From c26af11f7660c6f417de3d39245e9c0d6f575a34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Sat, 20 Dec 2025 15:24:28 +0800 Subject: [PATCH 01/24] feat(apikey system): tool system development --- api/app/controllers/__init__.py | 6 +- api/app/controllers/tool_controller.py | 585 ++++++++++++++++ .../controllers/tool_execution_controller.py | 430 ++++++++++++ api/app/core/api_key_auth.py | 5 +- api/app/core/config.py | 6 + api/app/core/tools/__init__.py | 37 ++ api/app/core/tools/base.py | 302 +++++++++ api/app/core/tools/builtin/__init__.py | 17 + .../core/tools/builtin/baidu_search_tool.py | 334 ++++++++++ api/app/core/tools/builtin/base.py | 118 ++++ api/app/core/tools/builtin/datetime_tool.py | 307 +++++++++ api/app/core/tools/builtin/json_tool.py | 430 ++++++++++++ api/app/core/tools/builtin/mineru_tool.py | 327 +++++++++ api/app/core/tools/builtin/textin_tool.py | 401 +++++++++++ api/app/core/tools/chain_manager.py | 485 ++++++++++++++ api/app/core/tools/config_manager.py | 264 ++++++++ .../configs/builtin/baidu_search_tool.json | 14 + .../tools/configs/builtin/datetime_tool.json | 12 + .../core/tools/configs/builtin/json_tool.json | 12 + .../tools/configs/builtin/mineru_tool.json | 14 + .../tools/configs/builtin/textin_tool.json | 14 + api/app/core/tools/configs/builtin_tools.json | 60 ++ api/app/core/tools/custom/__init__.py | 11 + api/app/core/tools/custom/auth_manager.py | 525 +++++++++++++++ api/app/core/tools/custom/base.py | 318 +++++++++ api/app/core/tools/custom/schema_parser.py | 477 +++++++++++++ api/app/core/tools/executor.py | 501 ++++++++++++++ api/app/core/tools/langchain_adapter.py | 375 +++++++++++ api/app/core/tools/mcp/__init__.py | 12 + api/app/core/tools/mcp/base.py | 258 ++++++++ api/app/core/tools/mcp/client.py | 626 ++++++++++++++++++ api/app/core/tools/mcp/service_manager.py | 604 +++++++++++++++++ api/app/core/tools/registry.py | 436 ++++++++++++ api/app/core/workflow/executor.py | 182 +++++ api/app/models/__init__.py | 16 +- api/app/models/tenant_model.py | 3 + api/app/models/tool_model.py | 226 +++++++ api/app/services/agent_tools.py | 218 ++++++ api/test_tool_system.py | 374 +++++++++++ 39 files changed, 9338 insertions(+), 4 deletions(-) create mode 100644 api/app/controllers/tool_controller.py create mode 100644 api/app/controllers/tool_execution_controller.py create mode 100644 api/app/core/tools/__init__.py create mode 100644 api/app/core/tools/base.py create mode 100644 api/app/core/tools/builtin/__init__.py create mode 100644 api/app/core/tools/builtin/baidu_search_tool.py create mode 100644 api/app/core/tools/builtin/base.py create mode 100644 api/app/core/tools/builtin/datetime_tool.py create mode 100644 api/app/core/tools/builtin/json_tool.py create mode 100644 api/app/core/tools/builtin/mineru_tool.py create mode 100644 api/app/core/tools/builtin/textin_tool.py create mode 100644 api/app/core/tools/chain_manager.py create mode 100644 api/app/core/tools/config_manager.py create mode 100644 api/app/core/tools/configs/builtin/baidu_search_tool.json create mode 100644 api/app/core/tools/configs/builtin/datetime_tool.json create mode 100644 api/app/core/tools/configs/builtin/json_tool.json create mode 100644 api/app/core/tools/configs/builtin/mineru_tool.json create mode 100644 api/app/core/tools/configs/builtin/textin_tool.json create mode 100644 api/app/core/tools/configs/builtin_tools.json create mode 100644 api/app/core/tools/custom/__init__.py create mode 100644 api/app/core/tools/custom/auth_manager.py create mode 100644 api/app/core/tools/custom/base.py create mode 100644 api/app/core/tools/custom/schema_parser.py create mode 100644 api/app/core/tools/executor.py create mode 100644 api/app/core/tools/langchain_adapter.py create mode 100644 api/app/core/tools/mcp/__init__.py create mode 100644 api/app/core/tools/mcp/base.py create mode 100644 api/app/core/tools/mcp/client.py create mode 100644 api/app/core/tools/mcp/service_manager.py create mode 100644 api/app/core/tools/registry.py create mode 100644 api/app/models/tool_model.py create mode 100644 api/test_tool_system.py diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index a3caaf4a..fe7c692e 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -28,7 +28,9 @@ from . import ( public_share_controller, multi_agent_controller, workflow_controller, - prompt_optimizer_controller + prompt_optimizer_controller, + tool_controller, + tool_execution_controller, ) # 创建管理端 API 路由器 @@ -60,5 +62,7 @@ manager_router.include_router(memory_dashboard_controller.router) manager_router.include_router(multi_agent_controller.router) manager_router.include_router(workflow_controller.router) manager_router.include_router(prompt_optimizer_controller.router) +manager_router.include_router(tool_controller.router) +manager_router.include_router(tool_execution_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py new file mode 100644 index 00000000..433392d2 --- /dev/null +++ b/api/app/controllers/tool_controller.py @@ -0,0 +1,585 @@ +"""工具管理API控制器""" +import base64 +from typing import List, Optional, Dict, Any + +from fastapi import APIRouter, Depends, HTTPException, Body +from langfuse.api.core import jsonable_encoder +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from pydantic import BaseModel, Field, PositiveInt, field_validator +from cryptography.fernet import Fernet + +from app.db import get_db +from app.dependencies import get_current_user +from app.models import User +from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig +from app.core.logging_config import get_business_logger +from app.core.config import settings +from app.core.tools.config_manager import ConfigManager + +logger = get_business_logger() + +router = APIRouter(prefix="/tools", tags=["工具管理"]) + + +# ==================== 辅助函数 ==================== + + +def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]: + """加密敏感参数""" + cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode()) + cipher = Fernet(cipher_key) + + encrypted_params = {} + sensitive_keys = ['api_key', 'token', 'api_secret', 'password'] + + for key, value in parameters.items(): + if any(sensitive in key.lower() for sensitive in sensitive_keys) and value: + encrypted_params[key] = cipher.encrypt(str(value).encode()).decode() + else: + encrypted_params[key] = value + + return encrypted_params + + +def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]: + """解密敏感参数""" + cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode()) + cipher = Fernet(cipher_key) + + decrypted_params = {} + sensitive_keys = ['api_key', 'token', 'secret', 'password'] + + for key, value in parameters.items(): + if any(sensitive in key.lower() for sensitive in sensitive_keys) and value: + try: + decrypted_params[key] = cipher.decrypt(value.encode()).decode() + except Exception as e: + decrypted_params[key] = value + else: + decrypted_params[key] = value + + return decrypted_params + + +def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str: + """更新工具状态并返回新状态""" + if tool_config.tool_type == ToolType.BUILTIN: + if not tool_info or not tool_info.get('requires_config', False): + new_status = ToolStatus.ACTIVE.value # 不需要配置的内置工具 + elif not builtin_config or not builtin_config.parameters: + new_status = ToolStatus.INACTIVE.value + else: + # 检查是否有必要的API密钥 + has_key = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token')) + new_status = ToolStatus.ACTIVE.value if has_key else ToolStatus.INACTIVE.value + else: # 自定义和MCP工具 + new_status = ToolStatus.ACTIVE.value if tool_config.config_data else ToolStatus.ERROR.value + + # 更新数据库中的状态 + if tool_config.status != new_status: + tool_config.status = new_status + + return new_status + + +# ==================== 请求/响应模型 ==================== + +class ToolListResponse(BaseModel): + """工具列表响应""" + id: str + name: str + description: str + tool_type: str + category: str + version: str = "1.0.0" + status: str # active inactive error loading + requires_config: bool = False + # is_configured: bool = False + + class Config: + from_attributes = True + +class BuiltinToolConfigRequest(BaseModel): + """内置工具配置请求""" + parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") + + +class CustomToolCreateRequest(BaseModel): + """自定义工具创建请求体模型,包含参数校验规则""" + name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填") + description: str = Field(None, description="工具描述") + base_url: str = Field(None, description="工具基础URL") + schema_url: str = Field(None, description="工具Schema URL") + schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容,可选") + auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型") + auth_config: Optional[Dict[str, Any]] = Field(None, description="认证配置,默认空字典") + timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间,1-300秒,默认30") + + # 自定义校验:当auth_type为api_key时,auth_config必须包含api_key字段 + @field_validator("auth_config") + def validate_auth_config(cls, v, values): + auth_type = values.data.get("auth_type") + if auth_type == "api_key" and (not v or "api_key" not in v): + raise ValueError("认证类型为api_key时,auth_config必须包含api_key字段") + if auth_type == "bearer_token" and (not v or "bearer_token" not in v): + raise ValueError("认证类型为bearer_token时,auth_config必须包含bearer_token字段") + return v + +class MCPToolCreateRequest(BaseModel): + """MCP工具创建请求体模型,适配MCP业务特性""" + # 基础必填字段(带长度/格式校验) + name: str = Field(..., min_length=1, max_length=100,description="MCP工具名称") + description: str = Field(None, description="MCP工具描述") + # MCP核心字段:服务端URL(强制HTTP/HTTPS格式) + server_url: str = Field(..., description="MCP服务端URL,仅支持http/https协议") + # 连接配置:默认空字典,可自定义校验规则(根据实际业务调整) + connection_config: Dict[str, Any] = Field({},description="MCP连接配置(如认证信息、超时、重试等),默认空字典") + + @field_validator("connection_config") + def validate_connection_config(cls, v): + # 示例1:若包含timeout,必须是1-300的整数 + if "timeout" in v: + timeout = v["timeout"] + if not isinstance(timeout, int) or timeout < 1 or timeout > 300: + raise ValueError("connection_config.timeout必须是1-300的整数") + return v + + # @field_validator("server_url") + # def validate_server_url_protocol(cls, v): + # if v.scheme != "https": + # raise ValueError("MCP服务端URL仅支持HTTPS协议(安全要求)") + # return v + + +# ==================== API端点 ==================== +@router.get("", response_model=List[ToolListResponse]) +async def list_tools( + name: Optional[str] = None, + tool_type: Optional[str] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取工具列表(包含内置工具、自定义工具和MCP工具)""" + try: + # 初始化内置工具(如果需要) + config_manager = ConfigManager() + config_manager.ensure_builtin_tools_initialized( + current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus + ) + + response_tools = [] + + query = db.query(ToolConfig).filter( + ToolConfig.tenant_id == current_user.tenant_id + ) + if tool_type: + query = query.filter(ToolConfig.tool_type == tool_type) + + if name: + query = query.filter(ToolConfig.name.ilike(f"%{name}%")) + + tools = query.all() + builtin_tools = config_manager.load_builtin_tools_config() + configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()} + + for tool_config in tools: + if tool_config.tool_type == ToolType.BUILTIN.value: + builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first() + tool_info = configured_tools.get(builtin_config.tool_class) + status = _update_tool_status(tool_config, builtin_config, tool_info) + else: + status = _update_tool_status(tool_config) + + response_tools.append(ToolListResponse( + id=str(tool_config.id), + name=tool_config.name, + description=tool_config.description, + tool_type=tool_config.tool_type, + category=tool_info['category'] if tool_config.tool_type == ToolType.BUILTIN.value else tool_config.tool_type, + version="1.0.0", + status=status, + requires_config=tool_info['requires_config'] if tool_config.tool_type == ToolType.BUILTIN.value else False, + )) + + return response_tools + except Exception as e: + logger.error(f"获取工具列表失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/builtin/{tool_id}") +async def get_builtin_tool_detail( + tool_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取内置工具详情""" + try: + config_manager = ConfigManager() + builtin_tools = config_manager.load_builtin_tools_config() + configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()} + tool_config = db.query(ToolConfig).filter( + ToolConfig.tenant_id == current_user.tenant_id, + ToolConfig.id == tool_id + ).first() + builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first() + tool_info = configured_tools.get(builtin_config.tool_class) + + is_configured = False + config_parameters = {} + + if builtin_config and builtin_config.parameters: + is_configured = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token')) + # 不返回敏感信息,只返回非敏感配置 + config_parameters = {k: v for k, v in builtin_config.parameters.items() + if not any(sensitive in k.lower() for sensitive in ['key', 'secret', 'token', 'password'])} + + return { + "id": tool_config.id, + "name": tool_config.name, + "description": tool_config.description, + "category": tool_info['category'], + "status": tool_config.tool_type, + "requires_config": tool_info['requires_config'], + "is_configured": is_configured, + "config_parameters": config_parameters + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/builtin/{tool_id}/configure") +async def configure_builtin_tool( + tool_id: str, + request: BuiltinToolConfigRequest = Body(...), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """配置内置工具参数(租户级别)""" + try: + # 查询工具配置 + tool_config = db.query(ToolConfig).filter( + ToolConfig.tenant_id == current_user.tenant_id, + ToolConfig.id == tool_id, + ToolConfig.tool_type == ToolType.BUILTIN + ).first() + + if not tool_config: + raise HTTPException(status_code=404, detail="工具不存在") + + # 获取内置工具配置 + builtin_config = db.query(BuiltinToolConfig).filter( + BuiltinToolConfig.id == tool_config.id + ).first() + + if not builtin_config: + raise HTTPException(status_code=404, detail="内置工具配置不存在") + + # 获取全局工具信息 + config_manager = ConfigManager() + builtin_tools_config = config_manager.load_builtin_tools_config() + tool_info = None + for tool_key, info in builtin_tools_config.items(): + if info['tool_class'] == builtin_config.tool_class: + tool_info = info + break + + if not tool_info: + raise HTTPException(status_code=404, detail="工具信息不存在") + + # 加密敏感参数 + encrypted_params = _encrypt_sensitive_params(request.parameters) + + # 更新配置 + builtin_config.parameters = encrypted_params + + # 更新状态 + _update_tool_status(tool_config, builtin_config, tool_info) + + db.commit() + + return { + "success": True, + "message": f"工具 {tool_config.name} 配置成功" + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"配置内置工具失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/builtin/{tool_id}/config") +async def get_builtin_tool_config( + tool_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取内置工具配置(用于使用)""" + try: + # 查询工具配置 + tool_config = db.query(ToolConfig).filter( + ToolConfig.tenant_id == current_user.tenant_id, + ToolConfig.id == tool_id, + ToolConfig.tool_type == ToolType.BUILTIN + ).first() + + if not tool_config: + raise HTTPException(status_code=404, detail="工具不存在") + + # 获取内置工具配置 + builtin_config = db.query(BuiltinToolConfig).filter( + BuiltinToolConfig.id == tool_config.id + ).first() + + if not builtin_config: + raise HTTPException(status_code=404, detail="内置工具配置不存在") + + # 解密参数 + decrypted_params = _decrypt_sensitive_params(builtin_config.parameters or {}) + + return { + "tool_id": tool_id, + "tool_class": builtin_config.tool_class, + "name": tool_config.name, + "parameters": decrypted_params, + "status": tool_config.status + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"获取工具配置失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/custom") +async def create_custom_tool( + request: CustomToolCreateRequest = Body(...), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """创建自定义工具""" + try: + config_data = jsonable_encoder(request.model_dump()) + config_data["tool_type"] = "custom" + + config_manager = ConfigManager() + is_valid, error_msg = config_manager.validate_config(config_data, "custom") + if not is_valid: + raise HTTPException(status_code=400, detail=error_msg) + + # 创建数据库记录 + tool_config = ToolConfig( + name=request.name, + description=request.description, + tool_type=ToolType.CUSTOM, + tenant_id=current_user.tenant_id, + status=ToolStatus.ACTIVE.value, + config_data=config_data + ) + db.add(tool_config) + db.flush() + + # 创建CustomToolConfig记录 + custom_config = CustomToolConfig( + id=tool_config.id, + base_url=request.base_url, + schema_url=request.schema_url, + schema_content=request.schema_content, + auth_type=request.auth_type, + auth_config=request.auth_config, + timeout=request.timeout + ) + db.add(custom_config) + + db.commit() + + return { + "success": True, + "message": f"自定义工具 {request.name} 创建成功", + "tool_id": str(tool_config.id) + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"创建自定义工具失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/mcp") +async def create_mcp_tool( + request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """创建MCP工具""" + try: + config_data = jsonable_encoder(request.model_dump()) + config_data["tool_type"] = "mcp" + + config_manager = ConfigManager() + is_valid, error_msg = config_manager.validate_config(config_data, "mcp") + if not is_valid: + raise HTTPException(status_code=400, detail=error_msg) + + # 创建数据库记录 + try: + tool_config = ToolConfig( + name=request.name, + description=request.description, + tool_type=ToolType.MCP, + tenant_id=current_user.tenant_id, + status=ToolStatus.ACTIVE.value, + config_data=config_data + ) + db.add(tool_config) + db.flush() + + # 创建MCPToolConfig记录 + mcp_config = MCPToolConfig( + id=tool_config.id, + server_url=request.server_url, + connection_config=request.connection_config + ) + db.add(mcp_config) + + db.commit() + except SQLAlchemyError as db_e: + db.rollback() + logger.error(f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},工具名:{request.name}): {str(db_e)}", + exc_info=True) + raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id}," + f"工具名:{request.name}):{str(db_e)}") + + return { + "success": True, + "message": f"MCP工具 {request.name} 创建成功", + "tool_id": str(tool_config.id) + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"创建MCP工具失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@router.delete("/{tool_id}") +async def delete_tool( + tool_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """删除工具(仅限自定义和MCP工具)""" + try: + tool = db.query(ToolConfig).filter( + ToolConfig.id == tool_id, + ToolConfig.tenant_id == current_user.tenant_id + ).first() + + if not tool: + raise HTTPException(status_code=404, detail="工具不存在") + + if tool.tool_type == ToolType.BUILTIN: + raise HTTPException(status_code=403, detail="内置工具不允许删除") + + db.delete(tool) + db.commit() + + return { + "success": True, + "message": f"工具 {tool.name} 删除成功" + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"删除工具失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.put("/{tool_id}") +async def update_tool( + tool_id: str, + config_data: Optional[Dict[str, Any]] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """更新工具(仅限自定义和MCP工具)""" + try: + tool = db.query(ToolConfig).filter( + ToolConfig.id == tool_id, + ToolConfig.tenant_id == current_user.tenant_id + ).first() + + if not tool: + raise HTTPException(status_code=404, detail="工具不存在") + + if tool.tool_type == ToolType.BUILTIN: + raise HTTPException(status_code=403, detail="内置工具不允许修改") + + if config_data is not None: + tool.config_data = config_data + # 更新状态 + _update_tool_status(tool) + + db.commit() + db.refresh(tool) + + return { + "success": True, + "message": f"工具 {tool.name} 更新成功", + "status": tool.status + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"更新工具失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/{tool_id}/toggle") +async def toggle_tool_status( + tool_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """切换工具活跃/非活跃状态""" + try: + tool = db.query(ToolConfig).filter( + ToolConfig.id == tool_id, + ToolConfig.tenant_id == current_user.tenant_id + ).first() + + if not tool: + raise HTTPException(status_code=404, detail="工具不存在") + + # 在active和inactive之间切换 + if tool.status == ToolStatus.ACTIVE.value: + tool.status = ToolStatus.INACTIVE.value + elif tool.status == ToolStatus.INACTIVE.value: + tool.status = ToolStatus.ACTIVE.value + else: + raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换") + + db.commit() + db.refresh(tool) + + return { + "success": True, + "message": f"工具 {tool.name} 状态已更新为 {tool.status}", + "status": tool.status + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"切换工具状态失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/api/app/controllers/tool_execution_controller.py b/api/app/controllers/tool_execution_controller.py new file mode 100644 index 00000000..486eb7cf --- /dev/null +++ b/api/app/controllers/tool_execution_controller.py @@ -0,0 +1,430 @@ +"""工具执行API控制器""" +import uuid +from typing import Dict, Any, List, Optional +from fastapi import APIRouter, Depends, HTTPException, Path, Query +from sqlalchemy.orm import Session +from pydantic import BaseModel, Field + +from app.db import get_db +from app.dependencies import get_current_user +from app.models import User +from app.core.tools.registry import ToolRegistry +from app.core.tools.executor import ToolExecutor +from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode +from app.core.tools.builtin import * +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + +router = APIRouter(prefix="/tools/execution", tags=["工具执行"]) + + +# ==================== 请求/响应模型 ==================== + +class ToolExecutionRequest(BaseModel): + """工具执行请求""" + tool_id: str = Field(..., description="工具ID") + parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") + timeout: Optional[float] = Field(None, ge=1, le=300, description="超时时间(秒)") + metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据") + + +class BatchExecutionRequest(BaseModel): + """批量执行请求""" + executions: List[ToolExecutionRequest] = Field(..., description="执行列表") + max_concurrency: int = Field(5, ge=1, le=20, description="最大并发数") + + +class ToolExecutionResponse(BaseModel): + """工具执行响应""" + success: bool + execution_id: str + tool_id: str + data: Any = None + error: Optional[str] = None + error_code: Optional[str] = None + execution_time: float + token_usage: Optional[Dict[str, int]] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + + +class ChainStepRequest(BaseModel): + """链步骤请求""" + tool_id: str = Field(..., description="工具ID") + parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") + condition: Optional[str] = Field(None, description="执行条件") + output_mapping: Optional[Dict[str, str]] = Field(None, description="输出映射") + error_handling: str = Field("stop", description="错误处理策略") + + +class ChainExecutionRequest(BaseModel): + """链执行请求""" + name: str = Field(..., description="链名称") + description: str = Field("", description="链描述") + steps: List[ChainStepRequest] = Field(..., description="执行步骤") + execution_mode: str = Field("sequential", description="执行模式") + initial_variables: Optional[Dict[str, Any]] = Field(None, description="初始变量") + global_timeout: Optional[float] = Field(None, description="全局超时") + + +class ExecutionHistoryResponse(BaseModel): + """执行历史响应""" + execution_id: str + tool_id: str + status: str + started_at: Optional[str] + completed_at: Optional[str] + execution_time: Optional[float] + user_id: Optional[str] + workspace_id: Optional[str] + input_data: Optional[Dict[str, Any]] + output_data: Optional[Any] + error_message: Optional[str] + token_usage: Optional[Dict[str, int]] + + +class ToolConnectionTestResponse(BaseModel): + """工具连接测试响应""" + success: bool + message: str + error: Optional[str] = None + details: Optional[Dict[str, Any]] = None + + +# ==================== 依赖注入 ==================== + +def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry: + """获取工具注册表""" + registry = ToolRegistry(db) + + # 注册内置工具类 + registry.register_tool_class(DateTimeTool) + registry.register_tool_class(JsonTool) + registry.register_tool_class(BaiduSearchTool) + registry.register_tool_class(MinerUTool) + registry.register_tool_class(TextInTool) + + return registry + + +def get_tool_executor( + db: Session = Depends(get_db), + registry: ToolRegistry = Depends(get_tool_registry) +) -> ToolExecutor: + """获取工具执行器""" + return ToolExecutor(db, registry) + + +def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager: + """获取链管理器""" + return ChainManager(executor) + + +# ==================== API端点 ==================== + +@router.post("/execute", response_model=ToolExecutionResponse) +async def execute_tool( + request: ToolExecutionRequest, + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """执行单个工具""" + try: + # 生成执行ID + execution_id = f"exec_{uuid.uuid4().hex[:16]}" + + # 执行工具 + result = await executor.execute_tool( + tool_id=request.tool_id, + parameters=request.parameters, + user_id=current_user.id, + workspace_id=current_user.current_workspace_id, + execution_id=execution_id, + timeout=request.timeout, + metadata=request.metadata + ) + + return ToolExecutionResponse( + success=result.success, + execution_id=execution_id, + tool_id=request.tool_id, + data=result.data, + error=result.error, + error_code=result.error_code, + execution_time=result.execution_time, + token_usage=result.token_usage, + metadata=result.metadata + ) + + except Exception as e: + logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/batch", response_model=List[ToolExecutionResponse]) +async def execute_tools_batch( + request: BatchExecutionRequest, + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """批量执行工具""" + try: + # 准备执行配置 + execution_configs = [] + execution_ids = [] + + for exec_request in request.executions: + execution_id = f"exec_{uuid.uuid4().hex[:16]}" + execution_ids.append(execution_id) + + execution_configs.append({ + "tool_id": exec_request.tool_id, + "parameters": exec_request.parameters, + "user_id": current_user.id, + "workspace_id": current_user.current_workspace_id, + "execution_id": execution_id, + "timeout": exec_request.timeout, + "metadata": exec_request.metadata + }) + + # 批量执行 + results = await executor.execute_tools_batch( + execution_configs, + max_concurrency=request.max_concurrency + ) + + # 转换响应格式 + responses = [] + for i, result in enumerate(results): + responses.append(ToolExecutionResponse( + success=result.success, + execution_id=execution_ids[i], + tool_id=request.executions[i].tool_id, + data=result.data, + error=result.error, + error_code=result.error_code, + execution_time=result.execution_time, + token_usage=result.token_usage, + metadata=result.metadata + )) + + return responses + + except Exception as e: + logger.error(f"批量执行失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/chain", response_model=Dict[str, Any]) +async def execute_tool_chain( + request: ChainExecutionRequest, + current_user: User = Depends(get_current_user), + chain_manager: ChainManager = Depends(get_chain_manager) +): + """执行工具链""" + try: + # 转换步骤格式 + steps = [] + for step_request in request.steps: + step = ChainStep( + tool_id=step_request.tool_id, + parameters=step_request.parameters, + condition=step_request.condition, + output_mapping=step_request.output_mapping, + error_handling=step_request.error_handling + ) + steps.append(step) + + # 创建链定义 + chain_definition = ChainDefinition( + name=request.name, + description=request.description, + steps=steps, + execution_mode=ChainExecutionMode(request.execution_mode), + global_timeout=request.global_timeout + ) + + # 注册并执行链 + chain_manager.register_chain(chain_definition) + + result = await chain_manager.execute_chain( + chain_name=request.name, + initial_variables=request.initial_variables + ) + + return result + + except Exception as e: + logger.error(f"工具链执行失败: {request.name}, 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/running", response_model=List[Dict[str, Any]]) +async def get_running_executions( + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """获取正在运行的执行""" + try: + running_executions = executor.get_running_executions() + + # 过滤当前工作空间的执行 + workspace_executions = [ + exec_info for exec_info in running_executions + if exec_info.get("workspace_id") == str(current_user.current_workspace_id) + ] + + return workspace_executions + + except Exception as e: + logger.error(f"获取运行中执行失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any]) +async def cancel_execution( + execution_id: str = Path(..., description="执行ID"), + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """取消工具执行""" + try: + success = await executor.cancel_execution(execution_id) + + if success: + return { + "success": True, + "message": "执行已取消" + } + else: + raise HTTPException(status_code=404, detail="执行不存在或已完成") + + except HTTPException: + raise + except Exception as e: + logger.error(f"取消执行失败: {execution_id}, 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/history", response_model=List[ExecutionHistoryResponse]) +async def get_execution_history( + tool_id: Optional[str] = Query(None, description="工具ID过滤"), + limit: int = Query(50, ge=1, le=200, description="返回数量限制"), + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """获取执行历史""" + try: + history = executor.get_execution_history( + tool_id=tool_id, + user_id=current_user.id, + workspace_id=current_user.current_workspace_id, + limit=limit + ) + + # 转换响应格式 + responses = [] + for record in history: + responses.append(ExecutionHistoryResponse( + execution_id=record["execution_id"], + tool_id=record["tool_id"], + status=record["status"], + started_at=record["started_at"], + completed_at=record["completed_at"], + execution_time=record["execution_time"], + user_id=record["user_id"], + workspace_id=record["workspace_id"], + input_data=record["input_data"], + output_data=record["output_data"], + error_message=record["error_message"], + token_usage=record["token_usage"] + )) + + return responses + + except Exception as e: + logger.error(f"获取执行历史失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/statistics", response_model=Dict[str, Any]) +async def get_execution_statistics( + days: int = Query(7, ge=1, le=90, description="统计天数"), + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """获取执行统计""" + try: + stats = executor.get_execution_statistics( + workspace_id=current_user.current_workspace_id, + days=days + ) + + return { + "success": True, + "statistics": stats + } + + except Exception as e: + logger.error(f"获取执行统计失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/chains/running", response_model=List[Dict[str, Any]]) +async def get_running_chains( + current_user: User = Depends(get_current_user), + chain_manager: ChainManager = Depends(get_chain_manager) +): + """获取正在运行的工具链""" + try: + running_chains = chain_manager.get_running_chains() + return running_chains + + except Exception as e: + logger.error(f"获取运行中工具链失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/chains", response_model=List[Dict[str, Any]]) +async def list_tool_chains( + current_user: User = Depends(get_current_user), + chain_manager: ChainManager = Depends(get_chain_manager) +): + """列出工具链""" + try: + chains = chain_manager.list_chains() + return chains + + except Exception as e: + logger.error(f"获取工具链列表失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse) +async def test_tool_connection( + tool_id: str = Path(..., description="工具ID"), + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """测试工具连接""" + try: + result = await executor.test_tool_connection( + tool_id=tool_id, + user_id=current_user.id, + workspace_id=current_user.current_workspace_id + ) + + return ToolConnectionTestResponse( + success=result.get("success", False), + message=result.get("message", ""), + error=result.get("error"), + details=result.get("details") + ) + + except Exception as e: + logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}") + return ToolConnectionTestResponse( + success=False, + message="连接测试失败", + error=str(e) + ) \ No newline at end of file diff --git a/api/app/core/api_key_auth.py b/api/app/core/api_key_auth.py index d90bb00d..e1021c6f 100644 --- a/api/app/core/api_key_auth.py +++ b/api/app/core/api_key_auth.py @@ -37,9 +37,10 @@ def require_api_key( @require_api_key(scopes=["app"]) def chat_with_app( resource_id: uuid.UUID, - api_key_auth: ApiKeyAuth = Depends(), + request: Request, + api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), - message: str + message: str = Query(..., description="聊天消息内容") ): # api_key_auth 包含验证后的API Key 信息 pass diff --git a/api/app/core/config.py b/api/app/core/config.py index 48f79d5e..d4d285fe 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -156,6 +156,12 @@ class Settings: MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json") MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json") + # Tool Management Configuration + TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools") + TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60")) + TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10")) + ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true" + def get_memory_output_path(self, filename: str = "") -> str: """ Get the full path for memory module output files. diff --git a/api/app/core/tools/__init__.py b/api/app/core/tools/__init__.py new file mode 100644 index 00000000..109bac13 --- /dev/null +++ b/api/app/core/tools/__init__.py @@ -0,0 +1,37 @@ +"""工具管理核心模块""" + +from .base import BaseTool, ToolResult, ToolParameter +from .registry import ToolRegistry +from .executor import ToolExecutor +from .langchain_adapter import LangchainAdapter +from .config_manager import ConfigManager +from .chain_manager import ChainManager + +# 可选导入,避免导入错误 +try: + from .custom.base import CustomTool +except ImportError: + CustomTool = None + +try: + from .mcp.base import MCPTool +except ImportError: + MCPTool = None + +__all__ = [ + "BaseTool", + "ToolResult", + "ToolParameter", + "ToolRegistry", + "ToolExecutor", + "LangchainAdapter", + "ConfigManager", + "ChainManager" +] + +# 只有在成功导入时才添加到__all__ +if CustomTool: + __all__.append("CustomTool") + +if MCPTool: + __all__.append("MCPTool") \ No newline at end of file diff --git a/api/app/core/tools/base.py b/api/app/core/tools/base.py new file mode 100644 index 00000000..d674af76 --- /dev/null +++ b/api/app/core/tools/base.py @@ -0,0 +1,302 @@ +"""工具基础接口定义""" +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Union +from pydantic import BaseModel, Field +from enum import Enum + +from app.models.tool_model import ToolType, ToolStatus + + +class ParameterType(str, Enum): + """参数类型枚举""" + STRING = "string" + INTEGER = "integer" + NUMBER = "number" + BOOLEAN = "boolean" + ARRAY = "array" + OBJECT = "object" + + +class ToolParameter(BaseModel): + """工具参数定义""" + name: str = Field(..., description="参数名称") + type: ParameterType = Field(..., description="参数类型") + description: str = Field("", description="参数描述") + required: bool = Field(False, description="是否必需") + default: Any = Field(None, description="默认值") + enum: Optional[List[Any]] = Field(None, description="枚举值") + minimum: Optional[Union[int, float]] = Field(None, description="最小值") + maximum: Optional[Union[int, float]] = Field(None, description="最大值") + pattern: Optional[str] = Field(None, description="正则表达式模式") + + class Config: + use_enum_values = True + + +class ToolResult(BaseModel): + """工具执行结果""" + success: bool = Field(..., description="执行是否成功") + data: Any = Field(None, description="返回数据") + error: Optional[str] = Field(None, description="错误信息") + error_code: Optional[str] = Field(None, description="错误代码") + execution_time: float = Field(..., description="执行时间(秒)") + token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况") + metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据") + + @classmethod + def success_result( + cls, + data: Any, + execution_time: float, + token_usage: Optional[Dict[str, int]] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> "ToolResult": + """创建成功结果""" + return cls( + success=True, + data=data, + execution_time=execution_time, + token_usage=token_usage, + metadata=metadata or {} + ) + + @classmethod + def error_result( + cls, + error: str, + execution_time: float, + error_code: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> "ToolResult": + """创建错误结果""" + return cls( + success=False, + error=error, + error_code=error_code, + execution_time=execution_time, + metadata=metadata or {} + ) + + +class ToolInfo(BaseModel): + """工具信息""" + id: str = Field(..., description="工具ID") + name: str = Field(..., description="工具名称") + description: str = Field(..., description="工具描述") + tool_type: ToolType = Field(..., description="工具类型") + version: str = Field("1.0.0", description="工具版本") + parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数") + status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态") + tags: List[str] = Field(default_factory=list, description="工具标签") + tenant_id: Optional[str] = Field(None, description="租户ID") + + class Config: + use_enum_values = True + + +class BaseTool(ABC): + """所有工具的基础抽象类""" + + def __init__(self, tool_id: str, config: Dict[str, Any]): + """初始化工具 + + Args: + tool_id: 工具ID + config: 工具配置 + """ + self.tool_id = tool_id + self.config = config + self._status = ToolStatus.ACTIVE + + @property + @abstractmethod + def name(self) -> str: + """工具名称""" + pass + + @property + @abstractmethod + def description(self) -> str: + """工具描述""" + pass + + @property + @abstractmethod + def tool_type(self) -> ToolType: + """工具类型""" + pass + + @property + def version(self) -> str: + """工具版本""" + return self.config.get("version", "1.0.0") + + @property + def status(self) -> ToolStatus: + """工具状态""" + return self._status + + @status.setter + def status(self, value: ToolStatus): + """设置工具状态""" + self._status = value + + @property + @abstractmethod + def parameters(self) -> List[ToolParameter]: + """工具参数定义""" + pass + + @property + def tags(self) -> List[str]: + """工具标签""" + return self.config.get("tags", []) + + def get_info(self) -> ToolInfo: + """获取工具信息""" + return ToolInfo( + id=self.tool_id, + name=self.name, + description=self.description, + tool_type=self.tool_type, + version=self.version, + parameters=self.parameters, + status=self.status, + tags=self.tags, + tenant_id=self.config.get("tenant_id") + ) + + def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]: + """验证参数 + + Args: + parameters: 输入参数 + + Returns: + 验证错误字典,空字典表示验证通过 + """ + errors = {} + param_definitions = {p.name: p for p in self.parameters} + + # 检查必需参数 + for param_def in self.parameters: + if param_def.required and param_def.name not in parameters: + errors[param_def.name] = f"Required parameter '{param_def.name}' is missing" + + # 检查参数类型和约束 + for param_name, param_value in parameters.items(): + if param_name not in param_definitions: + continue + + param_def = param_definitions[param_name] + + # 类型检查 + if not self._validate_parameter_type(param_value, param_def): + errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}" + + # 约束检查 + constraint_error = self._validate_parameter_constraints(param_value, param_def) + if constraint_error: + errors[param_name] = constraint_error + + return errors + + def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool: + """验证参数类型""" + if value is None: + return not param_def.required + + type_mapping = { + ParameterType.STRING: str, + ParameterType.INTEGER: int, + ParameterType.NUMBER: (int, float), + ParameterType.BOOLEAN: bool, + ParameterType.ARRAY: list, + ParameterType.OBJECT: dict + } + + expected_type = type_mapping.get(param_def.type) + if expected_type: + return isinstance(value, expected_type) + + return True + + def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]: + """验证参数约束""" + if value is None: + return None + + # 枚举值检查 + if param_def.enum and value not in param_def.enum: + return f"Value must be one of {param_def.enum}" + + # 数值范围检查 + if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER]: + if param_def.minimum is not None and value < param_def.minimum: + return f"Value must be >= {param_def.minimum}" + if param_def.maximum is not None and value > param_def.maximum: + return f"Value must be <= {param_def.maximum}" + + # 字符串模式检查 + if param_def.type == ParameterType.STRING and param_def.pattern: + import re + if not re.match(param_def.pattern, str(value)): + return f"Value must match pattern: {param_def.pattern}" + + return None + + @abstractmethod + async def execute(self, **kwargs) -> ToolResult: + """执行工具 + + Args: + **kwargs: 工具参数 + + Returns: + 执行结果 + """ + pass + + async def safe_execute(self, **kwargs) -> ToolResult: + """安全执行工具(包含参数验证和异常处理) + + Args: + **kwargs: 工具参数 + + Returns: + 执行结果 + """ + start_time = time.time() + + try: + # 参数验证 + validation_errors = self.validate_parameters(kwargs) + if validation_errors: + execution_time = time.time() - start_time + error_msg = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()]) + return ToolResult.error_result( + error=f"Parameter validation failed: {error_msg}", + error_code="VALIDATION_ERROR", + execution_time=execution_time + ) + + # 执行工具 + result = await self.execute(**kwargs) + return result + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="EXECUTION_ERROR", + execution_time=execution_time + ) + + def to_langchain_tool(self): + """转换为Langchain工具格式""" + from .langchain_adapter import LangchainAdapter + return LangchainAdapter.convert_tool(self) + + def __repr__(self): + return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>" \ No newline at end of file diff --git a/api/app/core/tools/builtin/__init__.py b/api/app/core/tools/builtin/__init__.py new file mode 100644 index 00000000..3813402c --- /dev/null +++ b/api/app/core/tools/builtin/__init__.py @@ -0,0 +1,17 @@ +"""内置工具模块""" + +from .base import BuiltinTool +from .datetime_tool import DateTimeTool +from .json_tool import JsonTool +from .baidu_search_tool import BaiduSearchTool +from .mineru_tool import MinerUTool +from .textin_tool import TextInTool + +__all__ = [ + "BuiltinTool", + "DateTimeTool", + "JsonTool", + "BaiduSearchTool", + "MinerUTool", + "TextInTool" +] \ No newline at end of file diff --git a/api/app/core/tools/builtin/baidu_search_tool.py b/api/app/core/tools/builtin/baidu_search_tool.py new file mode 100644 index 00000000..fddd6eb7 --- /dev/null +++ b/api/app/core/tools/builtin/baidu_search_tool.py @@ -0,0 +1,334 @@ +"""百度搜索工具 - 搜索引擎服务""" +import time +from typing import List, Dict, Any +import aiohttp + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class BaiduSearchTool(BuiltinTool): + """百度搜索工具 - 提供网页搜索、新闻搜索、图片搜索、实时结果""" + + @property + def name(self) -> str: + return "baidu_search_tool" + + @property + def description(self) -> str: + return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果" + + def get_required_config_parameters(self) -> List[str]: + return ["api_key"] + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="query", + type=ParameterType.STRING, + description="搜索关键词", + required=True + ), + ToolParameter( + name="search_type", + type=ParameterType.STRING, + description="搜索类型", + required=False, + default="web", + enum=["web", "news", "image", "video"] + ), + ToolParameter( + name="page_size", + type=ParameterType.INTEGER, + description="每页结果数", + required=False, + default=10, + minimum=1, + maximum=50 + ), + ToolParameter( + name="page_num", + type=ParameterType.INTEGER, + description="页码(从1开始)", + required=False, + default=1, + minimum=1, + maximum=10 + ), + ToolParameter( + name="safe_search", + type=ParameterType.BOOLEAN, + description="是否启用安全搜索", + required=False, + default=True + ), + ToolParameter( + name="region", + type=ParameterType.STRING, + description="搜索地区", + required=False, + default="cn", + enum=["cn", "hk", "tw", "us", "jp", "kr"] + ), + ToolParameter( + name="time_filter", + type=ParameterType.STRING, + description="时间过滤", + required=False, + enum=["all", "day", "week", "month", "year"] + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行百度搜索""" + start_time = time.time() + + try: + query = kwargs.get("query") + search_type = kwargs.get("search_type", "web") + page_size = kwargs.get("page_size", 10) + page_num = kwargs.get("page_num", 1) + safe_search = kwargs.get("safe_search", True) + region = kwargs.get("region", "cn") + time_filter = kwargs.get("time_filter") + + if not query: + raise ValueError("query 参数是必需的") + + # 根据搜索类型调用不同的API + if search_type == "web": + result = await self._web_search(query, page_size, page_num, safe_search, region, time_filter) + elif search_type == "news": + result = await self._news_search(query, page_size, page_num, region, time_filter) + elif search_type == "image": + result = await self._image_search(query, page_size, page_num, safe_search) + elif search_type == "video": + result = await self._video_search(query, page_size, page_num, safe_search) + else: + raise ValueError(f"不支持的搜索类型: {search_type}") + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="BAIDU_SEARCH_ERROR", + execution_time=execution_time + ) + + async def _web_search(self, query: str, page_size: int, page_num: int, + safe_search: bool, region: str, time_filter: str = None) -> Dict[str, Any]: + """网页搜索""" + payload = { + "messages": [{"role": "user", "content": query}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "web", "top_k": min(page_size, 50)}], + "enable_full_content": True + } + + if time_filter: + time_map = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"} + if time_filter in time_map: + payload["search_filter"] = {"range": {"page_time": {"gte": time_map[time_filter], "lt": "now/d"}}} + payload["search_recency_filter"] = time_filter + + results = await self._call_baidu_ai_search_api(payload) + + search_results = [] + if "references" in results: + for item in results["references"]: + search_results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "snippet": item.get("content", ""), + "display_url": item.get("url", ""), + "rank": len(search_results) + 1 + }) + + return { + "search_type": "web", + "query": query, + "total_results": len(search_results), + "page_num": page_num, + "page_size": page_size, + "results": search_results, + "answer": results.get("result", ""), + "references": results.get("references", []) + } + + async def _news_search(self, query: str, page_size: int, page_num: int, + region: str, time_filter: str = None) -> Dict[str, Any]: + """新闻搜索""" + payload = { + "messages": [{"role": "user", "content": query}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "new", "top_k": min(page_size, 50)}], + "enable_full_content": True + } + + if time_filter: + time_map = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"} + if time_filter in time_map: + payload["search_filter"] = {"range": {"page_time": {"gte": time_map[time_filter], "lt": "now/d"}}} + payload["search_recency_filter"] = time_filter + + results = await self._call_baidu_ai_search_api(payload) + + search_results = [] + if "references" in results: + for item in results["references"]: + search_results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "snippet": item.get("content", ""), + "display_url": item.get("url", ""), + "rank": len(search_results) + 1 + }) + + return { + "search_type": "new", + "query": query, + "total_results": len(search_results), + "page_num": page_num, + "page_size": page_size, + "results": search_results, + "answer": results.get("result", ""), + "references": results.get("references", []) + } + + async def _image_search(self, query: str, page_size: int, page_num: int, + safe_search: bool) -> Dict[str, Any]: + """图片搜索""" + payload = { + "messages": [{"role": "user", "content": query}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "image", "top_k": min(page_size, 30)}], + "enable_full_content": True + } + + results = await self._call_baidu_ai_search_api(payload) + + search_results = [] + if "references" in results: + for item in results["references"]: + search_results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "snippet": item.get("content", ""), + "display_url": item.get("url", ""), + "rank": len(search_results) + 1 + }) + + return { + "search_type": "image", + "query": query, + "total_results": len(search_results), + "page_num": page_num, + "page_size": page_size, + "results": search_results, + "answer": results.get("result", ""), + "references": results.get("references", []) + } + + async def _video_search(self, query: str, page_size: int, page_num: int, + safe_search: bool) -> Dict[str, Any]: + """视频搜索""" + payload = { + "messages": [{"role": "user", "content": query}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "video", "top_k": min(page_size, 10)}], + "enable_full_content": True + } + + results = await self._call_baidu_ai_search_api(payload) + + search_results = [] + if "references" in results: + for item in results["references"]: + search_results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "snippet": item.get("content", ""), + "display_url": item.get("url", ""), + "rank": len(search_results) + 1 + }) + + return { + "search_type": "video", + "query": query, + "total_results": len(search_results), + "page_num": page_num, + "page_size": page_size, + "results": search_results, + "answer": results.get("result", ""), + "references": results.get("references", []) + } + + async def _call_baidu_ai_search_api(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """调用百度AI搜索API""" + api_key = self.get_config_parameter("api_key") + + if not api_key: + raise ValueError("百度搜索API密钥未配置") + + url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions" + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}' + } + + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, headers=headers, json=payload) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"HTTP错误: {response.status}") + + async def test_connection(self) -> Dict[str, Any]: + """测试连接""" + try: + api_key = self.get_config_parameter("api_key") + + if not api_key: + return { + "success": False, + "error": "API密钥未配置" + } + + # 发送测试请求验证API key是否有效 + test_payload = { + "messages": [{"role": "user", "content": "test"}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "web", "top_k": 1}] + } + + try: + await self._call_baidu_ai_search_api(test_payload) + return { + "success": True, + "message": "连接测试成功", + "api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***" + } + except Exception as e: + return { + "success": False, + "error": f"API连接失败: {str(e)}" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } \ No newline at end of file diff --git a/api/app/core/tools/builtin/base.py b/api/app/core/tools/builtin/base.py new file mode 100644 index 00000000..532d0869 --- /dev/null +++ b/api/app/core/tools/builtin/base.py @@ -0,0 +1,118 @@ +"""内置工具基类""" +from abc import ABC, abstractmethod +from typing import Dict, Any, List + +from app.models.tool_model import ToolType +from app.core.tools.base import BaseTool, ToolResult, ToolParameter + + +class BuiltinTool(BaseTool, ABC): + """内置工具基类""" + + def __init__(self, tool_id: str, config: Dict[str, Any]): + """初始化内置工具 + + Args: + tool_id: 工具ID + config: 工具配置 + """ + super().__init__(tool_id, config) + self.parameters_config = config.get("parameters", {}) + + @property + def tool_type(self) -> ToolType: + """工具类型""" + return ToolType.BUILTIN + + @property + @abstractmethod + def name(self) -> str: + """工具名称 - 子类必须实现""" + pass + + @property + @abstractmethod + def description(self) -> str: + """工具描述 - 子类必须实现""" + pass + + @property + @abstractmethod + def parameters(self) -> List[ToolParameter]: + """工具参数定义 - 子类必须实现""" + pass + + @abstractmethod + async def execute(self, **kwargs) -> ToolResult: + """执行工具 - 子类必须实现 + + Args: + **kwargs: 工具参数 + + Returns: + 执行结果 + """ + pass + + @property + def is_configured(self) -> bool: + """检查工具是否已正确配置""" + required_params = self.get_required_config_parameters() + for param in required_params: + if not self.parameters_config.get(param): + return False + return True + + def get_required_config_parameters(self) -> List[str]: + """获取必需的配置参数列表 + + Returns: + 必需配置参数名称列表 + """ + return [] + + def get_config_parameter(self, name: str, default: Any = None) -> Any: + """获取配置参数值 + + Args: + name: 参数名称 + default: 默认值 + + Returns: + 参数值 + """ + return self.parameters_config.get(name, default) + + def validate_configuration(self) -> tuple[bool, str]: + """验证工具配置 + + Returns: + (是否有效, 错误信息) + """ + if not self.is_configured: + required_params = self.get_required_config_parameters() + missing_params = [p for p in required_params if not self.parameters_config.get(p)] + return False, f"缺少必需的配置参数: {', '.join(missing_params)}" + + return True, "" + + async def safe_execute(self, **kwargs) -> ToolResult: + """安全执行工具(包含配置验证) + + Args: + **kwargs: 工具参数 + + Returns: + 执行结果 + """ + # 首先验证配置 + is_valid, error_msg = self.validate_configuration() + if not is_valid: + return ToolResult.error_result( + error=f"工具配置无效: {error_msg}", + error_code="CONFIGURATION_ERROR", + execution_time=0.0 + ) + + # 调用父类的安全执行 + return await super().safe_execute(**kwargs) \ No newline at end of file diff --git a/api/app/core/tools/builtin/datetime_tool.py b/api/app/core/tools/builtin/datetime_tool.py new file mode 100644 index 00000000..475ce7be --- /dev/null +++ b/api/app/core/tools/builtin/datetime_tool.py @@ -0,0 +1,307 @@ +"""时间工具 - 日期时间处理""" +import time +from datetime import datetime, timezone, timedelta +from typing import List +import pytz + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class DateTimeTool(BuiltinTool): + """时间工具 - 提供时间格式转换、时区转换、时间戳转换、时间计算功能""" + + @property + def name(self) -> str: + return "datetime_tool" + + @property + def description(self) -> str: + return "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算" + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="operation", + type=ParameterType.STRING, + description="操作类型", + required=True, + enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"] + ), + ToolParameter( + name="input_value", + type=ParameterType.STRING, + description="输入值(时间字符串或时间戳)", + required=False + ), + ToolParameter( + name="input_format", + type=ParameterType.STRING, + description="输入时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="from_timezone", + type=ParameterType.STRING, + description="源时区(如:UTC, Asia/Shanghai)", + required=False, + default="UTC" + ), + ToolParameter( + name="to_timezone", + type=ParameterType.STRING, + description="目标时区(如:UTC, Asia/Shanghai)", + required=False, + default="UTC" + ), + ToolParameter( + name="calculation", + type=ParameterType.STRING, + description="时间计算表达式(如:+1d, -2h, +30m)", + required=False + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行时间工具操作""" + start_time = time.time() + + try: + operation = kwargs.get("operation") + + if operation == "now": + result = self._get_current_time(kwargs) + elif operation == "format": + result = self._format_datetime(kwargs) + elif operation == "convert_timezone": + result = self._convert_timezone(kwargs) + elif operation == "timestamp_to_datetime": + result = self._timestamp_to_datetime(kwargs) + elif operation == "datetime_to_timestamp": + result = self._datetime_to_timestamp(kwargs) + elif operation == "calculate": + result = self._calculate_datetime(kwargs) + else: + raise ValueError(f"不支持的操作类型: {operation}") + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="DATETIME_ERROR", + execution_time=execution_time + ) + + def _get_current_time(self, kwargs) -> dict: + """获取当前时间""" + timezone_str = kwargs.get("to_timezone", "UTC") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + + if timezone_str == "UTC": + tz = timezone.utc + else: + tz = pytz.timezone(timezone_str) + + now = datetime.now(tz) + + return { + "datetime": now.strftime(output_format), + "timestamp": int(now.timestamp()), + "timezone": timezone_str, + "iso_format": now.isoformat() + } + + def _format_datetime(self, kwargs) -> dict: + """格式化时间""" + input_value = kwargs.get("input_value") + input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + # 解析输入时间 + dt = datetime.strptime(input_value, input_format) + + return { + "original": input_value, + "formatted": dt.strftime(output_format), + "timestamp": int(dt.timestamp()), + "iso_format": dt.isoformat() + } + + def _convert_timezone(self, kwargs) -> dict: + """时区转换""" + input_value = kwargs.get("input_value") + input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + from_timezone = kwargs.get("from_timezone", "UTC") + to_timezone = kwargs.get("to_timezone", "UTC") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + # 解析输入时间 + dt = datetime.strptime(input_value, input_format) + + # 设置源时区 + if from_timezone == "UTC": + from_tz = pytz.UTC + else: + from_tz = pytz.timezone(from_timezone) + + # 设置目标时区 + if to_timezone == "UTC": + to_tz = pytz.UTC + else: + to_tz = pytz.timezone(to_timezone) + + # 本地化时间并转换时区 + if dt.tzinfo is None: + dt = from_tz.localize(dt) + + converted_dt = dt.astimezone(to_tz) + + return { + "original": input_value, + "original_timezone": from_timezone, + "converted": converted_dt.strftime(output_format), + "converted_timezone": to_timezone, + "timestamp": int(converted_dt.timestamp()) + } + + def _timestamp_to_datetime(self, kwargs) -> dict: + """时间戳转日期时间""" + input_value = kwargs.get("input_value") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + timezone_str = kwargs.get("to_timezone", "UTC") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + # 转换时间戳 + timestamp = float(input_value) + + # 设置时区 + if timezone_str == "UTC": + tz = timezone.utc + else: + tz = pytz.timezone(timezone_str) + + dt = datetime.fromtimestamp(timestamp, tz) + + return { + "timestamp": timestamp, + "datetime": dt.strftime(output_format), + "timezone": timezone_str, + "iso_format": dt.isoformat() + } + + def _datetime_to_timestamp(self, kwargs) -> dict: + """日期时间转时间戳""" + input_value = kwargs.get("input_value") + input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") + timezone_str = kwargs.get("from_timezone", "UTC") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + # 解析输入时间 + dt = datetime.strptime(input_value, input_format) + + # 设置时区 + if timezone_str == "UTC": + tz = timezone.utc + else: + tz = pytz.timezone(timezone_str) + + # 本地化时间 + if dt.tzinfo is None: + dt = tz.localize(dt) + + return { + "datetime": input_value, + "timezone": timezone_str, + "timestamp": int(dt.timestamp()), + "iso_format": dt.isoformat() + } + + def _calculate_datetime(self, kwargs) -> dict: + """时间计算""" + input_value = kwargs.get("input_value") + input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + calculation = kwargs.get("calculation") + timezone_str = kwargs.get("from_timezone", "UTC") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + if not calculation: + raise ValueError("calculation 参数是必需的") + + # 解析输入时间 + dt = datetime.strptime(input_value, input_format) + + # 设置时区 + if timezone_str == "UTC": + tz = timezone.utc + else: + tz = pytz.timezone(timezone_str) + + if dt.tzinfo is None: + dt = tz.localize(dt) + + # 解析计算表达式 + delta = self._parse_time_delta(calculation) + calculated_dt = dt + delta + + return { + "original": input_value, + "calculation": calculation, + "result": calculated_dt.strftime(output_format), + "timezone": timezone_str, + "timestamp": int(calculated_dt.timestamp()) + } + + def _parse_time_delta(self, calculation: str) -> timedelta: + """解析时间计算表达式""" + import re + + # 支持的单位:d(天), h(小时), m(分钟), s(秒) + pattern = r'([+-]?\d+)([dhms])' + matches = re.findall(pattern, calculation.lower()) + + if not matches: + raise ValueError(f"无效的时间计算表达式: {calculation}") + + total_delta = timedelta() + + for value_str, unit in matches: + value = int(value_str) + + if unit == 'd': + total_delta += timedelta(days=value) + elif unit == 'h': + total_delta += timedelta(hours=value) + elif unit == 'm': + total_delta += timedelta(minutes=value) + elif unit == 's': + total_delta += timedelta(seconds=value) + + return total_delta \ No newline at end of file diff --git a/api/app/core/tools/builtin/json_tool.py b/api/app/core/tools/builtin/json_tool.py new file mode 100644 index 00000000..135d252a --- /dev/null +++ b/api/app/core/tools/builtin/json_tool.py @@ -0,0 +1,430 @@ +"""JSON转换工具 - 数据格式转换""" +import json +import time +from typing import List, Any, Dict +import yaml +import xml.etree.ElementTree as ET +from xml.dom import minidom + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class JsonTool(BuiltinTool): + """JSON转换工具 - 提供JSON格式化、压缩、验证、格式转换功能""" + + @property + def name(self) -> str: + return "json_tool" + + @property + def description(self) -> str: + return "JSON转换工具 - 数据格式转换:JSON格式化、JSON压缩、JSON验证、格式转换" + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="operation", + type=ParameterType.STRING, + description="操作类型", + required=True, + enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", "extract"] + ), + ToolParameter( + name="input_data", + type=ParameterType.STRING, + description="输入数据(JSON字符串、YAML字符串或XML字符串)", + required=True + ), + ToolParameter( + name="indent", + type=ParameterType.INTEGER, + description="JSON格式化缩进空格数", + required=False, + default=2, + minimum=0, + maximum=8 + ), + ToolParameter( + name="ensure_ascii", + type=ParameterType.BOOLEAN, + description="是否确保ASCII编码", + required=False, + default=False + ), + ToolParameter( + name="sort_keys", + type=ParameterType.BOOLEAN, + description="是否对键进行排序", + required=False, + default=False + ), + ToolParameter( + name="merge_data", + type=ParameterType.STRING, + description="要合并的JSON数据(用于merge操作)", + required=False + ), + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(用于extract操作,如:$.user.name)", + required=False + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行JSON工具操作""" + start_time = time.time() + + try: + operation = kwargs.get("operation") + input_data = kwargs.get("input_data") + + if not input_data: + raise ValueError("input_data 参数是必需的") + + if operation == "format": + result = self._format_json(input_data, kwargs) + elif operation == "minify": + result = self._minify_json(input_data) + elif operation == "validate": + result = self._validate_json(input_data) + elif operation == "convert": + result = self._convert_json(input_data) + elif operation == "to_yaml": + result = self._json_to_yaml(input_data) + elif operation == "from_yaml": + result = self._yaml_to_json(input_data, kwargs) + elif operation == "to_xml": + result = self._json_to_xml(input_data) + elif operation == "from_xml": + result = self._xml_to_json(input_data, kwargs) + elif operation == "merge": + result = self._merge_json(input_data, kwargs) + elif operation == "extract": + result = self._extract_json_path(input_data, kwargs) + else: + raise ValueError(f"不支持的操作类型: {operation}") + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="JSON_ERROR", + execution_time=execution_time + ) + + def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """格式化JSON""" + indent = kwargs.get("indent", 2) + ensure_ascii = kwargs.get("ensure_ascii", False) + sort_keys = kwargs.get("sort_keys", False) + + # 解析JSON + data = json.loads(input_data) + + # 格式化输出 + formatted = json.dumps( + data, + indent=indent, + ensure_ascii=ensure_ascii, + sort_keys=sort_keys, + separators=(',', ': ') + ) + + return { + "original_size": len(input_data), + "formatted_size": len(formatted), + "formatted_json": formatted, + "is_valid": True, + "settings": { + "indent": indent, + "ensure_ascii": ensure_ascii, + "sort_keys": sort_keys + } + } + + def _minify_json(self, input_data: str) -> Dict[str, Any]: + """压缩JSON""" + # 解析并压缩 + data = json.loads(input_data) + minified = json.dumps(data, separators=(',', ':')) + + return { + "original_size": len(input_data), + "minified_size": len(minified), + "compression_ratio": round((1 - len(minified) / len(input_data)) * 100, 2), + "minified_json": minified, + "is_valid": True + } + + def _validate_json(self, input_data: str) -> Dict[str, Any]: + """验证JSON""" + try: + data = json.loads(input_data) + + # 统计信息 + stats = self._analyze_json_structure(data) + + return { + "is_valid": True, + "error": None, + "size": len(input_data), + "structure": stats + } + + except json.JSONDecodeError as e: + return { + "is_valid": False, + "error": str(e), + "error_line": getattr(e, 'lineno', None), + "error_column": getattr(e, 'colno', None), + "size": len(input_data) + } + + def _convert_json(self, input_data: str) -> Dict[str, Any]: + """JSON转义""" + data = json.loads(input_data) + converted = json.dumps(data, ensure_ascii=False) + + return { + "converted_json": converted, + "is_valid": True + } + + def _json_to_yaml(self, input_data: str) -> Dict[str, Any]: + """JSON转YAML""" + data = json.loads(input_data) + yaml_output = yaml.dump(data, default_flow_style=False, allow_unicode=True, indent=2) + + return { + "original_format": "json", + "target_format": "yaml", + "original_size": len(input_data), + "converted_size": len(yaml_output), + "converted_data": yaml_output + } + + def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """YAML转JSON""" + indent = kwargs.get("indent", 2) + ensure_ascii = kwargs.get("ensure_ascii", False) + + data = yaml.safe_load(input_data) + json_output = json.dumps(data, indent=indent, ensure_ascii=ensure_ascii) + + return { + "original_format": "yaml", + "target_format": "json", + "original_size": len(input_data), + "converted_size": len(json_output), + "converted_data": json_output + } + + def _json_to_xml(self, input_data: str) -> Dict[str, Any]: + """JSON转XML""" + data = json.loads(input_data) + + def dict_to_xml(data, root_name="root"): + """递归转换字典为XML""" + if isinstance(data, dict): + if len(data) == 1 and not root_name == "root": + # 如果字典只有一个键,使用该键作为根元素 + key, value = next(iter(data.items())) + return dict_to_xml(value, key) + + root = ET.Element(root_name) + for key, value in data.items(): + if isinstance(value, (dict, list)): + child = dict_to_xml(value, key) + root.append(child) + else: + child = ET.SubElement(root, key) + child.text = str(value) + return root + + elif isinstance(data, list): + root = ET.Element(root_name) + for i, item in enumerate(data): + if isinstance(item, (dict, list)): + child = dict_to_xml(item, f"item_{i}") + root.append(child) + else: + child = ET.SubElement(root, f"item_{i}") + child.text = str(item) + return root + + else: + root = ET.Element(root_name) + root.text = str(data) + return root + + xml_element = dict_to_xml(data) + xml_string = ET.tostring(xml_element, encoding='unicode') + + # 格式化XML + dom = minidom.parseString(xml_string) + formatted_xml = dom.toprettyxml(indent=" ") + + # 移除空行 + formatted_xml = '\n'.join([line for line in formatted_xml.split('\n') if line.strip()]) + + return { + "original_format": "json", + "target_format": "xml", + "original_size": len(input_data), + "converted_size": len(formatted_xml), + "converted_data": formatted_xml + } + + def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """XML转JSON""" + indent = kwargs.get("indent", 2) + + def xml_to_dict(element): + """递归转换XML元素为字典""" + result = {} + + # 处理属性 + if element.attrib: + result.update(element.attrib) + + # 处理文本内容 + if element.text and element.text.strip(): + if len(element) == 0: # 叶子节点 + return element.text.strip() + else: + result['text'] = element.text.strip() + + # 处理子元素 + for child in element: + child_data = xml_to_dict(child) + if child.tag in result: + # 如果标签已存在,转换为列表 + if not isinstance(result[child.tag], list): + result[child.tag] = [result[child.tag]] + result[child.tag].append(child_data) + else: + result[child.tag] = child_data + + return result + + root = ET.fromstring(input_data) + data = {root.tag: xml_to_dict(root)} + json_output = json.dumps(data, indent=indent, ensure_ascii=False) + + return { + "original_format": "xml", + "target_format": "json", + "original_size": len(input_data), + "converted_size": len(json_output), + "converted_data": json_output + } + + def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """合并JSON""" + merge_data = kwargs.get("merge_data") + if not merge_data: + raise ValueError("merge_data 参数是必需的") + + data1 = json.loads(input_data) + data2 = json.loads(merge_data) + + def deep_merge(dict1, dict2): + """深度合并字典""" + result = dict1.copy() + for key, value in dict2.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = deep_merge(result[key], value) + else: + result[key] = value + return result + + if isinstance(data1, dict) and isinstance(data2, dict): + merged = deep_merge(data1, data2) + elif isinstance(data1, list) and isinstance(data2, list): + merged = data1 + data2 + else: + raise ValueError("无法合并不同类型的数据") + + merged_json = json.dumps(merged, indent=2, ensure_ascii=False) + + return { + "operation": "merge", + "original_size": len(input_data), + "merge_size": len(merge_data), + "result_size": len(merged_json), + "merged_data": merged_json + } + + def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """提取JSON路径""" + json_path = kwargs.get("json_path") + if not json_path: + raise ValueError("json_path 参数是必需的") + + data = json.loads(input_data) + + # 简单的JSONPath实现(支持基本的点号路径) + try: + result = data + if json_path.startswith('$.'): + path_parts = json_path[2:].split('.') + else: + path_parts = json_path.split('.') + + for part in path_parts: + if part.isdigit(): + result = result[int(part)] + else: + result = result[part] + + extracted_json = json.dumps(result, indent=2, ensure_ascii=False) + + return { + "operation": "extract", + "json_path": json_path, + "found": True, + "extracted_data": extracted_json, + "data_type": type(result).__name__ + } + + except (KeyError, IndexError, TypeError) as e: + return { + "operation": "extract", + "json_path": json_path, + "found": False, + "error": str(e), + "extracted_data": None + } + + def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]: + """分析JSON结构""" + if isinstance(data, dict): + return { + "type": "object", + "keys": len(data), + "depth": depth, + "children": {k: self._analyze_json_structure(v, depth + 1) for k, v in data.items()} + } + elif isinstance(data, list): + return { + "type": "array", + "length": len(data), + "depth": depth, + "item_types": list(set(type(item).__name__ for item in data)) + } + else: + return { + "type": type(data).__name__, + "depth": depth, + "value": str(data)[:100] + "..." if len(str(data)) > 100 else str(data) + } \ No newline at end of file diff --git a/api/app/core/tools/builtin/mineru_tool.py b/api/app/core/tools/builtin/mineru_tool.py new file mode 100644 index 00000000..b2a544c0 --- /dev/null +++ b/api/app/core/tools/builtin/mineru_tool.py @@ -0,0 +1,327 @@ +"""MinerU PDF解析工具""" +import time +from typing import List, Dict, Any +import aiohttp + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class MinerUTool(BuiltinTool): + """MinerU PDF解析工具 - 提供PDF解析、表格提取、图片识别、文本提取功能""" + + @property + def name(self) -> str: + return "mineru_tool" + + @property + def description(self) -> str: + return "MinerU - PDF解析工具:PDF解析、表格提取、图片识别、文本提取" + + def get_required_config_parameters(self) -> List[str]: + return ["api_key", "api_url"] + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="operation", + type=ParameterType.STRING, + description="操作类型", + required=True, + enum=["parse_pdf", "extract_text", "extract_tables", "extract_images", "analyze_layout"] + ), + ToolParameter( + name="file_content", + type=ParameterType.STRING, + description="PDF文件内容(Base64编码)", + required=False + ), + ToolParameter( + name="file_url", + type=ParameterType.STRING, + description="PDF文件URL", + required=False + ), + ToolParameter( + name="parse_mode", + type=ParameterType.STRING, + description="解析模式", + required=False, + default="auto", + enum=["auto", "text_only", "table_priority", "image_priority", "layout_analysis"] + ), + ToolParameter( + name="extract_images", + type=ParameterType.BOOLEAN, + description="是否提取图片", + required=False, + default=True + ), + ToolParameter( + name="extract_tables", + type=ParameterType.BOOLEAN, + description="是否提取表格", + required=False, + default=True + ), + ToolParameter( + name="page_range", + type=ParameterType.STRING, + description="页面范围(如:1-5, 1,3,5)", + required=False + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出格式", + required=False, + default="json", + enum=["json", "markdown", "html", "text"] + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行MinerU PDF解析""" + start_time = time.time() + + try: + operation = kwargs.get("operation") + file_content = kwargs.get("file_content") + file_url = kwargs.get("file_url") + + if not file_content and not file_url: + raise ValueError("必须提供 file_content 或 file_url 参数") + + if operation == "parse_pdf": + result = await self._parse_pdf(kwargs) + elif operation == "extract_text": + result = await self._extract_text(kwargs) + elif operation == "extract_tables": + result = await self._extract_tables(kwargs) + elif operation == "extract_images": + result = await self._extract_images(kwargs) + elif operation == "analyze_layout": + result = await self._analyze_layout(kwargs) + else: + raise ValueError(f"不支持的操作类型: {operation}") + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="MINERU_ERROR", + execution_time=execution_time + ) + + async def _parse_pdf(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """完整PDF解析""" + parse_mode = kwargs.get("parse_mode", "auto") + extract_images = kwargs.get("extract_images", True) + extract_tables = kwargs.get("extract_tables", True) + page_range = kwargs.get("page_range") + output_format = kwargs.get("output_format", "json") + + # 构建请求参数 + request_data = { + "parse_mode": parse_mode, + "extract_images": extract_images, + "extract_tables": extract_tables, + "output_format": output_format + } + + if page_range: + request_data["page_range"] = page_range + + # 添加文件数据 + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + # 调用MinerU API + result = await self._call_mineru_api("parse", request_data) + + return { + "operation": "parse_pdf", + "parse_mode": parse_mode, + "total_pages": result.get("total_pages", 0), + "processed_pages": result.get("processed_pages", 0), + "text_content": result.get("text_content", ""), + "tables": result.get("tables", []), + "images": result.get("images", []), + "layout_info": result.get("layout_info", {}), + "metadata": result.get("metadata", {}), + "processing_time": result.get("processing_time", 0) + } + + async def _extract_text(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """提取文本""" + page_range = kwargs.get("page_range") + output_format = kwargs.get("output_format", "text") + + request_data = { + "operation": "extract_text", + "output_format": output_format + } + + if page_range: + request_data["page_range"] = page_range + + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + result = await self._call_mineru_api("extract_text", request_data) + + return { + "operation": "extract_text", + "total_pages": result.get("total_pages", 0), + "text_content": result.get("text_content", ""), + "word_count": len(result.get("text_content", "").split()), + "character_count": len(result.get("text_content", "")), + "pages_text": result.get("pages_text", []) + } + + async def _extract_tables(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """提取表格""" + page_range = kwargs.get("page_range") + output_format = kwargs.get("output_format", "json") + + request_data = { + "operation": "extract_tables", + "output_format": output_format + } + + if page_range: + request_data["page_range"] = page_range + + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + result = await self._call_mineru_api("extract_tables", request_data) + + return { + "operation": "extract_tables", + "total_tables": result.get("total_tables", 0), + "tables": result.get("tables", []), + "table_locations": result.get("table_locations", []) + } + + async def _extract_images(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """提取图片""" + page_range = kwargs.get("page_range") + + request_data = { + "operation": "extract_images" + } + + if page_range: + request_data["page_range"] = page_range + + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + result = await self._call_mineru_api("extract_images", request_data) + + return { + "operation": "extract_images", + "total_images": result.get("total_images", 0), + "images": result.get("images", []), + "image_locations": result.get("image_locations", []) + } + + async def _analyze_layout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """分析布局""" + page_range = kwargs.get("page_range") + + request_data = { + "operation": "analyze_layout" + } + + if page_range: + request_data["page_range"] = page_range + + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + result = await self._call_mineru_api("analyze_layout", request_data) + + return { + "operation": "analyze_layout", + "layout_info": result.get("layout_info", {}), + "page_layouts": result.get("page_layouts", []), + "text_blocks": result.get("text_blocks", []), + "image_blocks": result.get("image_blocks", []), + "table_blocks": result.get("table_blocks", []) + } + + async def _call_mineru_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]: + """调用MinerU API""" + api_key = self.get_config_parameter("api_key") + api_url = self.get_config_parameter("api_url") + timeout_seconds = self.get_config_parameter("timeout", 60) + + if not api_key or not api_url: + raise ValueError("MinerU API配置未完成") + + # 构建完整URL + url = f"{api_url.rstrip('/')}/{endpoint}" + + # 构建请求头 + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + # 发送请求 + timeout = aiohttp.ClientTimeout(total=timeout_seconds) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, json=data, headers=headers) as response: + if response.status == 200: + result = await response.json() + if result.get("success", True): + return result.get("data", result) + else: + raise Exception(f"MinerU API错误: {result.get('message', '未知错误')}") + else: + error_text = await response.text() + raise Exception(f"HTTP错误 {response.status}: {error_text}") + + def test_connection(self) -> Dict[str, Any]: + """测试连接""" + try: + api_key = self.get_config_parameter("api_key") + api_url = self.get_config_parameter("api_url") + + if not api_key or not api_url: + return { + "success": False, + "error": "API配置未完成" + } + + return { + "success": True, + "message": "连接配置有效", + "api_url": api_url, + "api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } \ No newline at end of file diff --git a/api/app/core/tools/builtin/textin_tool.py b/api/app/core/tools/builtin/textin_tool.py new file mode 100644 index 00000000..ec3e214e --- /dev/null +++ b/api/app/core/tools/builtin/textin_tool.py @@ -0,0 +1,401 @@ +"""TextIn OCR文字识别工具""" +import time +from typing import List, Dict, Any +import aiohttp + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class TextInTool(BuiltinTool): + """TextIn OCR工具 - 提供通用OCR、手写识别、多语言支持、高精度识别""" + + @property + def name(self) -> str: + return "textin_tool" + + @property + def description(self) -> str: + return "TextIn - OCR文字识别:通用OCR、手写识别、多语言支持、高精度识别" + + def get_required_config_parameters(self) -> List[str]: + return ["app_id", "secret_key", "api_url"] + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="image_content", + type=ParameterType.STRING, + description="图片内容(Base64编码)", + required=False + ), + ToolParameter( + name="image_url", + type=ParameterType.STRING, + description="图片URL", + required=False + ), + ToolParameter( + name="language", + type=ParameterType.STRING, + description="识别语言", + required=False, + default="auto", + enum=["auto", "zh-cn", "zh-tw", "en", "ja", "ko", "fr", "de", "es", "ru"] + ), + ToolParameter( + name="recognition_mode", + type=ParameterType.STRING, + description="识别模式", + required=False, + default="general", + enum=["general", "accurate", "handwriting", "formula", "table", "document"] + ), + ToolParameter( + name="return_location", + type=ParameterType.BOOLEAN, + description="是否返回文字位置信息", + required=False, + default=False + ), + ToolParameter( + name="return_confidence", + type=ParameterType.BOOLEAN, + description="是否返回置信度", + required=False, + default=True + ), + ToolParameter( + name="merge_lines", + type=ParameterType.BOOLEAN, + description="是否合并行", + required=False, + default=True + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出格式", + required=False, + default="text", + enum=["text", "json", "structured"] + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行TextIn OCR识别""" + start_time = time.time() + + try: + image_content = kwargs.get("image_content") + image_url = kwargs.get("image_url") + + if not image_content and not image_url: + raise ValueError("必须提供 image_content 或 image_url 参数") + + language = kwargs.get("language", "auto") + recognition_mode = kwargs.get("recognition_mode", "general") + return_location = kwargs.get("return_location", False) + return_confidence = kwargs.get("return_confidence", True) + merge_lines = kwargs.get("merge_lines", True) + output_format = kwargs.get("output_format", "text") + + # 根据识别模式调用不同的API + if recognition_mode == "general": + result = await self._general_ocr(kwargs) + elif recognition_mode == "accurate": + result = await self._accurate_ocr(kwargs) + elif recognition_mode == "handwriting": + result = await self._handwriting_ocr(kwargs) + elif recognition_mode == "formula": + result = await self._formula_ocr(kwargs) + elif recognition_mode == "table": + result = await self._table_ocr(kwargs) + elif recognition_mode == "document": + result = await self._document_ocr(kwargs) + else: + raise ValueError(f"不支持的识别模式: {recognition_mode}") + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="TEXTIN_ERROR", + execution_time=execution_time + ) + + async def _general_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """通用OCR识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "merge_lines": kwargs.get("merge_lines", True) + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("general_ocr", request_data) + + return self._format_ocr_result(result, kwargs.get("output_format", "text")) + + async def _accurate_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """高精度OCR识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "merge_lines": kwargs.get("merge_lines", True) + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("accurate_ocr", request_data) + + return self._format_ocr_result(result, kwargs.get("output_format", "text")) + + async def _handwriting_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """手写体识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True) + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("handwriting_ocr", request_data) + + return self._format_ocr_result(result, kwargs.get("output_format", "text")) + + async def _formula_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """公式识别""" + request_data = { + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "output_latex": True + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("formula_ocr", request_data) + + return self._format_formula_result(result, kwargs.get("output_format", "text")) + + async def _table_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """表格识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "output_excel": True + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("table_ocr", request_data) + + return self._format_table_result(result, kwargs.get("output_format", "text")) + + async def _document_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """文档识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "layout_analysis": True + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("document_ocr", request_data) + + return self._format_document_result(result, kwargs.get("output_format", "text")) + + def _format_ocr_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any] | None: + """格式化OCR结果""" + lines = result.get("lines", []) + + if output_format == "text": + text_content = "\n".join([line.get("text", "") for line in lines]) + return { + "recognition_mode": "ocr", + "text_content": text_content, + "line_count": len(lines), + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + elif output_format == "json": + return { + "recognition_mode": "ocr", + "lines": lines, + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + elif output_format == "structured": + return { + "recognition_mode": "ocr", + "text_content": "\n".join([line.get("text", "") for line in lines]), + "structured_data": { + "lines": lines, + "paragraphs": self._group_lines_to_paragraphs(lines), + "statistics": { + "line_count": len(lines), + "word_count": sum(len(line.get("text", "").split()) for line in lines), + "character_count": sum(len(line.get("text", "")) for line in lines) + } + }, + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]: + """格式化公式识别结果""" + formulas = result.get("formulas", []) + + return { + "recognition_mode": "formula", + "formula_count": len(formulas), + "formulas": formulas, + "latex_content": "\n".join([f.get("latex", "") for f in formulas]), + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]: + """格式化表格识别结果""" + tables = result.get("tables", []) + + return { + "recognition_mode": "table", + "table_count": len(tables), + "tables": tables, + "excel_data": result.get("excel_data"), + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]: + """格式化文档识别结果""" + return { + "recognition_mode": "document", + "layout_info": result.get("layout_info", {}), + "text_blocks": result.get("text_blocks", []), + "image_blocks": result.get("image_blocks", []), + "table_blocks": result.get("table_blocks", []), + "full_text": result.get("full_text", ""), + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """将行分组为段落""" + paragraphs = [] + current_paragraph = [] + + for line in lines: + text = line.get("text", "").strip() + if text: + current_paragraph.append(line) + else: + if current_paragraph: + paragraphs.append({ + "text": " ".join([l.get("text", "") for l in current_paragraph]), + "lines": current_paragraph + }) + current_paragraph = [] + + if current_paragraph: + paragraphs.append({ + "text": " ".join([l.get("text", "") for l in current_paragraph]), + "lines": current_paragraph + }) + + return paragraphs + + async def _call_textin_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]: + """调用TextIn API""" + app_id = self.get_config_parameter("app_id") + secret_key = self.get_config_parameter("secret_key") + api_url = self.get_config_parameter("api_url") + + if not app_id or not secret_key or not api_url: + raise ValueError("TextIn API配置未完成") + + # 构建完整URL + url = f"{api_url.rstrip('/')}/{endpoint}" + + # 构建请求头 + headers = { + "X-App-Id": app_id, + "X-Secret-Key": secret_key, + "Content-Type": "application/json" + } + + # 发送请求 + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, json=data, headers=headers) as response: + if response.status == 200: + result = await response.json() + if result.get("code") == 200: + return result.get("data", result) + else: + raise Exception(f"TextIn API错误: {result.get('message', '未知错误')}") + else: + error_text = await response.text() + raise Exception(f"HTTP错误 {response.status}: {error_text}") + + def test_connection(self) -> Dict[str, Any]: + """测试连接""" + try: + app_id = self.get_config_parameter("app_id") + secret_key = self.get_config_parameter("secret_key") + api_url = self.get_config_parameter("api_url") + + if not app_id or not secret_key or not api_url: + return { + "success": False, + "error": "API配置未完成" + } + + return { + "success": True, + "message": "连接配置有效", + "api_url": api_url, + "app_id": app_id, + "secret_key_masked": secret_key[:8] + "***" if len(secret_key) > 8 else "***" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } \ No newline at end of file diff --git a/api/app/core/tools/chain_manager.py b/api/app/core/tools/chain_manager.py new file mode 100644 index 00000000..713baa39 --- /dev/null +++ b/api/app/core/tools/chain_manager.py @@ -0,0 +1,485 @@ +"""工具链管理器 - 支持langchain的工具链模式""" +from typing import List, Dict, Any, Optional +from dataclasses import dataclass +from enum import Enum + +from app.core.tools.base import ToolResult +from app.core.tools.executor import ToolExecutor +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class ChainExecutionMode(str, Enum): + """链执行模式""" + SEQUENTIAL = "sequential" # 顺序执行 + PARALLEL = "parallel" # 并行执行 + CONDITIONAL = "conditional" # 条件执行 + + +@dataclass +class ChainStep: + """链步骤定义""" + tool_id: str + parameters: Dict[str, Any] + condition: Optional[str] = None # 执行条件 + output_mapping: Optional[Dict[str, str]] = None # 输出映射 + error_handling: str = "stop" # 错误处理:stop, continue, retry + + +@dataclass +class ChainDefinition: + """工具链定义""" + name: str + description: str + steps: List[ChainStep] + execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL + global_timeout: Optional[float] = None + retry_policy: Optional[Dict[str, Any]] = None + + +class ChainExecutionContext: + """链执行上下文""" + + def __init__(self, chain_id: str): + self.chain_id = chain_id + self.variables: Dict[str, Any] = {} + self.step_results: Dict[int, ToolResult] = {} + self.current_step = 0 + self.is_completed = False + self.is_failed = False + self.error_message: Optional[str] = None + + +class ChainManager: + """工具链管理器 - 支持langchain的工具链模式""" + + def __init__(self, executor: ToolExecutor): + """初始化工具链管理器 + + Args: + executor: 工具执行器 + """ + self.executor = executor + self._chains: Dict[str, ChainDefinition] = {} + self._running_chains: Dict[str, ChainExecutionContext] = {} + + def register_chain(self, chain: ChainDefinition) -> bool: + """注册工具链 + + Args: + chain: 工具链定义 + + Returns: + 注册是否成功 + """ + try: + # 验证工具链定义 + validation_result = self._validate_chain(chain) + if not validation_result[0]: + logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}") + return False + + self._chains[chain.name] = chain + logger.info(f"工具链注册成功: {chain.name}") + return True + + except Exception as e: + logger.error(f"工具链注册失败: {chain.name}, 错误: {e}") + return False + + def unregister_chain(self, chain_name: str) -> bool: + """注销工具链 + + Args: + chain_name: 工具链名称 + + Returns: + 注销是否成功 + """ + if chain_name in self._chains: + del self._chains[chain_name] + logger.info(f"工具链注销成功: {chain_name}") + return True + + return False + + def list_chains(self) -> List[Dict[str, Any]]: + """列出所有工具链 + + Returns: + 工具链信息列表 + """ + chains = [] + for name, chain in self._chains.items(): + chains.append({ + "name": name, + "description": chain.description, + "step_count": len(chain.steps), + "execution_mode": chain.execution_mode.value, + "global_timeout": chain.global_timeout + }) + + return chains + + async def execute_chain( + self, + chain_name: str, + initial_variables: Optional[Dict[str, Any]] = None, + chain_id: Optional[str] = None + ) -> Dict[str, Any] | None: + """执行工具链 + + Args: + chain_name: 工具链名称 + initial_variables: 初始变量 + chain_id: 链执行ID(可选) + + Returns: + 执行结果 + """ + if chain_name not in self._chains: + return { + "success": False, + "error": f"工具链不存在: {chain_name}", + "chain_id": chain_id + } + + chain = self._chains[chain_name] + + # 生成链ID + if not chain_id: + import uuid + chain_id = f"chain_{uuid.uuid4().hex[:16]}" + + # 创建执行上下文 + context = ChainExecutionContext(chain_id) + context.variables = initial_variables or {} + self._running_chains[chain_id] = context + + try: + logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})") + + # 根据执行模式执行 + if chain.execution_mode == ChainExecutionMode.SEQUENTIAL: + result = await self._execute_sequential(chain, context) + elif chain.execution_mode == ChainExecutionMode.PARALLEL: + result = await self._execute_parallel(chain, context) + elif chain.execution_mode == ChainExecutionMode.CONDITIONAL: + result = await self._execute_conditional(chain, context) + else: + raise ValueError(f"不支持的执行模式: {chain.execution_mode}") + + logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})") + return result + + except Exception as e: + logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}") + return { + "success": False, + "error": str(e), + "chain_id": chain_id, + "completed_steps": context.current_step, + "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()} + } + + finally: + # 清理执行上下文 + if chain_id in self._running_chains: + del self._running_chains[chain_id] + + async def _execute_sequential( + self, + chain: ChainDefinition, + context: ChainExecutionContext + ) -> Dict[str, Any]: + """顺序执行工具链""" + for i, step in enumerate(chain.steps): + context.current_step = i + + # 检查执行条件 + if step.condition and not self._evaluate_condition(step.condition, context): + logger.debug(f"跳过步骤 {i}: 条件不满足") + continue + + # 准备参数 + parameters = self._prepare_parameters(step.parameters, context) + + # 执行工具 + try: + result = await self.executor.execute_tool( + tool_id=step.tool_id, + parameters=parameters + ) + + context.step_results[i] = result + + # 处理输出映射 + if step.output_mapping and result.success: + self._apply_output_mapping(step.output_mapping, result.data, context) + + # 处理执行失败 + if not result.success: + if step.error_handling == "stop": + context.is_failed = True + context.error_message = result.error + break + elif step.error_handling == "continue": + logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}") + continue + elif step.error_handling == "retry": + # 简单重试逻辑 + retry_result = await self.executor.execute_tool( + tool_id=step.tool_id, + parameters=parameters + ) + context.step_results[i] = retry_result + if not retry_result.success and step.error_handling == "stop": + context.is_failed = True + context.error_message = retry_result.error + break + + except Exception as e: + logger.error(f"步骤 {i} 执行异常: {e}") + if step.error_handling == "stop": + context.is_failed = True + context.error_message = str(e) + break + + context.is_completed = not context.is_failed + + return { + "success": context.is_completed, + "error": context.error_message, + "chain_id": context.chain_id, + "completed_steps": context.current_step + 1, + "total_steps": len(chain.steps), + "final_variables": context.variables, + "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()} + } + + async def _execute_parallel( + self, + chain: ChainDefinition, + context: ChainExecutionContext + ) -> Dict[str, Any]: + """并行执行工具链""" + # 准备所有步骤的执行配置 + execution_configs = [] + + for i, step in enumerate(chain.steps): + # 检查执行条件 + if step.condition and not self._evaluate_condition(step.condition, context): + continue + + parameters = self._prepare_parameters(step.parameters, context) + execution_configs.append({ + "step_index": i, + "tool_id": step.tool_id, + "parameters": parameters + }) + + # 并行执行所有步骤 + try: + results = await self.executor.execute_tools_batch(execution_configs) + + # 处理结果 + for i, result in enumerate(results): + step_index = execution_configs[i]["step_index"] + context.step_results[step_index] = result + + # 处理输出映射 + step = chain.steps[step_index] + if step.output_mapping and result.success: + self._apply_output_mapping(step.output_mapping, result.data, context) + + # 检查是否有失败的步骤 + failed_steps = [i for i, result in context.step_results.items() if not result.success] + + context.is_completed = len(failed_steps) == 0 + if failed_steps: + context.error_message = f"步骤 {failed_steps} 执行失败" + + except Exception as e: + context.is_failed = True + context.error_message = str(e) + + return { + "success": context.is_completed, + "error": context.error_message, + "chain_id": context.chain_id, + "completed_steps": len(context.step_results), + "total_steps": len(chain.steps), + "final_variables": context.variables, + "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()} + } + + async def _execute_conditional( + self, + chain: ChainDefinition, + context: ChainExecutionContext + ) -> Dict[str, Any]: + """条件执行工具链""" + # 条件执行类似于顺序执行,但更严格地检查条件 + return await self._execute_sequential(chain, context) + + def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]: + """验证工具链定义 + + Args: + chain: 工具链定义 + + Returns: + (是否有效, 错误信息) + """ + if not chain.name: + return False, "工具链名称不能为空" + + if not chain.steps: + return False, "工具链必须包含至少一个步骤" + + for i, step in enumerate(chain.steps): + if not step.tool_id: + return False, f"步骤 {i} 缺少工具ID" + + if step.error_handling not in ["stop", "continue", "retry"]: + return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}" + + return True, None + + def _prepare_parameters( + self, + parameters: Dict[str, Any], + context: ChainExecutionContext + ) -> Dict[str, Any]: + """准备参数(支持变量替换) + + Args: + parameters: 原始参数 + context: 执行上下文 + + Returns: + 处理后的参数 + """ + prepared = {} + + for key, value in parameters.items(): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + # 变量替换 + var_name = value[2:-1] + if var_name in context.variables: + prepared[key] = context.variables[var_name] + else: + prepared[key] = value # 保持原值 + else: + prepared[key] = value + + return prepared + + def _evaluate_condition( + self, + condition: str, + context: ChainExecutionContext + ) -> bool: + """评估执行条件 + + Args: + condition: 条件表达式 + context: 执行上下文 + + Returns: + 条件是否满足 + """ + try: + # 简单的条件评估(可以扩展为更复杂的表达式解析) + # 支持格式:variable == value, variable != value, variable > value 等 + + if "==" in condition: + var_name, expected_value = condition.split("==", 1) + var_name = var_name.strip() + expected_value = expected_value.strip().strip('"\'') + + return str(context.variables.get(var_name, "")) == expected_value + + elif "!=" in condition: + var_name, expected_value = condition.split("!=", 1) + var_name = var_name.strip() + expected_value = expected_value.strip().strip('"\'') + + return str(context.variables.get(var_name, "")) != expected_value + + elif condition in context.variables: + # 简单的布尔检查 + return bool(context.variables[condition]) + + else: + # 默认为真 + return True + + except Exception as e: + logger.error(f"条件评估失败: {condition}, 错误: {e}") + return False + + def _apply_output_mapping( + self, + mapping: Dict[str, str], + output_data: Any, + context: ChainExecutionContext + ): + """应用输出映射 + + Args: + mapping: 输出映射配置 + output_data: 输出数据 + context: 执行上下文 + """ + try: + if isinstance(output_data, dict): + for source_key, target_var in mapping.items(): + if source_key in output_data: + context.variables[target_var] = output_data[source_key] + else: + # 如果输出不是字典,将整个输出映射到指定变量 + if "result" in mapping: + context.variables[mapping["result"]] = output_data + + except Exception as e: + logger.error(f"输出映射失败: {e}") + + def _serialize_result(self, result: ToolResult) -> Dict[str, Any]: + """序列化工具结果 + + Args: + result: 工具结果 + + Returns: + 序列化的结果 + """ + return { + "success": result.success, + "data": result.data, + "error": result.error, + "error_code": result.error_code, + "execution_time": result.execution_time, + "token_usage": result.token_usage, + "metadata": result.metadata + } + + def get_running_chains(self) -> List[Dict[str, Any]]: + """获取正在运行的工具链 + + Returns: + 运行中的工具链列表 + """ + chains = [] + for chain_id, context in self._running_chains.items(): + chains.append({ + "chain_id": chain_id, + "current_step": context.current_step, + "is_completed": context.is_completed, + "is_failed": context.is_failed, + "variables_count": len(context.variables), + "completed_steps": len(context.step_results) + }) + + return chains \ No newline at end of file diff --git a/api/app/core/tools/config_manager.py b/api/app/core/tools/config_manager.py new file mode 100644 index 00000000..fb8d1fff --- /dev/null +++ b/api/app/core/tools/config_manager.py @@ -0,0 +1,264 @@ +"""工具配置管理器 - 管理工具配置的加载和验证""" +import json +from pathlib import Path +from typing import Dict, Any, Optional +from pydantic import BaseModel, ValidationError + +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class ToolConfigSchema(BaseModel): + """工具配置基础Schema""" + name: str + description: str + tool_type: str + version: str = "1.0.0" + enabled: bool = True + parameters: Dict[str, Any] = {} + tags: list[str] = [] + + class Config: + extra = "allow" + + +class BuiltinToolConfigSchema(ToolConfigSchema): + """内置工具配置Schema""" + tool_class: str + tool_type: str = "builtin" + + +class CustomToolConfigSchema(ToolConfigSchema): + """自定义工具配置Schema""" + schema_url: Optional[str] = None + schema_content: Optional[Dict[str, Any]] = None + auth_type: str = "none" + auth_config: Dict[str, Any] = {} + base_url: Optional[str] = None + timeout: int = 30 + tool_type: str = "custom" + + +class MCPToolConfigSchema(ToolConfigSchema): + """MCP工具配置Schema""" + server_url: str + connection_config: Dict[str, Any] = {} + available_tools: list[str] = [] + tool_type: str = "mcp" + + +class ConfigManager: + """工具配置管理器""" + + def __init__(self, config_dir: Optional[str] = None): + """初始化配置管理器 + + Args: + config_dir: 配置文件目录,默认使用系统配置 + """ + self.config_dir = Path(config_dir or self._get_default_config_dir()) + self.config_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"配置管理器初始化完成,配置目录: {self.config_dir}") + + def _get_default_config_dir(self) -> str: + """获取默认配置目录""" + # 获取tools目录下的configs子目录 + tools_dir = Path(__file__).parent + return str(tools_dir / "configs") + + def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]: + """加载内置工具配置 + + Returns: + 内置工具配置字典 + """ + configs = {} + builtin_dir = self.config_dir / "builtin" + + if not builtin_dir.exists(): + logger.info("内置工具配置目录不存在,创建默认配置") + self._create_default_builtin_configs(builtin_dir) + + for config_file in builtin_dir.glob("*.json"): + try: + config_data = self._load_config_file(config_file) + config = BuiltinToolConfigSchema(**config_data) + configs[config.name] = config + logger.debug(f"加载内置工具配置: {config.name}") + except Exception as e: + logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}") + + return configs + + def load_builtin_tools_config(self) -> Dict[str, Any]: + """加载全局内置工具配置(兼容原有接口) + + Returns: + 内置工具配置字典 + """ + config_file = self.config_dir / "builtin_tools.json" + try: + with open(config_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"加载内置工具配置失败: {e}") + return {} + + def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum): + """确保内置工具已初始化到数据库 + + Args: + tenant_id: 租户ID + db_session: 数据库会话 + tool_config_model: ToolConfig模型类 + builtin_tool_config_model: BuiltinToolConfig模型类 + tool_type_enum: ToolType枚举 + tool_status_enum: ToolStatus枚举 + """ + # 检查是否已初始化 + existing_count = db_session.query(tool_config_model).filter( + tool_config_model.tenant_id == tenant_id, + tool_config_model.tool_type == tool_type_enum.BUILTIN + ).count() + + if existing_count > 0: + return # 已初始化 + + # 加载全局配置 + builtin_tools = self.load_builtin_tools_config() + + # 为租户创建内置工具记录 + for tool_key, tool_info in builtin_tools.items(): + # 设置初始状态 + initial_status = tool_status_enum.ACTIVE.value if not tool_info['requires_config'] else tool_status_enum.INACTIVE.value + + tool_config = tool_config_model( + name=tool_info['name'], + description=tool_info['description'], + tool_type=tool_type_enum.BUILTIN, + tenant_id=tenant_id, + status=initial_status + ) + db_session.add(tool_config) + db_session.flush() + + builtin_config = builtin_tool_config_model( + id=tool_config.id, + tool_class=tool_info['tool_class'], + parameters={} + ) + db_session.add(builtin_config) + + db_session.commit() + logger.info(f"租户 {tenant_id} 的内置工具初始化完成") + + def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool: + """保存工具配置 + + Args: + config: 工具配置 + tool_type: 工具类型 + + Returns: + 保存是否成功 + """ + try: + config_dir = self.config_dir / tool_type + config_dir.mkdir(parents=True, exist_ok=True) + + config_file = config_dir / f"{config.name}.json" + config_data = config.model_dump() + + with open(config_file, 'w', encoding='utf-8') as f: + json.dump(config_data, f, indent=2, ensure_ascii=False) + + logger.info(f"工具配置保存成功: {config.name} ({tool_type})") + return True + + except Exception as e: + logger.error(f"工具配置保存失败: {config.name}, 错误: {e}") + return False + + def delete_tool_config(self, tool_name: str, tool_type: str) -> bool: + """删除工具配置 + + Args: + tool_name: 工具名称 + tool_type: 工具类型 + + Returns: + 删除是否成功 + """ + try: + config_file = self.config_dir / tool_type / f"{tool_name}.json" + + if config_file.exists(): + config_file.unlink() + logger.info(f"工具配置删除成功: {tool_name} ({tool_type})") + return True + else: + logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})") + return False + + except Exception as e: + logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}") + return False + + def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]: + """验证工具配置 + + Args: + config_data: 配置数据 + tool_type: 工具类型 + + Returns: + (是否有效, 错误信息) + """ + try: + schema_map = { + "builtin": BuiltinToolConfigSchema, + "custom": CustomToolConfigSchema, + "mcp": MCPToolConfigSchema + } + + schema_class = schema_map.get(tool_type) + if not schema_class: + return False, f"不支持的工具类型: {tool_type}" + + # 验证配置 + schema_class(**config_data) + return True, None + + except ValidationError as e: + error_msg = "; ".join([f"{err['loc'][0]}: {err['msg']}" for err in e.errors()]) + return False, f"配置验证失败: {error_msg}" + except Exception as e: + return False, f"配置验证异常: {str(e)}" + + def _load_config_file(self, config_file: Path) -> Dict[str, Any]: + """加载配置文件 + + Args: + config_file: 配置文件路径 + + Returns: + 配置数据字典 + """ + try: + with open(config_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"加载配置文件失败: {config_file}, 错误: {e}") + raise + + def _create_default_builtin_configs(self, builtin_dir: Path): + """创建默认内置工具配置 + + Args: + builtin_dir: 内置工具配置目录 + """ + builtin_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"内置工具配置目录已创建: {builtin_dir}") + # 配置文件已经通过其他方式创建,这里只需要确保目录存在 \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin/baidu_search_tool.json b/api/app/core/tools/configs/builtin/baidu_search_tool.json new file mode 100644 index 00000000..e46a34e3 --- /dev/null +++ b/api/app/core/tools/configs/builtin/baidu_search_tool.json @@ -0,0 +1,14 @@ +{ + "name": "baidu_search_tool", + "description": "百度搜索工具 - 网络搜索:提供网页搜索、新闻搜索、图片搜索功能", + "tool_type": "builtin", + "tool_class": "BaiduSearchTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": "", + "secret_key": "", + "search_type": "web" + }, + "tags": ["search", "web", "baidu", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin/datetime_tool.json b/api/app/core/tools/configs/builtin/datetime_tool.json new file mode 100644 index 00000000..8652fd05 --- /dev/null +++ b/api/app/core/tools/configs/builtin/datetime_tool.json @@ -0,0 +1,12 @@ +{ + "name": "datetime_tool", + "description": "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算", + "tool_type": "builtin", + "tool_class": "DateTimeTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "timezone": "UTC" + }, + "tags": ["time", "utility", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin/json_tool.json b/api/app/core/tools/configs/builtin/json_tool.json new file mode 100644 index 00000000..4c9f8c4a --- /dev/null +++ b/api/app/core/tools/configs/builtin/json_tool.json @@ -0,0 +1,12 @@ +{ + "name": "json_tool", + "description": "JSON工具 - 数据格式处理:提供JSON格式化、压缩、验证、格式转换", + "tool_type": "builtin", + "tool_class": "JsonTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "indent": 2 + }, + "tags": ["json", "data", "utility", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin/mineru_tool.json b/api/app/core/tools/configs/builtin/mineru_tool.json new file mode 100644 index 00000000..e53d6a71 --- /dev/null +++ b/api/app/core/tools/configs/builtin/mineru_tool.json @@ -0,0 +1,14 @@ +{ + "name": "mineru_tool", + "description": "MinerU PDF解析工具 - 文档处理:提供PDF解析、表格提取、图片识别、文本提取功能", + "tool_type": "builtin", + "tool_class": "MinerUTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": "", + "parse_mode": "auto", + "timeout": 60 + }, + "tags": ["pdf", "document", "ocr", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin/textin_tool.json b/api/app/core/tools/configs/builtin/textin_tool.json new file mode 100644 index 00000000..d954f8f1 --- /dev/null +++ b/api/app/core/tools/configs/builtin/textin_tool.json @@ -0,0 +1,14 @@ +{ + "name": "textin_tool", + "description": "TextIn OCR工具 - 图像识别:提供通用OCR、手写识别、多语言支持功能", + "tool_type": "builtin", + "tool_class": "TextInTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "app_id": "", + "language": "auto", + "recognition_mode": "general" + }, + "tags": ["ocr", "image", "text", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin_tools.json b/api/app/core/tools/configs/builtin_tools.json new file mode 100644 index 00000000..ed0b87b1 --- /dev/null +++ b/api/app/core/tools/configs/builtin_tools.json @@ -0,0 +1,60 @@ +{ + "datetime": { + "name": "时间工具", + "description": "获取当前时间、日期计算", + "tool_class": "DateTimeTool", + "category": "utility", + "requires_config": false, + "version": "1.0.0", + "enabled": true, + "parameters": {} + }, + "json_converter": { + "name": "JSON转换工具", + "description": "JSON数据格式化和转换", + "tool_class": "JsonTool", + "category": "utility", + "requires_config": false, + "version": "1.0.0", + "enabled": true, + "parameters": {} + }, + "baidu_search": { + "name": "百度搜索", + "description": "百度网页搜索服务", + "tool_class": "BaiduSearchTool", + "category": "search", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true} + } + }, + "mineru": { + "name": "MinerU", + "description": "PDF文档解析工具", + "tool_class": "MinerUTool", + "category": "document", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": {"type": "string", "description": "MinerU API密钥", "sensitive": true, "required": true}, + "base_url": {"type": "string", "description": "API地址", "default": "https://api.mineru.com"} + } + }, + "textin": { + "name": "TextIn", + "description": "OCR文字识别服务", + "tool_class": "TextInTool", + "category": "ocr", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}, + "api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true} + } + } +} \ No newline at end of file diff --git a/api/app/core/tools/custom/__init__.py b/api/app/core/tools/custom/__init__.py new file mode 100644 index 00000000..87b0488a --- /dev/null +++ b/api/app/core/tools/custom/__init__.py @@ -0,0 +1,11 @@ +"""自定义工具模块""" + +from .base import CustomTool +from .schema_parser import OpenAPISchemaParser +from .auth_manager import AuthManager + +__all__ = [ + "CustomTool", + "OpenAPISchemaParser", + "AuthManager" +] \ No newline at end of file diff --git a/api/app/core/tools/custom/auth_manager.py b/api/app/core/tools/custom/auth_manager.py new file mode 100644 index 00000000..5d457f11 --- /dev/null +++ b/api/app/core/tools/custom/auth_manager.py @@ -0,0 +1,525 @@ +"""认证管理器 - 处理自定义工具的认证配置""" +import base64 +import hashlib +import hmac +import time +from typing import Dict, Any, Tuple +from urllib.parse import quote +import aiohttp + +from app.models.tool_model import AuthType +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class AuthManager: + """认证管理器 - 支持多种认证方式""" + + def __init__(self): + """初始化认证管理器""" + self.supported_auth_types = [ + AuthType.NONE, + AuthType.API_KEY, + AuthType.BEARER_TOKEN + ] + + def validate_auth_config(self, auth_type: AuthType, auth_config: Dict[str, Any]) -> Tuple[bool, str]: + """验证认证配置 + + Args: + auth_type: 认证类型 + auth_config: 认证配置 + + Returns: + (是否有效, 错误信息) + """ + try: + if auth_type not in self.supported_auth_types: + return False, f"不支持的认证类型: {auth_type}" + + if auth_type == AuthType.NONE: + return True, "" + + elif auth_type == AuthType.API_KEY: + return self._validate_api_key_config(auth_config) + + elif auth_type == AuthType.BEARER_TOKEN: + return self._validate_bearer_token_config(auth_config) + + return False, "未知的认证类型" + + except Exception as e: + return False, f"验证认证配置时出错: {e}" + + def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]: + """验证API Key认证配置 + + Args: + auth_config: 认证配置 + + Returns: + (是否有效, 错误信息) + """ + api_key = auth_config.get("api_key") + if not api_key: + return False, "API Key不能为空" + + if not isinstance(api_key, str): + return False, "API Key必须是字符串" + + # 验证key名称 + key_name = auth_config.get("key_name", "X-API-Key") + if not isinstance(key_name, str): + return False, "API Key名称必须是字符串" + + # 验证位置 + key_location = auth_config.get("location", "header") + if key_location not in ["header", "query", "cookie"]: + return False, "API Key位置必须是 header、query 或 cookie" + + return True, "" + + def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]: + """验证Bearer Token认证配置 + + Args: + auth_config: 认证配置 + + Returns: + (是否有效, 错误信息) + """ + token = auth_config.get("token") + if not token: + return False, "Bearer Token不能为空" + + if not isinstance(token, str): + return False, "Bearer Token必须是字符串" + + return True, "" + + def apply_authentication( + self, + auth_type: AuthType, + auth_config: Dict[str, Any], + url: str, + headers: Dict[str, str], + params: Dict[str, Any] + ) -> Tuple[str, Dict[str, str], Dict[str, Any]]: + """应用认证到请求 + + Args: + auth_type: 认证类型 + auth_config: 认证配置 + url: 请求URL + headers: 请求头 + params: 请求参数 + + Returns: + (修改后的URL, 修改后的headers, 修改后的params) + """ + try: + if auth_type == AuthType.NONE: + return url, headers, params + + elif auth_type == AuthType.API_KEY: + return self._apply_api_key_auth(auth_config, url, headers, params) + + elif auth_type == AuthType.BEARER_TOKEN: + return self._apply_bearer_token_auth(auth_config, url, headers, params) + + else: + logger.warning(f"不支持的认证类型: {auth_type}") + return url, headers, params + + except Exception as e: + logger.error(f"应用认证时出错: {e}") + return url, headers, params + + def _apply_api_key_auth( + self, + auth_config: Dict[str, Any], + url: str, + headers: Dict[str, str], + params: Dict[str, Any] + ) -> Tuple[str, Dict[str, str], Dict[str, Any]]: + """应用API Key认证 + + Args: + auth_config: 认证配置 + url: 请求URL + headers: 请求头 + params: 请求参数 + + Returns: + (修改后的URL, 修改后的headers, 修改后的params) + """ + api_key = auth_config.get("api_key") + key_name = auth_config.get("key_name", "X-API-Key") + location = auth_config.get("location", "header") + + if location == "header": + headers[key_name] = api_key + + elif location == "query": + # 添加到URL查询参数 + separator = "&" if "?" in url else "?" + encoded_key = quote(str(api_key)) + url += f"{separator}{key_name}={encoded_key}" + + elif location == "cookie": + # 添加到Cookie头 + cookie_value = f"{key_name}={api_key}" + if "Cookie" in headers: + headers["Cookie"] += f"; {cookie_value}" + else: + headers["Cookie"] = cookie_value + + return url, headers, params + + def _apply_bearer_token_auth( + self, + auth_config: Dict[str, Any], + url: str, + headers: Dict[str, str], + params: Dict[str, Any] + ) -> Tuple[str, Dict[str, str], Dict[str, Any]]: + """应用Bearer Token认证 + + Args: + auth_config: 认证配置 + url: 请求URL + headers: 请求头 + params: 请求参数 + + Returns: + (修改后的URL, 修改后的headers, 修改后的params) + """ + token = auth_config.get("token") + headers["Authorization"] = f"Bearer {token}" + + return url, headers, params + + def encrypt_auth_config(self, auth_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]: + """加密认证配置中的敏感信息 + + Args: + auth_config: 认证配置 + encryption_key: 加密密钥 + + Returns: + 加密后的认证配置 + """ + try: + encrypted_config = auth_config.copy() + + # 需要加密的字段 + sensitive_fields = ["api_key", "token", "secret", "password"] + + for field in sensitive_fields: + if field in encrypted_config: + value = encrypted_config[field] + if isinstance(value, str) and value: + encrypted_value = self._encrypt_string(value, encryption_key) + encrypted_config[field] = encrypted_value + encrypted_config[f"{field}_encrypted"] = True + + return encrypted_config + + except Exception as e: + logger.error(f"加密认证配置失败: {e}") + return auth_config + + def decrypt_auth_config(self, encrypted_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]: + """解密认证配置中的敏感信息 + + Args: + encrypted_config: 加密的认证配置 + encryption_key: 解密密钥 + + Returns: + 解密后的认证配置 + """ + try: + decrypted_config = encrypted_config.copy() + + # 需要解密的字段 + sensitive_fields = ["api_key", "token", "secret", "password"] + + for field in sensitive_fields: + if field in decrypted_config and decrypted_config.get(f"{field}_encrypted"): + encrypted_value = decrypted_config[field] + if isinstance(encrypted_value, str) and encrypted_value: + decrypted_value = self._decrypt_string(encrypted_value, encryption_key) + decrypted_config[field] = decrypted_value + # 移除加密标记 + decrypted_config.pop(f"{field}_encrypted", None) + + return decrypted_config + + except Exception as e: + logger.error(f"解密认证配置失败: {e}") + return encrypted_config + + def _encrypt_string(self, value: str, key: str) -> str: + """加密字符串 + + Args: + value: 要加密的字符串 + key: 加密密钥 + + Returns: + 加密后的字符串(Base64编码) + """ + try: + # 使用HMAC-SHA256进行简单加密 + key_bytes = key.encode('utf-8') + value_bytes = value.encode('utf-8') + + # 生成HMAC + hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256) + signature = hmac_obj.hexdigest() + + # 组合原始值和签名,然后Base64编码 + combined = f"{value}:{signature}" + encrypted = base64.b64encode(combined.encode('utf-8')).decode('utf-8') + + return encrypted + + except Exception as e: + logger.error(f"加密字符串失败: {e}") + return value + + def _decrypt_string(self, encrypted_value: str, key: str) -> str: + """解密字符串 + + Args: + encrypted_value: 加密的字符串 + key: 解密密钥 + + Returns: + 解密后的字符串 + """ + try: + # Base64解码 + decoded = base64.b64decode(encrypted_value.encode('utf-8')).decode('utf-8') + + # 分离原始值和签名 + if ':' not in decoded: + return encrypted_value # 可能不是加密的值 + + value, signature = decoded.rsplit(':', 1) + + # 验证签名 + key_bytes = key.encode('utf-8') + value_bytes = value.encode('utf-8') + + hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256) + expected_signature = hmac_obj.hexdigest() + + if signature == expected_signature: + return value + else: + logger.warning("解密时签名验证失败") + return encrypted_value + + except Exception as e: + logger.error(f"解密字符串失败: {e}") + return encrypted_value + + def test_authentication( + self, + auth_type: AuthType, + auth_config: Dict[str, Any], + test_url: str = None + ) -> Dict[str, Any]: + """测试认证配置 + + Args: + auth_type: 认证类型 + auth_config: 认证配置 + test_url: 测试URL(可选) + + Returns: + 测试结果 + """ + try: + # 验证配置 + is_valid, error_msg = self.validate_auth_config(auth_type, auth_config) + if not is_valid: + return { + "success": False, + "error": error_msg, + "auth_type": auth_type.value + } + + # 如果没有测试URL,只验证配置 + if not test_url: + return { + "success": True, + "message": "认证配置有效", + "auth_type": auth_type.value + } + + # 构建测试请求 + headers = {"User-Agent": "AuthManager-Test/1.0"} + params = {} + + # 应用认证 + test_url, headers, params = self.apply_authentication( + auth_type, auth_config, test_url, headers, params + ) + + return { + "success": True, + "message": "认证配置测试成功", + "auth_type": auth_type.value, + "test_url": test_url, + "headers": {k: v for k, v in headers.items() if k != "Authorization"}, # 不返回敏感信息 + "has_auth_header": "Authorization" in headers + } + + except Exception as e: + return { + "success": False, + "error": str(e), + "auth_type": auth_type.value if auth_type else "unknown" + } + + async def test_authentication_with_request( + self, + auth_type: AuthType, + auth_config: Dict[str, Any], + test_url: str, + timeout: int = 10 + ) -> Dict[str, Any]: + """通过实际HTTP请求测试认证 + + Args: + auth_type: 认证类型 + auth_config: 认证配置 + test_url: 测试URL + timeout: 超时时间(秒) + + Returns: + 测试结果 + """ + try: + # 验证配置 + is_valid, error_msg = self.validate_auth_config(auth_type, auth_config) + if not is_valid: + return { + "success": False, + "error": error_msg, + "auth_type": auth_type.value + } + + # 构建请求 + headers = {"User-Agent": "AuthManager-Test/1.0"} + params = {} + + # 应用认证 + test_url, headers, params = self.apply_authentication( + auth_type, auth_config, test_url, headers, params + ) + + # 发送测试请求 + client_timeout = aiohttp.ClientTimeout(total=timeout) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + async with session.get(test_url, headers=headers) as response: + status_code = response.status + + # 根据状态码判断认证是否成功 + if status_code == 200: + return { + "success": True, + "message": "认证测试成功", + "status_code": status_code, + "auth_type": auth_type.value + } + elif status_code == 401: + return { + "success": False, + "error": "认证失败 - 401 Unauthorized", + "status_code": status_code, + "auth_type": auth_type.value + } + elif status_code == 403: + return { + "success": False, + "error": "认证失败 - 403 Forbidden", + "status_code": status_code, + "auth_type": auth_type.value + } + else: + return { + "success": True, + "message": f"请求成功,状态码: {status_code}", + "status_code": status_code, + "auth_type": auth_type.value + } + + except aiohttp.ClientError as e: + return { + "success": False, + "error": f"网络请求失败: {e}", + "auth_type": auth_type.value + } + except Exception as e: + return { + "success": False, + "error": f"测试认证时出错: {e}", + "auth_type": auth_type.value + } + + def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]: + """获取认证配置模板 + + Args: + auth_type: 认证类型 + + Returns: + 配置模板 + """ + templates = { + AuthType.NONE: {}, + + AuthType.API_KEY: { + "api_key": "", + "key_name": "X-API-Key", + "location": "header", # header, query, cookie + "description": "API Key认证配置" + }, + + AuthType.BEARER_TOKEN: { + "token": "", + "description": "Bearer Token认证配置" + } + } + + return templates.get(auth_type, {}) + + def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]: + """遮蔽认证配置中的敏感信息 + + Args: + auth_config: 认证配置 + + Returns: + 遮蔽敏感信息后的配置 + """ + masked_config = auth_config.copy() + + # 需要遮蔽的字段 + sensitive_fields = ["api_key", "token", "secret", "password"] + + for field in sensitive_fields: + if field in masked_config: + value = masked_config[field] + if isinstance(value, str) and len(value) > 4: + # 只显示前2位和后2位 + masked_config[field] = f"{value[:2]}***{value[-2:]}" + elif isinstance(value, str) and value: + masked_config[field] = "***" + + return masked_config \ No newline at end of file diff --git a/api/app/core/tools/custom/base.py b/api/app/core/tools/custom/base.py new file mode 100644 index 00000000..eda6769b --- /dev/null +++ b/api/app/core/tools/custom/base.py @@ -0,0 +1,318 @@ +"""自定义工具基类""" +import time +from typing import Dict, Any, List, Optional +import aiohttp +from urllib.parse import urljoin + +from app.models.tool_model import ToolType, AuthType +from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class CustomTool(BaseTool): + """自定义工具 - 基于OpenAPI schema的工具""" + + def __init__(self, tool_id: str, config: Dict[str, Any]): + """初始化自定义工具 + + Args: + tool_id: 工具ID + config: 工具配置 + """ + super().__init__(tool_id, config) + self.schema_content = config.get("schema_content", {}) + self.schema_url = config.get("schema_url") + self.auth_type = AuthType(config.get("auth_type", "none")) + self.auth_config = config.get("auth_config", {}) + self.base_url = config.get("base_url", "") + self.timeout = config.get("timeout", 30) + + # 解析schema + self._parsed_operations = self._parse_openapi_schema() + + @property + def name(self) -> str: + """工具名称""" + if self.schema_content: + info = self.schema_content.get("info", {}) + return info.get("title", f"custom_tool_{self.tool_id[:8]}") + return f"custom_tool_{self.tool_id[:8]}" + + @property + def description(self) -> str: + """工具描述""" + if self.schema_content: + info = self.schema_content.get("info", {}) + return info.get("description", "自定义API工具") + return "自定义API工具" + + @property + def tool_type(self) -> ToolType: + """工具类型""" + return ToolType.CUSTOM + + @property + def parameters(self) -> List[ToolParameter]: + """工具参数定义""" + params = [] + + # 添加操作选择参数 + if len(self._parsed_operations) > 1: + params.append(ToolParameter( + name="operation", + type=ParameterType.STRING, + description="要执行的操作", + required=True, + enum=list(self._parsed_operations.keys()) + )) + + # 添加通用参数(基于第一个操作的参数) + if self._parsed_operations: + first_operation = next(iter(self._parsed_operations.values())) + for param_name, param_info in first_operation.get("parameters", {}).items(): + params.append(ToolParameter( + name=param_name, + type=self._convert_openapi_type(param_info.get("type", "string")), + description=param_info.get("description", ""), + required=param_info.get("required", False), + default=param_info.get("default"), + enum=param_info.get("enum"), + minimum=param_info.get("minimum"), + maximum=param_info.get("maximum"), + pattern=param_info.get("pattern") + )) + + return params + + async def execute(self, **kwargs) -> ToolResult: + """执行自定义工具""" + start_time = time.time() + + try: + # 确定要执行的操作 + operation_name = kwargs.get("operation") + if not operation_name and len(self._parsed_operations) == 1: + operation_name = next(iter(self._parsed_operations.keys())) + + if not operation_name or operation_name not in self._parsed_operations: + raise ValueError(f"无效的操作: {operation_name}") + + operation = self._parsed_operations[operation_name] + + # 构建请求 + url = self._build_request_url(operation, kwargs) + headers = self._build_request_headers(operation) + data = self._build_request_data(operation, kwargs) + + # 发送HTTP请求 + result = await self._send_http_request( + method=operation["method"], + url=url, + headers=headers, + data=data + ) + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="CUSTOM_TOOL_ERROR", + execution_time=execution_time + ) + + def _parse_openapi_schema(self) -> Dict[str, Any]: + """解析OpenAPI schema""" + operations = {} + + if not self.schema_content: + return operations + + paths = self.schema_content.get("paths", {}) + + for path, path_item in paths.items(): + for method, operation in path_item.items(): + if method.lower() in ["get", "post", "put", "delete", "patch"]: + operation_id = operation.get("operationId", f"{method}_{path.replace('/', '_')}") + + # 解析参数 + parameters = {} + if "parameters" in operation: + for param in operation["parameters"]: + param_name = param.get("name") + param_schema = param.get("schema", {}) + parameters[param_name] = { + "type": param_schema.get("type", "string"), + "description": param.get("description", ""), + "required": param.get("required", False), + "in": param.get("in", "query"), + **param_schema + } + + # 解析请求体 + request_body = None + if "requestBody" in operation: + content = operation["requestBody"].get("content", {}) + if "application/json" in content: + request_body = content["application/json"].get("schema", {}) + + operations[operation_id] = { + "method": method.upper(), + "path": path, + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": parameters, + "request_body": request_body + } + + return operations + + def _convert_openapi_type(self, openapi_type: str) -> ParameterType: + """转换OpenAPI类型到内部类型""" + type_mapping = { + "string": ParameterType.STRING, + "integer": ParameterType.INTEGER, + "number": ParameterType.NUMBER, + "boolean": ParameterType.BOOLEAN, + "array": ParameterType.ARRAY, + "object": ParameterType.OBJECT + } + return type_mapping.get(openapi_type, ParameterType.STRING) + + def _build_request_url(self, operation: Dict[str, Any], params: Dict[str, Any]) -> str: + """构建请求URL""" + path = operation["path"] + + # 替换路径参数 + for param_name, param_info in operation.get("parameters", {}).items(): + if param_info.get("in") == "path" and param_name in params: + path = path.replace(f"{{{param_name}}}", str(params[param_name])) + + # 构建完整URL + if self.base_url: + url = urljoin(self.base_url, path.lstrip("/")) + else: + # 从schema中获取服务器URL + servers = self.schema_content.get("servers", []) + if servers: + base_url = servers[0].get("url", "") + url = urljoin(base_url, path.lstrip("/")) + else: + url = path + + # 添加查询参数 + query_params = {} + for param_name, param_info in operation.get("parameters", {}).items(): + if param_info.get("in") == "query" and param_name in params: + query_params[param_name] = params[param_name] + + if query_params: + from urllib.parse import urlencode + url += "?" + urlencode(query_params) + + return url + + def _build_request_headers(self, operation: Dict[str, Any]) -> Dict[str, str]: + """构建请求头""" + headers = { + "Content-Type": "application/json", + "User-Agent": "CustomTool/1.0" + } + + # 添加认证头 + if self.auth_type == AuthType.API_KEY: + api_key = self.auth_config.get("api_key") + key_name = self.auth_config.get("key_name", "X-API-Key") + if api_key: + headers[key_name] = api_key + + elif self.auth_type == AuthType.BEARER_TOKEN: + token = self.auth_config.get("token") + if token: + headers["Authorization"] = f"Bearer {token}" + + return headers + + def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """构建请求数据""" + if operation["method"] in ["POST", "PUT", "PATCH"]: + request_body = operation.get("request_body") + if request_body: + # 构建请求体数据 + data = {} + properties = request_body.get("properties", {}) + + for prop_name, prop_schema in properties.items(): + if prop_name in params: + data[prop_name] = params[prop_name] + + return data if data else None + + return None + + async def _send_http_request( + self, + method: str, + url: str, + headers: Dict[str, str], + data: Optional[Dict[str, Any]] = None + ) -> Any: + """发送HTTP请求""" + timeout = aiohttp.ClientTimeout(total=self.timeout) + + async with aiohttp.ClientSession(timeout=timeout) as session: + kwargs = { + "headers": headers + } + + if data and method in ["POST", "PUT", "PATCH"]: + kwargs["json"] = data + + async with session.request(method, url, **kwargs) as response: + if response.status >= 400: + error_text = await response.text() + raise Exception(f"HTTP {response.status}: {error_text}") + + # 尝试解析JSON响应 + try: + return await response.json() + except Exception as e: + return await response.text() + + @classmethod + def from_url(cls, schema_url: str, auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool': + """从URL导入OpenAPI schema创建工具""" + import uuid + if not tool_id: + tool_id = str(uuid.uuid4()) + + config = { + "schema_url": schema_url, + "auth_config": auth_config, + "auth_type": auth_config.get("type", "none") + } + + # 这里应该异步加载schema,为了简化暂时返回空配置 + return cls(tool_id, config) + + @classmethod + def from_schema(cls, schema_dict: Dict[str, Any], auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool': + """从schema字典创建工具""" + import uuid + if not tool_id: + tool_id = str(uuid.uuid4()) + + config = { + "schema_content": schema_dict, + "auth_config": auth_config, + "auth_type": auth_config.get("type", "none") + } + + return cls(tool_id, config) \ No newline at end of file diff --git a/api/app/core/tools/custom/schema_parser.py b/api/app/core/tools/custom/schema_parser.py new file mode 100644 index 00000000..21ac28b6 --- /dev/null +++ b/api/app/core/tools/custom/schema_parser.py @@ -0,0 +1,477 @@ +"""OpenAPI Schema解析器""" +import json +import yaml +from typing import Dict, Any, List, Optional, Tuple +from urllib.parse import urlparse +import aiohttp +import asyncio + +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class OpenAPISchemaParser: + """OpenAPI Schema解析器 - 解析OpenAPI 3.0规范""" + + def __init__(self): + """初始化解析器""" + self.supported_versions = ["3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"] + + async def parse_from_url(self, schema_url: str, timeout: int = 30) -> Tuple[bool, Dict[str, Any], str]: + """从URL解析OpenAPI schema + + Args: + schema_url: Schema URL + timeout: 超时时间(秒) + + Returns: + (是否成功, schema内容, 错误信息) + """ + try: + # 验证URL格式 + parsed_url = urlparse(schema_url) + if not parsed_url.scheme or not parsed_url.netloc: + return False, {}, "无效的URL格式" + + # 下载schema + client_timeout = aiohttp.ClientTimeout(total=timeout) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + async with session.get(schema_url) as response: + if response.status != 200: + return False, {}, f"HTTP错误: {response.status}" + + content_type = response.headers.get('content-type', '').lower() + content = await response.text() + + # 解析内容 + schema_dict = self._parse_content(content, content_type) + if not schema_dict: + return False, {}, "无法解析schema内容" + + # 验证schema + is_valid, error_msg = self.validate_schema(schema_dict) + if not is_valid: + return False, {}, error_msg + + return True, schema_dict, "" + + except asyncio.TimeoutError: + return False, {}, "请求超时" + except Exception as e: + logger.error(f"从URL解析schema失败: {schema_url}, 错误: {e}") + return False, {}, str(e) + + def parse_from_content(self, content: str, content_type: str = "application/json") -> Tuple[bool, Dict[str, Any], str]: + """从内容解析OpenAPI schema + + Args: + content: Schema内容 + content_type: 内容类型 + + Returns: + (是否成功, schema内容, 错误信息) + """ + try: + # 解析内容 + schema_dict = self._parse_content(content, content_type) + if not schema_dict: + return False, {}, "无法解析schema内容" + + # 验证schema + is_valid, error_msg = self.validate_schema(schema_dict) + if not is_valid: + return False, {}, error_msg + + return True, schema_dict, "" + + except Exception as e: + logger.error(f"解析schema内容失败: {e}") + return False, {}, str(e) + + def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]: + """解析内容为字典 + + Args: + content: 内容字符串 + content_type: 内容类型 + + Returns: + 解析后的字典,失败返回None + """ + try: + # 根据内容类型解析 + if 'json' in content_type: + return json.loads(content) + elif 'yaml' in content_type or 'yml' in content_type: + return yaml.safe_load(content) + else: + # 尝试自动检测格式 + try: + return json.loads(content) + except json.JSONDecodeError: + try: + return yaml.safe_load(content) + except yaml.YAMLError: + return None + except Exception as e: + logger.error(f"解析内容失败: {e}") + return None + + def validate_schema(self, schema_dict: Dict[str, Any]) -> Tuple[bool, str]: + """验证OpenAPI schema + + Args: + schema_dict: Schema字典 + + Returns: + (是否有效, 错误信息) + """ + try: + # 检查基本结构 + if not isinstance(schema_dict, dict): + return False, "Schema必须是JSON对象" + + # 检查OpenAPI版本 + openapi_version = schema_dict.get("openapi") + if not openapi_version: + return False, "缺少openapi版本字段" + + if openapi_version not in self.supported_versions: + return False, f"不支持的OpenAPI版本: {openapi_version}" + + # 检查必需字段 + required_fields = ["info", "paths"] + for field in required_fields: + if field not in schema_dict: + return False, f"缺少必需字段: {field}" + + # 验证info字段 + info = schema_dict.get("info", {}) + if not isinstance(info, dict): + return False, "info字段必须是对象" + + if "title" not in info: + return False, "info.title字段是必需的" + + # 验证paths字段 + paths = schema_dict.get("paths", {}) + if not isinstance(paths, dict): + return False, "paths字段必须是对象" + + # 验证至少有一个路径 + if not paths: + return False, "至少需要定义一个API路径" + + return True, "" + + except Exception as e: + return False, f"验证schema时出错: {e}" + + def extract_tool_info(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]: + """从schema提取工具信息 + + Args: + schema_dict: Schema字典 + + Returns: + 工具信息字典 + """ + info = schema_dict.get("info", {}) + + return { + "name": info.get("title", "Custom API Tool"), + "description": info.get("description", ""), + "version": info.get("version", "1.0.0"), + "servers": schema_dict.get("servers", []), + "operations": self._extract_operations(schema_dict) + } + + def _extract_operations(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]: + """提取API操作信息 + + Args: + schema_dict: Schema字典 + + Returns: + 操作信息字典 + """ + operations = {} + paths = schema_dict.get("paths", {}) + + for path, path_item in paths.items(): + if not isinstance(path_item, dict): + continue + + for method, operation in path_item.items(): + if method.lower() not in ["get", "post", "put", "delete", "patch", "head", "options"]: + continue + + if not isinstance(operation, dict): + continue + + # 生成操作ID + operation_id = operation.get("operationId") + if not operation_id: + operation_id = f"{method.lower()}_{path.replace('/', '_').replace('{', '').replace('}', '')}" + + # 提取操作信息 + operations[operation_id] = { + "method": method.upper(), + "path": path, + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": self._extract_parameters(operation), + "request_body": self._extract_request_body(operation), + "responses": self._extract_responses(operation), + "tags": operation.get("tags", []) + } + + return operations + + def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]: + """提取操作参数 + + Args: + operation: 操作定义 + + Returns: + 参数信息字典 + """ + parameters = {} + + for param in operation.get("parameters", []): + if not isinstance(param, dict): + continue + + param_name = param.get("name") + if not param_name: + continue + + param_schema = param.get("schema", {}) + + parameters[param_name] = { + "name": param_name, + "in": param.get("in", "query"), + "description": param.get("description", ""), + "required": param.get("required", False), + "type": param_schema.get("type", "string"), + "format": param_schema.get("format"), + "enum": param_schema.get("enum"), + "default": param_schema.get("default"), + "minimum": param_schema.get("minimum"), + "maximum": param_schema.get("maximum"), + "pattern": param_schema.get("pattern"), + "example": param.get("example") or param_schema.get("example") + } + + return parameters + + def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """提取请求体信息 + + Args: + operation: 操作定义 + + Returns: + 请求体信息,如果没有返回None + """ + request_body = operation.get("requestBody") + if not request_body: + return None + + content = request_body.get("content", {}) + + # 优先使用application/json + if "application/json" in content: + schema = content["application/json"].get("schema", {}) + elif content: + # 使用第一个可用的内容类型 + first_content_type = next(iter(content.keys())) + schema = content[first_content_type].get("schema", {}) + else: + return None + + return { + "description": request_body.get("description", ""), + "required": request_body.get("required", False), + "schema": schema, + "content_types": list(content.keys()) + } + + def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]: + """提取响应信息 + + Args: + operation: 操作定义 + + Returns: + 响应信息字典 + """ + responses = {} + + for status_code, response in operation.get("responses", {}).items(): + if not isinstance(response, dict): + continue + + content = response.get("content", {}) + schema = None + + # 尝试获取响应schema + if "application/json" in content: + schema = content["application/json"].get("schema") + elif content: + first_content_type = next(iter(content.keys())) + schema = content[first_content_type].get("schema") + + responses[status_code] = { + "description": response.get("description", ""), + "schema": schema, + "content_types": list(content.keys()) if content else [] + } + + return responses + + def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]: + """生成工具参数定义 + + Args: + operations: 操作信息字典 + + Returns: + 参数定义列表 + """ + parameters = [] + + # 如果有多个操作,添加操作选择参数 + if len(operations) > 1: + parameters.append({ + "name": "operation", + "type": "string", + "description": "要执行的操作", + "required": True, + "enum": list(operations.keys()) + }) + + # 收集所有参数(去重) + all_params = {} + + for operation_id, operation in operations.items(): + # 路径参数和查询参数 + for param_name, param_info in operation.get("parameters", {}).items(): + if param_name not in all_params: + all_params[param_name] = { + "name": param_name, + "type": param_info.get("type", "string"), + "description": param_info.get("description", ""), + "required": param_info.get("required", False), + "enum": param_info.get("enum"), + "default": param_info.get("default"), + "minimum": param_info.get("minimum"), + "maximum": param_info.get("maximum"), + "pattern": param_info.get("pattern") + } + + # 请求体参数 + request_body = operation.get("request_body") + if request_body: + schema = request_body.get("schema", {}) + properties = schema.get("properties", {}) + + for prop_name, prop_schema in properties.items(): + if prop_name not in all_params: + all_params[prop_name] = { + "name": prop_name, + "type": prop_schema.get("type", "string"), + "description": prop_schema.get("description", ""), + "required": prop_name in schema.get("required", []), + "enum": prop_schema.get("enum"), + "default": prop_schema.get("default"), + "minimum": prop_schema.get("minimum"), + "maximum": prop_schema.get("maximum"), + "pattern": prop_schema.get("pattern") + } + + # 转换为参数列表 + parameters.extend(all_params.values()) + + return parameters + + def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]: + """验证操作参数 + + Args: + operation: 操作定义 + params: 输入参数 + + Returns: + (是否有效, 错误信息列表) + """ + errors = [] + + # 验证路径参数和查询参数 + for param_name, param_info in operation.get("parameters", {}).items(): + if param_info.get("required", False) and param_name not in params: + errors.append(f"缺少必需参数: {param_name}") + + if param_name in params: + value = params[param_name] + param_type = param_info.get("type", "string") + + # 类型验证 + if not self._validate_parameter_type(value, param_type): + errors.append(f"参数 {param_name} 类型错误,期望: {param_type}") + + # 枚举验证 + enum_values = param_info.get("enum") + if enum_values and value not in enum_values: + errors.append(f"参数 {param_name} 值无效,必须是: {enum_values}") + + # 验证请求体参数 + request_body = operation.get("request_body") + if request_body: + schema = request_body.get("schema", {}) + required_props = schema.get("required", []) + properties = schema.get("properties", {}) + + for prop_name in required_props: + if prop_name not in params: + errors.append(f"缺少必需的请求体参数: {prop_name}") + + for prop_name, value in params.items(): + if prop_name in properties: + prop_schema = properties[prop_name] + prop_type = prop_schema.get("type", "string") + + if not self._validate_parameter_type(value, prop_type): + errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}") + + return len(errors) == 0, errors + + def _validate_parameter_type(self, value: Any, expected_type: str) -> bool: + """验证参数类型 + + Args: + value: 参数值 + expected_type: 期望类型 + + Returns: + 是否类型匹配 + """ + if value is None: + return True + + type_mapping = { + "string": str, + "integer": int, + "number": (int, float), + "boolean": bool, + "array": list, + "object": dict + } + + expected_python_type = type_mapping.get(expected_type) + if expected_python_type: + return isinstance(value, expected_python_type) + + return True \ No newline at end of file diff --git a/api/app/core/tools/executor.py b/api/app/core/tools/executor.py new file mode 100644 index 00000000..c0ba87fb --- /dev/null +++ b/api/app/core/tools/executor.py @@ -0,0 +1,501 @@ +"""工具执行器 - 负责工具的实际调用和执行管理""" +import asyncio +import uuid +import time +from typing import Dict, Any, List, Optional +from datetime import datetime +from sqlalchemy.orm import Session + +from app.models.tool_model import ToolExecution, ExecutionStatus +from app.core.tools.base import BaseTool, ToolResult +from app.core.tools.registry import ToolRegistry +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class ExecutionContext: + """执行上下文""" + + def __init__( + self, + execution_id: str, + tool_id: str, + user_id: Optional[uuid.UUID] = None, + workspace_id: Optional[uuid.UUID] = None, + timeout: Optional[float] = None, + metadata: Optional[Dict[str, Any]] = None + ): + self.execution_id = execution_id + self.tool_id = tool_id + self.user_id = user_id + self.workspace_id = workspace_id + self.timeout = timeout or 60.0 # 默认60秒超时 + self.metadata = metadata or {} + self.started_at = datetime.now() + self.completed_at: Optional[datetime] = None + self.status = ExecutionStatus.PENDING + + +class ToolExecutor: + """工具执行器 - 使用langchain标准接口执行工具""" + + def __init__(self, db: Session, registry: ToolRegistry): + """初始化工具执行器 + + Args: + db: 数据库会话 + registry: 工具注册表 + """ + self.db = db + self.registry = registry + self._running_executions: Dict[str, ExecutionContext] = {} + self._execution_lock = asyncio.Lock() + + async def execute_tool( + self, + tool_id: str, + parameters: Dict[str, Any], + user_id: Optional[uuid.UUID] = None, + workspace_id: Optional[uuid.UUID] = None, + execution_id: Optional[str] = None, + timeout: Optional[float] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> ToolResult: + """执行工具 + + Args: + tool_id: 工具ID + parameters: 工具参数 + user_id: 用户ID + workspace_id: 工作空间ID + execution_id: 执行ID(可选,自动生成) + timeout: 超时时间(秒) + metadata: 额外元数据 + + Returns: + 工具执行结果 + """ + # 生成执行ID + if not execution_id: + execution_id = f"exec_{uuid.uuid4().hex[:16]}" + + # 创建执行上下文 + context = ExecutionContext( + execution_id=execution_id, + tool_id=tool_id, + user_id=user_id, + workspace_id=workspace_id, + timeout=timeout, + metadata=metadata + ) + + try: + # 获取工具实例 + tool = self.registry.get_tool(tool_id) + if not tool: + return ToolResult.error_result( + error=f"工具不存在: {tool_id}", + error_code="TOOL_NOT_FOUND", + execution_time=0.0 + ) + + # 记录执行开始 + await self._record_execution_start(context, parameters) + + # 执行工具 + result = await self._execute_with_timeout(tool, parameters, context) + + # 记录执行完成 + await self._record_execution_complete(context, result) + + return result + + except Exception as e: + logger.error(f"工具执行异常: {execution_id}, 错误: {e}") + + # 记录执行失败 + error_result = ToolResult.error_result( + error=str(e), + error_code="EXECUTION_ERROR", + execution_time=time.time() - context.started_at.timestamp() + ) + await self._record_execution_complete(context, error_result) + + return error_result + + finally: + # 清理执行上下文 + async with self._execution_lock: + if execution_id in self._running_executions: + del self._running_executions[execution_id] + + async def execute_tools_batch( + self, + tool_executions: List[Dict[str, Any]], + max_concurrency: int = 5 + ) -> List[ToolResult]: + """批量执行工具 + + Args: + tool_executions: 工具执行配置列表,每个包含tool_id和parameters + max_concurrency: 最大并发数 + + Returns: + 执行结果列表 + """ + semaphore = asyncio.Semaphore(max_concurrency) + + async def execute_single(exec_config: Dict[str, Any]) -> ToolResult: + async with semaphore: + return await self.execute_tool( + tool_id=exec_config["tool_id"], + parameters=exec_config.get("parameters", {}), + user_id=exec_config.get("user_id"), + workspace_id=exec_config.get("workspace_id"), + timeout=exec_config.get("timeout"), + metadata=exec_config.get("metadata") + ) + + # 并发执行所有工具 + tasks = [execute_single(config) for config in tool_executions] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理异常结果 + processed_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + processed_results.append( + ToolResult.error_result( + error=str(result), + error_code="BATCH_EXECUTION_ERROR", + execution_time=0.0 + ) + ) + else: + processed_results.append(result) + + return processed_results + + async def cancel_execution(self, execution_id: str) -> bool: + """取消工具执行 + + Args: + execution_id: 执行ID + + Returns: + 是否成功取消 + """ + async with self._execution_lock: + if execution_id not in self._running_executions: + return False + + context = self._running_executions[execution_id] + context.status = ExecutionStatus.FAILED + + # 更新数据库记录 + execution_record = self.db.query(ToolExecution).filter( + ToolExecution.execution_id == execution_id + ).first() + + if execution_record: + execution_record.status = ExecutionStatus.FAILED.value + execution_record.error_message = "执行被取消" + execution_record.completed_at = datetime.now() + self.db.commit() + + logger.info(f"工具执行已取消: {execution_id}") + return True + + def get_running_executions(self) -> List[Dict[str, Any]]: + """获取正在运行的执行列表 + + Returns: + 执行信息列表 + """ + executions = [] + for execution_id, context in self._running_executions.items(): + executions.append({ + "execution_id": execution_id, + "tool_id": context.tool_id, + "user_id": str(context.user_id) if context.user_id else None, + "workspace_id": str(context.workspace_id) if context.workspace_id else None, + "started_at": context.started_at.isoformat(), + "status": context.status.value, + "elapsed_time": (datetime.now() - context.started_at).total_seconds() + }) + + return executions + + async def _execute_with_timeout( + self, + tool: BaseTool, + parameters: Dict[str, Any], + context: ExecutionContext + ) -> ToolResult: + """带超时的工具执行 + + Args: + tool: 工具实例 + parameters: 参数 + context: 执行上下文 + + Returns: + 执行结果 + """ + async with self._execution_lock: + self._running_executions[context.execution_id] = context + context.status = ExecutionStatus.RUNNING + + try: + # 使用asyncio.wait_for实现超时控制 + result = await asyncio.wait_for( + tool.safe_execute(**parameters), + timeout=context.timeout + ) + + context.status = ExecutionStatus.COMPLETED + return result + + except asyncio.TimeoutError: + context.status = ExecutionStatus.TIMEOUT + return ToolResult.error_result( + error=f"工具执行超时({context.timeout}秒)", + error_code="EXECUTION_TIMEOUT", + execution_time=context.timeout + ) + + except Exception as e: + context.status = ExecutionStatus.FAILED + raise + + async def _record_execution_start( + self, + context: ExecutionContext, + parameters: Dict[str, Any] + ): + """记录执行开始""" + try: + execution_record = ToolExecution( + execution_id=context.execution_id, + tool_config_id=uuid.UUID(context.tool_id), + status=ExecutionStatus.RUNNING.value, + input_data=parameters, + started_at=context.started_at, + user_id=context.user_id, + workspace_id=context.workspace_id + ) + + self.db.add(execution_record) + self.db.commit() + + logger.debug(f"执行记录已创建: {context.execution_id}") + + except Exception as e: + logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}") + + async def _record_execution_complete( + self, + context: ExecutionContext, + result: ToolResult + ): + """记录执行完成""" + try: + context.completed_at = datetime.now() + + execution_record = self.db.query(ToolExecution).filter( + ToolExecution.execution_id == context.execution_id + ).first() + + if execution_record: + execution_record.status = ( + ExecutionStatus.COMPLETED.value if result.success + else ExecutionStatus.FAILED.value + ) + execution_record.output_data = result.data if result.success else None + execution_record.error_message = result.error if not result.success else None + execution_record.completed_at = context.completed_at + execution_record.execution_time = result.execution_time + execution_record.token_usage = result.token_usage + + self.db.commit() + + logger.debug(f"执行记录已更新: {context.execution_id}") + + except Exception as e: + logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}") + + def get_execution_history( + self, + tool_id: Optional[str] = None, + user_id: Optional[uuid.UUID] = None, + workspace_id: Optional[uuid.UUID] = None, + limit: int = 100 + ) -> List[Dict[str, Any]]: + """获取执行历史 + + Args: + tool_id: 工具ID过滤 + user_id: 用户ID过滤 + workspace_id: 工作空间ID过滤 + limit: 返回数量限制 + + Returns: + 执行历史列表 + """ + try: + query = self.db.query(ToolExecution).order_by( + ToolExecution.started_at.desc() + ) + + if tool_id: + query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id)) + + if user_id: + query = query.filter(ToolExecution.user_id == user_id) + + if workspace_id: + query = query.filter(ToolExecution.workspace_id == workspace_id) + + executions = query.limit(limit).all() + + history = [] + for execution in executions: + history.append({ + "execution_id": execution.execution_id, + "tool_id": str(execution.tool_config_id), + "status": execution.status, + "started_at": execution.started_at.isoformat() if execution.started_at else None, + "completed_at": execution.completed_at.isoformat() if execution.completed_at else None, + "execution_time": execution.execution_time, + "user_id": str(execution.user_id) if execution.user_id else None, + "workspace_id": str(execution.workspace_id) if execution.workspace_id else None, + "input_data": execution.input_data, + "output_data": execution.output_data, + "error_message": execution.error_message, + "token_usage": execution.token_usage + }) + + return history + + except Exception as e: + logger.error(f"获取执行历史失败, 错误: {e}") + return [] + + def get_execution_statistics( + self, + workspace_id: Optional[uuid.UUID] = None, + days: int = 7 + ) -> Dict[str, Any]: + """获取执行统计信息 + + Args: + workspace_id: 工作空间ID + days: 统计天数 + + Returns: + 统计信息 + """ + try: + from datetime import timedelta + + start_date = datetime.now() - timedelta(days=days) + + query = self.db.query(ToolExecution).filter( + ToolExecution.started_at >= start_date + ) + + if workspace_id: + query = query.filter(ToolExecution.workspace_id == workspace_id) + + executions = query.all() + + # 统计数据 + total_executions = len(executions) + successful_executions = len([e for e in executions if e.status == ExecutionStatus.COMPLETED.value]) + failed_executions = len([e for e in executions if e.status == ExecutionStatus.FAILED.value]) + + # 平均执行时间 + completed_executions = [e for e in executions if e.execution_time is not None] + avg_execution_time = ( + sum(e.execution_time for e in completed_executions) / len(completed_executions) + if completed_executions else 0 + ) + + # 按工具统计 + tool_stats = {} + for execution in executions: + tool_id = str(execution.tool_config_id) + if tool_id not in tool_stats: + tool_stats[tool_id] = {"total": 0, "successful": 0, "failed": 0} + + tool_stats[tool_id]["total"] += 1 + if execution.status == ExecutionStatus.COMPLETED.value: + tool_stats[tool_id]["successful"] += 1 + elif execution.status == ExecutionStatus.FAILED.value: + tool_stats[tool_id]["failed"] += 1 + + return { + "period_days": days, + "total_executions": total_executions, + "successful_executions": successful_executions, + "failed_executions": failed_executions, + "success_rate": successful_executions / total_executions if total_executions > 0 else 0, + "average_execution_time": avg_execution_time, + "tool_statistics": tool_stats + } + + except Exception as e: + logger.error(f"获取执行统计失败, 错误: {e}") + return {} + + async def test_tool_connection( + self, + tool_id: str, + user_id: Optional[uuid.UUID] = None, + workspace_id: Optional[uuid.UUID] = None + ) -> Dict[str, Any]: + """测试工具连接""" + try: + from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig + from .mcp.client import MCPClient + + tool_config = self.db.query(ToolConfig).filter( + ToolConfig.id == uuid.UUID(tool_id) + ).first() + + if not tool_config: + return {"success": False, "message": "工具不存在"} + + if tool_config.tool_type == ToolType.MCP.value: + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == tool_config.id + ).first() + + if not mcp_config: + return {"success": False, "message": "MCP配置不存在"} + + client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {}) + + if await client.connect(): + try: + tools = await client.list_tools() + await client.disconnect() + return { + "success": True, + "message": "MCP连接成功", + "details": {"server_url": mcp_config.server_url, "tools": len(tools)} + } + except: + await client.disconnect() + return {"success": False, "message": "MCP功能测试失败"} + else: + return {"success": False, "message": "MCP连接失败"} + else: + tool = self.registry.get_tool(tool_id) + if tool and hasattr(tool, 'test_connection'): + result = tool.test_connection() + return {"success": result.get("success", False), "message": result.get("message", "")} + return {"success": True, "message": "工具无需连接测试"} + except Exception as e: + return {"success": False, "message": "测试失败", "error": str(e)} \ No newline at end of file diff --git a/api/app/core/tools/langchain_adapter.py b/api/app/core/tools/langchain_adapter.py new file mode 100644 index 00000000..1b6969b9 --- /dev/null +++ b/api/app/core/tools/langchain_adapter.py @@ -0,0 +1,375 @@ +"""Langchain适配器 - 将工具转换为langchain兼容格式""" +import json +from typing import Dict, Any, List, Optional, Type +from pydantic import BaseModel, Field +from langchain.tools import BaseTool as LangchainBaseTool +from langchain_core.tools import ToolException + +from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class LangchainToolWrapper(LangchainBaseTool): + """Langchain工具包装器""" + + name: str = Field(..., description="工具名称") + description: str = Field(..., description="工具描述") + args_schema: Optional[Type[BaseModel]] = Field(None, description="参数schema") + return_direct: bool = Field(False, description="是否直接返回结果") + + # 内部工具实例 + tool_instance: BaseTool = Field(..., description="内部工具实例") + + class Config: + arbitrary_types_allowed = True + + def __init__(self, tool_instance: BaseTool, **kwargs): + """初始化Langchain工具包装器 + + Args: + tool_instance: 内部工具实例 + """ + # 动态创建参数schema + args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters) + + super().__init__( + name=tool_instance.name, + description=tool_instance.description, + args_schema=args_schema, + _tool_instance=tool_instance, + **kwargs + ) + + def _run( + self, + run_manager=None, + **kwargs: Any, + ) -> str: + """同步执行工具(Langchain要求)""" + # 由于我们的工具是异步的,这里抛出异常提示使用异步版本 + raise NotImplementedError("请使用 _arun 方法进行异步调用") + + async def _arun( + self, + run_manager=None, + **kwargs: Any, + ) -> str: + """异步执行工具""" + try: + # 执行内部工具 + result = await self._tool_instance.safe_execute(**kwargs) + + # 转换结果为Langchain格式 + return LangchainAdapter._format_result_for_langchain(result) + + except Exception as e: + logger.error(f"工具执行失败: {self.name}, 错误: {e}") + raise ToolException(f"工具执行失败: {str(e)}") + + +class LangchainAdapter: + """Langchain适配器 - 负责工具格式转换和标准化""" + + @staticmethod + def convert_tool(tool: BaseTool) -> LangchainToolWrapper: + """将内部工具转换为Langchain工具 + + Args: + tool: 内部工具实例 + + Returns: + Langchain兼容的工具包装器 + """ + try: + wrapper = LangchainToolWrapper(tool_instance=tool) + logger.debug(f"工具转换成功: {tool.name} -> Langchain格式") + return wrapper + + except Exception as e: + logger.error(f"工具转换失败: {tool.name}, 错误: {e}") + raise + + @staticmethod + def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]: + """批量转换工具 + + Args: + tools: 工具列表 + + Returns: + Langchain工具列表 + """ + converted_tools = [] + + for tool in tools: + try: + converted_tool = LangchainAdapter.convert_tool(tool) + converted_tools.append(converted_tool) + except Exception as e: + logger.error(f"跳过工具转换: {tool.name}, 错误: {e}") + + logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具") + return converted_tools + + @staticmethod + def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]: + """根据工具参数创建Pydantic schema + + Args: + parameters: 工具参数列表 + + Returns: + Pydantic模型类 + """ + # 构建字段定义 + fields = {} + annotations = {} + + for param in parameters: + # 确定Python类型 + python_type = LangchainAdapter._get_python_type(param.type) + + # 处理可选参数 + if not param.required: + python_type = Optional[python_type] + + # 创建Field定义 + field_kwargs = { + "description": param.description + } + + if param.default is not None: + field_kwargs["default"] = param.default + elif not param.required: + field_kwargs["default"] = None + else: + field_kwargs["default"] = ... # 必需字段 + + # 添加验证约束 + if param.enum: + # 枚举值约束 + field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$" + + if param.minimum is not None: + field_kwargs["ge"] = param.minimum + + if param.maximum is not None: + field_kwargs["le"] = param.maximum + + if param.pattern: + field_kwargs["regex"] = param.pattern + + fields[param.name] = Field(**field_kwargs) + annotations[param.name] = python_type + + # 动态创建Pydantic模型 + schema_class = type( + "ToolArgsSchema", + (BaseModel,), + { + "__annotations__": annotations, + **fields, + "Config": type("Config", (), {"extra": "forbid"}) + } + ) + + return schema_class + + @staticmethod + def _get_python_type(param_type: ParameterType) -> type: + """获取参数类型对应的Python类型 + + Args: + param_type: 参数类型 + + Returns: + Python类型 + """ + type_mapping = { + ParameterType.STRING: str, + ParameterType.INTEGER: int, + ParameterType.NUMBER: float, + ParameterType.BOOLEAN: bool, + ParameterType.ARRAY: list, + ParameterType.OBJECT: dict + } + + return type_mapping.get(param_type, str) + + @staticmethod + def _format_result_for_langchain(result: ToolResult) -> str: + """将工具结果格式化为Langchain标准格式 + + Args: + result: 工具执行结果 + + Returns: + 格式化的字符串结果 + """ + if not result.success: + # 错误结果 + error_info = { + "success": False, + "error": result.error, + "error_code": result.error_code, + "execution_time": result.execution_time + } + return json.dumps(error_info, ensure_ascii=False, indent=2) + + # 成功结果 + if isinstance(result.data, str): + # 如果数据已经是字符串,直接返回 + return result.data + elif isinstance(result.data, (dict, list)): + # 如果是结构化数据,转换为JSON + return json.dumps(result.data, ensure_ascii=False, indent=2) + else: + # 其他类型转换为字符串 + return str(result.data) + + @staticmethod + def create_tool_description(tool: BaseTool) -> Dict[str, Any]: + """创建工具描述(用于工具发现和文档生成) + + Args: + tool: 工具实例 + + Returns: + 工具描述字典 + """ + return { + "name": tool.name, + "description": tool.description, + "tool_type": tool.tool_type.value, + "version": tool.version, + "status": tool.status.value, + "tags": tool.tags, + "parameters": [ + { + "name": param.name, + "type": param.type.value, + "description": param.description, + "required": param.required, + "default": param.default, + "enum": param.enum, + "minimum": param.minimum, + "maximum": param.maximum, + "pattern": param.pattern + } + for param in tool.parameters + ], + "langchain_compatible": True + } + + @staticmethod + def validate_langchain_compatibility(tool: BaseTool) -> tuple[bool, List[str]]: + """验证工具是否与Langchain兼容 + + Args: + tool: 工具实例 + + Returns: + (是否兼容, 问题列表) + """ + issues = [] + + # 检查工具名称 + if not tool.name or not isinstance(tool.name, str): + issues.append("工具名称必须是非空字符串") + + # 检查工具描述 + if not tool.description or not isinstance(tool.description, str): + issues.append("工具描述必须是非空字符串") + + # 检查参数定义 + for param in tool.parameters: + if not param.name or not isinstance(param.name, str): + issues.append(f"参数名称无效: {param.name}") + + if param.type not in ParameterType: + issues.append(f"不支持的参数类型: {param.type}") + + if param.required and param.default is not None: + issues.append(f"必需参数不应有默认值: {param.name}") + + # 检查是否有execute方法 + if not hasattr(tool, 'execute') or not callable(getattr(tool, 'execute')): + issues.append("工具必须实现execute方法") + + return len(issues) == 0, issues + + @staticmethod + def get_langchain_tool_schema(tool: BaseTool) -> Dict[str, Any]: + """获取Langchain工具的OpenAPI schema + + Args: + tool: 工具实例 + + Returns: + OpenAPI schema字典 + """ + # 构建参数schema + properties = {} + required = [] + + for param in tool.parameters: + prop_schema = { + "type": LangchainAdapter._get_openapi_type(param.type), + "description": param.description + } + + if param.enum: + prop_schema["enum"] = param.enum + + if param.minimum is not None: + prop_schema["minimum"] = param.minimum + + if param.maximum is not None: + prop_schema["maximum"] = param.maximum + + if param.pattern: + prop_schema["pattern"] = param.pattern + + if param.default is not None: + prop_schema["default"] = param.default + + properties[param.name] = prop_schema + + if param.required: + required.append(param.name) + + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": properties, + "required": required + } + } + } + + @staticmethod + def _get_openapi_type(param_type: ParameterType) -> str: + """获取OpenAPI类型 + + Args: + param_type: 参数类型 + + Returns: + OpenAPI类型字符串 + """ + type_mapping = { + ParameterType.STRING: "string", + ParameterType.INTEGER: "integer", + ParameterType.NUMBER: "number", + ParameterType.BOOLEAN: "boolean", + ParameterType.ARRAY: "array", + ParameterType.OBJECT: "object" + } + + return type_mapping.get(param_type, "string") \ No newline at end of file diff --git a/api/app/core/tools/mcp/__init__.py b/api/app/core/tools/mcp/__init__.py new file mode 100644 index 00000000..faf13ceb --- /dev/null +++ b/api/app/core/tools/mcp/__init__.py @@ -0,0 +1,12 @@ +"""MCP工具模块""" + +from .base import MCPTool +from .client import MCPClient, MCPConnectionPool +from .service_manager import MCPServiceManager + +__all__ = [ + "MCPTool", + "MCPClient", + "MCPConnectionPool", + "MCPServiceManager" +] \ No newline at end of file diff --git a/api/app/core/tools/mcp/base.py b/api/app/core/tools/mcp/base.py new file mode 100644 index 00000000..241069cd --- /dev/null +++ b/api/app/core/tools/mcp/base.py @@ -0,0 +1,258 @@ +"""MCP工具基类""" +import time +from typing import Dict, Any, List +import aiohttp + +from app.models.tool_model import ToolType +from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class MCPTool(BaseTool): + """MCP工具 - Model Context Protocol工具""" + + def __init__(self, tool_id: str, config: Dict[str, Any]): + """初始化MCP工具 + + Args: + tool_id: 工具ID + config: 工具配置 + """ + super().__init__(tool_id, config) + self.server_url = config.get("server_url", "") + self.connection_config = config.get("connection_config", {}) + self.available_tools = config.get("available_tools", []) + self._client = None + self._connected = False + + @property + def name(self) -> str: + """工具名称""" + return f"mcp_tool_{self.tool_id[:8]}" + + @property + def description(self) -> str: + """工具描述""" + return f"MCP工具 - 连接到 {self.server_url}" + + @property + def tool_type(self) -> ToolType: + """工具类型""" + return ToolType.MCP + + @property + def parameters(self) -> List[ToolParameter]: + """工具参数定义""" + params = [] + + # 添加工具选择参数 + if len(self.available_tools) > 1: + params.append(ToolParameter( + name="tool_name", + type=ParameterType.STRING, + description="要调用的MCP工具名称", + required=True, + enum=self.available_tools + )) + + # 添加通用参数 + params.extend([ + ToolParameter( + name="arguments", + type=ParameterType.OBJECT, + description="工具参数(JSON对象)", + required=False, + default={} + ), + ToolParameter( + name="timeout", + type=ParameterType.INTEGER, + description="超时时间(秒)", + required=False, + default=30, + minimum=1, + maximum=300 + ) + ]) + + return params + + async def execute(self, **kwargs) -> ToolResult: + """执行MCP工具""" + start_time = time.time() + + try: + # 确保连接 + if not self._connected: + await self.connect() + + # 确定要调用的工具 + tool_name = kwargs.get("tool_name") + if not tool_name and len(self.available_tools) == 1: + tool_name = self.available_tools[0] + + if not tool_name: + raise ValueError("必须指定要调用的MCP工具名称") + + if tool_name not in self.available_tools: + raise ValueError(f"MCP工具不存在: {tool_name}") + + # 获取参数 + arguments = kwargs.get("arguments", {}) + timeout = kwargs.get("timeout", 30) + + # 调用MCP工具 + result = await self._call_mcp_tool(tool_name, arguments, timeout) + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="MCP_ERROR", + execution_time=execution_time + ) + + async def connect(self) -> bool: + """连接到MCP服务器""" + try: + # 这里应该实现实际的MCP连接逻辑 + # 为了简化,这里只是模拟连接 + + # 测试服务器连接 + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + # 尝试获取服务器信息 + async with session.get(f"{self.server_url}/info") as response: + if response.status == 200: + server_info = await response.json() + self.available_tools = server_info.get("tools", []) + self._connected = True + logger.info(f"MCP服务器连接成功: {self.server_url}") + return True + else: + raise Exception(f"服务器响应错误: {response.status}") + + except Exception as e: + logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}") + self._connected = False + return False + + async def disconnect(self) -> bool: + """断开MCP服务器连接""" + try: + if self._client: + # 这里应该实现实际的断开逻辑 + self._client = None + + self._connected = False + logger.info(f"MCP服务器连接已断开: {self.server_url}") + return True + + except Exception as e: + logger.error(f"断开MCP服务器连接失败: {e}") + return False + + def get_health_status(self) -> Dict[str, Any]: + """获取MCP服务健康状态""" + return { + "connected": self._connected, + "server_url": self.server_url, + "available_tools": self.available_tools, + "last_check": time.time() + } + + async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any: + """调用MCP工具""" + # 构建MCP请求 + request_data = { + "jsonrpc": "2.0", + "id": f"req_{int(time.time() * 1000)}", + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments + } + } + + # 发送请求 + client_timeout = aiohttp.ClientTimeout(total=timeout) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + async with session.post( + f"{self.server_url}/mcp", + json=request_data, + headers={"Content-Type": "application/json"} + ) as response: + + if response.status != 200: + error_text = await response.text() + raise Exception(f"MCP请求失败 {response.status}: {error_text}") + + result = await response.json() + + # 检查MCP响应 + if "error" in result: + error = result["error"] + raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}") + + return result.get("result", {}) + + async def list_available_tools(self) -> List[Dict[str, Any]]: + """列出可用的MCP工具""" + try: + if not self._connected: + await self.connect() + + # 获取工具列表 + request_data = { + "jsonrpc": "2.0", + "id": f"req_{int(time.time() * 1000)}", + "method": "tools/list" + } + + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post( + f"{self.server_url}/mcp", + json=request_data, + headers={"Content-Type": "application/json"} + ) as response: + + if response.status == 200: + result = await response.json() + if "result" in result: + tools = result["result"].get("tools", []) + self.available_tools = [tool.get("name") for tool in tools] + return tools + + return [] + + except Exception as e: + logger.error(f"获取MCP工具列表失败: {e}") + return [] + + def test_connection(self) -> Dict[str, Any]: + """测试MCP连接""" + try: + # 这里应该实现同步的连接测试 + # 为了简化,返回基本信息 + return { + "success": bool(self.server_url), + "server_url": self.server_url, + "connected": self._connected, + "available_tools_count": len(self.available_tools), + "message": "MCP配置有效" if self.server_url else "缺少服务器URL配置" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } \ No newline at end of file diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py new file mode 100644 index 00000000..3be2e9bf --- /dev/null +++ b/api/app/core/tools/mcp/client.py @@ -0,0 +1,626 @@ +"""MCP客户端 - Model Context Protocol客户端实现""" +import asyncio +import json +import time +from typing import Dict, Any, List, Optional, Callable +from urllib.parse import urlparse +import aiohttp +import websockets +from websockets.exceptions import ConnectionClosed + +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class MCPConnectionError(Exception): + """MCP连接错误""" + pass + + +class MCPProtocolError(Exception): + """MCP协议错误""" + pass + + +class MCPClient: + """MCP客户端 - 支持HTTP和WebSocket连接""" + + def __init__(self, server_url: str, connection_config: Dict[str, Any] = None): + """初始化MCP客户端 + + Args: + server_url: MCP服务器URL + connection_config: 连接配置 + """ + self.server_url = server_url + self.connection_config = connection_config or {} + + # 解析URL确定连接类型 + parsed_url = urlparse(server_url) + self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http" + + # 连接状态 + self._connected = False + self._websocket = None + self._session = None + + # 请求管理 + self._request_id = 0 + self._pending_requests: Dict[str, asyncio.Future] = {} + + # 连接池配置 + self.max_connections = self.connection_config.get("max_connections", 10) + self.connection_timeout = self.connection_config.get("timeout", 30) + self.retry_attempts = self.connection_config.get("retry_attempts", 3) + self.retry_delay = self.connection_config.get("retry_delay", 1) + + # 健康检查 + self.health_check_interval = self.connection_config.get("health_check_interval", 60) + self._health_check_task = None + self._last_health_check = None + + # 事件回调 + self._on_connect_callbacks: List[Callable] = [] + self._on_disconnect_callbacks: List[Callable] = [] + self._on_error_callbacks: List[Callable] = [] + + async def connect(self) -> bool: + """连接到MCP服务器 + + Returns: + 连接是否成功 + """ + try: + if self._connected: + return True + + logger.info(f"连接MCP服务器: {self.server_url}") + + if self.connection_type == "websocket": + success = await self._connect_websocket() + else: + success = await self._connect_http() + + if success: + self._connected = True + await self._start_health_check() + await self._notify_connect_callbacks() + logger.info(f"MCP服务器连接成功: {self.server_url}") + + return success + + except Exception as e: + logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}") + await self._notify_error_callbacks(e) + return False + + async def disconnect(self) -> bool: + """断开MCP服务器连接 + + Returns: + 断开是否成功 + """ + try: + if not self._connected: + return True + + logger.info(f"断开MCP服务器连接: {self.server_url}") + + # 停止健康检查 + await self._stop_health_check() + + # 取消所有待处理的请求 + for future in self._pending_requests.values(): + if not future.done(): + future.cancel() + self._pending_requests.clear() + + # 断开连接 + if self.connection_type == "websocket" and self._websocket: + await self._websocket.close() + self._websocket = None + elif self._session: + await self._session.close() + self._session = None + + self._connected = False + await self._notify_disconnect_callbacks() + logger.info(f"MCP服务器连接已断开: {self.server_url}") + + return True + + except Exception as e: + logger.error(f"断开MCP服务器连接失败: {e}") + return False + + async def _connect_websocket(self) -> bool: + """建立WebSocket连接""" + try: + # WebSocket连接配置 + extra_headers = self.connection_config.get("headers", {}) + + self._websocket = await websockets.connect( + self.server_url, + extra_headers=extra_headers, + timeout=self.connection_timeout + ) + + # 启动消息监听 + asyncio.create_task(self._websocket_message_handler()) + + # 发送初始化消息 + init_message = { + "jsonrpc": "2.0", + "id": self._get_next_request_id(), + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} + }, + "clientInfo": { + "name": "ToolManagementSystem", + "version": "1.0.0" + } + } + } + + await self._websocket.send(json.dumps(init_message)) + + # 等待初始化响应 + response = await asyncio.wait_for( + self._websocket.recv(), + timeout=self.connection_timeout + ) + + init_response = json.loads(response) + if "error" in init_response: + raise MCPProtocolError(f"初始化失败: {init_response['error']}") + + return True + + except Exception as e: + logger.error(f"WebSocket连接失败: {e}") + return False + + async def _connect_http(self) -> bool: + """建立HTTP连接""" + try: + # HTTP会话配置 + timeout = aiohttp.ClientTimeout(total=self.connection_timeout) + headers = self.connection_config.get("headers", {}) + + self._session = aiohttp.ClientSession( + timeout=timeout, + headers=headers + ) + + # 测试连接 + test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health" + + async with self._session.get(test_url) as response: + if response.status == 200: + return True + else: + # 尝试根路径 + async with self._session.get(self.server_url) as root_response: + return root_response.status < 400 + + except Exception as e: + logger.error(f"HTTP连接失败: {e}") + if self._session: + await self._session.close() + self._session = None + return False + + async def _websocket_message_handler(self): + """WebSocket消息处理器""" + try: + while self._websocket and not self._websocket.closed: + try: + message = await self._websocket.recv() + await self._handle_message(json.loads(message)) + except ConnectionClosed: + break + except json.JSONDecodeError as e: + logger.error(f"解析WebSocket消息失败: {e}") + except Exception as e: + logger.error(f"处理WebSocket消息失败: {e}") + + except Exception as e: + logger.error(f"WebSocket消息处理器异常: {e}") + finally: + self._connected = False + await self._notify_disconnect_callbacks() + + async def _handle_message(self, message: Dict[str, Any]): + """处理收到的消息""" + try: + # 检查是否是响应消息 + if "id" in message: + request_id = str(message["id"]) + if request_id in self._pending_requests: + future = self._pending_requests.pop(request_id) + if not future.done(): + future.set_result(message) + + # 处理通知消息 + elif "method" in message: + await self._handle_notification(message) + + except Exception as e: + logger.error(f"处理消息失败: {e}") + + async def _handle_notification(self, message: Dict[str, Any]): + """处理通知消息""" + method = message.get("method") + params = message.get("params", {}) + + logger.debug(f"收到MCP通知: {method}, 参数: {params}") + + # 这里可以根据需要处理特定的通知 + # 例如:工具列表更新、服务器状态变化等 + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]: + """调用MCP工具 + + Args: + tool_name: 工具名称 + arguments: 工具参数 + timeout: 超时时间(秒) + + Returns: + 工具执行结果 + + Raises: + MCPConnectionError: 连接错误 + MCPProtocolError: 协议错误 + """ + if not self._connected: + raise MCPConnectionError("MCP客户端未连接") + + request_data = { + "jsonrpc": "2.0", + "id": self._get_next_request_id(), + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments + } + } + + try: + response = await self._send_request(request_data, timeout) + + if "error" in response: + error = response["error"] + raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}") + + return response.get("result", {}) + + except asyncio.TimeoutError: + raise MCPProtocolError(f"工具调用超时: {tool_name}") + + async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]: + """获取可用工具列表 + + Args: + timeout: 超时时间(秒) + + Returns: + 工具列表 + + Raises: + MCPConnectionError: 连接错误 + MCPProtocolError: 协议错误 + """ + if not self._connected: + raise MCPConnectionError("MCP客户端未连接") + + request_data = { + "jsonrpc": "2.0", + "id": self._get_next_request_id(), + "method": "tools/list" + } + + try: + response = await self._send_request(request_data, timeout) + + if not response["error"] is None: + error = response["error"] + raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}") + + result = response.get("result", {}) + return result.get("tools", []) + + except asyncio.TimeoutError: + raise MCPProtocolError("获取工具列表超时") + + async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]: + """发送请求并等待响应 + + Args: + request_data: 请求数据 + timeout: 超时时间(秒) + + Returns: + 响应数据 + """ + request_id = str(request_data["id"]) + + if self.connection_type == "websocket": + return await self._send_websocket_request(request_data, request_id, timeout) + else: + return await self._send_http_request(request_data, timeout) + + async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]: + """发送WebSocket请求""" + if not self._websocket or self._websocket.closed: + raise MCPConnectionError("WebSocket连接已断开") + + # 创建Future等待响应 + future = asyncio.Future() + self._pending_requests[request_id] = future + + try: + # 发送请求 + await self._websocket.send(json.dumps(request_data)) + + # 等待响应 + response = await asyncio.wait_for(future, timeout=timeout) + return response + + except asyncio.TimeoutError: + self._pending_requests.pop(request_id, None) + raise + except Exception as e: + self._pending_requests.pop(request_id, None) + raise MCPConnectionError(f"发送WebSocket请求失败: {e}") + + async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]: + """发送HTTP请求""" + if not self._session: + raise MCPConnectionError("HTTP会话未建立") + + try: + url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp" + + async with self._session.post( + url, + json=request_data, + timeout=aiohttp.ClientTimeout(total=timeout) + ) as response: + + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") + + return await response.json() + + except aiohttp.ClientError as e: + raise MCPConnectionError(f"HTTP请求失败: {e}") + + async def health_check(self) -> Dict[str, Any]: + """执行健康检查 + + Returns: + 健康状态信息 + """ + try: + if not self._connected: + return { + "healthy": False, + "error": "未连接", + "timestamp": time.time() + } + + # 发送ping请求 + request_data = { + "jsonrpc": "2.0", + "id": self._get_next_request_id(), + "method": "ping" + } + + start_time = time.time() + response = await self._send_request(request_data, timeout=5) + response_time = time.time() - start_time + + self._last_health_check = time.time() + + return { + "healthy": True, + "response_time": response_time, + "timestamp": self._last_health_check, + "server_info": response.get("result", {}) + } + + except Exception as e: + return { + "healthy": False, + "error": str(e), + "timestamp": time.time() + } + + async def _start_health_check(self): + """启动健康检查任务""" + if self.health_check_interval > 0: + self._health_check_task = asyncio.create_task(self._health_check_loop()) + + async def _stop_health_check(self): + """停止健康检查任务""" + if self._health_check_task: + self._health_check_task.cancel() + try: + await self._health_check_task + except asyncio.CancelledError: + pass + self._health_check_task = None + + async def _health_check_loop(self): + """健康检查循环""" + try: + while self._connected: + await asyncio.sleep(self.health_check_interval) + + if self._connected: + health_status = await self.health_check() + if not health_status["healthy"]: + logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}") + # 可以在这里实现重连逻辑 + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"健康检查循环异常: {e}") + + def _get_next_request_id(self) -> str: + """获取下一个请求ID""" + self._request_id += 1 + return f"req_{self._request_id}_{int(time.time() * 1000)}" + + # 事件回调管理 + def on_connect(self, callback: Callable): + """注册连接回调""" + self._on_connect_callbacks.append(callback) + + def on_disconnect(self, callback: Callable): + """注册断开连接回调""" + self._on_disconnect_callbacks.append(callback) + + def on_error(self, callback: Callable): + """注册错误回调""" + self._on_error_callbacks.append(callback) + + async def _notify_connect_callbacks(self): + """通知连接回调""" + for callback in self._on_connect_callbacks: + try: + if asyncio.iscoroutinefunction(callback): + await callback() + else: + callback() + except Exception as e: + logger.error(f"连接回调执行失败: {e}") + + async def _notify_disconnect_callbacks(self): + """通知断开连接回调""" + for callback in self._on_disconnect_callbacks: + try: + if asyncio.iscoroutinefunction(callback): + await callback() + else: + callback() + except Exception as e: + logger.error(f"断开连接回调执行失败: {e}") + + async def _notify_error_callbacks(self, error: Exception): + """通知错误回调""" + for callback in self._on_error_callbacks: + try: + if asyncio.iscoroutinefunction(callback): + await callback(error) + else: + callback(error) + except Exception as e: + logger.error(f"错误回调执行失败: {e}") + + @property + def is_connected(self) -> bool: + """检查是否已连接""" + return self._connected + + @property + def last_health_check(self) -> Optional[float]: + """获取最后一次健康检查时间""" + return self._last_health_check + + def get_connection_info(self) -> Dict[str, Any]: + """获取连接信息""" + return { + "server_url": self.server_url, + "connection_type": self.connection_type, + "connected": self._connected, + "last_health_check": self._last_health_check, + "pending_requests": len(self._pending_requests), + "config": self.connection_config + } + + async def __aenter__(self): + """异步上下文管理器入口""" + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """异步上下文管理器出口""" + await self.disconnect() + + +class MCPConnectionPool: + """MCP连接池 - 管理多个MCP客户端连接""" + + def __init__(self, max_connections: int = 10): + """初始化连接池 + + Args: + max_connections: 最大连接数 + """ + self.max_connections = max_connections + self._clients: Dict[str, MCPClient] = {} + self._lock = asyncio.Lock() + + async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient: + """获取或创建MCP客户端 + + Args: + server_url: 服务器URL + connection_config: 连接配置 + + Returns: + MCP客户端实例 + """ + async with self._lock: + if server_url in self._clients: + client = self._clients[server_url] + if client.is_connected: + return client + else: + # 尝试重连 + if await client.connect(): + return client + else: + # 移除失效的客户端 + del self._clients[server_url] + + # 检查连接数限制 + if len(self._clients) >= self.max_connections: + # 移除最旧的连接 + oldest_url = next(iter(self._clients)) + await self._clients[oldest_url].disconnect() + del self._clients[oldest_url] + + # 创建新客户端 + client = MCPClient(server_url, connection_config) + if await client.connect(): + self._clients[server_url] = client + return client + else: + raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}") + + async def disconnect_all(self): + """断开所有连接""" + async with self._lock: + for client in self._clients.values(): + await client.disconnect() + self._clients.clear() + + def get_pool_status(self) -> Dict[str, Any]: + """获取连接池状态""" + return { + "total_connections": len(self._clients), + "max_connections": self.max_connections, + "connections": { + url: client.get_connection_info() + for url, client in self._clients.items() + } + } \ No newline at end of file diff --git a/api/app/core/tools/mcp/service_manager.py b/api/app/core/tools/mcp/service_manager.py new file mode 100644 index 00000000..53b83ddd --- /dev/null +++ b/api/app/core/tools/mcp/service_manager.py @@ -0,0 +1,604 @@ +"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控""" +import asyncio +import time +import uuid +from typing import Dict, Any, List, Optional, Tuple +from datetime import datetime +from sqlalchemy.orm import Session + +from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType +from app.core.logging_config import get_business_logger +from .client import MCPClient, MCPConnectionPool + +logger = get_business_logger() + + +class MCPServiceManager: + """MCP服务管理器 - 管理MCP服务的生命周期""" + + def __init__(self, db: Session): + """初始化MCP服务管理器 + + Args: + db: 数据库会话 + """ + self.db = db + self.connection_pool = MCPConnectionPool(max_connections=20) + + # 服务状态管理 + self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info + self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task + + # 配置 + self.health_check_interval = 60 # 健康检查间隔(秒) + self.max_retry_attempts = 3 # 最大重试次数 + self.retry_delay = 5 # 重试延迟(秒) + + # 状态 + self._running = False + self._manager_task = None + + async def start(self): + """启动服务管理器""" + if self._running: + return + + self._running = True + logger.info("MCP服务管理器启动") + + # 加载现有服务 + await self._load_existing_services() + + # 启动管理任务 + self._manager_task = asyncio.create_task(self._management_loop()) + + async def stop(self): + """停止服务管理器""" + if not self._running: + return + + self._running = False + logger.info("MCP服务管理器停止") + + # 停止管理任务 + if self._manager_task: + self._manager_task.cancel() + try: + await self._manager_task + except asyncio.CancelledError: + pass + + # 停止所有监控任务 + for task in self._monitoring_tasks.values(): + task.cancel() + + if self._monitoring_tasks: + await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True) + + self._monitoring_tasks.clear() + + # 断开所有连接 + await self.connection_pool.disconnect_all() + + async def register_service( + self, + server_url: str, + connection_config: Dict[str, Any], + tenant_id: uuid.UUID, + service_name: str = None + ) -> Tuple[bool, str, Optional[str]]: + """注册MCP服务 + + Args: + server_url: 服务器URL + connection_config: 连接配置 + tenant_id: 租户ID + service_name: 服务名称(可选) + + Returns: + (是否成功, 服务ID或错误信息, 错误详情) + """ + try: + # 检查服务是否已存在 + existing_service = self.db.query(MCPToolConfig).filter( + MCPToolConfig.server_url == server_url + ).first() + + if existing_service: + return False, "服务已存在", f"URL {server_url} 已被注册" + + # 测试连接 + try: + client = MCPClient(server_url, connection_config) + if not await client.connect(): + return False, "连接测试失败", "无法连接到MCP服务器" + + # 获取可用工具 + available_tools = await client.list_tools() + tool_names = [tool.get("name") for tool in available_tools if tool.get("name")] + + await client.disconnect() + + except Exception as e: + return False, "连接测试失败", str(e) + + # 创建工具配置 + if not service_name: + service_name = f"mcp_service_{server_url.split('/')[-1]}" + + tool_config = ToolConfig( + name=service_name, + description=f"MCP服务 - {server_url}", + tool_type=ToolType.MCP.value, + tenant_id=tenant_id, + version="1.0.0", + config_data={ + "server_url": server_url, + "connection_config": connection_config + } + ) + + self.db.add(tool_config) + self.db.flush() + + # 创建MCP特定配置 + mcp_config = MCPToolConfig( + id=tool_config.id, + server_url=server_url, + connection_config=connection_config, + available_tools=tool_names, + health_status="healthy", + last_health_check=datetime.utcnow() + ) + + self.db.add(mcp_config) + self.db.commit() + + service_id = str(tool_config.id) + + # 添加到内存管理 + self._services[service_id] = { + "id": service_id, + "server_url": server_url, + "connection_config": connection_config, + "tenant_id": tenant_id, + "available_tools": tool_names, + "status": "healthy", + "last_health_check": time.time(), + "retry_count": 0, + "created_at": time.time() + } + + # 启动监控 + await self._start_service_monitoring(service_id) + + logger.info(f"MCP服务注册成功: {service_id} ({server_url})") + return True, service_id, None + + except Exception as e: + self.db.rollback() + logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}") + return False, "注册失败", str(e) + + async def unregister_service(self, service_id: str) -> Tuple[bool, str]: + """注销MCP服务 + + Args: + service_id: 服务ID + + Returns: + (是否成功, 错误信息) + """ + try: + # 从数据库删除 + tool_config = self.db.get(ToolConfig, uuid.UUID(service_id)) + if not tool_config: + return False, "服务不存在" + + self.db.delete(tool_config) + self.db.commit() + + # 停止监控 + await self._stop_service_monitoring(service_id) + + # 从内存移除 + if service_id in self._services: + del self._services[service_id] + + logger.info(f"MCP服务注销成功: {service_id}") + return True, "" + + except Exception as e: + self.db.rollback() + logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}") + return False, str(e) + + async def update_service( + self, + service_id: str, + connection_config: Dict[str, Any] = None, + enabled: bool = None + ) -> Tuple[bool, str]: + """更新MCP服务配置 + + Args: + service_id: 服务ID + connection_config: 新的连接配置 + enabled: 是否启用 + + Returns: + (是否成功, 错误信息) + """ + try: + # 更新数据库 + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == uuid.UUID(service_id) + ).first() + + if not mcp_config: + return False, "服务不存在" + + tool_config = mcp_config.base_config + + if connection_config is not None: + mcp_config.connection_config = connection_config + tool_config.config_data["connection_config"] = connection_config + + if enabled is not None: + tool_config.is_enabled = enabled + + self.db.commit() + + # 更新内存状态 + if service_id in self._services: + if connection_config is not None: + self._services[service_id]["connection_config"] = connection_config + + # 如果配置有变化,重启监控 + if connection_config is not None: + await self._restart_service_monitoring(service_id) + + logger.info(f"MCP服务更新成功: {service_id}") + return True, "" + + except Exception as e: + self.db.rollback() + logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}") + return False, str(e) + + async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]: + """获取服务状态 + + Args: + service_id: 服务ID + + Returns: + 服务状态信息 + """ + if service_id not in self._services: + return None + + service_info = self._services[service_id].copy() + + # 添加实时健康检查 + try: + client = await self.connection_pool.get_client( + service_info["server_url"], + service_info["connection_config"] + ) + + health_status = await client.health_check() + service_info["real_time_health"] = health_status + + except Exception as e: + service_info["real_time_health"] = { + "healthy": False, + "error": str(e), + "timestamp": time.time() + } + + return service_info + + async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]: + """列出所有服务 + + Args: + tenant_id: 租户ID过滤 + + Returns: + 服务列表 + """ + services = [] + + for service_id, service_info in self._services.items(): + if tenant_id and service_info["tenant_id"] != tenant_id: + continue + + services.append(service_info.copy()) + + return services + + async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]: + """获取服务的可用工具 + + Args: + service_id: 服务ID + + Returns: + 工具列表 + """ + if service_id not in self._services: + return [] + + service_info = self._services[service_id] + + try: + client = await self.connection_pool.get_client( + service_info["server_url"], + service_info["connection_config"] + ) + + tools = await client.list_tools() + + # 更新缓存的工具列表 + tool_names = [tool.get("name") for tool in tools if tool.get("name")] + service_info["available_tools"] = tool_names + + # 更新数据库 + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == uuid.UUID(service_id) + ).first() + + if mcp_config: + mcp_config.available_tools = tool_names + self.db.commit() + + return tools + + except Exception as e: + logger.error(f"获取服务工具失败: {service_id}, 错误: {e}") + return [] + + async def call_service_tool( + self, + service_id: str, + tool_name: str, + arguments: Dict[str, Any], + timeout: int = 30 + ) -> Dict[str, Any]: + """调用服务工具 + + Args: + service_id: 服务ID + tool_name: 工具名称 + arguments: 工具参数 + timeout: 超时时间 + + Returns: + 执行结果 + """ + if service_id not in self._services: + raise ValueError(f"服务不存在: {service_id}") + + service_info = self._services[service_id] + + try: + client = await self.connection_pool.get_client( + service_info["server_url"], + service_info["connection_config"] + ) + + result = await client.call_tool(tool_name, arguments, timeout) + + # 更新服务状态为健康 + service_info["status"] = "healthy" + service_info["last_health_check"] = time.time() + service_info["retry_count"] = 0 + + return result + + except Exception as e: + # 更新服务状态为错误 + service_info["status"] = "error" + service_info["last_error"] = str(e) + service_info["retry_count"] += 1 + + logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}") + raise + + async def _load_existing_services(self): + """加载现有服务""" + try: + mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter( + ToolConfig.is_enabled == True + ).all() + + for mcp_config in mcp_configs: + tool_config = mcp_config.base_config + service_id = str(mcp_config.id) + + self._services[service_id] = { + "id": service_id, + "server_url": mcp_config.server_url, + "connection_config": mcp_config.connection_config or {}, + "tenant_id": tool_config.tenant_id, + "available_tools": mcp_config.available_tools or [], + "status": mcp_config.health_status or "unknown", + "last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0, + "retry_count": 0, + "created_at": tool_config.created_at.timestamp() + } + + # 启动监控 + await self._start_service_monitoring(service_id) + + logger.info(f"加载了 {len(mcp_configs)} 个MCP服务") + + except Exception as e: + logger.error(f"加载现有服务失败: {e}") + + async def _start_service_monitoring(self, service_id: str): + """启动服务监控""" + if service_id in self._monitoring_tasks: + return + + task = asyncio.create_task(self._monitor_service(service_id)) + self._monitoring_tasks[service_id] = task + + async def _stop_service_monitoring(self, service_id: str): + """停止服务监控""" + if service_id in self._monitoring_tasks: + task = self._monitoring_tasks.pop(service_id) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def _restart_service_monitoring(self, service_id: str): + """重启服务监控""" + await self._stop_service_monitoring(service_id) + await self._start_service_monitoring(service_id) + + async def _monitor_service(self, service_id: str): + """监控单个服务""" + try: + while self._running and service_id in self._services: + service_info = self._services[service_id] + + try: + # 执行健康检查 + client = await self.connection_pool.get_client( + service_info["server_url"], + service_info["connection_config"] + ) + + health_status = await client.health_check() + + if health_status["healthy"]: + # 服务健康 + service_info["status"] = "healthy" + service_info["retry_count"] = 0 + + # 更新工具列表 + try: + tools = await client.list_tools() + tool_names = [tool.get("name") for tool in tools if tool.get("name")] + service_info["available_tools"] = tool_names + except Exception as e: + logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}") + + else: + # 服务不健康 + service_info["status"] = "unhealthy" + service_info["last_error"] = health_status.get("error", "健康检查失败") + service_info["retry_count"] += 1 + + service_info["last_health_check"] = time.time() + + # 更新数据库 + await self._update_service_health_in_db(service_id, health_status) + + except Exception as e: + # 监控异常 + service_info["status"] = "error" + service_info["last_error"] = str(e) + service_info["retry_count"] += 1 + service_info["last_health_check"] = time.time() + + logger.error(f"服务监控异常: {service_id}, 错误: {e}") + + # 如果重试次数过多,暂停监控 + if service_info["retry_count"] >= self.max_retry_attempts: + logger.warning(f"服务 {service_id} 重试次数过多,暂停监控") + await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间 + service_info["retry_count"] = 0 # 重置重试计数 + + # 等待下次检查 + await asyncio.sleep(self.health_check_interval) + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"服务监控任务异常: {service_id}, 错误: {e}") + + async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]): + """更新数据库中的服务健康状态""" + try: + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == uuid.UUID(service_id) + ).first() + + if mcp_config: + mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy" + mcp_config.last_health_check = datetime.utcnow() + + if not health_status["healthy"]: + mcp_config.error_message = health_status.get("error", "") + else: + mcp_config.error_message = None + + self.db.commit() + + except Exception as e: + logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}") + self.db.rollback() + + async def _management_loop(self): + """管理循环 - 处理服务清理等任务""" + try: + while self._running: + # 清理失效的服务 + await self._cleanup_failed_services() + + # 等待下次循环 + await asyncio.sleep(300) # 5分钟 + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"管理循环异常: {e}") + + async def _cleanup_failed_services(self): + """清理长期失效的服务""" + try: + current_time = time.time() + cleanup_threshold = 24 * 60 * 60 # 24小时 + + services_to_cleanup = [] + + for service_id, service_info in self._services.items(): + # 检查服务是否长期失效 + if (service_info["status"] in ["error", "unhealthy"] and + current_time - service_info["last_health_check"] > cleanup_threshold): + + services_to_cleanup.append(service_id) + + for service_id in services_to_cleanup: + logger.warning(f"清理长期失效的服务: {service_id}") + + # 停止监控但不删除数据库记录 + await self._stop_service_monitoring(service_id) + + # 标记为禁用 + tool_config = self.db.get(ToolConfig, uuid.UUID(service_id)) + if tool_config: + tool_config.is_enabled = False + self.db.commit() + + # 从内存移除 + del self._services[service_id] + + except Exception as e: + logger.error(f"清理失效服务失败: {e}") + + def get_manager_status(self) -> Dict[str, Any]: + """获取管理器状态""" + return { + "running": self._running, + "total_services": len(self._services), + "healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]), + "unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]), + "monitoring_tasks": len(self._monitoring_tasks), + "connection_pool_status": self.connection_pool.get_pool_status() + } \ No newline at end of file diff --git a/api/app/core/tools/registry.py b/api/app/core/tools/registry.py new file mode 100644 index 00000000..b56c1bf7 --- /dev/null +++ b/api/app/core/tools/registry.py @@ -0,0 +1,436 @@ +"""工具注册表 - 管理所有工具的元数据和状态""" +import uuid +import asyncio +from typing import Dict, List, Optional, Type, Any + +from sqlalchemy.orm import Session +from sqlalchemy import and_, or_ + +from app.models.tool_model import ( + ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, + ToolType, ToolStatus, ToolExecution, ExecutionStatus +) +from app.core.logging_config import get_business_logger +from .base import BaseTool, ToolInfo +from .custom.base import CustomTool +from .mcp.base import MCPTool + +logger = get_business_logger() + + +class ToolRegistry: + """工具注册表 - 管理所有工具的元数据和实例""" + + def __init__(self, db: Session): + """初始化工具注册表 + + Args: + db: 数据库会话 + """ + self.db = db + self._tools: Dict[str, BaseTool] = {} # 工具实例缓存 + self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表 + self._lock = asyncio.Lock() # 异步锁 + + def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None): + """注册工具类 + + Args: + tool_class: 工具类 + class_name: 类名(可选,默认使用类的__name__) + """ + class_name = class_name or tool_class.__name__ + self._tool_classes[class_name] = tool_class + logger.info(f"工具类已注册: {class_name}") + + async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool: + """注册工具实例到系统 + + Args: + tool: 工具实例 + tenant_id: 租户ID(内置工具可以为None,表示全局工具) + + Returns: + 注册是否成功 + """ + async with self._lock: + try: + # 检查工具是否已存在 + if tenant_id: + existing_config = self.db.query(ToolConfig).filter( + and_( + ToolConfig.name == tool.name, + ToolConfig.tenant_id == tenant_id, + ToolConfig.tool_type == tool.tool_type.value + ) + ).first() + else: + # 全局工具(内置工具) + existing_config = self.db.query(ToolConfig).filter( + and_( + ToolConfig.name == tool.name, + ToolConfig.tenant_id.is_(None), + ToolConfig.tool_type == tool.tool_type.value + ) + ).first() + + if existing_config: + logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})") + return False + + # 创建工具配置 + tool_config = ToolConfig( + name=tool.name, + description=tool.description, + tool_type=tool.tool_type.value, + tenant_id=tenant_id, + version=tool.version, + tags=tool.tags, + config_data=tool.config + ) + + self.db.add(tool_config) + self.db.flush() # 获取ID + + # 根据工具类型创建特定配置 + if tool.tool_type == ToolType.BUILTIN: + builtin_config = BuiltinToolConfig( + id=tool_config.id, + tool_class=tool.__class__.__name__, + parameters=tool.config.get("parameters", {}) + ) + self.db.add(builtin_config) + + elif tool.tool_type == ToolType.CUSTOM: + custom_config = CustomToolConfig( + id=tool_config.id, + schema_url=tool.config.get("schema_url"), + schema_content=tool.config.get("schema_content"), + auth_type=tool.config.get("auth_type", "none"), + auth_config=tool.config.get("auth_config", {}), + base_url=tool.config.get("base_url"), + timeout=tool.config.get("timeout", 30) + ) + self.db.add(custom_config) + + elif tool.tool_type == ToolType.MCP: + mcp_config = MCPToolConfig( + id=tool_config.id, + server_url=tool.config.get("server_url"), + connection_config=tool.config.get("connection_config", {}), + available_tools=tool.config.get("available_tools", []) + ) + self.db.add(mcp_config) + + self.db.commit() + + # 缓存工具实例 + tool.tool_id = str(tool_config.id) + self._tools[str(tool_config.id)] = tool + + logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})") + return True + + except Exception as e: + self.db.rollback() + logger.error(f"工具注册失败: {tool.name}, 错误: {e}") + return False + + async def unregister_tool(self, tool_id: str) -> bool: + """从系统注销工具 + + Args: + tool_id: 工具ID + + Returns: + 注销是否成功 + """ + async with self._lock: + try: + # 检查工具是否存在 + tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id)) + if not tool_config: + logger.warning(f"工具不存在: {tool_id}") + return False + + # 检查是否有正在执行的任务 + running_executions = self.db.query(ToolExecution).filter( + and_( + ToolExecution.tool_config_id == uuid.UUID(tool_id), + ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value]) + ) + ).count() + + if running_executions > 0: + logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}") + return False + + # 删除工具配置(级联删除相关记录) + self.db.delete(tool_config) + self.db.commit() + + # 从缓存中移除 + if tool_id in self._tools: + del self._tools[tool_id] + + logger.info(f"工具注销成功: {tool_id}") + return True + + except Exception as e: + self.db.rollback() + logger.error(f"工具注销失败: {tool_id}, 错误: {e}") + return False + + def get_tool(self, tool_id: str) -> Optional[BaseTool]: + """获取工具实例 + + Args: + tool_id: 工具ID + + Returns: + 工具实例,如果不存在返回None + """ + # 先从缓存获取 + if tool_id in self._tools: + return self._tools[tool_id] + + # 从数据库加载 + try: + tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id)) + if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value: + return None + + # 根据工具类型加载实例 + tool_instance = self._load_tool_instance(tool_config) + if tool_instance: + self._tools[tool_id] = tool_instance + return tool_instance + + except Exception as e: + logger.error(f"加载工具失败: {tool_id}, 错误: {e}") + + return None + + def list_tools( + self, + tenant_id: Optional[uuid.UUID] = None, + tool_type: Optional[ToolType] = None, + status: Optional[ToolStatus] = None, + tags: Optional[List[str]] = None + ) -> List[ToolInfo]: + """列出工具 + + Args: + tenant_id: 租户ID过滤 + tool_type: 工具类型过滤 + status: 工具状态过滤 + tags: 标签过滤 + + Returns: + 工具信息列表 + """ + try: + query = self.db.query(ToolConfig) + + # 应用过滤条件 + if tenant_id: + # 返回全局工具(tenant_id为空)和该租户的工具 + query = query.filter( + or_( + ToolConfig.tenant_id == tenant_id, + ToolConfig.tenant_id.is_(None) + ) + ) + + if tool_type: + query = query.filter(ToolConfig.tool_type == tool_type.value) + + if status == ToolStatus.ACTIVE: + query = query.filter(ToolConfig.is_enabled == True) + elif status == ToolStatus.INACTIVE: + query = query.filter(ToolConfig.is_enabled == False) + + if tags: + for tag in tags: + query = query.filter(ToolConfig.tags.contains([tag])) + + tool_configs = query.all() + + # 转换为ToolInfo + tool_infos = [] + for config in tool_configs: + tool_info = ToolInfo( + id=str(config.id), + name=config.name, + description=config.description or "", + tool_type=ToolType(config.tool_type), + version=config.version, + status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE, + tags=config.tags or [], + tenant_id=str(config.tenant_id) if config.tenant_id else None + ) + + # 尝试获取参数信息 + tool_instance = self.get_tool(str(config.id)) + if tool_instance: + tool_info.parameters = tool_instance.parameters + + tool_infos.append(tool_info) + + return tool_infos + + except Exception as e: + logger.error(f"列出工具失败, 错误: {e}") + return [] + + async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool: + """更新工具状态 + + Args: + tool_id: 工具ID + status: 新状态 + + Returns: + 更新是否成功 + """ + try: + tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id)) + if not tool_config: + logger.warning(f"工具不存在: {tool_id}") + return False + + # 更新状态 + if status == ToolStatus.ACTIVE: + tool_config.is_enabled = True + elif status == ToolStatus.INACTIVE: + tool_config.is_enabled = False + + self.db.commit() + + # 更新缓存中的工具状态 + if tool_id in self._tools: + self._tools[tool_id].status = status + + logger.info(f"工具状态更新成功: {tool_id} -> {status}") + return True + + except Exception as e: + self.db.rollback() + logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}") + return False + + def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]: + """从配置加载工具实例 + + Args: + tool_config: 工具配置 + + Returns: + 工具实例 + """ + try: + if tool_config.tool_type == ToolType.BUILTIN.value: + # 加载内置工具 + builtin_config = self.db.query(BuiltinToolConfig).filter( + BuiltinToolConfig.id == tool_config.id + ).first() + + if builtin_config and builtin_config.tool_class in self._tool_classes: + tool_class = self._tool_classes[builtin_config.tool_class] + config = { + **tool_config.config_data, + "parameters": builtin_config.parameters, + "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None, + "version": tool_config.version, + "tags": tool_config.tags + } + return tool_class(str(tool_config.id), config) + + elif tool_config.tool_type == ToolType.CUSTOM.value: + # 加载自定义工具 + try: + custom_config = self.db.query(CustomToolConfig).filter( + CustomToolConfig.id == tool_config.id + ).first() + + if custom_config: + config = { + **tool_config.config_data, + "schema_url": custom_config.schema_url, + "schema_content": custom_config.schema_content, + "auth_type": custom_config.auth_type, + "auth_config": custom_config.auth_config, + "base_url": custom_config.base_url, + "timeout": custom_config.timeout, + "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None, + "version": tool_config.version, + "tags": tool_config.tags + } + return CustomTool(str(tool_config.id), config) + except ImportError as e: + logger.error(f"无法导入自定义工具模块: {e}") + + elif tool_config.tool_type == ToolType.MCP.value: + # 加载MCP工具 + try: + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == tool_config.id + ).first() + + if mcp_config: + config = { + **tool_config.config_data, + "server_url": mcp_config.server_url, + "connection_config": mcp_config.connection_config, + "available_tools": mcp_config.available_tools, + "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None, + "version": tool_config.version, + "tags": tool_config.tags + } + return MCPTool(str(tool_config.id), config) + except ImportError as e: + logger.error(f"无法导入MCP工具模块: {e}") + + except Exception as e: + logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}") + + return None + + def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]: + """获取工具统计信息 + + Args: + tenant_id: 租户ID + + Returns: + 统计信息字典 + """ + try: + query = self.db.query(ToolConfig) + if tenant_id: + query = query.filter(ToolConfig.tenant_id == tenant_id) + + total_tools = query.count() + active_tools = query.filter(ToolConfig.is_enabled == True).count() + + # 按类型统计 + type_stats = {} + for tool_type in ToolType: + count = query.filter(ToolConfig.tool_type == tool_type.value).count() + type_stats[tool_type.value] = count + + return { + "total_tools": total_tools, + "active_tools": active_tools, + "inactive_tools": total_tools - active_tools, + "by_type": type_stats + } + + except Exception as e: + logger.error(f"获取工具统计失败, 错误: {e}") + return {} + + def clear_cache(self): + """清空工具缓存""" + self._tools.clear() + logger.info("工具缓存已清空") \ No newline at end of file diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index a945356a..04bc54dd 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -15,8 +15,13 @@ from langgraph.graph import StateGraph, START, END from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.expression_evaluator import evaluate_condition from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution +from app.core.tools.registry import ToolRegistry +from app.core.tools.executor import ToolExecutor +from app.core.tools.langchain_adapter import LangchainAdapter +TOOL_MANAGEMENT_AVAILABLE = True from app.db import get_db + logger = logging.getLogger(__name__) @@ -434,3 +439,180 @@ async def execute_workflow_stream( ) async for event in executor.execute_stream(input_data): yield event + + +# ==================== 工具管理系统集成 ==================== + +def get_workflow_tools(workspace_id: str, user_id: str) -> list: + """获取工作流可用的工具列表 + + Args: + workspace_id: 工作空间ID + user_id: 用户ID + + Returns: + 可用工具列表 + """ + if not TOOL_MANAGEMENT_AVAILABLE: + logger.warning("工具管理系统不可用") + return [] + + try: + from sqlalchemy.orm import Session + db = next(get_db()) + + # 创建工具注册表 + registry = ToolRegistry(db) + + # 注册内置工具类 + from app.core.tools.builtin import ( + DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool + ) + registry.register_tool_class(DateTimeTool) + registry.register_tool_class(JsonTool) + registry.register_tool_class(BaiduSearchTool) + registry.register_tool_class(MinerUTool) + registry.register_tool_class(TextInTool) + + # 获取活跃的工具 + import uuid + tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id)) + active_tools = [tool for tool in tools if tool.status.value == "active"] + + # 转换为Langchain工具 + langchain_tools = [] + for tool_info in active_tools: + try: + tool_instance = registry.get_tool(tool_info.id) + if tool_instance: + langchain_tool = LangchainAdapter.convert_tool(tool_instance) + langchain_tools.append(langchain_tool) + except Exception as e: + logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}") + + logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具") + return langchain_tools + + except Exception as e: + logger.error(f"获取工作流工具失败: {e}") + return [] + + +class ToolWorkflowNode: + """工具工作流节点 - 在工作流中执行工具""" + + def __init__(self, node_config: dict, workflow_config: dict): + """初始化工具节点 + + Args: + node_config: 节点配置 + workflow_config: 工作流配置 + """ + self.node_config = node_config + self.workflow_config = workflow_config + self.tool_id = node_config.get("tool_id") + self.tool_parameters = node_config.get("parameters", {}) + + async def run(self, state: WorkflowState) -> WorkflowState: + """执行工具节点""" + if not TOOL_MANAGEMENT_AVAILABLE: + logger.error("工具管理系统不可用") + state["error"] = "工具管理系统不可用" + return state + + try: + from sqlalchemy.orm import Session + db = next(get_db()) + + # 创建工具执行器 + registry = ToolRegistry(db) + executor = ToolExecutor(db, registry) + + # 准备参数(支持变量替换) + parameters = self._prepare_parameters(state) + + # 执行工具 + result = await executor.execute_tool( + tool_id=self.tool_id, + parameters=parameters, + user_id=uuid.UUID(state["user_id"]), + workspace_id=uuid.UUID(state["workspace_id"]) + ) + + # 更新状态 + node_id = self.node_config.get("id") + if result.success: + state["node_outputs"][node_id] = { + "type": "tool", + "tool_id": self.tool_id, + "output": result.data, + "execution_time": result.execution_time, + "token_usage": result.token_usage + } + + # 更新运行时变量 + if isinstance(result.data, dict): + for key, value in result.data.items(): + state["runtime_vars"][f"{node_id}.{key}"] = value + else: + state["runtime_vars"][f"{node_id}.result"] = result.data + else: + state["error"] = result.error + state["error_node"] = node_id + state["node_outputs"][node_id] = { + "type": "tool", + "tool_id": self.tool_id, + "error": result.error, + "execution_time": result.execution_time + } + + return state + + except Exception as e: + logger.error(f"工具节点执行失败: {e}") + state["error"] = str(e) + state["error_node"] = self.node_config.get("id") + return state + + def _prepare_parameters(self, state: WorkflowState) -> dict: + """准备工具参数(支持变量替换)""" + parameters = {} + + for key, value in self.tool_parameters.items(): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + # 变量替换 + var_path = value[2:-1] + + # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result} + if "." in var_path: + parts = var_path.split(".") + current = state.get("variables", {}) + + for part in parts: + if isinstance(current, dict) and part in current: + current = current[part] + else: + # 尝试从运行时变量获取 + runtime_key = ".".join(parts) + current = state.get("runtime_vars", {}).get(runtime_key, value) + break + + parameters[key] = current + else: + # 简单变量 + variables = state.get("variables", {}) + parameters[key] = variables.get(var_path, value) + else: + parameters[key] = value + + return parameters + + +# 注册工具节点到NodeFactory(如果存在) +try: + from app.core.workflow.nodes import NodeFactory + if hasattr(NodeFactory, 'register_node_type'): + NodeFactory.register_node_type("tool", ToolWorkflowNode) + logger.info("工具节点已注册到工作流系统") +except Exception as e: + logger.warning(f"注册工具节点失败: {e}") \ No newline at end of file diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index fc497215..09c88ba3 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -21,6 +21,10 @@ from .multi_agent_model import MultiAgentConfig, AgentInvocation from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution from .retrieval_info import RetrievalInfo from .prompt_optimizer_model import PromptOptimizerModelConfig, PromptOptimizerSession, PromptOptimizerSessionHistory +from .tool_model import ( + ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, + ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus +) __all__ = [ "Tenants", @@ -58,5 +62,15 @@ __all__ = [ "RetrievalInfo", "PromptOptimizerModelConfig", "PromptOptimizerSession", - "PromptOptimizerSessionHistory" + "PromptOptimizerSessionHistory", + "RetrievalInfo", + "ToolConfig", + "BuiltinToolConfig", + "CustomToolConfig", + "MCPToolConfig", + "ToolExecution", + "ToolType", + "ToolStatus", + "AuthType", + "ExecutionStatus" ] diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index fd3d9a31..552e87b5 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -21,3 +21,6 @@ class Tenants(Base): # Relationship to workspaces owned by the tenant owned_workspaces = relationship("Workspace", back_populates="tenant") + + # Relationship to tool configs owned by the tenant + tool_configs = relationship("ToolConfig", back_populates="tenant") diff --git a/api/app/models/tool_model.py b/api/app/models/tool_model.py new file mode 100644 index 00000000..ac719317 --- /dev/null +++ b/api/app/models/tool_model.py @@ -0,0 +1,226 @@ +"""工具管理相关数据模型""" +import uuid +from datetime import datetime +from enum import StrEnum + +from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship + +from app.db import Base + + +class ToolType(StrEnum): + """工具类型枚举""" + BUILTIN = "builtin" + CUSTOM = "custom" + MCP = "mcp" + + +class ToolStatus(StrEnum): + """工具状态枚举""" + ACTIVE = "active" + INACTIVE = "inactive" + ERROR = "error" + LOADING = "loading" + + +class AuthType(StrEnum): + """认证类型枚举""" + NONE = "none" + API_KEY = "api_key" + BEARER_TOKEN = "bearer_token" + + +class ExecutionStatus(StrEnum): + """执行状态枚举""" + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + TIMEOUT = "timeout" + + +class ToolConfig(Base): + """工具配置基础模型""" + __tablename__ = "tool_configs" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String(255), nullable=False, index=True) + description = Column(Text) + tool_type = Column(String(50), nullable=False, index=True) + tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True) # 必须属于租户 + status = Column(String(50), default=ToolStatus.INACTIVE.value, nullable=False, index=True) # 工具状态 + + # 工具特定配置(JSON格式存储) + config_data = Column(JSON, default=dict) + + # 元数据 + version = Column(String(50), default="1.0.0") + tags = Column(JSON, default=list) # 标签列表 + + # 时间戳 + created_at = Column(DateTime, default=datetime.now, nullable=False) + updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False) + + # 关联关系 + tenant = relationship("Tenants", back_populates="tool_configs") + executions = relationship("ToolExecution", back_populates="tool_config", cascade="all, delete-orphan") + + def __repr__(self): + return f"" + + +class BuiltinToolConfig(Base): + """内置工具配置模型""" + __tablename__ = "builtin_tool_configs" + + id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) + tool_class = Column(String(255), nullable=False) # 工具类名 + parameters = Column(JSON, default=dict) # 工具参数配置 + + # 关联关系 + base_config = relationship("ToolConfig", foreign_keys=[id]) + + def __repr__(self): + return f"" + + +class CustomToolConfig(Base): + """自定义工具配置模型""" + __tablename__ = "custom_tool_configs" + + id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) + schema_url = Column(String(1000)) # OpenAPI schema URL + schema_content = Column(JSON) # OpenAPI schema 内容 + + # 认证配置 + auth_type = Column(String(50), default=AuthType.NONE.value, nullable=False) + auth_config = Column(JSON, default=dict) # 认证配置(加密存储) + + # API配置 + base_url = Column(String(1000)) # API基础URL + timeout = Column(Integer, default=30) # 超时时间(秒) + + # 关联关系 + base_config = relationship("ToolConfig", foreign_keys=[id]) + + def __repr__(self): + return f"" + + +class MCPToolConfig(Base): + """MCP工具配置模型""" + __tablename__ = "mcp_tool_configs" + + id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) + server_url = Column(String(1000), nullable=False) # MCP服务器URL + connection_config = Column(JSON, default=dict) # 连接配置 + + # 服务状态 + last_health_check = Column(DateTime) + health_status = Column(String(50), default="unknown") + error_message = Column(Text) + + # 可用工具列表 + available_tools = Column(JSON, default=list) + + # 关联关系 + base_config = relationship("ToolConfig", foreign_keys=[id]) + + def __repr__(self): + return f"" + + +class ToolExecution(Base): + """工具执行记录模型""" + __tablename__ = "tool_executions" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + tool_config_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False, index=True) + + # 执行信息 + execution_id = Column(String(255), nullable=False, index=True) # 执行ID(可用于关联工作流等) + status = Column(String(50), default=ExecutionStatus.PENDING.value, nullable=False, index=True) + + # 输入输出 + input_data = Column(JSON) # 输入参数 + output_data = Column(JSON) # 输出结果 + error_message = Column(Text) # 错误信息 + + # 性能指标 + started_at = Column(DateTime, nullable=False, index=True) + completed_at = Column(DateTime) + execution_time = Column(Float) # 执行时间(秒) + + # Token使用情况(如果适用) + token_usage = Column(JSON) + + # 用户信息 + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True) + workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False, index=True) + + # 关联关系 + tool_config = relationship("ToolConfig", back_populates="executions") + user = relationship("User") + workspace = relationship("Workspace") + + def __repr__(self): + return f"" + + +# class ToolDependency(Base): +# """工具依赖关系模型""" +# __tablename__ = "tool_dependencies" +# +# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) +# tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False) +# depends_on_tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False) +# +# # 依赖类型和版本要求 +# dependency_type = Column(String(50), default="required") # required, optional +# version_constraint = Column(String(100)) # 版本约束,如 ">=1.0.0" +# +# # 时间戳 +# created_at = Column(DateTime, default=datetime.now, nullable=False) +# +# # 关联关系 +# tool = relationship("ToolConfig", foreign_keys=[tool_id]) +# depends_on_tool = relationship("ToolConfig", foreign_keys=[depends_on_tool_id]) +# +# def __repr__(self): +# return f"" + + +# class PluginConfig(Base): +# """插件配置模型""" +# __tablename__ = "plugin_configs" +# +# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) +# name = Column(String(255), nullable=False, unique=True, index=True) +# description = Column(Text) +# +# # 插件信息 +# plugin_path = Column(String(1000), nullable=False) # 插件文件路径 +# entry_point = Column(String(255), nullable=False) # 入口点 +# version = Column(String(50), default="1.0.0") +# +# # 状态 +# is_enabled = Column(Boolean, default=True, nullable=False) +# is_loaded = Column(Boolean, default=False, nullable=False) +# load_error = Column(Text) # 加载错误信息 +# +# # 配置 +# config_schema = Column(JSON) # 配置schema +# config_data = Column(JSON, default=dict) # 配置数据 +# +# # 依赖 +# dependencies = Column(JSON, default=list) # 依赖的其他插件 +# +# # 时间戳 +# created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) +# updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) +# last_loaded_at = Column(DateTime) +# +# def __repr__(self): +# return f"" \ No newline at end of file diff --git a/api/app/services/agent_tools.py b/api/app/services/agent_tools.py index 4c011a87..7fe6a0c0 100644 --- a/api/app/services/agent_tools.py +++ b/api/app/services/agent_tools.py @@ -13,6 +13,11 @@ from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.error_codes import BizCode from app.core.logging_config import get_business_logger from app.repositories import workspace_repository, knowledge_repository +from app.core.tools.registry import ToolRegistry +from app.core.tools.executor import ToolExecutor +from app.core.tools.langchain_adapter import LangchainAdapter +TOOL_MANAGEMENT_AVAILABLE = True + logger = get_business_logger() @@ -329,3 +334,216 @@ def create_agent_invocation_tool( return f"调用 Agent 失败: {str(e)}" return invoke_agent + +def get_available_tools_for_agent( + db: Session, + workspace_id: uuid.UUID, + agent_id: Optional[uuid.UUID] = None +) -> List[Dict[str, Any]]: + """获取Agent可用的工具列表 + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + agent_id: Agent ID(可选) + + Returns: + 可用工具列表 + """ + if not TOOL_MANAGEMENT_AVAILABLE: + logger.warning("工具管理系统不可用") + return [] + + try: + # 创建工具注册表 + registry = ToolRegistry(db) + + # 获取工具列表 + tools = registry.list_tools(workspace_id=workspace_id) + + # 转换为Agent可用的格式 + available_tools = [] + for tool_info in tools: + if tool_info.status.value == "active": + available_tools.append({ + "id": tool_info.id, + "name": tool_info.name, + "description": tool_info.description, + "type": tool_info.tool_type.value, + "version": tool_info.version, + "tags": tool_info.tags, + "parameters": [ + { + "name": param.name, + "type": param.type.value, + "description": param.description, + "required": param.required, + "default": param.default + } + for param in tool_info.parameters + ] + }) + + logger.info(f"为Agent获取到 {len(available_tools)} 个可用工具") + return available_tools + + except Exception as e: + logger.error(f"获取Agent可用工具失败: {e}") + return [] + + +def create_langchain_tools_for_agent( + db: Session, + workspace_id: uuid.UUID, + agent_id: Optional[uuid.UUID] = None +) -> List[Any]: + """为Agent创建Langchain兼容的工具列表 + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + agent_id: Agent ID(可选) + + Returns: + Langchain工具列表 + """ + if not TOOL_MANAGEMENT_AVAILABLE: + logger.warning("工具管理系统不可用") + return [] + + try: + # 创建工具注册表 + registry = ToolRegistry(db) + + # 注册内置工具类 + from app.core.tools.builtin import ( + DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool + ) + registry.register_tool_class(DateTimeTool) + registry.register_tool_class(JsonTool) + registry.register_tool_class(BaiduSearchTool) + registry.register_tool_class(MinerUTool) + registry.register_tool_class(TextInTool) + + # 获取活跃的工具 + tools = registry.list_tools(workspace_id=workspace_id) + active_tools = [tool for tool in tools if tool.status.value == "active"] + + # 转换为Langchain工具 + langchain_tools = [] + for tool_info in active_tools: + try: + tool_instance = registry.get_tool(tool_info.id) + if tool_instance: + langchain_tool = LangchainAdapter.convert_tool(tool_instance) + langchain_tools.append(langchain_tool) + except Exception as e: + logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}") + + logger.info(f"为Agent创建了 {len(langchain_tools)} 个Langchain工具") + return langchain_tools + + except Exception as e: + logger.error(f"创建Agent Langchain工具失败: {e}") + return [] + + +class ToolExecutionInput(BaseModel): + """工具执行输入参数""" + tool_id: str = Field(..., description="工具ID") + parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") + timeout: Optional[float] = Field(None, description="超时时间(秒)") + + +def create_tool_execution_tool( + db: Session, + workspace_id: uuid.UUID, + user_id: uuid.UUID +): + """创建工具执行工具 + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + user_id: 用户ID + + Returns: + 工具执行工具 + """ + if not TOOL_MANAGEMENT_AVAILABLE: + logger.warning("工具管理系统不可用") + return None + + @tool(args_schema=ToolExecutionInput) + async def execute_tool( + tool_id: str, + parameters: Dict[str, Any] = None, + timeout: Optional[float] = None + ) -> str: + """执行指定的工具。当需要使用系统中的工具来完成特定任务时使用。 + + Args: + tool_id: 工具ID(通过工具列表获取) + parameters: 工具参数(根据工具要求提供) + timeout: 超时时间(秒,可选) + + Returns: + 工具执行结果 + """ + try: + # 创建工具执行器 + registry = ToolRegistry(db) + executor = ToolExecutor(db, registry) + + # 执行工具 + result = await executor.execute_tool( + tool_id=tool_id, + parameters=parameters or {}, + user_id=user_id, + workspace_id=workspace_id, + timeout=timeout + ) + + if result.success: + # 格式化成功结果 + if isinstance(result.data, str): + return result.data + else: + import json + return json.dumps(result.data, ensure_ascii=False, indent=2) + else: + return f"工具执行失败: {result.error}" + + except Exception as e: + logger.error(f"工具执行异常: {tool_id}, 错误: {e}") + return f"工具执行异常: {str(e)}" + + return execute_tool + + +def get_tool_management_tools( + db: Session, + workspace_id: uuid.UUID, + user_id: uuid.UUID +) -> List[Any]: + """获取工具管理相关的工具 + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + user_id: 用户ID + + Returns: + 工具管理工具列表 + """ + if not TOOL_MANAGEMENT_AVAILABLE: + return [] + + tools = [] + + # 添加工具执行工具 + execution_tool = create_tool_execution_tool(db, workspace_id, user_id) + if execution_tool: + tools.append(execution_tool) + + return tools \ No newline at end of file diff --git a/api/test_tool_system.py b/api/test_tool_system.py new file mode 100644 index 00000000..30d60d23 --- /dev/null +++ b/api/test_tool_system.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +""" +工具管理系统基础测试脚本 +用于验证系统的基本功能是否正常 +""" + +import asyncio +import uuid +from datetime import datetime + +# 测试导入 +def test_imports(): + """测试模块导入""" + print("测试模块导入...") + + try: + from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType + print("✓ 基础工具模块导入成功") + except ImportError as e: + print(f"✗ 基础工具模块导入失败: {e}") + return False + + try: + from app.core.tools.builtin.datetime_tool import DateTimeTool + from app.core.tools.builtin.json_tool import JsonTool + print("✓ 内置工具模块导入成功") + except ImportError as e: + print(f"✗ 内置工具模块导入失败: {e}") + return False + + try: + from app.core.tools.langchain_adapter import LangchainAdapter + print("✓ Langchain适配器导入成功") + except ImportError as e: + print(f"✗ Langchain适配器导入失败: {e}") + return False + + try: + from app.models.tool_model import ToolConfig, ToolType, ToolStatus + print("✓ 工具模型导入成功") + except ImportError as e: + print(f"✗ 工具模型导入失败: {e}") + return False + + try: + from app.core.tools.custom import CustomTool, OpenAPISchemaParser, AuthManager + print("✓ 自定义工具模块导入成功") + except ImportError as e: + print(f"✗ 自定义工具模块导入失败: {e}") + return False + + try: + from app.core.tools.mcp import MCPTool, MCPClient, MCPServiceManager + print("✓ MCP工具模块导入成功") + except ImportError as e: + print(f"✗ MCP工具模块导入失败: {e}") + return False + + return True + + +def test_tool_creation(): + """测试工具创建""" + print("\n测试工具创建...") + + try: + from app.core.tools.builtin.datetime_tool import DateTimeTool + + # 创建时间工具实例(全局工具) + tool_id = str(uuid.uuid4()) + config = { + "parameters": {"timezone": "UTC"}, + "tenant_id": None, # 全局工具 + "version": "1.0.0", + "tags": ["time", "utility", "builtin"] + } + + datetime_tool = DateTimeTool(tool_id, config) + + # 验证工具属性 + assert datetime_tool.name == "datetime_tool" + assert datetime_tool.tool_type.value == "builtin" + assert len(datetime_tool.parameters) > 0 + + print("✓ 时间工具创建成功(全局工具)") + return True + + except Exception as e: + print(f"✗ 工具创建失败: {e}") + return False + + +async def test_tool_execution(): + """测试工具执行""" + print("\n测试工具执行...") + + try: + from app.core.tools.builtin.datetime_tool import DateTimeTool + + # 创建时间工具实例 + tool_id = str(uuid.uuid4()) + config = { + "parameters": {"timezone": "UTC"}, + "tenant_id": None, # 全局工具 + "version": "1.0.0" + } + + datetime_tool = DateTimeTool(tool_id, config) + + # 测试获取当前时间 + result = await datetime_tool.safe_execute(operation="now") + + assert result.success == True + assert "datetime" in result.data + assert result.execution_time > 0 + + print("✓ 工具执行成功") + print(f" 执行时间: {result.execution_time:.3f}秒") + print(f" 返回数据: {result.data}") + + return True + + except Exception as e: + print(f"✗ 工具执行失败: {e}") + return False + + +def test_langchain_adapter(): + """测试Langchain适配器""" + print("\n测试Langchain适配器...") + + try: + from app.core.tools.builtin.json_tool import JsonTool + from app.core.tools.langchain_adapter import LangchainAdapter + + # 创建JSON工具实例 + tool_id = str(uuid.uuid4()) + config = { + "parameters": {"indent": 2}, + "tenant_id": None, # 全局工具 + "version": "1.0.0" + } + + json_tool = JsonTool(tool_id, config) + + # 验证Langchain兼容性 + is_compatible, issues = LangchainAdapter.validate_langchain_compatibility(json_tool) + + if not is_compatible: + print(f"✗ Langchain兼容性验证失败: {issues}") + return False + + # 创建工具描述 + description = LangchainAdapter.create_tool_description(json_tool) + + assert "name" in description + assert "parameters" in description + assert description["langchain_compatible"] == True + + print("✓ Langchain适配器测试成功") + return True + + except Exception as e: + print(f"✗ Langchain适配器测试失败: {e}") + return False + + +def test_config_manager(): + """测试配置管理器""" + print("\n测试配置管理器...") + + try: + from app.core.tools.config_manager import ConfigManager + + # 创建配置管理器 + config_manager = ConfigManager() + + # 获取配置摘要 + summary = config_manager.get_config_summary() + + assert "config_dir" in summary + assert "total_configs" in summary + + print("✓ 配置管理器测试成功") + print(f" 配置目录: {summary['config_dir']}") + print(f" 总配置数: {summary['total_configs']}") + + return True + + except Exception as e: + print(f"✗ 配置管理器测试失败: {e}") + return False + + +def test_schema_parser(): + """测试OpenAPI Schema解析器""" + print("\n测试OpenAPI Schema解析器...") + + try: + from app.core.tools.custom.schema_parser import OpenAPISchemaParser + + # 创建解析器 + parser = OpenAPISchemaParser() + + # 测试简单的OpenAPI schema + test_schema = { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0", + "description": "测试API" + }, + "paths": { + "/test": { + "get": { + "summary": "测试接口", + "operationId": "test_operation", + "responses": { + "200": { + "description": "成功" + } + } + } + } + } + } + + # 验证schema + is_valid, error_msg = parser.validate_schema(test_schema) + assert is_valid, f"Schema验证失败: {error_msg}" + + # 提取工具信息 + tool_info = parser.extract_tool_info(test_schema) + assert tool_info["name"] == "Test API" + assert "test_operation" in tool_info["operations"] + + print("✓ OpenAPI Schema解析器测试成功") + return True + + except Exception as e: + print(f"✗ OpenAPI Schema解析器测试失败: {e}") + return False + + +def test_auth_manager(): + """测试认证管理器""" + print("\n测试认证管理器...") + + try: + from app.core.tools.custom.auth_manager import AuthManager + from app.models.tool_model import AuthType + + # 创建认证管理器 + auth_manager = AuthManager() + + # 测试API Key认证配置 + api_key_config = { + "api_key": "test-key-123", + "key_name": "X-API-Key", + "location": "header" + } + + is_valid, error_msg = auth_manager.validate_auth_config(AuthType.API_KEY, api_key_config) + assert is_valid, f"API Key配置验证失败: {error_msg}" + + # 测试Bearer Token认证配置 + bearer_config = { + "token": "bearer-token-123" + } + + is_valid, error_msg = auth_manager.validate_auth_config(AuthType.BEARER_TOKEN, bearer_config) + assert is_valid, f"Bearer Token配置验证失败: {error_msg}" + + # 测试认证应用 + url = "https://api.example.com/test" + headers = {} + params = {} + + new_url, new_headers, new_params = auth_manager.apply_authentication( + AuthType.API_KEY, api_key_config, url, headers, params + ) + + assert "X-API-Key" in new_headers + assert new_headers["X-API-Key"] == "test-key-123" + + print("✓ 认证管理器测试成功") + return True + + except Exception as e: + print(f"✗ 认证管理器测试失败: {e}") + return False + + +def test_builtin_initializer(): + """测试内置工具初始化器""" + print("\n测试内置工具初始化器...") + + try: + from app.core.tools.builtin_initializer import BuiltinToolInitializer + + # 注意:这里不能真正初始化,因为需要数据库连接 + # 只测试类的创建和基本方法 + + # 模拟数据库会话(实际使用中需要真实的数据库连接) + class MockDB: + def query(self, *args): + return self + def filter(self, *args): + return self + def first(self): + return None + def all(self): + return [] + + mock_db = MockDB() + initializer = BuiltinToolInitializer(mock_db) + + # 测试获取内置工具状态(会返回空列表,因为没有真实数据) + status = initializer.get_builtin_tools_status() + assert isinstance(status, list) + + print("✓ 内置工具初始化器测试成功") + return True + + except Exception as e: + print(f"✗ 内置工具初始化器测试失败: {e}") + return False + + +async def main(): + """主测试函数""" + print("=" * 50) + print("工具管理系统基础测试") + print("=" * 50) + + tests = [ + ("模块导入", test_imports), + ("工具创建", test_tool_creation), + ("工具执行", test_tool_execution), + ("Langchain适配", test_langchain_adapter), + ("配置管理", test_config_manager), + ("Schema解析器", test_schema_parser), + ("认证管理器", test_auth_manager), + ("内置工具初始化器", test_builtin_initializer) + ] + + passed = 0 + total = len(tests) + + for test_name, test_func in tests: + try: + if asyncio.iscoroutinefunction(test_func): + result = await test_func() + else: + result = test_func() + + if result: + passed += 1 + except Exception as e: + print(f"✗ {test_name}测试异常: {e}") + + print("\n" + "=" * 50) + print(f"测试结果: {passed}/{total} 通过") + + if passed == total: + print("🎉 所有基础测试通过!工具管理系统基本功能正常。") + return True + else: + print("⚠️ 部分测试失败,请检查相关模块。") + return False + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file From f38c065f944b0328e0e6dfd7bb4dccb7b036fce4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= Date: Thu, 18 Dec 2025 09:56:35 +0000 Subject: [PATCH 02/24] Merge #13 into develop from fix/stream-output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 'fix/stream-output' * fix/stream-output: (17 commits squashed) - [fix]Fix the issue where the streaming output effect is not obvious. - [fix]Fix the issue where the streaming output effect is not obvious. - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix] - [fix]Skip time extraction - [fix] - [fix]Skip time extraction - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix]Remove human-induced delays - [fix]Fix the issue where the streaming output effect is not obvious. - [fix] - [fix]Skip time extraction - [fix]Fix the issue where the streaming output effect is not obvious. - [fix] - [fix]Skip time extraction - [fix]Remove human-induced delays - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output Signed-off-by: 乐力齐 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/13 --- .../extraction_orchestrator.py | 239 ++++++++++-------- 1 file changed, 138 insertions(+), 101 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 7eec1189..e00bcf0a 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -179,8 +179,21 @@ class ExtractionOrchestrator: all_statements_list.extend(chunk.statements) total_statements = len(all_statements_list) - # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成 - logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成") + # 🔥 陈述句提取完成后,立即发送知识抽取完成消息 + if self.progress_callback: + extraction_stats = { + "statements_count": total_statements, + "entities_count": 0, # 暂时为0,后续会更新 + "triplets_count": 0, # 暂时为0,后续会更新 + "temporal_ranges_count": 0, # 暂时为0,后续会更新 + } + await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats) + + # 🔥 立即发送下一阶段的开始消息,让前端知道进入了创建节点和边阶段 + await self.progress_callback("creating_nodes_edges", "正在创建节点和边...") + + # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成(后台静默执行) + logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成(后台静默执行)") ( triplet_maps, temporal_maps, @@ -206,72 +219,6 @@ class ExtractionOrchestrator: logger.info("步骤 3/6: 生成实体嵌入") triplet_maps = await self._generate_entity_embeddings(triplet_maps) - # 进度回调:按三个阶段分别输出知识抽取结果 - if self.progress_callback: - # 第一阶段:陈述句提取结果 - for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句 - stmt_result = { - "extraction_type": "statement", - "statement_index": i + 1, - "statement": stmt.statement, - "statement_id": stmt.id - } - await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result) - - # 第二阶段:三元组提取结果 - for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组 - triplet_result = { - "extraction_type": "triplet", - "triplet_index": i + 1, - "subject": triplet.subject_name, - "predicate": triplet.predicate, - "object": triplet.object_name - } - await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result) - - # 第三阶段:时间提取结果 - if total_temporal > 0: - # 收集时间信息 - temporal_results = [] - for dialog in dialog_data_list: - for chunk in dialog.chunks: - for statement in chunk.statements: - if hasattr(statement, 'temporal_validity') and statement.temporal_validity: - temporal_results.append({ - "statement_id": statement.id, - "statement": statement.statement, - "valid_at": statement.temporal_validity.valid_at, - "invalid_at": statement.temporal_validity.invalid_at - }) - - # 输出时间提取结果 - for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果 - time_result = { - "extraction_type": "temporal", - "temporal_index": i + 1, - "statement": temporal_result["statement"], - "valid_at": temporal_result["valid_at"], - "invalid_at": temporal_result["invalid_at"] - } - await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result) - else: - # 如果没有时间信息,也发送一个时间提取完成的消息 - time_result = { - "extraction_type": "temporal", - "temporal_index": 0, - "message": "未发现时间信息" - } - await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result) - - # 进度回调:知识抽取完成,传递知识抽取的统计信息 - extraction_stats = { - "statements_count": total_statements, - "entities_count": total_entities, - "triplets_count": total_triplets, - "temporal_ranges_count": total_temporal, - } - await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats) - # 步骤 4: 将提取的数据赋值到语句 logger.info("步骤 4/6: 数据赋值") dialog_data_list = await self._assign_extracted_data( @@ -285,6 +232,9 @@ class ExtractionOrchestrator: # 步骤 5: 创建节点和边 logger.info("步骤 5/6: 创建节点和边") + + # 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送 + ( dialogue_nodes, chunk_nodes, @@ -304,6 +254,8 @@ class ExtractionOrchestrator: else: logger.info("步骤 6/6: 两阶段去重和消歧") + # 注意:deduplication 消息已在创建节点和边完成后立即发送 + result = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, @@ -328,7 +280,7 @@ class ExtractionOrchestrator: self, dialog_data_list: List[DialogData] ) -> List[DialogData]: """ - 从对话中提取陈述句(优化版:全局分块级并行) + 从对话中提取陈述句(流式输出版本:边提取边发送进度) Args: dialog_data_list: 对话数据列表 @@ -336,7 +288,7 @@ class ExtractionOrchestrator: Returns: 更新后的对话数据列表(包含提取的陈述句) """ - logger.info("开始陈述句提取(全局分块级并行)") + logger.info("开始陈述句提取(全局分块级并行 + 流式输出)") # 收集所有分块及其元数据 all_chunks = [] @@ -349,17 +301,44 @@ class ExtractionOrchestrator: chunk_metadata.append((d_idx, c_idx)) logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") + + # 用于跟踪已完成的分块数量 + completed_chunks = 0 + total_chunks = len(all_chunks) # 全局并行处理所有分块 - async def extract_for_chunk(chunk_data): + async def extract_for_chunk(chunk_data, chunk_index): + nonlocal completed_chunks chunk, group_id, dialogue_content = chunk_data try: - return await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) + statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) + + # 流式输出:每提取完一个分块的陈述句,立即发送进度 + # 注意:只在试运行模式下发送陈述句详情,正式模式不发送 + completed_chunks += 1 + if self.progress_callback and statements and self.is_pilot_run: + # 发送前3个陈述句作为示例 + for idx, stmt in enumerate(statements[:3]): + stmt_result = { + "extraction_type": "statement", + "statement": stmt.statement, + "statement_id": stmt.id, + "chunk_progress": f"{completed_chunks}/{total_chunks}", + "statement_index_in_chunk": idx + 1 + } + await self.progress_callback( + "knowledge_extraction_result", + f"陈述句提取中 ({completed_chunks}/{total_chunks})", + stmt_result + ) + + return statements except Exception as e: logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}") + completed_chunks += 1 return [] - tasks = [extract_for_chunk(chunk_data) for chunk_data in all_chunks] + tasks = [extract_for_chunk(chunk_data, i) for i, chunk_data in enumerate(all_chunks)] results = await asyncio.gather(*tasks, return_exceptions=True) # 将结果分配回对话 @@ -391,7 +370,7 @@ class ExtractionOrchestrator: self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ - 从对话中提取三元组(优化版:全局陈述句级并行) + 从对话中提取三元组(流式输出版本:边提取边发送进度) Args: dialog_data_list: 对话数据列表 @@ -399,7 +378,7 @@ class ExtractionOrchestrator: Returns: 三元组映射列表,每个对话对应一个字典 """ - logger.info("开始三元组提取(全局陈述句级并行)") + logger.info("开始三元组提取(全局陈述句级并行 + 流式输出)") # 收集所有陈述句及其元数据 all_statements = [] @@ -412,18 +391,30 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组") + + # 用于跟踪已完成的陈述句数量 + completed_statements = 0 + total_statements = len(all_statements) # 全局并行处理所有陈述句 - async def extract_for_statement(stmt_data): + async def extract_for_statement(stmt_data, stmt_index): + nonlocal completed_statements statement, chunk_content = stmt_data try: - return await self.triplet_extractor._extract_triplets(statement, chunk_content) + triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content) + + # 注意:不再发送三元组提取的流式输出 + # 三元组提取在后台执行,但不向前端发送详细信息 + completed_statements += 1 + + return triplet_info except Exception as e: logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}") + completed_statements += 1 from app.core.memory.models.triplet_models import TripletExtractionResponse return TripletExtractionResponse(triplets=[], entities=[]) - tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] + tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)] results = await asyncio.gather(*tasks, return_exceptions=True) # 将结果组织成对话级别的映射 @@ -458,7 +449,7 @@ class ExtractionOrchestrator: self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ - 从对话中提取时间信息(优化版:全局陈述句级并行) + 从对话中提取时间信息(流式输出版本:边提取边发送进度) Args: dialog_data_list: 对话数据列表 @@ -466,7 +457,21 @@ class ExtractionOrchestrator: Returns: 时间信息映射列表,每个对话对应一个字典 """ - logger.info("开始时间信息提取(全局陈述句级并行)") + # 试运行模式:跳过时间提取以节省时间 + if self.is_pilot_run: + logger.info("试运行模式:跳过时间信息提取(节省约 10-15 秒)") + # 为所有陈述句返回空的时间范围 + from app.core.memory.models.message_models import TemporalValidityRange + temporal_maps = [] + for dialog in dialog_data_list: + temporal_map = {} + for chunk in dialog.chunks: + for statement in chunk.statements: + temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None) + temporal_maps.append(temporal_map) + return temporal_maps + + logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)") # 收集所有需要提取时间的陈述句 all_statements = [] @@ -494,18 +499,30 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取") + + # 用于跟踪已完成的时间提取数量 + completed_temporal = 0 + total_temporal_statements = len(all_statements) # 全局并行处理所有陈述句 - async def extract_for_statement(stmt_data): + async def extract_for_statement(stmt_data, stmt_index): + nonlocal completed_temporal statement, ref_dates = stmt_data try: - return await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates) + temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates) + + # 注意:不再发送时间提取的流式输出 + # 时间提取在后台执行,但不向前端发送详细信息 + completed_temporal += 1 + + return temporal_range except Exception as e: logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}") + completed_temporal += 1 from app.core.memory.models.message_models import TemporalValidityRange return TemporalValidityRange(valid_at=None, invalid_at=None) - tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] + tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)] results = await asyncio.gather(*tasks, return_exceptions=True) # 将结果组织成对话级别的映射 @@ -832,9 +849,7 @@ class ExtractionOrchestrator: """ logger.info("开始创建节点和边") - # 进度回调:正在创建节点和边 - if self.progress_callback: - await self.progress_callback("creating_nodes_edges", "正在创建节点和边...") + # 注意:开始消息已在 run 方法中发送,这里不再重复发送 dialogue_nodes = [] chunk_nodes = [] @@ -846,8 +861,13 @@ class ExtractionOrchestrator: # 用于去重的集合 entity_id_set = set() + + # 用于跟踪进度 + total_dialogs = len(dialog_data_list) + processed_dialogs = 0 for dialog_data in dialog_data_list: + processed_dialogs += 1 # 创建对话节点 dialogue_node = DialogueNode( id=dialog_data.id, @@ -994,6 +1014,26 @@ class ExtractionOrchestrator: expired_at=dialog_data.expired_at, ) entity_entity_edges.append(entity_entity_edge) + + # 流式输出:每创建一个关系边,立即发送进度(限制发送数量) + if self.progress_callback and len(entity_entity_edges) <= 10: + # 获取实体名称 + source_name = triplet.subject_name + target_name = triplet.object_name + relationship_result = { + "result_type": "relationship_creation", + "relationship_index": len(entity_entity_edges), + "source_entity": source_name, + "relation_type": triplet.predicate, + "target_entity": target_name, + "relationship_text": f"{source_name} -[{triplet.predicate}]-> {target_name}", + "dialog_progress": f"{processed_dialogs}/{total_dialogs}" + } + await self.progress_callback( + "creating_nodes_edges_result", + f"关系创建中 ({processed_dialogs}/{total_dialogs})", + relationship_result + ) else: logger.warning( f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, " @@ -1008,12 +1048,9 @@ class ExtractionOrchestrator: f"实体-实体边: {len(entity_entity_edges)}" ) - # 进度回调:只输出关系创建结果 + # 进度回调:创建节点和边完成,传递结果统计 + # 注意:具体的关系创建结果已经在创建过程中实时发送了 if self.progress_callback: - # 输出关系创建结果 - await self._output_relationship_creation_results(entity_entity_edges, entity_nodes) - - # 进度回调:创建节点和边完成,传递结果统计 nodes_edges_stats = { "dialogue_nodes_count": len(dialogue_nodes), "chunk_nodes_count": len(chunk_nodes), @@ -1071,7 +1108,7 @@ class ExtractionOrchestrator: """ logger.info("开始两阶段实体去重和消歧") - # 进度回调:正在去重消歧 + # 进度回调:发送去重消歧开始消息 if self.progress_callback: await self.progress_callback("deduplication", "正在去重消歧...") @@ -1154,25 +1191,26 @@ class ExtractionOrchestrator: f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}" ) - # 进度回调:输出去重消歧的具体结果 + # 流式输出:实时输出去重消歧的具体结果 if self.progress_callback: - # 分析实体合并情况 + # 分析实体合并情况(使用内存中的记录) merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes) - # 输出去重合并的实体示例 + # 逐个输出去重合并的实体示例 for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果 dedup_result = { "result_type": "entity_merge", "merged_entity_name": merge_detail["main_entity_name"], "merged_count": merge_detail["merged_count"], + "merge_progress": f"{i + 1}/{min(len(merge_info), 5)}", "message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并" } - await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result) + await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result) - # 分析实体消歧情况 + # 分析实体消歧情况(使用内存中的记录) disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes) - # 输出实体消歧的结果 + # 逐个输出实体消歧的结果 for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果 disamb_result = { "result_type": "entity_disambiguation", @@ -1180,11 +1218,10 @@ class ExtractionOrchestrator: "disambiguation_type": disamb_detail["disamb_type"], "confidence": disamb_detail.get("confidence", "unknown"), "reason": disamb_detail.get("reason", ""), + "disamb_progress": f"{i + 1}/{min(len(disamb_info), 5)}", "message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}" } - await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result) - - + await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result) # 进度回调:去重消歧完成,传递去重和消歧的具体效果 await self._send_dedup_progress_callback( From 9e48f2143ee18df1923946d0d1078bc2499a082d Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Thu, 18 Dec 2025 18:51:32 +0800 Subject: [PATCH 03/24] [fix]document chunk QA --- api/app/core/rag/graphrag/utils.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/api/app/core/rag/graphrag/utils.py b/api/app/core/rag/graphrag/utils.py index 65beb31f..a2290516 100644 --- a/api/app/core/rag/graphrag/utils.py +++ b/api/app/core/rag/graphrag/utils.py @@ -1,12 +1,23 @@ import xxhash -from app.aioRedis import aio_redis_set, aio_redis_get +import redis +from app.core.config import settings + +redis_client = redis.StrictRedis( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD, + decode_responses=True, + max_connections=30 +) + def get_llm_cache(llmnm, txt, history, genconf): hasher = xxhash.xxh64() - hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8")) + hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8")) k = hasher.hexdigest() - bin = aio_redis_get(k) + bin = redis_client.get(k) if not bin: return None return bin @@ -14,6 +25,6 @@ def get_llm_cache(llmnm, txt, history, genconf): def set_llm_cache(llmnm, txt, v, history, genconf): hasher = xxhash.xxh64() - hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8")) + hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8")) k = hasher.hexdigest() - aio_redis_set(k, v.encode("utf-8"), 24 * 3600) + redis_client.set(k, v.encode("utf-8"), 24 * 3600) From 3aff6baccbd04008b6360540fa4c65b58c6ca0a4 Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 18 Dec 2025 19:46:36 +0800 Subject: [PATCH 04/24] [add] workflow support stream mode --- api/app/controllers/app_controller.py | 19 +- api/app/core/workflow/executor.py | 250 +++++++++++---------- api/app/core/workflow/nodes/base_config.py | 5 + api/app/core/workflow/nodes/end/node.py | 1 - api/app/core/workflow/nodes/llm/node.py | 6 +- api/app/services/workflow_service.py | 145 +++++++++++- 6 files changed, 282 insertions(+), 144 deletions(-) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 3d09f5fc..a92cfab2 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -421,8 +421,8 @@ async def draft_run( # 流式返回 if payload.stream: async def event_generator(): - - + + async for event in draft_service.run_stream( agent_config=agent_cfg, model_config=model_config, @@ -574,7 +574,7 @@ async def draft_run( # 3. 流式返回 if payload.stream: logger.debug( - "开始多智能体流式试运行", + "开始工作流流式试运行", extra={ "app_id": str(app_id), "message_length": len(payload.message), @@ -583,16 +583,13 @@ async def draft_run( ) async def event_generator(): - """多智能体流式事件生成器""" - multiservice = MultiAgentService(db) + """工作流事件生成器""" # 调用多智能体服务的流式方法 - async for event in multiservice.run_stream( + async for event in workflow_service.run_stream( app_id=app_id, - request=multi_agent_request, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - + payload=payload, + config=config ): yield event @@ -617,7 +614,7 @@ async def draft_run( ) result = await workflow_service.run(app_id, payload,config) - + logger.debug( "工作流试运行返回结果", extra={ diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 04bc54dd..9cf711db 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -11,26 +11,24 @@ from typing import Any from langchain_core.messages import HumanMessage from langgraph.graph import StateGraph, START, END +from langgraph.graph.state import CompiledStateGraph from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.expression_evaluator import evaluate_condition -from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution from app.core.tools.registry import ToolRegistry from app.core.tools.executor import ToolExecutor from app.core.tools.langchain_adapter import LangchainAdapter TOOL_MANAGEMENT_AVAILABLE = True -from app.db import get_db - logger = logging.getLogger(__name__) class WorkflowExecutor: """工作流执行器 - + 负责将工作流配置转换为 LangGraph 并执行。 """ - + def __init__( self, workflow_config: dict[str, Any], @@ -39,7 +37,7 @@ class WorkflowExecutor: user_id: str ): """初始化执行器 - + Args: workflow_config: 工作流配置 execution_id: 执行 ID @@ -53,25 +51,25 @@ class WorkflowExecutor: self.nodes = workflow_config.get("nodes", []) self.edges = workflow_config.get("edges", []) self.execution_config = workflow_config.get("execution_config", {}) - + def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: """准备初始状态(注入系统变量和会话变量) - + 变量命名空间: - sys.xxx - 系统变量(execution_id, workspace_id, user_id, message, input_variables 等) - conv.xxx - 会话变量(跨多轮对话保持) - node_id.xxx - 节点输出(执行时动态生成) - + Args: input_data: 输入数据 - + Returns: 初始化的工作流状态 """ user_message = input_data.get("message") or "" conversation_vars = input_data.get("conversation_vars") or {} input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 - + # 构建分层的变量结构 variables = { "sys": { @@ -84,7 +82,7 @@ class WorkflowExecutor: }, "conv": conversation_vars # 会话级变量(跨多轮对话保持) } - + return { "messages": [HumanMessage(content=user_message)], "variables": variables, @@ -96,34 +94,34 @@ class WorkflowExecutor: "error": None, "error_node": None } - - - def build_graph(self) -> StateGraph: + + + def build_graph(self) -> CompiledStateGraph: """构建 LangGraph - + Returns: 编译后的状态图 """ logger.info(f"开始构建工作流图: execution_id={self.execution_id}") - + # 1. 创建状态图 workflow = StateGraph(WorkflowState) - + # 2. 添加所有节点(包括 start 和 end) start_node_id = None end_node_ids = [] - + for node in self.nodes: node_type = node.get("type") node_id = node.get("id") - + # 记录 start 和 end 节点 ID if node_type == "start": start_node_id = node_id elif node_type == "end": end_node_ids.append(node_id) - + # 创建节点实例(现在 start 和 end 也会被创建) node_instance = NodeFactory.create_node(node, self.workflow_config) if node_instance: @@ -133,40 +131,40 @@ class WorkflowExecutor: async def node_func(state: WorkflowState): return await inst.run(state) return node_func - + workflow.add_node(node_id, make_node_func(node_instance)) logger.debug(f"添加节点: {node_id} (type={node_type})") - + # 3. 添加边 # 从 START 连接到 start 节点 if start_node_id: workflow.add_edge(START, start_node_id) logger.debug(f"添加边: START -> {start_node_id}") - + for edge in self.edges: source = edge.get("source") target = edge.get("target") edge_type = edge.get("type") condition = edge.get("condition") - + # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) if source == start_node_id: # 但要连接 start 到下一个节点 workflow.add_edge(source, target) logger.debug(f"添加边: {source} -> {target}") continue - + # 处理到 end 节点的边 if target in end_node_ids: # 连接到 end 节点 workflow.add_edge(source, target) logger.debug(f"添加边: {source} -> {target}") continue - + # 跳过错误边(在节点内部处理) if edge_type == "error": continue - + if condition: # 条件边 def router(state: WorkflowState, cond=condition, tgt=target): @@ -183,74 +181,74 @@ class WorkflowExecutor: ): return tgt return END # 条件不满足,结束 - + workflow.add_conditional_edges(source, router) logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") else: # 普通边 workflow.add_edge(source, target) logger.debug(f"添加边: {source} -> {target}") - + # 从 end 节点连接到 END for end_node_id in end_node_ids: workflow.add_edge(end_node_id, END) logger.debug(f"添加边: {end_node_id} -> END") - + # 4. 编译图 graph = workflow.compile() logger.info(f"工作流图构建完成: execution_id={self.execution_id}") - + return graph - + async def execute( self, input_data: dict[str, Any] ) -> dict[str, Any]: """执行工作流(非流式) - + Args: input_data: 输入数据,包含 message 和 variables - + Returns: 执行结果,包含 status, output, node_outputs, elapsed_time, token_usage """ logger.info(f"开始执行工作流: execution_id={self.execution_id}") - + # 记录开始时间 start_time = datetime.datetime.now() - + # 1. 构建图 graph = self.build_graph() - + # 2. 初始化状态(自动注入系统变量) initial_state = self._prepare_initial_state(input_data) - + # 3. 执行工作流 try: result = await graph.ainvoke(initial_state) - + # 计算耗时 end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - + # 提取节点输出(现在包含 start 和 end 节点) node_outputs = result.get("node_outputs", {}) - + # 提取最终输出(从最后一个非 start/end 节点) final_output = self._extract_final_output(node_outputs) - + # 聚合 token 使用情况 token_usage = self._aggregate_token_usage(node_outputs) - + # 提取 conversation_id(从 start 节点输出) conversation_id = None for node_id, node_output in node_outputs.items(): if node_output.get("node_type") == "start": conversation_id = node_output.get("output", {}).get("conversation_id") break - + logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s") - + return { "status": "completed", "output": final_output, @@ -261,12 +259,12 @@ class WorkflowExecutor: "token_usage": token_usage, "error": result.get("error") } - + except Exception as e: # 计算耗时(即使失败也记录) end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - + logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True) return { "status": "failed", @@ -276,86 +274,94 @@ class WorkflowExecutor: "elapsed_time": elapsed_time, "token_usage": None } - + async def execute_stream( self, input_data: dict[str, Any] ): """执行工作流(流式) - + + 手动执行节点以支持细粒度的流式输出: + - workflow_start: 工作流开始 + - node_start: 节点开始执行 + - node_chunk: LLM 节点的流式输出片段(逐 token) + - node_complete: 节点执行完成 + - workflow_complete: 工作流完成 + Args: input_data: 输入数据 - + Yields: 流式事件 """ - logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}") - + # + logger.info(f"开始执行工作流: execution_id={self.execution_id}") + + # 记录开始时间 + start_time = datetime.datetime.now() + # 1. 构建图 graph = self.build_graph() - + # 2. 初始化状态(自动注入系统变量) initial_state = self._prepare_initial_state(input_data) - - # 3. 流式执行工作流 + + # 3. 执行工作流 try: - # 使用 astream 获取节点级别的更新 - async for event in graph.astream(initial_state, stream_mode="updates"): - for node_name, state_update in event.items(): - yield { - "type": "node_complete", - "node": node_name, - "data": state_update, - "execution_id": self.execution_id - } - - logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}") - - # 发送完成事件 - yield { - "type": "workflow_complete", - "execution_id": self.execution_id - } - + async for chunk in graph.astream( + initial_state, + # subgraphs=True, + stream_mode="updates", + ): + # print(chunk) + yield chunk + except Exception as e: - logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True) + # 计算耗时(即使失败也记录) + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + + logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True) yield { - "type": "workflow_error", - "execution_id": self.execution_id, - "error": str(e) + "status": "failed", + "error": str(e), + "output": None, + "node_outputs": {}, + "elapsed_time": elapsed_time, + "token_usage": None } - + def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None: """从节点输出中提取最终输出 - + 优先级: 1. 最后一个执行的非 start/end 节点的 output 2. 如果没有节点输出,返回 None - + Args: node_outputs: 所有节点的输出 - + Returns: 最终输出字符串或 None """ if not node_outputs: return None - + # 获取最后一个节点的输出 last_node_output = list(node_outputs.values())[-1] if node_outputs else None - + if last_node_output and isinstance(last_node_output, dict): return last_node_output.get("output") - + return None - + def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None: """聚合所有节点的 token 使用情况 - + Args: node_outputs: 所有节点的输出 - + Returns: 聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z} 如果没有 token 使用信息,返回 None @@ -364,7 +370,7 @@ class WorkflowExecutor: total_completion_tokens = 0 total_tokens = 0 has_token_info = False - + for node_output in node_outputs.values(): if isinstance(node_output, dict): token_usage = node_output.get("token_usage") @@ -373,16 +379,16 @@ class WorkflowExecutor: total_prompt_tokens += token_usage.get("prompt_tokens", 0) total_completion_tokens += token_usage.get("completion_tokens", 0) total_tokens += token_usage.get("total_tokens", 0) - + if not has_token_info: return None - + return { "prompt_tokens": total_prompt_tokens, "completion_tokens": total_completion_tokens, "total_tokens": total_tokens } - + async def execute_workflow( workflow_config: dict[str, Any], @@ -392,14 +398,14 @@ async def execute_workflow( user_id: str ) -> dict[str, Any]: """执行工作流(便捷函数) - + Args: workflow_config: 工作流配置 input_data: 输入数据 execution_id: 执行 ID workspace_id: 工作空间 ID user_id: 用户 ID - + Returns: 执行结果 """ @@ -420,14 +426,14 @@ async def execute_workflow_stream( user_id: str ): """执行工作流(流式,便捷函数) - + Args: workflow_config: 工作流配置 input_data: 输入数据 execution_id: 执行 ID workspace_id: 工作空间 ID user_id: 用户 ID - + Yields: 流式事件 """ @@ -445,25 +451,25 @@ async def execute_workflow_stream( def get_workflow_tools(workspace_id: str, user_id: str) -> list: """获取工作流可用的工具列表 - + Args: workspace_id: 工作空间ID user_id: 用户ID - + Returns: 可用工具列表 """ if not TOOL_MANAGEMENT_AVAILABLE: logger.warning("工具管理系统不可用") return [] - + try: from sqlalchemy.orm import Session db = next(get_db()) - + # 创建工具注册表 registry = ToolRegistry(db) - + # 注册内置工具类 from app.core.tools.builtin import ( DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool @@ -473,12 +479,12 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list: registry.register_tool_class(BaiduSearchTool) registry.register_tool_class(MinerUTool) registry.register_tool_class(TextInTool) - + # 获取活跃的工具 import uuid tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id)) active_tools = [tool for tool in tools if tool.status.value == "active"] - + # 转换为Langchain工具 langchain_tools = [] for tool_info in active_tools: @@ -489,10 +495,10 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list: langchain_tools.append(langchain_tool) except Exception as e: logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}") - + logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具") return langchain_tools - + except Exception as e: logger.error(f"获取工作流工具失败: {e}") return [] @@ -500,10 +506,10 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list: class ToolWorkflowNode: """工具工作流节点 - 在工作流中执行工具""" - + def __init__(self, node_config: dict, workflow_config: dict): """初始化工具节点 - + Args: node_config: 节点配置 workflow_config: 工作流配置 @@ -512,25 +518,25 @@ class ToolWorkflowNode: self.workflow_config = workflow_config self.tool_id = node_config.get("tool_id") self.tool_parameters = node_config.get("parameters", {}) - + async def run(self, state: WorkflowState) -> WorkflowState: """执行工具节点""" if not TOOL_MANAGEMENT_AVAILABLE: logger.error("工具管理系统不可用") state["error"] = "工具管理系统不可用" return state - + try: from sqlalchemy.orm import Session db = next(get_db()) - + # 创建工具执行器 registry = ToolRegistry(db) executor = ToolExecutor(db, registry) - + # 准备参数(支持变量替换) parameters = self._prepare_parameters(state) - + # 执行工具 result = await executor.execute_tool( tool_id=self.tool_id, @@ -538,7 +544,7 @@ class ToolWorkflowNode: user_id=uuid.UUID(state["user_id"]), workspace_id=uuid.UUID(state["workspace_id"]) ) - + # 更新状态 node_id = self.node_config.get("id") if result.success: @@ -549,7 +555,7 @@ class ToolWorkflowNode: "execution_time": result.execution_time, "token_usage": result.token_usage } - + # 更新运行时变量 if isinstance(result.data, dict): for key, value in result.data.items(): @@ -565,29 +571,29 @@ class ToolWorkflowNode: "error": result.error, "execution_time": result.execution_time } - + return state - + except Exception as e: logger.error(f"工具节点执行失败: {e}") state["error"] = str(e) state["error_node"] = self.node_config.get("id") return state - + def _prepare_parameters(self, state: WorkflowState) -> dict: """准备工具参数(支持变量替换)""" parameters = {} - + for key, value in self.tool_parameters.items(): if isinstance(value, str) and value.startswith("${") and value.endswith("}"): # 变量替换 var_path = value[2:-1] - + # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result} if "." in var_path: parts = var_path.split(".") current = state.get("variables", {}) - + for part in parts: if isinstance(current, dict) and part in current: current = current[part] @@ -596,7 +602,7 @@ class ToolWorkflowNode: runtime_key = ".".join(parts) current = state.get("runtime_vars", {}).get(runtime_key, value) break - + parameters[key] = current else: # 简单变量 @@ -604,7 +610,7 @@ class ToolWorkflowNode: parameters[key] = variables.get(var_path, value) else: parameters[key] = value - + return parameters diff --git a/api/app/core/workflow/nodes/base_config.py b/api/app/core/workflow/nodes/base_config.py index 8423f479..90d02732 100644 --- a/api/app/core/workflow/nodes/base_config.py +++ b/api/app/core/workflow/nodes/base_config.py @@ -50,6 +50,11 @@ class VariableDefinition(BaseModel): description="变量描述" ) + max_length: int = Field( + default=200, + description="只对字符串类型生效" + ) + class Config: json_schema_extra = { "examples": [ diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 1c0e6747..ad028f31 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -5,7 +5,6 @@ End 节点实现 """ import logging -from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index bfc7da58..cf665ff1 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -10,10 +10,8 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.models import RedBearLLM, RedBearModelConfig -from app.models import ModelConfig -from app.db import get_db, get_db_context -from app.models.models_model import ModelApiKey -from app.services.model_service import ModelConfigService, ModelApiKeyService +from app.db import get_db_context +from app.services.model_service import ModelConfigService from app.core.exceptions import BusinessException from app.core.error_codes import BizCode diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index c604697b..f0b71824 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -1,7 +1,7 @@ """ 工作流服务层 """ - +import json import logging import uuid import datetime @@ -438,7 +438,7 @@ class WorkflowService: message=f"工作流配置不存在: app_id={app_id}" ) input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id} - + # 转换 user_id 为 UUID triggered_by_uuid = None if payload.user_id: @@ -446,7 +446,7 @@ class WorkflowService: triggered_by_uuid = uuid.UUID(payload.user_id) except (ValueError, AttributeError): logger.warning(f"无效的 user_id 格式: {payload.user_id}") - + # 转换 conversation_id 为 UUID conversation_id_uuid = None if payload.conversation_id: @@ -454,7 +454,7 @@ class WorkflowService: conversation_id_uuid = uuid.UUID(payload.conversation_id) except (ValueError, AttributeError): logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}") - + # 2. 创建执行记录 execution = self.create_execution( workflow_config_id=config.id, @@ -530,6 +530,109 @@ class WorkflowService: message=f"工作流执行失败: {str(e)}" ) + async def run_stream( + self, + app_id: uuid.UUID, + payload: DraftRunRequest, + config: WorkflowConfig + ): + """运行工作流(流式) + + Args: + app_id: 应用 ID + payload: 请求对象(包含 message, variables, conversation_id 等) + config: 存储类型(可选) + + Yields: + SSE 格式的流式事件 + + Raises: + BusinessException: 配置不存在或执行失败时抛出 + """ + # 1. 获取工作流配置 + if not config: + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException( + code=BizCode.CONFIG_MISSING, + message=f"工作流配置不存在: app_id={app_id}" + ) + input_data = {"message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id} + + # 转换 user_id 为 UUID + triggered_by_uuid = None + if payload.user_id: + try: + triggered_by_uuid = uuid.UUID(payload.user_id) + except (ValueError, AttributeError): + logger.warning(f"无效的 user_id 格式: {payload.user_id}") + + # 转换 conversation_id 为 UUID + conversation_id_uuid = None + if payload.conversation_id: + try: + conversation_id_uuid = uuid.UUID(payload.conversation_id) + except (ValueError, AttributeError): + logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}") + + # 2. 创建执行记录 + execution = self.create_execution( + workflow_config_id=config.id, + app_id=app_id, + trigger_type="manual", + triggered_by=triggered_by_uuid, + conversation_id=conversation_id_uuid, + input_data=input_data + ) + + # 3. 构建工作流配置字典 + workflow_config_dict = { + "nodes": config.nodes, + "edges": config.edges, + "variables": config.variables, + "execution_config": config.execution_config + } + + # 4. 获取工作空间 ID(从 app 获取) + from app.models import App + + # 5. 流式执行工作流 + from app.core.workflow.executor import execute_workflow, execute_workflow_stream + + try: + # 更新状态为运行中 + self.update_execution_status(execution.execution_id, "running") + + # 发送开始事件 + yield f"data: {json.dumps({'type': 'workflow_start', 'execution_id': execution.execution_id})}\n\n" + + # 调用流式执行 + async for event in self._run_workflow_stream( + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id="", + user_id=payload.user_id + ): + # 清理事件数据,移除不可序列化的对象 + cleaned_event = self._clean_event_for_json(event) + # 转换为 SSE 格式 + yield f"data: {json.dumps(cleaned_event)}\n\n" + + # 发送完成事件 + yield f"data: {json.dumps({'type': 'workflow_end', 'execution_id': execution.execution_id})}\n\n" + + except Exception as e: + logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) + self.update_execution_status( + execution.execution_id, + "failed", + error_message=str(e) + ) + # 发送错误事件 + yield f"data: {json.dumps({'type': 'error', 'execution_id': execution.execution_id, 'error': str(e)})}\n\n" + async def run_workflow( self, app_id: uuid.UUID, @@ -651,14 +754,44 @@ class WorkflowService: message=f"工作流执行失败: {str(e)}" ) + def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]: + """清理事件数据,移除不可序列化的对象 + + Args: + event: 原始事件数据 + + Returns: + 可序列化的事件数据 + """ + from langchain_core.messages import BaseMessage + + def clean_value(value): + """递归清理值""" + if isinstance(value, BaseMessage): + # 将 Message 对象转换为字典 + return { + "type": value.__class__.__name__, + "content": value.content, + } + elif isinstance(value, dict): + return {k: clean_value(v) for k, v in value.items()} + elif isinstance(value, list): + return [clean_value(item) for item in value] + elif isinstance(value, (str, int, float, bool, type(None))): + return value + else: + # 其他不可序列化的对象转换为字符串 + return str(value) + + return clean_value(event) + async def _run_workflow_stream( self, workflow_config: dict[str, Any], input_data: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str - ): + user_id: str): """运行工作流(流式,内部方法) Args: From 240e94cb38fa375a7f602a01b9e25cf51321c372 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=96=B0=E6=9C=88?= Date: Fri, 19 Dec 2025 08:04:12 +0000 Subject: [PATCH 05/24] Merge #9 into develop from fix/memory_reflection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) * fix/memory_reflection: (24 commits squashed) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 Signed-off-by: aliyun8644380055 Commented-by: aliyun8644380055 Commented-by: aliyun6762716068 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/9 --- api/app/celery_app.py | 19 +- api/app/controllers/__init__.py | 2 + .../memory_reflection_controller.py | 200 +++++++++ api/app/core/config.py | 1 + .../reflection_engine/example/example.json | 210 +++++++++ .../reflection_engine/self_reflexion.py | 322 ++++++++------ api/app/core/memory/utils/config/get_data.py | 62 +-- .../utils/prompt/prompts/evaluate.jinja2 | 221 +++++++++- .../utils/prompt/prompts/reflexion.jinja2 | 307 +++++++++++++- .../memory/utils/prompt/template_render.py | 28 +- api/app/models/data_config_model.py | 26 +- api/app/models/end_user_model.py | 1 + .../repositories/data_config_repository.py | 252 +++++++---- api/app/repositories/neo4j/cypher_queries.py | 54 +++ api/app/repositories/neo4j/neo4j_update.py | 227 ++++++++++ api/app/schemas/end_user_schema.py | 1 + api/app/schemas/memory_reflection_schemas.py | 54 +++ api/app/schemas/memory_storage_schema.py | 63 ++- api/app/services/memory_reflection_service.py | 397 ++++++++++++++++++ api/app/tasks.py | 163 ++++++- api/check_code.py | 108 +++++ 21 files changed, 2383 insertions(+), 335 deletions(-) create mode 100644 api/app/controllers/memory_reflection_controller.py create mode 100644 api/app/core/memory/storage_services/reflection_engine/example/example.json create mode 100644 api/app/repositories/neo4j/neo4j_update.py create mode 100644 api/app/schemas/memory_reflection_schemas.py create mode 100644 api/app/services/memory_reflection_service.py create mode 100755 api/check_code.py diff --git a/api/app/celery_app.py b/api/app/celery_app.py index d072a346..ce7e9300 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -83,17 +83,18 @@ celery_app.autodiscover_tasks(['app']) reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS) health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) - +workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME # 构建定时任务配置 beat_schedule_config = { - "run-reflection-engine": { - "task": "app.core.memory.agent.reflection.timer", - "schedule": reflection_schedule, - "args": (), - }, - "check-read-service": { - "task": "app.core.memory.agent.health.check_read_service", - "schedule": health_schedule, + + # "check-read-service": { + # "task": "app.core.memory.agent.health.check_read_service", + # "schedule": health_schedule, + # "args": (), + # }, + "run-workspace-reflection": { + "task": "app.tasks.workspace_reflection_task", + "schedule": workspace_reflection_schedule, "args": (), }, } diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index fe7c692e..47cc8688 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -23,6 +23,7 @@ from . import ( memory_dashboard_controller, memory_storage_controller, memory_dashboard_controller, + memory_reflection_controller, api_key_controller, release_share_controller, public_share_controller, @@ -62,6 +63,7 @@ manager_router.include_router(memory_dashboard_controller.router) manager_router.include_router(multi_agent_controller.router) manager_router.include_router(workflow_controller.router) manager_router.include_router(prompt_optimizer_controller.router) +manager_router.include_router(memory_reflection_controller.router) manager_router.include_router(tool_controller.router) manager_router.include_router(tool_execution_controller.router) diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py new file mode 100644 index 00000000..759c25c5 --- /dev/null +++ b/api/app/controllers/memory_reflection_controller.py @@ -0,0 +1,200 @@ +import asyncio + +from dotenv import load_dotenv +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session +from sqlalchemy import text + +from app.core.logging_config import get_api_logger +from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine +from app.dependencies import get_current_user +from app.db import get_db +from app.models.user_model import User +from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService + +from app.schemas.memory_reflection_schemas import Memory_Reflection + +load_dotenv() +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory", + tags=["Memory"], +) + + +@router.post("/reflection/save") +async def save_reflection_config( + request: Memory_Reflection, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """Save reflection configuration to data_comfig table""" + + + + try: + config_id = request.config_id + if not config_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="缺少必需参数: config_id" + ) + + api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") + + update_params = { + "enable_self_reflexion": request.reflectionenabled, + "iteration_period": request.reflection_period_in_hours, + "reflexion_range": request.reflexion_range, + "baseline": request.baseline, + "reflection_model_id": request.reflection_model_id, + "memory_verify": request.memory_verify, + "quality_assessment": request.quality_assessment, + } + + + + query, params = DataConfigRepository.build_update_reflection(config_id, **update_params) + + result = db.execute(text(query), params) + if result.rowcount == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"未找到config_id为 {config_id} 的配置" + ) + + db.commit() + + # 查询更新后的配置 + select_query, select_params = DataConfigRepository.build_select_reflection(config_id) + result = db.execute(text(select_query), select_params).fetchone() + + if not result: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"更新后未找到config_id为 {config_id} 的配置" + ) + + api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}") + + # 返回结果 + return { + "status": "成功", + "message": "反思配置已保存", + "config_id": config_id, + "database_record": { + "config_id": result.config_id, + "enable_self_reflexion": result.enable_self_reflexion, + "iteration_period": result.iteration_period, + "reflexion_range": result.reflexion_range, + "baseline": result.baseline, + "reflection_model_id": result.reflection_model_id, + "memory_verify": result.memory_verify, + "quality_assessment": result.quality_assessment, + "user_id": result.user_id + } + } + + except ValueError as ve: + api_logger.error(f"参数错误: {str(ve)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"参数错误: {str(ve)}" + ) + except Exception as e: + api_logger.error(f"反思配置保存失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"反思配置保存失败: {str(e)}" + ) + + +@router.post("/reflection") +async def start_workspace_reflection( + request: dict, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """Activate the reflection function for all matching applications in the workspace""" + workspace_id = current_user.current_workspace_id + reflection_service = MemoryReflectionService(db) + + try: + api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}") + + service = WorkspaceAppService(db) + result = service.get_workspace_apps_detailed(workspace_id) + + reflection_results = [] + + for data in result['apps_detailed_info']: + if data['data_configs'] == []: + continue + + releases = data['releases'] + data_configs = data['data_configs'] + end_users = data['end_users'] + + for base, config, user in zip(releases, data_configs, end_users): + if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: + # 调用反思服务 + api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") + + reflection_result = await reflection_service.start_reflection_from_data( + config_data=config, + end_user_id=user['id'] + ) + + reflection_results.append({ + "app_id": base['app_id'], + "config_id": config['config_id'], + "end_user_id": user['id'], + "reflection_result": reflection_result + }) + + return { + "status": "完成", + "message": f"成功处理 {len(reflection_results)} 个反思任务", + "workspace_id": str(workspace_id), + "reflection_count": len(reflection_results), + "reflection_results": reflection_results + } + + except Exception as e: + api_logger.error(f"启动workspace反思失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"启动workspace反思失败: {str(e)}" + ) + +@router.post("/reflection/run") +async def reflection_run( + reflection: Memory_Reflection, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """Activate the reflection function for all matching applications in the workspace""" + config = ReflectionConfig( + enabled=reflection.reflectionenabled, + iteration_period=reflection.reflection_period_in_hours, + reflexion_range=reflection.reflexion_range, + baseline=reflection.baseline, + output_example='', + memory_verify=reflection.memory_verify, + quality_assessment=reflection.quality_assessment, + violation_handling_strategy="block", + model_id=reflection.reflection_model_id + ) + connector = Neo4jConnector() + engine = ReflectionEngine( + config=config, + neo4j_connector=connector, + llm_client=reflection.reflection_model_id # 传入 model_id + ) + + result=await (engine.reflection_run()) + return result diff --git a/api/app/core/config.py b/api/app/core/config.py index d4d285fe..bf5ff45a 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -148,6 +148,7 @@ class Settings: HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) + REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") diff --git a/api/app/core/memory/storage_services/reflection_engine/example/example.json b/api/app/core/memory/storage_services/reflection_engine/example/example.json new file mode 100644 index 00000000..6528da60 --- /dev/null +++ b/api/app/core/memory/storage_services/reflection_engine/example/example.json @@ -0,0 +1,210 @@ +{ + "memory_verify": { + "source_data": [ + { + "statement_name": "用户是2023年春天去北京工作的。", + "statement_id": "62beac695b1346f4871740a45db88782", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户后来基本一直都在北京上班。", + "statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户从2023年开始就一直在北京生活。", + "statement_id": "e612a44da4db483993c350df7c97a1a1", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户从来没有长期离开过北京。", + "statement_id": "b3c787a2e33c49f7981accabbbb4538a", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。", + "statement_id": "64cde4230cb24a4da726e7db9e7aa616", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。", + "statement_id": "8b1b12e23b844b8088dfeb67da6ad669", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。", + "statement_id": "030afd362e9b4110b139e68e5d3e7143", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户的银行卡号是6222023847595898。", + "statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户的身份信息和银行卡信息一直没变。", + "statement_id": "b3ca618e1e204b83bebd70e75cf2073f", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户认为在上海的那段时间更多算是远程配合。", + "statement_id": "150af89d2c154e6eb41ff1a91e37f962", + "statement_created_at": "2025-12-19T10:31:15.239252" + } + ], + "databasets": [ + { + "entity1_name": "Person", + "description": "表示人类个体的通用类型", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "用户", + "entity2": { + "entity_idx": 0, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "connect_strength": "strong", + "created_at": "2025-12-19T10:31:15.239252000", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Person", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "用户", + "apply_id": "88a459f5_text08", + "id": "3d3896797b334572a80d57590026063d" + } + }, + { + "entity1_name": "用户", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "身份信息", + "entity2": { + "entity_idx": 1, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "connect_strength": "Strong", + "description": "用于个人身份识别的数据", + "created_at": "2025-12-19T10:31:15.239252000", + "statement_id": "030afd362e9b4110b139e68e5d3e7143", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Information", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "身份信息", + "apply_id": "88a459f5_text08", + "id": "aa766a517e82490599a9b3af54cfd933" + } + }, + { + "entity1_name": "用户", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "6222023847595898", + "entity2": { + "entity_idx": 1, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "connect_strength": "Strong", + "description": "用户的银行卡号码", + "created_at": "2025-12-19T10:31:15.239252000", + "statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Numeric", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "6222023847595898", + "apply_id": "88a459f5_text08", + "id": "610ba361918f4e68a65ce6ad06e5c7a0" + } + }, + { + "entity1_name": "用户", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "上海办公室", + "entity2": { + "entity_idx": 1, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "aliases": ["上海办"], + "connect_strength": "Strong", + "created_at": "2025-12-19T10:31:15.239252000", + "description": "位于上海的工作办公场所", + "statement_id": "8b1b12e23b844b8088dfeb67da6ad669", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Location", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "上海办公室", + "apply_id": "88a459f5_text08", + "id": "fb702ef695c14e14af3e56786bc8815b" + } + }, + { + "entity1_name": "用户", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "北京", + "entity2": { + "entity_idx": 2, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "aliases": ["京", "京城", "北平"], + "connect_strength": "strong", + "created_at": "2025-12-19T10:31:15.239252000", + "description": "中国的首都城市,用户主要工作和生活所在地", + "statement_id": "62beac695b1346f4871740a45db88782", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Location", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "北京", + "apply_id": "88a459f5_text08", + "id": "81b2d1a571bb46a08a2d7a1e87efb945" + } + }, + { + "entity1_name": "11010119950308123X", + "description": "具体的身份证号码值", + "statement_id": "030afd362e9b4110b139e68e5d3e7143", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "身份证号", + "entity2": { + "entity_idx": 2, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "connect_strength": "strong", + "description": "中华人民共和国公民的身份号码", + "created_at": "2025-12-19T10:31:15.239252000", + "statement_id": "030afd362e9b4110b139e68e5d3e7143", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Identifier", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "身份证号", + "apply_id": "88a459f5_text08", + "id": "3e5f920645b2404fadb0e9ff60d1306e" + } + } + ] + } +} \ No newline at end of file diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index b3e5813d..8f5b9bae 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -8,17 +8,20 @@ 4. 反思结果应用 - 更新记忆库 """ -import os import json import logging import asyncio +import os +import time from typing import List, Dict, Any, Optional -from datetime import datetime from enum import Enum import uuid -from pydantic import BaseModel, Field +from pydantic import BaseModel +from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all +from app.repositories.neo4j.neo4j_update import neo4j_data +from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 配置日志 _root_logger = logging.getLogger() @@ -33,14 +36,14 @@ else: class ReflectionRange(str, Enum): """反思范围枚举""" - RETRIEVAL = "retrieval" # 从检索结果中反思 - DATABASE = "database" # 从整个数据库中反思 + PARTIAL = "partial" # 从检索结果中反思 + ALL = "all" # 从整个数据库中反思 class ReflectionBaseline(str, Enum): """反思基线枚举""" - TIME = "TIME" # 基于时间的反思 - FACT = "FACT" # 基于事实的反思 + TIME = "TIME" # 基于时间的反思 + FACT = "FACT" # 基于事实的反思 HYBRID = "HYBRID" # 混合反思 @@ -48,9 +51,16 @@ class ReflectionConfig(BaseModel): """反思引擎配置""" enabled: bool = False iteration_period: str = "3" # 反思周期 - reflexion_range: ReflectionRange = ReflectionRange.RETRIEVAL + reflexion_range: ReflectionRange = ReflectionRange.PARTIAL baseline: ReflectionBaseline = ReflectionBaseline.TIME - concurrency: int = Field(default=5, description="并发数量") + model_id: Optional[str] = None # 模型ID + end_user_id: Optional[str] = None + output_example: Optional[str] = None # 输出示例 + + # 评估相关字段 + memory_verify: bool = True # 记忆验证 + quality_assessment: bool = True # 质量评估 + violation_handling_strategy: str = "warn" # 违规处理策略 class Config: use_enum_values = True @@ -75,16 +85,16 @@ class ReflectionEngine: """ def __init__( - self, - config: ReflectionConfig, - neo4j_connector: Optional[Any] = None, - llm_client: Optional[Any] = None, - get_data_func: Optional[Any] = None, - render_evaluate_prompt_func: Optional[Any] = None, - render_reflexion_prompt_func: Optional[Any] = None, - conflict_schema: Optional[Any] = None, - reflexion_schema: Optional[Any] = None, - update_query: Optional[str] = None + self, + config: ReflectionConfig, + neo4j_connector: Optional[Any] = None, + llm_client: Optional[Any] = None, + get_data_func: Optional[Any] = None, + render_evaluate_prompt_func: Optional[Any] = None, + render_reflexion_prompt_func: Optional[Any] = None, + conflict_schema: Optional[Any] = None, + reflexion_schema: Optional[Any] = None, + update_query: Optional[str] = None ): """ 初始化反思引擎 @@ -109,7 +119,7 @@ class ReflectionEngine: self.conflict_schema = conflict_schema self.reflexion_schema = reflexion_schema self.update_query = update_query - self._semaphore = asyncio.Semaphore(config.concurrency) + self._semaphore = asyncio.Semaphore(5) # 默认并发数为5 # 延迟导入以避免循环依赖 self._lazy_init_done = False @@ -127,11 +137,21 @@ class ReflectionEngine: from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.config import definitions as config_defs self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) + elif isinstance(self.llm_client, str): + # 如果 llm_client 是字符串(model_id),则用它初始化客户端 + from app.core.memory.utils.llm.llm_utils import get_llm_client + model_id = self.llm_client + self.llm_client = get_llm_client(model_id) if self.get_data_func is None: from app.core.memory.utils.config.get_data import get_data self.get_data_func = get_data + # 导入get_data_statement函数 + if not hasattr(self, 'get_data_statement'): + from app.core.memory.utils.config.get_data import get_data_statement + self.get_data_statement = get_data_statement + if self.render_evaluate_prompt_func is None: from app.core.memory.utils.prompt.template_render import render_evaluate_prompt self.render_evaluate_prompt_func = render_evaluate_prompt @@ -154,13 +174,11 @@ class ReflectionEngine: self._lazy_init_done = True - async def execute_reflection(self, host_id: uuid.UUID) -> ReflectionResult: + async def execute_reflection(self, host_id) -> ReflectionResult: """ 执行完整的反思流程 - Args: host_id: 主机ID - Returns: ReflectionResult: 反思结果 """ @@ -176,9 +194,10 @@ class ReflectionEngine: start_time = asyncio.get_event_loop().time() logging.info("====== 自我反思流程开始 ======") + print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment) try: # 1. 获取反思数据 - reflexion_data = await self._get_reflexion_data(host_id) + reflexion_data, statement_databasets = await self._get_reflexion_data(host_id) if not reflexion_data: return ReflectionResult( success=True, @@ -187,22 +206,21 @@ class ReflectionEngine: ) # 2. 检测冲突(基于事实的反思) - conflict_data = await self._detect_conflicts(reflexion_data) - if not conflict_data: - return ReflectionResult( - success=True, - message="无冲突,无需反思", - execution_time=asyncio.get_event_loop().time() - start_time - ) + conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets) + print(100 * '-') + print(conflict_data) + print(100 * '-') - conflicts_found = len(conflict_data) - logging.info(f"发现 {conflicts_found} 个冲突") + # 检查是否真的有冲突 + has_conflict = conflict_data[0].get('conflict', False) + conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0 + logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突") # 记录冲突数据 await self._log_data("conflict", conflict_data) # 3. 解决冲突 - solved_data = await self._resolve_conflicts(conflict_data) + solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) if not solved_data: return ReflectionResult( success=False, @@ -210,6 +228,9 @@ class ReflectionEngine: conflicts_found=conflicts_found, execution_time=asyncio.get_event_loop().time() - start_time ) + print(100 * '*') + print(solved_data) + print(100 * '*') conflicts_resolved = len(solved_data) logging.info(f"解决了 {conflicts_resolved} 个冲突") @@ -230,7 +251,8 @@ class ReflectionEngine: conflicts_found=conflicts_found, conflicts_resolved=conflicts_resolved, memories_updated=memories_updated, - execution_time=execution_time + execution_time=execution_time, + ) except Exception as e: @@ -241,6 +263,79 @@ class ReflectionEngine: execution_time=asyncio.get_event_loop().time() - start_time ) + async def reflection_run(self): + self._lazy_init() + start_time = time.time() + + asyncio.get_event_loop().time() + logging.info("====== 自我反思流程开始 ======") + + result_data = {} + + source_data, databasets = await self.extract_fields_from_json() + result_data['baseline'] = self.config.baseline + result_data[ + 'source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合" + + # 2. 检测冲突(基于事实的反思) + conflict_data = await self._detect_conflicts(databasets, source_data) + # 遍历数据提取字段 + quality_assessments = [] + memory_verifies = [] + for item in conflict_data: + print(item) + quality_assessments.append(item['quality_assessment']) + memory_verifies.append(item['memory_verify']) + result_data['quality_assessments'] = quality_assessments + result_data['memory_verifies'] = memory_verifies + + # 检查是否真的有冲突 + has_conflict = conflict_data[0].get('conflict', False) + conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0 + logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突") + + # 记录冲突数据 + await self._log_data("conflict", conflict_data) + + # 3. 解决冲突 + solved_data = await self._resolve_conflicts(conflict_data, source_data) + if not solved_data: + return ReflectionResult( + success=False, + message="反思失败,未解决冲突", + conflicts_found=conflicts_found, + execution_time=asyncio.get_event_loop().time() - start_time + ) + reflexion_data = [] + + # 遍历数据提取reflexion字段 + for item in solved_data: + if 'results' in item: + for result in item['results']: + reflexion_data.append(result['reflexion']) + result_data['reflexion_data'] = reflexion_data + execution_time = time.time() - start_time + return {"status": "SUCCESS", "message": "反思试运行", "data": result_data, "time": execution_time} + + async def extract_fields_from_json(self): + """从example.json中提取source_data和databasets字段""" + + prompt_dir = os.path.join(os.path.dirname(__file__), "example") + try: + # 读取JSON文件 + with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f: + data = json.loads(f.read()) + + # 提取memory_verify下的字段 + memory_verify = data.get("memory_verify", {}) + source_data = memory_verify.get("source_data", []) + databasets = memory_verify.get("databasets", []) + + return source_data, databasets + + except Exception as e: + return [], [] + async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]: """ 获取反思数据 @@ -253,17 +348,28 @@ class ReflectionEngine: Returns: List[Any]: 反思数据列表 """ - if self.config.reflexion_range == ReflectionRange.RETRIEVAL: - # 从检索结果中获取数据 - return await self.get_data_func(host_id) - elif self.config.reflexion_range == ReflectionRange.DATABASE: - # 从整个数据库中获取数据(待实现) - logging.warning("从数据库获取反思数据功能尚未实现") - return [] - else: - raise ValueError(f"未知的反思范围: {self.config.reflexion_range}") - async def _detect_conflicts(self, data: List[Any]) -> List[Any]: + + + if self.config.reflexion_range == ReflectionRange.PARTIAL: + neo4j_query = neo4j_query_part.format(host_id) + neo4j_statement = neo4j_statement_part.format(host_id) + elif self.config.reflexion_range == ReflectionRange.ALL: + neo4j_query = neo4j_query_all.format(host_id) + neo4j_statement = neo4j_statement_all.format(host_id) + try: + result = await self.neo4j_connector.execute_query(neo4j_query) + result_statement = await self.neo4j_connector.execute_query(neo4j_statement) + neo4j_databasets = await self.get_data_func(result) + neo4j_state = await self.get_data_statement(result_statement) + return neo4j_databasets, neo4j_state + + + except Exception as e: + logging.error(f"Neo4j查询失败: {e}") + return [], [] + + async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]: """ 检测冲突(基于事实的反思) @@ -278,14 +384,28 @@ class ReflectionEngine: if not data: return [] + # 数据预处理:如果数据量太少,直接返回无冲突 + if len(data) < 2: + logging.info("数据量不足,无需检测冲突") + return [] + + # 使用转换后的数据 + print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长 + memory_verify = self.config.memory_verify + logging.info("====== 冲突检测开始 ======") start_time = asyncio.get_event_loop().time() + quality_assessment = self.config.quality_assessment try: # 渲染冲突检测提示词 rendered_prompt = await self.render_evaluate_prompt_func( data, - self.conflict_schema + self.conflict_schema, + self.config.baseline, + memory_verify, + quality_assessment, + statement_databasets ) messages = [{"role": "user", "content": rendered_prompt}] @@ -316,7 +436,7 @@ class ReflectionEngine: logging.error(f"冲突检测失败: {e}", exc_info=True) return [] - async def _resolve_conflicts(self, conflicts: List[Any]) -> List[Any]: + async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]: """ 解决冲突 @@ -332,6 +452,8 @@ class ReflectionEngine: return [] logging.info("====== 冲突解决开始 ======") + baseline = self.config.baseline + memory_verify = self.config.memory_verify # 并行处理每个冲突 async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]: @@ -341,7 +463,10 @@ class ReflectionEngine: # 渲染反思提示词 rendered_prompt = await self.render_reflexion_prompt_func( [conflict], - self.reflexion_schema + self.reflexion_schema, + baseline, + memory_verify, + statement_databasets ) messages = [{"role": "user", "content": rendered_prompt}] @@ -381,8 +506,8 @@ class ReflectionEngine: return solved async def _apply_reflection_results( - self, - solved_data: List[Dict[str, Any]] + self, + solved_data: List[Dict[str, Any]] ) -> int: """ 应用反思结果(更新记忆库) @@ -395,57 +520,7 @@ class ReflectionEngine: Returns: int: 成功更新的记忆数量 """ - if not solved_data: - logging.warning("无解决方案数据,跳过更新") - return 0 - - logging.info("====== 记忆更新开始 ======") - - success_count = 0 - - async def _update_one(item: Dict[str, Any]) -> bool: - """更新单条记忆""" - async with self._semaphore: - try: - if not isinstance(item, dict): - return False - - # 提取更新参数 - resolved = item.get("resolved", {}) - resolved_mem = resolved.get("resolved_memory", {}) - group_id = resolved_mem.get("group_id") - memory_id = resolved_mem.get("id") - new_invalid_at = resolved_mem.get("invalid_at") - - if not all([group_id, memory_id, new_invalid_at]): - logging.warning(f"记忆更新参数缺失,跳过此项: {item}") - return False - - # 执行更新 - await self.neo4j_connector.execute_query( - self.update_query, - group_id=group_id, - id=memory_id, - new_invalid_at=new_invalid_at, - ) - - return True - - except Exception as e: - logging.error(f"更新单条记忆失败: {e}") - return False - - # 并发执行所有更新任务 - tasks = [ - _update_one(item) - for item in solved_data - if isinstance(item, dict) - ] - results = await asyncio.gather(*tasks, return_exceptions=False) - success_count = sum(1 for r in results if r) - - logging.info(f"成功更新 {success_count}/{len(solved_data)} 条记忆") - + success_count = await neo4j_data(solved_data) return success_count async def _log_data(self, label: str, data: Any) -> None: @@ -456,6 +531,7 @@ class ReflectionEngine: label: 数据标签 data: 要记录的数据 """ + def _write(): try: with open("reflexion_data.json", "a", encoding="utf-8") as f: @@ -470,9 +546,9 @@ class ReflectionEngine: # 基于时间的反思方法 async def time_based_reflection( - self, - host_id: uuid.UUID, - time_period: Optional[str] = None + self, + host_id: uuid.UUID, + time_period: Optional[str] = None ) -> ReflectionResult: """ 基于时间的反思 @@ -494,8 +570,8 @@ class ReflectionEngine: # 基于事实的反思方法 async def fact_based_reflection( - self, - host_id: uuid.UUID + self, + host_id: uuid.UUID ) -> ReflectionResult: """ 基于事实的反思 @@ -515,8 +591,8 @@ class ReflectionEngine: # 综合反思方法 async def comprehensive_reflection( - self, - host_id: uuid.UUID + self, + host_id: uuid.UUID ) -> ReflectionResult: """ 综合反思 @@ -553,33 +629,3 @@ class ReflectionEngine: else: raise ValueError(f"未知的反思基线: {self.config.baseline}") - -# 便捷函数:创建默认配置的反思引擎 -def create_reflection_engine( - enabled: bool = False, - iteration_period: str = "3", - reflexion_range: str = "retrieval", - baseline: str = "TIME", - concurrency: int = 5 -) -> ReflectionEngine: - """ - 创建反思引擎实例 - - Args: - enabled: 是否启用反思 - iteration_period: 反思周期 - reflexion_range: 反思范围 - baseline: 反思基线 - concurrency: 并发数量 - - Returns: - ReflectionEngine: 反思引擎实例 - """ - config = ReflectionConfig( - enabled=enabled, - iteration_period=iteration_period, - reflexion_range=reflexion_range, - baseline=baseline, - concurrency=concurrency - ) - return ReflectionEngine(config) diff --git a/api/app/core/memory/utils/config/get_data.py b/api/app/core/memory/utils/config/get_data.py index f2f21198..a099694e 100644 --- a/api/app/core/memory/utils/config/get_data.py +++ b/api/app/core/memory/utils/config/get_data.py @@ -1,13 +1,8 @@ import json -import os import uuid -from typing import List, Dict, Any, Optional -from sqlalchemy.orm import Session -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo -from app.schemas.memory_storage_schema import BaseDataSchema - import logging + +from typing import List, Dict, Any logger = logging.getLogger(__name__) async def _load_(data: List[Any]) -> List[Dict]: @@ -60,27 +55,46 @@ async def _load_(data: List[Any]) -> List[Dict]: return results -async def get_data(host_id: uuid.UUID) -> List[Dict]: +async def get_data(result): """ 从数据库中获取数据 """ - # 从数据库会话中获取会话 - db: Session = next(get_db()) - try: - data = db.query(RetrievalInfo.retrieve_info).filter(RetrievalInfo.host_id == host_id).all() + neo4j_databasets=[] + for item in result: + filtered_item = {} + for key, value in item.items(): + if 'name_embedding' not in key.lower(): + if key == 'relationship' and value is not None: + # 只保留relationship的指定字段 + rel_filtered = {} + if hasattr(value, 'get'): + rel_filtered['run_id'] = value.get('run_id') + rel_filtered['statement'] = value.get('statement') + rel_filtered['statement_id'] = value.get('statement_id') + rel_filtered['expired_at'] = value.get('expired_at') + rel_filtered['created_at'] = value.get('created_at') + filtered_item[key] = rel_filtered + elif key == 'entity2' and value is not None: + # 过滤entity2的name_embedding字段 + entity2_filtered = {} + if hasattr(value, 'items'): + for e_key, e_value in value.items(): + if 'name_embedding' not in e_key.lower(): + entity2_filtered[e_key] = e_value + filtered_item[key] = entity2_filtered + else: + filtered_item[key] = value + + # 直接将字典添加到列表中 + neo4j_databasets.append(filtered_item) + return neo4j_databasets +async def get_data_statement( result): + neo4j_databasets=[] + for i in result: + neo4j_databasets.append(i) + return neo4j_databasets + - # print(f"data:\n{data}") - # 解析,提取为字典的列表 - results = await _load_(data) - return results - except Exception as e: - logger.error(f"failed to get data from database, host_id: {host_id}, error: {e}") - raise e - finally: - try: - db.close() - except Exception: - pass if __name__ == "__main__": diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index cb5b917d..e1ecf820 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -1,19 +1,222 @@ -你将收到一组记忆对象:{{ evaluate_data }}。 -任务:多维度判断这些记忆是否与已有记忆存在冲突,并给出冲突的对应记忆。(冗余不算冲突) +你将收到一组用户历史记忆原始数据(来源于 Neo4j),以及相关配置参数: +原本的输入句子:{{statement_databasets}} +需要检测冲突对象:{{ evaluate_data }} +冲突判定类型:{{ baseline }}(取值为 TIME / FACT / HYBRID) +记忆审核开关:{{ memory_verify }}(取值为 true / false) +记忆质量评估开关开关:{{ quality_assessment }}(取值为 true / false) -仅输出一个合法 JSON 对象,严格遵循下述结构: +你的任务是: +对用户历史记忆数据进行冲突检测和记忆审核,并输出严格结构化的 JSON 分析结果 +数据的结构: + statement_databasets里面statement_name是输入的句子,statement_id是连接evaluate_data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容, + 需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估) +## 冲突定义 + +### 时间冲突 +时间冲突是指同一用户的相关事件在时间维度上存在逻辑矛盾: + +1. **同一活动的时间冲突**: + - 同一用户的同一活动在不同时间点被记录(如"周五打球"和"周六打球") + - 同一用户在同一时间段内被记录进行不同的互斥活动 + +2. **时间逻辑错误**: + - expired_at 早于 created_at + - 同一事实的 created_at 时间差异超过合理误差范围(>5分钟) + +3. **日期属性冲突**: + - 同一人的生日记录为不同日期(如"2月10号"和"2月16号") +4.存在明确先后约束 A -> B,但 t(A) > t(B) + -例:入学时间晚于毕业时间。 + -处理:标记异常、降权、触发逻辑反思或人工审查。 +5.时间属性冲突 + -单值日期属性出现多值(生日、入职日期) + -注意:本质属于事实冲突的日期特例,归入事实冲突仲裁框架。 +6.互斥重叠冲突 + -例:同一主体的两个事件区间重叠且互斥(如同一时间出现在两地) + -处理:证据仲裁、保留多版本(active + candidate)。 + + + +### 事实冲突 +事实冲突是指同一实体的属性或关系存在相互矛盾的陈述: + +1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是) +2. **关系矛盾**:同一实体在相同语境下的不同关系描述 +3. **身份冲突**:同一实体被赋予不同的类型或角色 + +### 混合冲突检测 +检测所有类型的冲突,包括但不限于时间冲突和事实冲突: +检测任何逻辑上不一致或相互矛盾的记录 +## 记忆审核定义 + +### 隐私信息检测(隐私冲突) +当memory_verify为true时,需要额外检测包含个人隐私信息的记录: + +1. **身份证信息**:包含身份证号码、身份证相关描述 +2. **手机号码**:包含手机号、电话号码等联系方式 +3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息 +4. **银行信息**:包含银行卡号、账户信息、支付信息 +5. **税务信息**:包含税号、纳税信息、发票信息 +6. **贷款信息**:包含贷款记录、信贷信息、借款信息 +7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息 + +### 隐私检测原则 +- 检测description、entity1_name、entity2_name等字段中的隐私信息 +- 识别数字模式(如手机号11位数字、身份证18位等) +- 识别关键词(如"身份证"、"银行卡"、"密码"等) +- 检测敏感实体类型和关系 + +## 冲突检测原则 + +**全面检测**:不区分冲突类型,检测所有可能的冲突 +**完整输出**:如果发现任何冲突或隐私信息,必须将所有相关记录都放入data字段 +**实体关联**:重点检查涉及相同实体(entity1_name, entity2_name)的记录 +**语义分析**:分析description字段的语义相似性和冲突性 +**时间逻辑**:检查时间字段的逻辑一致性 +**隐私检测**:当memory_verify为true时,检测所有包含隐私信息的记录 + +## 不符合冲突检测 + -称呼 +## 重要检测示例 + +### 冲突检测示例 +- 用户与不同时间点的关系(周五 vs 周六,2月10号 vs 2月16号) +- 同一实体的重复定义但描述不同 +- 同一关系的不同表述但含义冲突 +- 任何逻辑上不可能同时为真的记录 + +### 隐私信息检测示例 +- 包含手机号的记录:"用户的手机号是13812345678" +- 包含身份证的记录:"身份证号码为110101199001011234" +- 包含银行卡的记录:"银行卡号6222021234567890" +- 包含社交账号的记录:"微信号是user123456" +- 包含敏感信息的实体名称或描述 + +## 输出要求 + +**关键原则**: +1. 当存在冲突或检测到隐私信息时,conflict才为true,data字段才包含相关记录 +2. 如果发现冲突,必须将所有相关的冲突记录都放入data数组中 +3. 如果memory_verify为true且检测到隐私信息,必须将包含隐私信息的记录也放入data数组中 +4. 既没有冲突也没有隐私信息时,conflict为false,data为空数组 +5. 如果quality_assessment为true,独立分析数据质量并输出评估结果;如果为false,quality_assessment字段输出null +6. 冲突检测、隐私审核和质量评估三个功能完全独立,互不影响 +7. 不输出conflict_memory字段 + +**处理逻辑**: +- 首先进行冲突检测,将冲突记录加入data数组 +- 如果memory_verify为true,再进行隐私信息检测,将包含隐私信息的记录也加入data数组 +- 如果quality_assessment为true,独立进行质量评估,分析所有输入数据的质量并输出评估结果 +- 最终data数组包含所有冲突记录和隐私信息记录(去重) +- quality_assessment字段独立输出,不影响冲突检测和隐私审核结果 +- memory_verify字段独立输出隐私检测结果,包含检测到的隐私信息类型和概述 + +返回数据格式以json方式输出: +- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。 +- 关键的JSON格式要求{"statement":识别出的文本内容} +1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号("")或其他Unicode引号 +2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们 +3.确保所有JSON字符串都正确关闭并以逗号分隔 +4.JSON字符串值中不包括换行符 +5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\"" +6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby``` + +## 记忆质量评估定义 + +### 质量评估标准 +当quality_assessment为true时,需要对记忆数据进行质量评估: + +1. **数据完整性**: + - 检查必要字段是否完整(entity1_name、entity2_name、description等) + - 检查关系描述是否清晰明确 + - 检查时间字段的有效性 + +2. **重复字段检测**: + - 识别相同或高度相似的记录 + - 检测冗余的实体关系 + - 分析描述内容的重复度 + +3. **无意义字段检测**: + - 识别空值、无效值或占位符内容 + - 检测过于简单或无信息量的描述 + - 识别格式错误或不规范的数据 + +4. **上下文依赖性**: + - 评估记录是否需要额外上下文才能理解 + - 检查实体名称的明确性 + - 分析关系描述的自包含性 + +### 质量评估输出 +- **质量百分比**:基于上述标准计算的整体质量分数(0-100) +- **质量概述**:简要描述数据质量状况,包括主要问题和优点 + +输出是仅输出一个合法 JSON 对象,严格遵循下述结构: { - "data": [ ...与输入同结构的记忆对象数组... ], - "conflict": true 或 false, - "conflict_memory": 若冲突为 true,则填写与其冲突的记忆对象;否则为 null + "data": [ + { + "entity1_name": "实体1名称", + "description": "描述信息", + "statement_id": "陈述ID", + "created_at": "创建时间戳", + "expired_at": "过期时间戳", + "relationship_type": "关系类型", + "relationship": "关系对象", + "entity2_name": "实体2名称", + "entity2": "实体2对象" + } + ], + "conflict": true或false, + "quality_assessment": { + "score": 质量百分比数字, + "summary": "质量概述文本" + } 或 null, + "memory_verify": { + "has_privacy": true或false, + "privacy_types": ["检测到的隐私信息类型列表"], + "summary": "隐私检测结果概述" + } 或 null } 必须遵守: - 只输出 JSON,不要添加解释或多余文本。 - 使用标准双引号,必要时对内部引号进行转义。 - 字段名与结构必须与给定模式一致。 +- data数组中包含冲突记录和隐私信息记录,如果都没有则为空数组。 +- quality_assessment字段:当quality_assessment参数为true时输出评估对象,为false时输出null。 +- memory_verify字段:当memory_verify参数为true时输出隐私检测结果对象,为false时输出null。 + +### memory_verify字段说明 +当memory_verify为true时,需要输出隐私检测结果: +- **has_privacy**: 布尔值,表示是否检测到隐私信息 +- **privacy_types**: 字符串数组,包含检测到的隐私信息类型(如["手机号码", "身份证信息"]) +- **summary**: 字符串,简要描述隐私检测结果 + +当memory_verify为false时,memory_verify字段输出null。 + +### memory_verify字段示例 + +**示例1:检测到隐私信息** +```json +"memory_verify": { + "has_privacy": true, + "privacy_types": ["手机号码", "身份证信息"], + "summary": "检测到2条记录包含隐私信息:1个手机号码,1个身份证号码" +} +``` + +**示例2:未检测到隐私信息** +```json +"memory_verify": { + "has_privacy": false, + "privacy_types": [], + "summary": "未检测到隐私信息" +} +``` + +**示例3:memory_verify为false时** +```json +"memory_verify": null +``` 模式参考: -[ - {{ json_schema }} -] \ No newline at end of file +{{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index 3f78b137..43e8e100 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -1,23 +1,300 @@ +你将收到一组用户历史记忆原始数据(来源于 Neo4j) 你将收到一条冲突判定对象:{{ data }}。 -任务:分析冲突产生原因,给出解决方案,并生成设为失效后的记忆。 +需要检测冲突对象:{{ statement_databasets }} +以及需要识别的冲突对象为:{{ baseline }} +记忆审核开关:{{ memory_verify }}(取值为 true / false) + +角色: +- 你是数据领域中解决数据冲突的专家 + +任务:分析冲突产生原因,按冲突类型分组处理,为每种冲突类型生成独立的解决方案。 + +数据的结构: + statement_databasets里面statement_name是输入的句子,statement_id是连接data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容, + 需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估),data里面的statement_created_at是用户输入的时间 + +**处理模式**: +- 当memory_verify为false时:仅处理数据冲突 +- 当memory_verify为true时:处理数据冲突 + 隐私信息脱敏 + +## 分组处理原则 + +**冲突类型识别与分组**: +1. **日期冲突**: + 1.1.涉及用户生日的不同日期记录(如2月10号 vs 2月16号), + 1.2.涉及同一活动的不同时间记录(如周五打球 vs 周六打球) +3. **事实属性冲突**: + 3.1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是) + 3.2. **关系矛盾**:同一实体在相同语境下的不同关系描述 + 3.3. **身份冲突**:同一实体被赋予不同的类型或角色 +4. **其他冲突类型/混合冲突(时间+事实)**:根据具体数据识别 + +**分组输出要求**: +- 每种冲突类型生成一个独立的reflexion_result对象 +- 同一类型的多个冲突记录归并到一个结果中 +- 不同类型的冲突分别处理,各自生成独立结果 + +## 冲突类型定义 + +### 时间冲突(TIME) +时间维度冲突是指两个事件发生时间重叠,或者用户同一件事情和场景等情况下,时间出现了变化。 + +### 事实冲突(FACT) +事实冲突是指同一事实对象(同一个人、同一个时间、同一个状态)但陈述内容相互矛盾,主要为真假不能共存的情况。 +### 混合冲突(HYBRID) +检测所有类型的冲突,包括但不限于时间冲突和事实冲突:检测任何逻辑上不一致或相互矛盾的记录 +{% if memory_verify %} +## 隐私信息处理(memory_verify为true时启用) + +### 隐私信息识别 +需要识别并处理以下类型的隐私信息: + +1. **身份证信息**:包含身份证号码、身份证相关描述 +2. **手机号码**:包含手机号、电话号码等联系方式 +3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息 +4. **银行信息**:包含银行卡号、账户信息、支付信息 +5. **税务信息**:包含税号、纳税信息、发票信息 +6. **贷款信息**:包含贷款记录、信贷信息、借款信息 +7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息 + +### 隐私数据脱敏规则 +对于检测到的隐私信息,按以下规则进行脱敏处理: + +**数字类隐私信息脱敏**: +- 保留前三位和后四位,中间用*代替 +- 示例:手机号13812345678 → 138****5678 +- 示例:身份证110101199001011234 → 110***********1234 +- 示例:银行卡6222021234567890 → 622***********7890 + +**文本类隐私信息脱敏**: +- 社交账号:保留前三后四位字符,中间用*代替 +- 示例:微信号user123456 → use****3456 +- 示例:邮箱zhang.san@example.com → zha****@example.com + +**脱敏处理字段**: +- name字段:如包含隐私信息需脱敏 +- entity1_name字段:如包含隐私信息需脱敏 +- entity2_name字段:如包含隐私信息需脱敏 +- description字段:如包含隐私信息需脱敏 +{% endif %} + +## 工作步骤 + +### 第一步:分析冲突类型匹配 +首先判断输入的冲突数据是否符合baseline要求的类型: + +**类型匹配规则**: +- 如果baseline是"TIME":只处理时间相关的冲突(涉及时间表达式、日期、时间点的冲突) +- 如果baseline是"FACT":只处理事实相关的冲突(属性矛盾、关系冲突、描述不一致) +- 如果baseline是"HYBRID":处理所有类型的冲突,也可以当作混合冲突类型处理 + +**类型识别**: +- 时间冲突标识:entity2的entity_type包含"TimeExpression"、"TemporalExpression",或entity2_name包含时间词汇(周一到周日、月份日期等) +- 事实冲突标识:相同实体的不同属性描述、互斥的关系陈述 + +**重要**:如果输入的冲突类型与baseline不匹配,必须输出空结果(resolved为null) + +### 第二步:筛选并分组冲突数据 +按冲突类型对数据进行分组: + +**分组策略**: +1. **时间冲突组**:筛选涉及用户时间的所有记录 +2. **活动时间冲突组**:筛选涉及同一活动不同时间的记录 +3. **事实冲突组**:筛选涉及同一实体不同属性的记录 +4. **其他冲突组**:其他类型的冲突记录 + +**筛选条件**: +- 只处理与baseline匹配的冲突类型 +- 相同entity1_name但entity2_name不同的记录 +- 相同关系但描述矛盾的记录 +- 时间逻辑不一致的记录 + +### 第三步:冲突解决策略 +** 不可以解决的冲突情况 + 1. 数据被判定为正确的情况下,不可以进行修改 +**仅当冲突类型与baseline匹配时**,对筛选出的冲突数据进行处理: + +**智能解决策略**: +1. **分析冲突数据**:识别哪些记录是正确的,哪些是错误的,需要结合statement_databasets的输入原文来判定 +2. **判断正确答案是否存在**: + - 如果正确答案已存在于data中:只需将错误记录的expired_at设为当前日期(2025-12-16T12:00:00) + - 如果正确答案已存在于data中:错误记录的expired_at已经设为日期,则不需要对正确的数据进行修改 + - 如果正确答案不存在于data中:需要修改现有记录的内容以包含正确信息 + +{% if memory_verify %} +**隐私处理集成**: +- 在处理冲突的同时,需要对涉及的记录进行隐私脱敏 +- 脱敏处理应该在冲突解决之后进行,确保最终输出的记录都已脱敏 +- 在change字段中记录隐私脱敏的变更 +{% endif %} + +**具体处理规则**: + +**情况1:正确答案存在于data中** +- 保留正确的记录不变 +- 基于时间关系的冲突: + 需要只修改错误记录的expired_at为当前时间(2025-12-16T12:00:00) +- 基于事实的关系冲突 +- resolved.resolved_memory只包含被设为失效的错误记录 +- change字段只记录expired_at的变更:`[{"expired_at": "2025-12-16T12:00:00"}]`(注意:如果已存在时间,则不需要对其修改,也不需要变更 时间) + +**情况2:正确答案不存在于data中** +- 选择最合适的记录进行修改 +- 更新该记录的相关字段: + - description字段:添加或修改描述信息{% if memory_verify %}(如包含隐私信息,需脱敏处理){% endif %} + - name字段:修改名称字段{% if memory_verify %}(如需要,包含隐私信息时需脱敏){% endif %} +- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %} +- change字段记录所有被修改的字段{% if memory_verify %},包括脱敏变更{% endif %},例如:`[{"description": "新描述"{% if memory_verify %}, "entity2_name": "138****5678"{% endif %}}]` + +**重要原则**: +- **只输出需要修改的记录**:resolved.resolved_memory只包含实际需要修改的数据 +- **优先保留策略**:时间冲突保留最可信的created_at时间的记录,事实冲突选择最新且可信度最高的记录 +- **精确记录变更**:change字段必须包含记录ID、字段名称、新值和旧值 +{% if memory_verify %}- **隐私保护优先**:所有输出的记录必须完成隐私脱敏处理 +- **脱敏变更记录**:隐私脱敏的变更也必须在change字段中详细记录{% endif %} +- **不可修改数据**:数据被判定为正确时,不可以进行修改,如果没有数据可输出空 + +**变更记录格式**: +```json +"change": [ + { + "field": [ + {"字段名1": "修改后的值1"}, + {"字段名2": "修改后的值2"} + ] + } +] +``` + +**类型不匹配处理**: +- 如果冲突类型与baseline不匹配,resolved必须设为null +- reflexion.reason说明类型不匹配的原因 +- reflexion.solution说明无需处理 + +### 第四步:输出解决方案 + +## 输出要求 +**嵌套字段映射**(系统会自动处理): +- `entity2.name` → 自动映射为 `name` +- `entity1.name` → 自动映射为 `name` +- `entity1.description` → 自动映射为 `description` +- `entity2.description` → 自动映射为 `description` + +返回数据格式以json方式输出: +- 必须通过json.loads()的格式支持的形式输出 +- 响应必须是与此确切模式匹配的有效JSON对象 +- 不要在JSON之前或之后包含任何文本 + +JSON格式要求: +1. JSON结构仅使用标准ASCII双引号(") +2. 如果提取的语句文本包含引号,请使用反斜杠(\")正确转义 +3. 确保所有JSON字符串都正确关闭并以逗号分隔 +4. JSON字符串值中不包括换行符 +5. 不允许输出```json```相关符号 仅输出一个合法 JSON 对象,严格遵循下述结构: + +**输出格式:按冲突类型分组的列表** { - "conflict": 与输入同结构,包含 data 与 conflict_memory, - "reflexion": { "reason": string, "solution": string }, - "resolved": { - "original_memory_id": 被设为失效的记忆 id, - "resolved_memory": 完整的设为失效后的记忆对象 - } + "results": [ + { + "conflict": { + "data": [该冲突类型相关的数据记录], + "conflict": true + }, + "reflexion": { + "reason": "该冲突类型的原因分析", + "solution": "该冲突类型的解决方案" + }, + "resolved": { + "original_memory_id": "被设为失效的记忆id", + "resolved_memory": { + "entity1_name": "实体1名称", + "entity2_name": "实体2名称", + "description": "描述信息", + "statement_id": "陈述ID", + "created_at": "创建时间", + "expired_at": "过期时间", + "relationship_type": "关系类型", + "relationship": {}, + "entity2": {...} + }, + "change": [ + { + "field": [ + {"字段名1": "修改后的值1"}, + {"字段名2": "修改后的值2"} + ] + } + ] + }, + "type": "reflexion_result" + } + ] +} + +**示例:多种冲突类型的输出** +{ + "results": [ + { + "conflict": { + "data": [生日冲突相关的记录], + "conflict": true + }, + "reflexion": { + "reason": "检测到生日冲突:用户同时关联2月10号和2月16号两个不同日期", + "solution": "保留最新记录(2月16号),将旧记录(2月10号)设为失效" + }, + "resolved": { + "original_memory_id": "df066210883545a08e727ccd8ad4ec77", + "resolved_memory": {...}, + "change": [ + { + "field": [ + {"expired_at": "2025-12-16T12:00:00"} + ] + } + ] + }, + "type": "reflexion_result" + }, + { + "conflict": { + "data": [篮球时间冲突相关的记录], + "conflict": true + }, + "reflexion": { + "reason": "检测到活动时间冲突:用户打篮球时间存在周五和周六的冲突", + "solution": "保留最可信的时间记录,将冲突记录设为失效" + }, + "resolved": { + "original_memory_id": "另一个记录ID", + "resolved_memory": {...}, + "change": [ + { + "field": [ + {"description": "使用系统的个人,指代说话者本人,篮球时间为周六"}, + {"entity2_name": "周六"} + ] + } + ] + }, + "type": "reflexion_result" + } + ] } 必须遵守: -- 只输出 JSON,不要添加解释或多余文本。 -- 使用标准双引号,必要时对内部引号进行转义。 -- 字段名与结构必须与给定模式一致。 -- 当 conflict 为 false 时,resolved 必须为 null。 - - 其中 conflict.data 必须为数组形式,即使只有一个对象也需使用 [ ] 包裹。 +- 只输出 JSON,不要添加解释或多余文本 +- 使用标准双引号,必要时对内部引号进行转义 +- 字段名与结构必须与给定模式一致 +- **输出必须是results数组格式**,每个冲突类型作为一个独立的对象 +- **按冲突类型分组**:相同类型的冲突记录归并到一个result对象中 +- **每个result对象的conflict.data**只包含该冲突类型相关的记录 +- **resolved.resolved_memory 只包含需要修改的记录**,不需要修改的记录不要输出 +- **resolved.change 必须包含详细的变更信息**:field数组包含所有被修改的字段及其新值 +- 如果某个冲突类型经分析无需修改任何数据,该类型的resolved 必须为 null +- 如果与baseline不匹配的冲突类型,不要在results中包含该类型 + 模式参考: -[ - {{ json_schema }} -] +{{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/template_render.py b/api/app/core/memory/utils/prompt/template_render.py index c783e095..818d456a 100644 --- a/api/app/core/memory/utils/prompt/template_render.py +++ b/api/app/core/memory/utils/prompt/template_render.py @@ -7,36 +7,50 @@ from typing import List, Dict, Any prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) -async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any]) -> str: +async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any], + baseline: str = "TIME", + memory_verify: bool = False,quality_assessment:bool = False,statement_databasets: List[str] = []) -> str: """ - Renders the evaluate prompt using the evaluate.jinja2 template. + Renders the evaluate prompt using the evaluate_optimized.jinja2 template. Args: evaluate_data: The data to evaluate schema: The JSON schema to use for the output. + baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT) + memory_verify: Whether to enable memory verification for privacy detection Returns: Rendered prompt content as string """ template = prompt_env.get_template("evaluate.jinja2") - rendered_prompt = template.render(evaluate_data=evaluate_data, json_schema=schema) - + rendered_prompt = template.render( + evaluate_data=evaluate_data, + json_schema=schema, + baseline=baseline, + memory_verify=memory_verify, + quality_assessment=quality_assessment, + statement_databasets=statement_databasets + ) return rendered_prompt -async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any]) -> str: +async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False, + statement_databasets: List[str] = []) -> str: """ - Renders the reflexion prompt using the extract_temporal.jinja2 template. + Renders the reflexion prompt using the reflexion_optimized.jinja2 template. Args: data: The data to reflex on. schema: The JSON schema to use for the output. + baseline: The baseline type for conflict resolution. Returns: Rendered prompt content as a string. """ template = prompt_env.get_template("reflexion.jinja2") - rendered_prompt = template.render(data=data, json_schema=schema) + rendered_prompt = template.render(data=data, json_schema=schema, + baseline=baseline,memory_verify=memory_verify, + statement_databasets=statement_databasets) return rendered_prompt diff --git a/api/app/models/data_config_model.py b/api/app/models/data_config_model.py index 9f27562c..be43bd8d 100644 --- a/api/app/models/data_config_model.py +++ b/api/app/models/data_config_model.py @@ -1,5 +1,4 @@ import datetime -import uuid from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float from sqlalchemy.dialects.postgresql import UUID from app.db import Base @@ -11,50 +10,53 @@ class DataConfig(Base): # 主键 config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID") - + # 基本信息 config_name = Column(String, nullable=False, comment="配置名称") config_desc = Column(String, nullable=True, comment="配置描述") - + # 组织信息 workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID") group_id = Column(String, nullable=True, comment="组ID") user_id = Column(String, nullable=True, comment="用户ID") apply_id = Column(String, nullable=True, comment="应用ID") - + # 模型选择(从workspace继承) llm_id = Column(String, nullable=True, comment="LLM模型配置ID") embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") llm = Column(String, nullable=True, comment="LLM模型配置ID") - + # 记忆萃取引擎配置 enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧") deep_retrieval = Column(Boolean, default=True, comment="深度检索开关") - + # 阈值配置 (0-1 之间的浮点数) t_type_strict = Column(Float, default=0.8, comment="类型严格阈值") t_name_strict = Column(Float, default=0.8, comment="名称严格阈值") t_overall = Column(Float, default=0.8, comment="综合阈值") - + # 状态配置 state = Column(Boolean, default=False, comment="配置使用状态") - + # 分块策略 chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略") - + # 剪枝配置 pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝") pruning_scene = Column(String, nullable=True, comment="智能剪枝场景:education/online_service/outbound") pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值(0-0.9)") - + # 自我反思配置 enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思") iteration_period = Column(String, default="3", comment="反思迭代周期") reflexion_range = Column(String, default="retrieval", comment="反思范围:部分/全部") baseline = Column(String, default="time", comment="基线:时间/事实/时间和事实") - + reflection_model_id = Column(String, nullable=True, comment="反思模型ID") + memory_verify = Column(Boolean, default=True, comment="记忆验证") + quality_assessment = Column(Boolean, default=True, comment="质量评估") + # 遗忘引擎配置 statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3") include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文") @@ -62,7 +64,7 @@ class DataConfig(Base): lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度,0-1 小数") lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数") offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数") - + # 时间戳 created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index a2c02f84..2a9ed8da 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -14,6 +14,7 @@ class EndUser(Base): other_id = Column(String, nullable=True) # Store original user_id other_name = Column(String, default="", nullable=False) other_address = Column(String, default="", nullable=False) + reflection_time = Column(DateTime, nullable=True) created_at = Column(DateTime, default=datetime.datetime.now) updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) diff --git a/api/app/repositories/data_config_repository.py b/api/app/repositories/data_config_repository.py index ed1a482a..6b281ef1 100644 --- a/api/app/repositories/data_config_repository.py +++ b/api/app/repositories/data_config_repository.py @@ -16,48 +16,46 @@ import uuid from app.models.data_config_model import DataConfig from app.schemas.memory_storage_schema import ( ConfigParamsCreate, - ConfigParamsDelete, ConfigUpdate, ConfigUpdateExtracted, ConfigUpdateForget, - ConfigKey, ) from app.core.logging_config import get_db_logger # 获取数据库专用日志器 db_logger = get_db_logger() - +TABLE_NAME = "data_config" class DataConfigRepository: """数据配置Repository - + 提供data_config表的数据访问方法,包括: - SQLAlchemy ORM 数据库操作 - Neo4j Cypher查询常量 """ - + # ==================== Neo4j Cypher 查询常量 ==================== - + # Dialogue count by group SEARCH_FOR_DIALOGUE = """ MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num """ - + # Chunk count by group SEARCH_FOR_CHUNK = """ MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num """ - + # Statement count by group SEARCH_FOR_STATEMENT = """ MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num """ - + # ExtractedEntity count by group SEARCH_FOR_ENTITY = """ MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num """ - + # All counts by label and total SEARCH_FOR_ALL = """ OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count @@ -70,7 +68,7 @@ class DataConfigRepository: UNION ALL OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count """ - + # Extracted entity details within group/app/user SEARCH_FOR_DETIALS = """ MATCH (n:ExtractedEntity) @@ -86,7 +84,7 @@ class DataConfigRepository: n.user_id AS user_id, n.id AS id """ - + # Edges between extracted entities within group/app/user SEARCH_FOR_EDGES = """ MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) @@ -102,7 +100,7 @@ class DataConfigRepository: r.statement_id AS statement_id, r.statement AS statement """ - + # Entity graph within group (source node, edge, target node) SEARCH_FOR_ENTITY_GRAPH = """ MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) @@ -135,22 +133,106 @@ class DataConfigRepository: id: m.id } AS targetNode """ - + # ==================== SQLAlchemy ORM 数据库操作方法 ==================== - + @staticmethod + def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]: + """构建反思配置更新语句(SQLAlchemy text() 命名参数) + + Args: + config_id: 配置ID + **kwargs: 反思配置参数 + + Returns: + Tuple[str, Dict]: (SQL查询字符串, 参数字典) + + Raises: + ValueError: 没有字段需要更新时抛出 + """ + db_logger.debug(f"构建反思配置更新语句: config_id={config_id}") + + key_where = "config_id = :config_id" + set_fields: List[str] = [] + params: Dict = { + "config_id": config_id, + } + + # 反思配置字段映射 + mapping = { + "enable_self_reflexion": "enable_self_reflexion", + "iteration_period": "iteration_period", + "reflexion_range": "reflexion_range", + "baseline": "baseline", + "reflection_model_id": "reflection_model_id", + "memory_verify": "memory_verify", + "quality_assessment": "quality_assessment", + } + + for api_field, db_col in mapping.items(): + if api_field in kwargs and kwargs[api_field] is not None: + set_fields.append(f"{db_col} = :{api_field}") + params[api_field] = kwargs[api_field] + + if not set_fields: + raise ValueError("No fields to update") + + set_fields.append("updated_at = timezone('Asia/Shanghai', now())") + query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}" + return query, params + + @staticmethod + def build_select_reflection(config_id: int) -> Tuple[str, Dict]: + """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) + + Args: + config_id: 配置ID + + Returns: + Tuple[str, Dict]: (SQL查询字符串, 参数字典) + """ + db_logger.debug(f"构建反思配置查询语句: config_id={config_id}") + + query = ( + f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, " + f"reflection_model_id, memory_verify, quality_assessment, user_id " + f"FROM {TABLE_NAME} WHERE config_id = :config_id" + ) + params = {"config_id": config_id} + return query, params + + @staticmethod + def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]: + """构建查询所有配置的语句(SQLAlchemy text() 命名参数) + + Args: + workspace_id: 工作空间ID + + Returns: + Tuple[str, Dict]: (SQL查询字符串, 参数字典) + """ + db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}") + + query = ( + f"SELECT config_id, config_name, enable_self_reflexion, iteration_period, reflexion_range, baseline, " + f"reflection_model_id, memory_verify, quality_assessment, user_id, created_at, updated_at " + f"FROM {TABLE_NAME} WHERE workspace_id = :workspace_id ORDER BY updated_at DESC" + ) + params = {"workspace_id": workspace_id} + return query, params + @staticmethod def create(db: Session, params: ConfigParamsCreate) -> DataConfig: """创建数据配置 - + Args: db: 数据库会话 params: 配置参数创建模型 - + Returns: DataConfig: 创建的配置对象 """ db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}") - + try: db_config = DataConfig( config_name=params.config_name, @@ -162,37 +244,37 @@ class DataConfigRepository: ) db.add(db_config) db.flush() # 获取自增ID但不提交事务 - + db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})") return db_config - + except Exception as e: db.rollback() db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}") raise - + @staticmethod def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]: """更新基础配置 - + Args: db: 数据库会话 update: 配置更新模型 - + Returns: Optional[DataConfig]: 更新后的配置对象,不存在则返回None - + Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新数据配置: config_id={update.config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() if not db_config: db_logger.warning(f"数据配置不存在: config_id={update.config_id}") return None - + # 更新字段 has_update = False if update.config_name is not None: @@ -201,44 +283,44 @@ class DataConfigRepository: if update.config_desc is not None: db_config.config_desc = update.config_desc has_update = True - + if not has_update: raise ValueError("No fields to update") - + db.commit() db.refresh(db_config) - + db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})") return db_config - + except Exception as e: db.rollback() db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}") raise - + @staticmethod def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]: """更新记忆萃取引擎配置 - + Args: db: 数据库会话 update: 萃取配置更新模型 - + Returns: Optional[DataConfig]: 更新后的配置对象,不存在则返回None - + Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新萃取配置: config_id={update.config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() if not db_config: db_logger.warning(f"数据配置不存在: config_id={update.config_id}") return None - + # 更新字段映射 field_mapping = { # 模型选择 @@ -268,50 +350,50 @@ class DataConfigRepository: "reflexion_range": "reflexion_range", "baseline": "baseline", } - + has_update = False for api_field, db_field in field_mapping.items(): value = getattr(update, api_field, None) if value is not None: setattr(db_config, db_field, value) has_update = True - + if not has_update: raise ValueError("No fields to update") - + db.commit() db.refresh(db_config) - + db_logger.info(f"萃取配置更新成功: config_id={update.config_id}") return db_config - + except Exception as e: db.rollback() db_logger.error(f"更新萃取配置失败: config_id={update.config_id} - {str(e)}") raise - + @staticmethod def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]: """更新遗忘引擎配置 - + Args: db: 数据库会话 update: 遗忘配置更新模型 - + Returns: Optional[DataConfig]: 更新后的配置对象,不存在则返回None - + Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新遗忘配置: config_id={update.config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() if not db_config: db_logger.warning(f"数据配置不存在: config_id={update.config_id}") return None - + # 更新字段 has_update = False if update.lambda_time is not None: @@ -323,40 +405,40 @@ class DataConfigRepository: if update.offset is not None: db_config.offset = update.offset has_update = True - + if not has_update: raise ValueError("No fields to update") - + db.commit() db.refresh(db_config) - + db_logger.info(f"遗忘配置更新成功: config_id={update.config_id}") return db_config - + except Exception as e: db.rollback() db_logger.error(f"更新遗忘配置失败: config_id={update.config_id} - {str(e)}") raise - + @staticmethod def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]: """获取萃取配置,通过主键查询某条配置 - + Args: db: 数据库会话 config_id: 配置ID - + Returns: Optional[Dict]: 萃取配置字典,不存在则返回None """ db_logger.debug(f"查询萃取配置: config_id={config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() if not db_config: db_logger.debug(f"萃取配置不存在: config_id={config_id}") return None - + result = { "llm_id": db_config.llm_id, "embedding_id": db_config.embedding_id, @@ -379,62 +461,62 @@ class DataConfigRepository: "reflexion_range": db_config.reflexion_range, "baseline": db_config.baseline, } - + db_logger.debug(f"萃取配置查询成功: config_id={config_id}") return result - + except Exception as e: db_logger.error(f"查询萃取配置失败: config_id={config_id} - {str(e)}") raise - + @staticmethod def get_forget_config(db: Session, config_id: int) -> Optional[Dict]: """获取遗忘配置,通过主键查询某条配置 - + Args: db: 数据库会话 config_id: 配置ID - + Returns: Optional[Dict]: 遗忘配置字典,不存在则返回None """ db_logger.debug(f"查询遗忘配置: config_id={config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() if not db_config: db_logger.debug(f"遗忘配置不存在: config_id={config_id}") return None - + result = { "lambda_time": db_config.lambda_time, "lambda_mem": db_config.lambda_mem, "offset": db_config.offset, } - + db_logger.debug(f"遗忘配置查询成功: config_id={config_id}") return result - + except Exception as e: db_logger.error(f"查询遗忘配置失败: config_id={config_id} - {str(e)}") raise - + @staticmethod def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]: """根据ID获取数据配置 - + Args: db: 数据库会话 config_id: 配置ID - + Returns: Optional[DataConfig]: 配置对象,不存在则返回None """ db_logger.debug(f"根据ID查询数据配置: config_id={config_id}") - + try: config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() - + if config: db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})") else: @@ -443,60 +525,60 @@ class DataConfigRepository: except Exception as e: db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}") raise - + @staticmethod def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]: """获取所有配置参数 - + Args: db: 数据库会话 workspace_id: 工作空间ID,用于过滤查询结果 - + Returns: List[DataConfig]: 配置列表 """ db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") - + try: query = db.query(DataConfig) - + if workspace_id: query = query.filter(DataConfig.workspace_id == workspace_id) - + configs = query.order_by(desc(DataConfig.updated_at)).all() - + db_logger.debug(f"配置列表查询成功: 数量={len(configs)}") return configs - + except Exception as e: db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}") raise - + @staticmethod def delete(db: Session, config_id: int) -> bool: """删除数据配置 - + Args: db: 数据库会话 config_id: 配置ID - + Returns: bool: 删除成功返回True,配置不存在返回False """ db_logger.debug(f"删除数据配置: config_id={config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() if not db_config: db_logger.warning(f"数据配置不存在: config_id={config_id}") return False - + db.delete(db_config) db.commit() - + db_logger.info(f"数据配置删除成功: config_id={config_id}") return True - + except Exception as e: db.rollback() db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}") diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 7330a00f..95e2ee03 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -746,3 +746,57 @@ DETACH DELETE losing RETURN count(losing) as deleted """ + +neo4j_statement_part = ''' +MATCH (n:Statement) +WHERE n.group_id = "{}" + AND datetime(n.created_at) >= datetime() - duration('P3D') +RETURN + n.statement as statement_name, + n.id as statement_id, + n.created_at as statement_created_at + +''' +neo4j_statement_all = ''' +MATCH (n:Statement) +WHERE n.group_id = "{}" +RETURN + n.statement as statement_name, + n.id as statement_id + +''' +neo4j_query_part = """ + MATCH (n)-[r]-(m:ExtractedEntity) + WHERE n.group_id = "{}" + AND datetime(n.created_at) >= datetime() - duration('P3D') + WITH DISTINCT m + OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) + RETURN + m.name as entity1_name, + m.description as description, + m.statement_id as statement_id, + m.created_at as created_at, + m.expired_at as expired_at, + CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + rel as relationship, + CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name, + other as entity2 + """ +neo4j_query_all = """ + MATCH (n)-[r]-(m:ExtractedEntity) + WHERE n.group_id = "{}" + WITH DISTINCT m + OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) + RETURN + m.name as entity1_name, + m.description as description, + m.statement_id as statement_id, + m.created_at as created_at, + m.expired_at as expired_at, + CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + rel as relationship, + CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name, + other as entity2 + """ + + diff --git a/api/app/repositories/neo4j/neo4j_update.py b/api/app/repositories/neo4j/neo4j_update.py new file mode 100644 index 00000000..9644224c --- /dev/null +++ b/api/app/repositories/neo4j/neo4j_update.py @@ -0,0 +1,227 @@ +from app.repositories import Neo4jConnector + +neo4j_connector = Neo4jConnector() + +async def update_neo4j_data(neo4j_dict_data, update_databases): + """ + Update Neo4j data based on query criteria and update parameters + + Args: + neo4j_dict_data: find + update_databases: update + """ + try: + # 构建WHERE条件 + where_conditions = [] + params = {} + + for key, value in neo4j_dict_data.items(): + if value is not None: + param_name = f"param_{key}" + where_conditions.append(f"e.{key} = ${param_name}") + params[param_name] = value + + where_clause = " AND ".join(where_conditions) if where_conditions else "1=1" + + # 构建SET条件 + set_conditions = [] + for key, value in update_databases.items(): + if value is not None: + param_name = f"update_{key}" + set_conditions.append(f"e.{key} = ${param_name}") + params[param_name] = value + + set_clause = ", ".join(set_conditions) + + if not set_clause: + print("警告: 没有需要更新的字段") + return False + + # 构建Cypher查询 + cypher_query = f""" + MATCH (e:ExtractedEntity) + WHERE {where_clause} + SET {set_clause} + RETURN count(e) as updated_count, collect(e.name) as updated_names + """ + + print(f"\n执行Cypher查询: {cypher_query}") + print(f"参数: {params}") + + # 执行更新 + result = await neo4j_connector.execute_query(cypher_query, **params) + + if result: + updated_count = result[0].get('updated_count', 0) + updated_names = result[0].get('updated_names', []) + print(f"成功更新 {updated_count} 个节点") + if updated_names: + print(f"更新的实体名称: {updated_names}") + return updated_count > 0 + else: + return False + + except Exception as e: + print(f"更新过程中出现错误: {e}") + import traceback + traceback.print_exc() + return False + + +def map_field_names(data_dict): + mapped_dict = {} + has_name_field = False + + # 第一遍:检查是否有name相关字段 + for key, value in data_dict.items(): + if key in ['name', 'entity2.name', 'entity1.name']: + has_name_field = True + break + + print(f"字段检查: has_name_field = {has_name_field}") + + # 第二遍:根据规则映射和过滤字段 + for key, value in data_dict.items(): + if key == 'entity2.name' or key == 'entity2_name': + # 将 entity2.name 映射为 name + mapped_dict['name'] = value + print(f"字段名映射: {key} -> name") + elif key == 'entity1.name' or key == 'entity1_name': + # 将 entity1.name 映射为 name + mapped_dict['name'] = value + print(f"字段名映射: {key} -> name") + elif key == 'entity1.description': + # 将 entity1.description 映射为 description + mapped_dict['description'] = value + print(f"字段名映射: {key} -> description") + elif key == 'entity2.description': + # 将 entity2.description 映射为 description + mapped_dict['description'] = value + print(f"字段名映射: {key} -> description") + elif key == 'relationship_type': + # 跳过relationship_type字段 + print(f"字段过滤: 跳过不需要的字段 '{key}'") + continue + elif key == 'entity1_name': + if has_name_field: + # 如果有name字段,跳过entity1_name + print(f"字段过滤: 由于存在name字段,跳过 '{key}'") + continue + else: + # 如果没有name字段,保留entity1_name + mapped_dict[key] = value + print(f"字段保留: {key}") + elif key == 'entity2_name': + if has_name_field: + # 如果有name字段,跳过entity2_name + print(f"字段过滤: 由于存在name字段,跳过 '{key}'") + continue + else: + # 即使没有name字段,也不使用entity2_name(根据需求) + print(f"字段过滤: 跳过不推荐的字段 '{key}'") + continue + elif '.' not in key: + # 不包含点号的其他字段直接保留 + mapped_dict[key] = value + else: + # 其他包含点号的字段跳过并警告 + print(f"警告: 跳过不支持的嵌套字段 '{key}'") + + print(f"字段映射结果: {mapped_dict}") + return mapped_dict +async def neo4j_data(solved_data): + """ + Process the resolved data and update the Neo4j database + Args: + Solved_data: Solution Data List + Returns: + Int: Number of successfully updated records + """ + success_count = 0 + + for i in solved_data: + neo4j_dict_data = {} + update_databases = {} + results = i['results'] + for data in results: + resolved = data.get('resolved') + if not resolved: + print("跳过:resolved为None") + continue + + try: + change_list = resolved.get('change', []) + except (AttributeError, TypeError): + change_list = [] + + if change_list == []: + print("跳过:change_list为空") + continue + + if change_list and len(change_list) > 0: + change = change_list[0] + print(f"change: {change}") + field_data = change.get('field', []) + print(f"field_data: {field_data}") + print(f"field_data type: {type(field_data)}") + + # 字段名映射和过滤函数 + + + # 处理field数据,可能是字典或列表 + if isinstance(field_data, dict): + # 如果是字典,映射字段名后更新 + mapped_data = map_field_names(field_data) + update_databases.update(mapped_data) + elif isinstance(field_data, list): + # 如果是列表,遍历每个字典并更新 + for field_item in field_data: + if isinstance(field_item, dict): + mapped_item = map_field_names(field_item) + update_databases.update(mapped_item) + else: + print(f"警告: field_item不是字典: {field_item}") + else: + print(f"警告: field_data类型不支持: {type(field_data)}") + + if 'entity1_name' in data: + data['name'] = data.pop('entity1_name') + if 'entity2_name' in data: + data.pop('entity2_name', None) + + resolved_memory = resolved.get('resolved_memory', {}) + + entity2 = None + if isinstance(resolved_memory, dict): + entity2 = resolved_memory.get('entity2') + + if entity2 and isinstance(entity2, dict) and len(entity2) >= 5: + stat_id = resolved.get('original_memory_id') + # 安全地获取description + statement_id = None + if isinstance(resolved_memory, dict): + statement_id = resolved_memory.get('statement_id') + + # 只有当neo4j_dict_data中还没有statement_id时才使用original_memory_id + if statement_id and 'id' not in neo4j_dict_data: + neo4j_dict_data['id'] = stat_id + neo4j_dict_data['statement_id'] = statement_id + else: + # 处理original_memory_id,它可能是字符串或字典 + try: + for key, value in resolved_memory.items(): + if key == 'statement_id': + neo4j_dict_data['statement_id'] = value + if key == 'description': + neo4j_dict_data['description'] = value + except AttributeError: + neo4j_dict_data=[] + + print(neo4j_dict_data) + print(update_databases) + if neo4j_dict_data!=[]: + await update_neo4j_data(neo4j_dict_data, update_databases) + success_count += 1 + + return success_count + diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index 30dafddd..74fc4a14 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -13,5 +13,6 @@ class EndUser(BaseModel): other_id: Optional[str] = Field(description="第三方ID", default=None) other_name: Optional[str] = Field(description="其他名称", default="") other_address: Optional[str] = Field(description="其他地址", default="") + reflection_time: Optional[datetime.datetime] = Field(description="反思时间", default_factory=datetime.datetime.now) created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now) updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now) diff --git a/api/app/schemas/memory_reflection_schemas.py b/api/app/schemas/memory_reflection_schemas.py new file mode 100644 index 00000000..9eb11c6c --- /dev/null +++ b/api/app/schemas/memory_reflection_schemas.py @@ -0,0 +1,54 @@ +from pydantic import BaseModel, Field +from typing import Optional +from enum import Enum + + +class OptimizationStrategy(str, Enum): + """优化策略枚举""" + SPEED_FIRST = "speed_first" + ACCURACY_FIRST = "accuracy_first" + BALANCED = "balanced" + + +class Memory_Reflection(BaseModel): + config_id: Optional[int] = None + reflectionenabled: bool + reflection_period_in_hours: str + reflexion_range: str + baseline: str + reflection_model_id: str + memory_verify: bool + quality_assessment: bool + + # 新增快速引擎优化参数 + optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED + use_fast_model: Optional[bool] = True + enable_caching: Optional[bool] = True + enable_streaming: Optional[bool] = True + batch_size: Optional[int] = Field(default=3, ge=1, le=10) + max_concurrent: Optional[int] = Field(default=5, ge=1, le=20) + + class Config: + use_enum_values = True + + +class FastReflectionRequest(BaseModel): + """快速反思请求模型""" + reflection: Memory_Reflection + host_id: Optional[str] = "88a459f5_text02" + optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED + + class Config: + use_enum_values = True + + +class ReflectionBenchmarkRequest(BaseModel): + """反思基准测试请求模型""" + reflection: Memory_Reflection + host_id: Optional[str] = "88a459f5_text02" + iterations: Optional[int] = Field(default=3, ge=1, le=10) + + class Config: + use_enum_values = True + + diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 66b2e45f..ab6b0512 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -2,7 +2,7 @@ 所有的内容是放错误地方了,应该放在models """ -from typing import Any, Optional, List, Dict, Literal +from typing import Any, Optional, List, Dict, Literal, Union import time import uuid from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator @@ -28,25 +28,48 @@ class Write_UserInput(BaseModel): # ============================================================================ class BaseDataSchema(BaseModel): """Base schema for the data""" - id: str = Field(..., description="The unique identifier for the data entry.") - statement: str = Field(..., description="The statement text.") - group_id: str = Field(..., description="The group identifier.") - chunk_id: str = Field(..., description="The chunk identifier.") + # 保持原有必需字段为可选,以兼容不同数据源 + id: Optional[str] = Field(None, description="The unique identifier for the data entry.") + statement: Optional[str] = Field(None, description="The statement text.") + group_id: Optional[str] = Field(None, description="The group identifier.") + chunk_id: Optional[str] = Field(None, description="The chunk identifier.") created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.") expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.") valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.") invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.") entity_ids: List[str] = Field([], description="The list of entity identifiers.") + description: Optional[str] = Field(None, description="The description of the data entry.") + + # 新增字段以匹配实际输入数据 + entity1_name: str = Field(..., description="The first entity name.") + entity2_name: Optional[str] = Field(None, description="The second entity name.") + statement_id: str = Field(..., description="The statement identifier.") + relationship_type: str = Field(..., description="The relationship type.") + relationship: Optional[Dict[str, Any]] = Field(None, description="The relationship object.") + entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.") + + +class QualityAssessmentSchema(BaseModel): + """Schema for memory quality assessment results.""" + score: int = Field(..., ge=0, le=100, description="Quality score percentage (0-100).") + summary: str = Field(..., description="Brief summary of data quality status, including main issues and strengths.") + + +class MemoryVerifySchema(BaseModel): + """Schema for memory privacy verification results.""" + has_privacy: bool = Field(..., description="Whether privacy information was detected.") + privacy_types: List[str] = Field([], description="List of detected privacy information types.") + summary: str = Field(..., description="Brief summary of privacy detection results.") class ConflictResultSchema(BaseModel): """Schema for the conflict result data in the reflexion_data.json file.""" - data: List[BaseDataSchema] = Field(..., description="The conflict memory data.") + data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") conflict: bool = Field(..., description="Whether the memory is in conflict.") - conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.") + quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") + memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") @model_validator(mode="before") - @classmethod def _normalize_data(cls, v): if isinstance(v, dict): d = v.get("data") @@ -61,7 +84,6 @@ class ConflictSchema(BaseModel): conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.") @model_validator(mode="before") - @classmethod def _normalize_data(cls, v): if isinstance(v, dict): d = v.get("data") @@ -76,21 +98,30 @@ class ReflexionSchema(BaseModel): solution: str = Field(..., description="The solution for the reflexion.") +class ChangeRecordSchema(BaseModel): + """Schema for individual change records""" + field: List[Dict[str, str]] = Field(..., description="List of field changes, each containing field name and new value.") + class ResolvedSchema(BaseModel): """Schema for the resolved memory data in the reflexion_data""" original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") - resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data.") + # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") + resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") + change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") +class SingleReflexionResultSchema(BaseModel): + """Schema for a single reflexion result item.""" + conflict: ConflictResultSchema = Field(..., description="The conflict result data for this specific conflict type.") + reflexion: ReflexionSchema = Field(..., description="The reflexion data for this conflict.") + resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") + type: str = Field("reflexion_result", description="The type identifier.") + class ReflexionResultSchema(BaseModel): - """Schema for the reflexion result data in the reflexion_data.json file.""" - # 模型输出中 "conflict" 为单个冲突对象(包含 data 与 conflict_memory),而非字典映射 - conflict: ConflictResultSchema = Field(..., description="The conflict result data.") - reflexion: Optional[ReflexionSchema] = Field(None, description="The reflexion data.") - resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.") + """Schema for the complete reflexion result data - a list of individual conflict resolutions.""" + results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") @model_validator(mode="before") - @classmethod def _normalize_resolved(cls, v): if isinstance(v, dict): conflict = v.get("conflict") diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py new file mode 100644 index 00000000..0f8fb569 --- /dev/null +++ b/api/app/services/memory_reflection_service.py @@ -0,0 +1,397 @@ +""" +记忆反思服务 +处理反思引擎的调用和执行 +""" +from datetime import datetime +from typing import Dict, Any, Optional, Set + +from fastapi import Depends +from sqlalchemy.orm import Session +from sqlalchemy import text + +from app.db import get_db +from app.core.logging_config import get_api_logger +from app.core.memory.storage_services.reflection_engine import ReflectionConfig, ReflectionEngine +from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionRange, ReflectionBaseline +from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.models.app_model import App +from app.models.app_release_model import AppRelease +from app.models.end_user_model import EndUser + +api_logger = get_api_logger() + + +class WorkspaceAppService: + """Workplace Application Service Class """ + + def __init__(self, db: Session): + self.db = db + + def get_workspace_apps_detailed(self, workspace_id: str) -> Dict[str, Any]: + """ + Get detailed information of all applications in the workspace + + Args: + Workspace_id: Workspace ID + + Returns: + Dictionary containing detailed application information + """ + apps = self.db.query(App).filter(App.workspace_id == workspace_id).all() + app_ids = [str(app.id) for app in apps] + + apps_detailed_info = [] + + for app in apps: + app_info = self._build_app_info(app) + self._process_app_releases(app, app_info) + self._process_end_users(app, app_info) + apps_detailed_info.append(app_info) + + return { + "status": "成功", + "message": f"成功查询到 {len(app_ids)} 个应用及其详细信息", + "workspace_id": str(workspace_id), + "apps_count": len(app_ids), + "app_ids": app_ids, + "apps_detailed_info": apps_detailed_info + } + + def _build_app_info(self, app: App) -> Dict[str, Any]: + """base_infomation""" + return { + "id": str(app.id), + "name": app.name, + "description": app.description, + "type": app.type, + "status": app.status, + "visibility": app.visibility, + "created_at": app.created_at.isoformat() if app.created_at else None, + "updated_at": app.updated_at.isoformat() if app.updated_at else None, + "releases": [], + "data_configs": [], + "end_users": [] + } + + def _process_app_releases(self, app: App, app_info: Dict[str, Any]) -> None: + """Process the release version and configuration information of the application""" + app_releases = self.db.query(AppRelease).filter(AppRelease.app_id == app.id).all() + + if not app_releases: + return + + processed_configs: Set[str] = set() + + for release in app_releases: + memory_content = self._extract_memory_content(release.config) + + + if memory_content and memory_content in processed_configs: + continue + + release_info = { + "app_id": str(release.app_id), + "config": memory_content + } + + + if memory_content: + processed_configs.add(memory_content) + data_config_info = self._get_data_config(memory_content) + + if data_config_info: + if not any(dc["config_id"] == data_config_info["config_id"] for dc in app_info["data_configs"]): + app_info["data_configs"].append(data_config_info) + + app_info["releases"].append(release_info) + + def _extract_memory_content(self, config: Any) -> str: + """Extract memory_comtent from config""" + if not config or not isinstance(config, dict): + return None + + memory_obj = config.get('memory') + if memory_obj and isinstance(memory_obj, dict): + return memory_obj.get('memory_content') + + return None + + def _get_data_config(self, memory_content: str) -> Dict[str, Any]: + """Retrieve data_comfig information based on memory_comtent""" + try: + data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content) + data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone() + if data_config_result is None: + return None + + if data_config_result: + return { + "config_id": data_config_result.config_id, + "enable_self_reflexion": data_config_result.enable_self_reflexion, + "iteration_period": data_config_result.iteration_period, + "reflexion_range": data_config_result.reflexion_range, + "baseline": data_config_result.baseline, + "reflection_model_id": data_config_result.reflection_model_id, + "memory_verify": data_config_result.memory_verify, + "quality_assessment": data_config_result.quality_assessment, + "user_id": data_config_result.user_id + } + except Exception as e: + api_logger.warning(f"查询data_config失败,memory_content: {memory_content}, 错误: {str(e)}") + + return None + + def _process_end_users(self, app: App, app_info: Dict[str, Any]) -> None: + """Processing end-user information for applications""" + end_users = self.db.query(EndUser).filter(EndUser.app_id == app.id).all() + + for end_user in end_users: + end_user_info = { + "id": str(end_user.id), + "app_id": str(end_user.app_id) + } + app_info["end_users"].append(end_user_info) + + def get_end_user_reflection_time(self, end_user_id: str) -> Optional[Any]: + """ + Read the reflection time of end users + + Args: + End_user_id: End User ID + + Returns: + Reflection time or None + """ + try: + end_user = self.db.query(EndUser).filter(EndUser.id == end_user_id).first() + if end_user: + return end_user.reflection_time + return None + except Exception as e: + api_logger.error(f"读取用户反思时间失败,end_user_id: {end_user_id}, 错误: {str(e)}") + return None + + def update_end_user_reflection_time(self, end_user_id: str) -> bool: + """ + Update the reflection time of end users to the current time + + Args: + End_user_id: End User ID + + Returns: + Is the update successful + """ + try: + from datetime import datetime + + end_user = self.db.query(EndUser).filter(EndUser.id == end_user_id).first() + if end_user: + end_user.reflection_time = datetime.now() + self.db.commit() + api_logger.info(f"成功更新用户反思时间,end_user_id: {end_user_id}") + return True + else: + api_logger.warning(f"未找到用户,end_user_id: {end_user_id}") + return False + except Exception as e: + api_logger.error(f"更新用户反思时间失败,end_user_id: {end_user_id}, 错误: {str(e)}") + self.db.rollback() + return False + + +class MemoryReflectionService: + """Memory reflection service category""" + + def __init__(self,db: Session = Depends(get_db)): + self.db=db + + + async def start_reflection_from_data(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]: + """ + Starting Reflection from Configuration Data + + Args: + config_data: Configure data dictionary, including reflective configuration information + end_user_id: end_user_id + + Returns: + Reflect on the execution results + """ + try: + config_id = config_data.get("config_id") + api_logger.info(f"从配置数据启动反思,config_id: {config_id}, end_user_id: {end_user_id}") + + + if not config_data.get("enable_self_reflexion", False): + return { + "status": "跳过", + "message": "反思引擎未启用", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data + } + + + config_data_id=config_data['config_id'] + reflection_config=WorkspaceAppService(self.db)._get_data_config(config_data_id) + if reflection_config is not None and reflection_config['enable_self_reflexion']: + reflection_config= self._create_reflection_config_from_data(reflection_config) + iteration_period=reflection_config.iteration_period + workspace_service = WorkspaceAppService(self.db) + current_reflection_time = workspace_service.get_end_user_reflection_time(end_user_id) + + reflection_time = datetime.fromisoformat(str(current_reflection_time)) + + current_time = datetime.now() + time_diff = current_time - reflection_time + hours_diff = int(time_diff.total_seconds() / 3600) + if iteration_period==hours_diff or current_reflection_time is None: + api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时") + # 3. 执行反思引擎 + reflection_results = await self._execute_reflection_engine( + reflection_config, end_user_id + ) + # 更新反思时间为当前时间 + update_success = workspace_service.update_end_user_reflection_time(end_user_id) + if update_success: + api_logger.info(f"成功更新用户 {end_user_id} 的反思时间") + else: + api_logger.error(f"更新用户 {end_user_id} 的反思时间失败") + + return { + "status": "完成", + "message": "反思引擎执行完成", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data, + "reflection_results": reflection_results + } + else: + return { + "status": "等待中..", + "message": "反思引擎未开始执行执", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data, + "reflection_results": '' + } + + except Exception as e: + config_id = config_data.get("config_id", "unknown") + api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}") + return { + "status": "错误", + "message": f"启动反思失败: {str(e)}", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data + } + + def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig: + """Create reflective configuration objects from configuration data""" + + reflexion_range_value = config_data.get("reflexion_range") + if reflexion_range_value is None or reflexion_range_value == "": + reflexion_range_value = "partial" + reflexion_range = ReflectionRange(reflexion_range_value) + + baseline_value = config_data.get("baseline") + if baseline_value is None or baseline_value == "": + baseline_value = "TIME" + baseline = ReflectionBaseline(baseline_value) + + # iteration_period = + iteration_period = config_data.get("iteration_period", 24) + if isinstance(iteration_period, str): + try: + iteration_period = int(iteration_period) + except (ValueError, TypeError): + iteration_period = 24 # 默认24小时 + + return ReflectionConfig( + enabled=config_data.get("enable_self_reflexion", False), + iteration_period=str(iteration_period), # ReflectionConfig期望字符串 + reflexion_range=reflexion_range, + baseline=baseline, + memory_verify=config_data.get("memory_verify", False), + quality_assessment=config_data.get("quality_assessment", False), + model_id=config_data.get("reflection_model_id", "") + ) + + async def _execute_reflection_engine( + self, + reflection_config: ReflectionConfig, + user_id: str + ) -> Dict[str, Any]: + """Execute Reflection Engine""" + try: + # 创建Neo4j连接器 + connector = Neo4jConnector() + + # 创建反思引擎 + engine = ReflectionEngine( + config=reflection_config, + neo4j_connector=connector, + llm_client=reflection_config.model_id + ) + + # 执行反思 + reflection_result = await engine.execute_reflection(user_id) + + return { + "success": reflection_result.success, + "message": reflection_result.message, + "conflicts_found": reflection_result.conflicts_found, + "conflicts_resolved": reflection_result.conflicts_resolved, + "memories_updated": reflection_result.memories_updated, + "execution_time": reflection_result.execution_time, + "details": reflection_result.details + } + + except Exception as e: + api_logger.error(f"反思引擎执行失败: {str(e)}") + return { + "success": False, + "message": f"反思引擎执行失败: {str(e)}", + "conflicts_found": 0, + "conflicts_resolved": 0, + "memories_updated": 0, + "execution_time": 0.0 + } + + +class Memory_Reflection_Service: + """Memory Reflection Service - Used for calling the/reflection interface""" + + def __init__(self, db: Session): + self.db = db + self.reflection_service = MemoryReflectionService(db) + + async def start_reflection(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]: + """ + Activate the reflection function + + Args: + config_data: 配置数据,格式如下: + { + "config_id": 26, + "enable_self_reflexion": true, + "iteration_period": "6", + "reflexion_range": "partial", + "baseline": "TIME", + "reflection_model_id": "ea405fa6-c387-4d78-80ab-826d692301b3", + "memory_verify": true, + "quality_assessment": false, + "user_id": null + } + end_user_id: end_user_id,example "12a8b235-6eb1-4481-a53c-b77933b5c949" + + Returns: + """ + api_logger.info(f"Memory_Reflection_Service启动反思,config_id: {config_data.get('config_id')}, end_user_id: {end_user_id}") + + # 调用核心反思服务 + result = await self.reflection_service.start_reflection_from_data(config_data, end_user_id) + + return result \ No newline at end of file diff --git a/api/app/tasks.py b/api/app/tasks.py index 2d461cd3..39758275 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -295,26 +295,6 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage } -def reflection_engine() -> None: - """Empty function placeholder for timed background reflection. - - Intentionally left blank; replace with real reflection logic later. - """ - from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion - import asyncio - - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - asyncio.run(self_reflexion(host_id)) - - -@celery_app.task(name="app.core.memory.agent.reflection.timer") -def reflection_timer_task() -> None: - """Periodic Celery task that invokes reflection_engine. - - Raises an exception on failure. - """ - reflection_engine() - @celery_app.task(name="app.core.memory.agent.health.check_read_service") def check_read_service_task() -> Dict[str, str]: @@ -464,4 +444,147 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: "error": str(e), "workspace_id": workspace_id, "elapsed_time": elapsed_time, + } + + +@celery_app.task(name="app.tasks.workspace_reflection_task", bind=True) +def workspace_reflection_task(self) -> Dict[str, Any]: + """定时任务:每30秒运行工作空间反思功能 + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService + from app.models.workspace_model import Workspace + from app.core.logging_config import get_api_logger + + api_logger = get_api_logger() + db = next(get_db()) + + try: + # 获取所有工作空间 + workspaces = db.query(Workspace).all() + + if not workspaces: + return { + "status": "SUCCESS", + "message": "没有找到工作空间", + "workspace_count": 0, + "reflection_results": [] + } + + all_reflection_results = [] + + # 遍历每个工作空间 + for workspace in workspaces: + workspace_id = workspace.id + api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") + + try: + reflection_service = MemoryReflectionService(db) + + # 使用服务类处理复杂查询逻辑 + service = WorkspaceAppService(db) + result = service.get_workspace_apps_detailed(str(workspace_id)) + + workspace_reflection_results = [] + + for data in result['apps_detailed_info']: + if data['data_configs'] == []: + continue + + releases = data['releases'] + data_configs = data['data_configs'] + end_users = data['end_users'] + + for base, config, user in zip(releases, data_configs, end_users): + if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: + # 调用反思服务 + api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") + + reflection_result = await reflection_service.start_reflection_from_data( + config_data=config, + end_user_id=user['id'] + ) + + workspace_reflection_results.append({ + "app_id": base['app_id'], + "config_id": config['config_id'], + "end_user_id": user['id'], + "reflection_result": reflection_result + }) + + all_reflection_results.append({ + "workspace_id": str(workspace_id), + "reflection_count": len(workspace_reflection_results), + "reflection_results": workspace_reflection_results + }) + + api_logger.info( + f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") + + except Exception as e: + api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") + all_reflection_results.append({ + "workspace_id": str(workspace_id), + "error": str(e), + "reflection_count": 0, + "reflection_results": [] + }) + + total_reflections = sum(r.get("reflection_count", 0) for r in all_reflection_results) + + return { + "status": "SUCCESS", + "message": f"成功处理 {len(workspaces)} 个工作空间,总共 {total_reflections} 个反思任务", + "workspace_count": len(workspaces), + "total_reflections": total_reflections, + "workspace_results": all_reflection_results + } + + except Exception as e: + api_logger.error(f"工作空间反思任务执行失败: {str(e)}") + return { + "status": "FAILURE", + "error": str(e), + "workspace_count": 0, + "reflection_results": [] + } + finally: + db.close() + + try: + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + result["elapsed_time"] = elapsed_time + result["task_id"] = self.request.id + + return result + except Exception as e: + elapsed_time = time.time() - start_time + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": elapsed_time, + "task_id": self.request.id } \ No newline at end of file diff --git a/api/check_code.py b/api/check_code.py new file mode 100755 index 00000000..e4634d91 --- /dev/null +++ b/api/check_code.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +""" +代码质量检查脚本 +自动检查代码中的导入错误、未使用变量、语法问题等 + +用法: + python check_code.py # 检查整个 app/ 目录 + python check_code.py file1.py file2.py # 检查指定文件 +""" + +import subprocess +import sys +from pathlib import Path + + +def run_command(cmd: list[str], description: str) -> tuple[bool, str]: + """运行命令并返回结果""" + print(f"\n{'=' * 60}") + print(f"🔍 {description}") + print(f"{'=' * 60}") + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=False) + + output = result.stdout + result.stderr + success = result.returncode == 0 + + if success: + print(f"✅ {description} - 通过") + else: + print(f"❌ {description} - 发现问题") + if output: + print(output[:2000]) # 只显示前2000字符 + + return success, output + + except Exception as e: + print(f"❌ 执行失败: {e}") + return False, str(e) + + +def main(): + """主函数""" + # 获取命令行参数中的文件列表 + target_files = sys.argv[1:] if len(sys.argv) > 1 else None + + if target_files: + # 检查指定文件 + print(f"🚀 开始代码质量检查 (指定文件: {len(target_files)} 个)...") + target_paths = target_files + ruff_target = target_files + py_compile_files = [f for f in target_files if f.endswith('.py')] + else: + # 检查整个 app/ 目录 + print("🚀 开始代码质量检查 (整个 app/ 目录)...") + target_paths = ["app/"] + ruff_target = ["app/"] + py_compile_files = list(Path("app").rglob("*.py")) + + checks = [ + { + "cmd": ["ruff", "check"] + ruff_target + ["--output-format=concise"], + "description": "Ruff 代码检查 (导入、语法、风格)", + "auto_fix": ["ruff", "check"] + ruff_target + ["--fix", "--unsafe-fixes"], + }, + { + "cmd": ["python", "-m", "py_compile"] + [str(f) for f in py_compile_files], + "description": "Python 语法检查", + "auto_fix": None, + }, + ] + + results = [] + for check in checks: + success, output = run_command(check["cmd"], check["description"]) + results.append( + {"name": check["description"], "success": success, "output": output, "auto_fix": check.get("auto_fix")} + ) + + # 汇总报告 + print(f"\n{'=' * 60}") + print("📊 检查汇总") + print(f"{'=' * 60}") + + all_passed = True + for result in results: + status = "✅ 通过" if result["success"] else "❌ 失败" + print(f"{status} - {result['name']}") + if not result["success"]: + all_passed = False + if result["auto_fix"]: + print(f" 💡 可以运行自动修复: {' '.join(result['auto_fix'])}") + + if all_passed: + print("\n🎉 所有检查通过!") + return 0 + else: + print("\n⚠️ 发现问题,请查看上面的详细信息") + print("\n💡 快速修复命令:") + if target_files: + print(f" ruff check {' '.join(target_files)} --fix --unsafe-fixes") + else: + print(" ruff check app/ --fix --unsafe-fixes") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) From 2ab0335f880fb7740ce2869801d3fb00f4da2491 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=96=B0=E6=9C=88?= Date: Fri, 19 Dec 2025 09:40:40 +0000 Subject: [PATCH 06/24] Merge #18 into develop from fix/memory_reflection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 反思优化 * fix/memory_reflection: (28 commits squashed) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py # api/app/schemas/memory_reflection_schemas.py - 反思优化 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection Signed-off-by: aliyun8644380055 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/18 --- .../memory_reflection_controller.py | 107 +++++++++++++++--- api/app/schemas/memory_reflection_schemas.py | 4 +- 2 files changed, 94 insertions(+), 17 deletions(-) diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index 759c25c5..bd9e0e09 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -16,7 +16,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService from app.schemas.memory_reflection_schemas import Memory_Reflection - +from app.services.model_service import ModelConfigService load_dotenv() api_logger = get_api_logger() @@ -47,7 +47,7 @@ async def save_reflection_config( api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") update_params = { - "enable_self_reflexion": request.reflectionenabled, + "enable_self_reflexion": request.reflection_enabled, "iteration_period": request.reflection_period_in_hours, "reflexion_range": request.reflexion_range, "baseline": request.baseline, @@ -115,7 +115,7 @@ async def save_reflection_config( @router.post("/reflection") async def start_workspace_reflection( - request: dict, + config_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -171,30 +171,109 @@ async def start_workspace_reflection( detail=f"启动workspace反思失败: {str(e)}" ) -@router.post("/reflection/run") + +@router.get("/reflection/configs") +async def start_reflection_configs( + config_id: int, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """通过config_id查询data_config表中的反思配置信息""" + + try: + api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") + + # 使用DataConfigRepository查询反思配置 + select_query, select_params = DataConfigRepository.build_select_reflection(config_id) + result = db.execute(text(select_query), select_params).fetchone() + + if not result: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"未找到config_id为 {config_id} 的配置" + ) + + # 构建返回数据 + reflection_config = { + "config_id": result.config_id, + "enable_self_reflexion": result.enable_self_reflexion, + "iteration_period": result.iteration_period, + "reflexion_range": result.reflexion_range, + "baseline": result.baseline, + "reflection_model_id": result.reflection_model_id, + "memory_verify": result.memory_verify, + "quality_assessment": result.quality_assessment, + "user_id": result.user_id + } + + api_logger.info(f"成功查询反思配置,config_id: {config_id}") + + return { + "status": "成功", + "message": "反思配置查询成功", + "data": reflection_config + } + + except HTTPException: + # 重新抛出HTTP异常 + raise + except Exception as e: + api_logger.error(f"查询反思配置失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"查询反思配置失败: {str(e)}" + ) + +@router.get("/reflection/run") async def reflection_run( - reflection: Memory_Reflection, + config_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: """Activate the reflection function for all matching applications in the workspace""" + + api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") + + # 使用DataConfigRepository查询反思配置 + select_query, select_params = DataConfigRepository.build_select_reflection(config_id) + result = db.execute(text(select_query), select_params).fetchone() + + if not result: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"未找到config_id为 {config_id} 的配置" + ) + + api_logger.info(f"成功查询反思配置,config_id: {config_id}") + + # 验证模型ID是否存在 + model_id = result.reflection_model_id + if model_id: + try: + ModelConfigService.get_model_by_id(db=db, model_id=model_id) + api_logger.info(f"模型ID验证成功: {model_id}") + except Exception as e: + api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}") + # 可以设置为None,让反思引擎使用默认模型 + model_id = None + config = ReflectionConfig( - enabled=reflection.reflectionenabled, - iteration_period=reflection.reflection_period_in_hours, - reflexion_range=reflection.reflexion_range, - baseline=reflection.baseline, + enabled=result.enable_self_reflexion, + iteration_period=result.iteration_period, + reflexion_range=result.reflexion_range, + baseline=result.baseline, output_example='', - memory_verify=reflection.memory_verify, - quality_assessment=reflection.quality_assessment, + memory_verify=result.memory_verify, + quality_assessment=result.quality_assessment, violation_handling_strategy="block", - model_id=reflection.reflection_model_id + model_id=model_id ) connector = Neo4jConnector() engine = ReflectionEngine( config=config, neo4j_connector=connector, - llm_client=reflection.reflection_model_id # 传入 model_id + llm_client=model_id # 传入验证后的 model_id ) result=await (engine.reflection_run()) - return result + return result \ No newline at end of file diff --git a/api/app/schemas/memory_reflection_schemas.py b/api/app/schemas/memory_reflection_schemas.py index 9eb11c6c..ada92cf2 100644 --- a/api/app/schemas/memory_reflection_schemas.py +++ b/api/app/schemas/memory_reflection_schemas.py @@ -8,11 +8,9 @@ class OptimizationStrategy(str, Enum): SPEED_FIRST = "speed_first" ACCURACY_FIRST = "accuracy_first" BALANCED = "balanced" - - class Memory_Reflection(BaseModel): config_id: Optional[int] = None - reflectionenabled: bool + reflection_enabled: bool reflection_period_in_hours: str reflexion_range: str baseline: str From 8c73aa60b90fee4ba0c0bba9a019903c7f121540 Mon Sep 17 00:00:00 2001 From: Mark Date: Fri, 19 Dec 2025 18:06:49 +0800 Subject: [PATCH 07/24] [add] migration script --- api/app/models/workspace_model.py | 2 +- .../versions/f96a53af914c_202512191805.py | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 api/migrations/versions/f96a53af914c_202512191805.py diff --git a/api/app/models/workspace_model.py b/api/app/models/workspace_model.py index abb5adeb..4d42ed32 100644 --- a/api/app/models/workspace_model.py +++ b/api/app/models/workspace_model.py @@ -1,7 +1,7 @@ import datetime from enum import StrEnum import uuid -from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean +from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from app.db import Base diff --git a/api/migrations/versions/f96a53af914c_202512191805.py b/api/migrations/versions/f96a53af914c_202512191805.py new file mode 100644 index 00000000..9c3d34b5 --- /dev/null +++ b/api/migrations/versions/f96a53af914c_202512191805.py @@ -0,0 +1,36 @@ +"""202512191805 + +Revision ID: f96a53af914c +Revises: 87a6537b4074 +Create Date: 2025-12-19 18:05:14.964454 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'f96a53af914c' +down_revision: Union[str, None] = '87a6537b4074' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('data_config', sa.Column('reflection_model_id', sa.String(), nullable=True, comment='反思模型ID')) + op.add_column('data_config', sa.Column('memory_verify', sa.Boolean(), nullable=True, comment='记忆验证')) + op.add_column('data_config', sa.Column('quality_assessment', sa.Boolean(), nullable=True, comment='质量评估')) + op.add_column('end_users', sa.Column('reflection_time', sa.DateTime(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('end_users', 'reflection_time') + op.drop_column('data_config', 'quality_assessment') + op.drop_column('data_config', 'memory_verify') + op.drop_column('data_config', 'reflection_model_id') + # ### end Alembic commands ### From 15fac38e30d1fb60273fcfc3b258d76715b36db8 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Thu, 18 Dec 2025 14:50:10 +0800 Subject: [PATCH 08/24] fix(workflow): fix run_workflow streaming issues Resolve exceptions during run_workflow streaming and define proper status codes for error cases. --- api/app/controllers/workflow_controller.py | 2 +- api/app/services/workflow_service.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/api/app/controllers/workflow_controller.py b/api/app/controllers/workflow_controller.py index 9ccfa858..091846f6 100644 --- a/api/app/controllers/workflow_controller.py +++ b/api/app/controllers/workflow_controller.py @@ -473,7 +473,7 @@ async def run_workflow( async def event_generator(): """生成 SSE 事件""" try: - async for event in service.run_workflow( + async for event in await service.run_workflow( app_id=app_id, input_data=input_data, triggered_by=current_user.id, diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index f0b71824..fbf09505 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -5,7 +5,7 @@ import json import logging import uuid import datetime -from typing import Any, Annotated +from typing import Any, Annotated, AsyncGenerator from sqlalchemy.orm import Session from fastapi import Depends @@ -81,7 +81,7 @@ class WorkflowService: if not is_valid: logger.warning(f"工作流配置验证失败: {errors}") raise BusinessException( - error_code=BizCode.INVALID_PARAMETER, + code=BizCode.INVALID_PARAMETER, message=f"工作流配置无效: {'; '.join(errors)}" ) @@ -140,7 +140,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"工作流配置不存在: app_id={app_id}" ) @@ -166,7 +166,7 @@ class WorkflowService: if not is_valid: logger.warning(f"工作流配置验证失败: {errors}") raise BusinessException( - error_code=BizCode.INVALID_PARAMETER, + code=BizCode.INVALID_PARAMETER, message=f"工作流配置无效: {'; '.join(errors)}" ) @@ -245,7 +245,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"工作流配置不存在: app_id={app_id}" ) @@ -359,7 +359,7 @@ class WorkflowService: execution = self.get_execution(execution_id) if not execution: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"执行记录不存在: execution_id={execution_id}" ) @@ -640,7 +640,7 @@ class WorkflowService: triggered_by: uuid.UUID, conversation_id: uuid.UUID | None = None, stream: bool = False - ): + ) -> AsyncGenerator | dict: """运行工作流 Args: @@ -660,7 +660,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"工作流配置不存在: app_id={app_id}" ) @@ -687,7 +687,7 @@ class WorkflowService: app = self.db.query(App).filter(App.id == app_id).first() if not app: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"应用不存在: app_id={app_id}" ) @@ -750,7 +750,7 @@ class WorkflowService: error_message=str(e) ) raise BusinessException( - error_code=BizCode.INTERNAL_ERROR, + code=BizCode.INTERNAL_ERROR, message=f"工作流执行失败: {str(e)}" ) From 5cd46e441e36dad8a6b0313de51e39a6a78dca35 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 14:08:54 +0800 Subject: [PATCH 09/24] fix(prompt-optimizer): switch to built-in system prompt - Replace the system prompt of the prompt optimization model with a built-in prompt. - Remove system prompt entries from the database. - Remove the API endpoint for managing system prompt configuration. --- .../prompt_optimizer_controller.py | 34 +--- api/app/models/__init__.py | 3 +- api/app/models/prompt_optimizer_model.py | 43 ----- .../prompt_optimizer_repository.py | 105 ----------- api/app/services/prompt_optimizer_service.py | 170 +++++++++--------- 5 files changed, 86 insertions(+), 269 deletions(-) diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index d647f0c0..d73ea0df 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -117,7 +117,7 @@ async def get_prompt_opt( session_id=session_id, user_id=current_user.id, current_prompt=data.current_prompt, - message=data.message + user_require=data.message ) service.create_message( tenant_id=current_user.tenant_id, @@ -136,35 +136,3 @@ async def get_prompt_opt( return success(data=result_schema) -@router.put( - "/model", - summary="Create or update prompt model config", - response_model=ApiResponse -) -def set_system_prompt( - data: PromptOptModelSet = ..., - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """ - Create or update a system prompt model configuration for the tenant. - - Args: - data (PromptOptModelSet): Model configuration data including model ID, - system prompt, and optional configuration ID - db (Session): Database session - current_user: Current user information - - Returns: - UUID: The ID of the created or updated model configuration. - """ - if data.id is None: - data.id = uuid.uuid4() - - model_config = PromptOptimizerService(db).create_update_model_config( - current_user.tenant_id, - data.id, - data.system_prompt - ) - return success(data=model_config.id) - diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 09c88ba3..01dad24e 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -20,7 +20,7 @@ from .data_config_model import DataConfig from .multi_agent_model import MultiAgentConfig, AgentInvocation from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution from .retrieval_info import RetrievalInfo -from .prompt_optimizer_model import PromptOptimizerModelConfig, PromptOptimizerSession, PromptOptimizerSessionHistory +from .prompt_optimizer_model import PromptOptimizerSession, PromptOptimizerSessionHistory from .tool_model import ( ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus @@ -60,7 +60,6 @@ __all__ = [ "WorkflowExecution", "WorkflowNodeExecution", "RetrievalInfo", - "PromptOptimizerModelConfig", "PromptOptimizerSession", "PromptOptimizerSessionHistory", "RetrievalInfo", diff --git a/api/app/models/prompt_optimizer_model.py b/api/app/models/prompt_optimizer_model.py index 5191fc2e..39845ee7 100644 --- a/api/app/models/prompt_optimizer_model.py +++ b/api/app/models/prompt_optimizer_model.py @@ -27,49 +27,6 @@ class RoleType(StrEnum): ASSISTANT = "assistant" -class PromptOptimizerModelConfig(Base): - """ - Prompt Optimization Model Configuration. - - This table stores system-level prompt configurations for each tenant. - The configuration defines the base system prompt used during prompt - optimization sessions and serves as a foundational instruction set - for the optimization process. - - Each tenant may have one or more model configurations depending on - business requirements. - - Table Name: - prompt_model_config - - Columns: - id (UUID): - Primary key. Unique identifier for the prompt model configuration. - tenant_id (UUID): - Foreign key referencing `tenants.id`. - Identifies the tenant that owns this configuration. - system_prompt (Text): - The system-level prompt used to guide prompt optimization logic. - created_at (DateTime): - Timestamp indicating when the configuration was created. - updated_at (DateTime): - Timestamp indicating the last update time of the configuration. - - Usage: - - Loaded when initializing a prompt optimization session - - Acts as the root system instruction for all subsequent prompts - """ - __tablename__ = "prompt_model_config" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID") - # model_id = Column(UUID(as_uuid=True), nullable=False, comment="Model ID") - system_prompt = Column(Text, nullable=False, comment="System Prompt") - - created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="Update Time") - - class PromptOptimizerSession(Base): """ Prompt Optimization Session Registry. diff --git a/api/app/repositories/prompt_optimizer_repository.py b/api/app/repositories/prompt_optimizer_repository.py index ecb2af98..ba65257a 100644 --- a/api/app/repositories/prompt_optimizer_repository.py +++ b/api/app/repositories/prompt_optimizer_repository.py @@ -1,120 +1,15 @@ import uuid -from typing import Optional from sqlalchemy.orm import Session from app.core.logging_config import get_db_logger from app.models.prompt_optimizer_model import ( - PromptOptimizerModelConfig, PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType ) db_logger = get_db_logger() -class PromptOptimizerModelConfigRepository: - """Repository for managing prompt optimizer model configurations.""" - - def __init__(self, db: Session): - self.db = db - - def get_by_tenant_id(self, tenant_id: uuid.UUID) -> Optional[PromptOptimizerModelConfig]: - """ - Retrieve the prompt optimizer model configuration for a specific tenant. - - Args: - tenant_id (uuid.UUID): The unique identifier of the tenant. - - Returns: - Optional[PromptOptimizerModelConfig]: The model configuration if found, else None. - """ - db_logger.debug(f"Get prompt optimization model configuration: tenant_id={tenant_id}") - - try: - config = self.db.query(PromptOptimizerModelConfig).filter( - PromptOptimizerModelConfig.tenant_id == tenant_id, - # PromptOptimizerModelConfig.model_id == model_id - ).first() - if config: - db_logger.debug(f"Prompt optimization model configuration found: (ID: {config.id})") - else: - db_logger.debug(f"Prompt optimization model configuration not found: tenant_id={tenant_id}") - return config - except Exception as e: - db_logger.error( - f"Error retrieving prompt optimization model configuration: tenant_id={tenant_id} - {str(e)}") - raise - - def get_by_config_id(self, tenant_id: uuid.UUID, config_id: uuid.UUID) -> Optional[PromptOptimizerModelConfig]: - """ - Retrieve a specific prompt optimizer model configuration by config ID and tenant ID. - - Args: - tenant_id (uuid.UUID): The unique identifier of the tenant. - config_id (uuid.UUID): The unique identifier of the model configuration. - - Returns: - Optional[PromptOptimizerModelConfig]: The model configuration if found, else None. - """ - db_logger.debug(f"Get prompt optimization model configuration: config_id={config_id}, tenant_id={tenant_id}") - try: - model = self.db.query(PromptOptimizerModelConfig).filter( - PromptOptimizerModelConfig.tenant_id == tenant_id, - PromptOptimizerModelConfig.id == config_id - ).first() - if model: - db_logger.debug(f"Prompt optimization model configuration found: (ID: {model.id})") - else: - db_logger.debug(f"Prompt optimization model configuration not found: config_id={config_id}") - return model - except Exception as e: - db_logger.error( - f"Error retrieving prompt optimization model configuration: model_id={config_id} - {str(e)}") - raise - - def create_or_update( - self, - config_id: uuid.UUID, - tenant_id: uuid.UUID, - system_prompt: str, - ) -> Optional[PromptOptimizerModelConfig]: - """ - Create a new or update an existing prompt optimizer model configuration. - - If a configuration with the given config_id exists, it updates its system_prompt. - Otherwise, it creates a new configuration record. - - Args: - config_id (uuid.UUID): The unique identifier for the configuration. - tenant_id (uuid.UUID): The tenant's unique identifier. - system_prompt (str): The system prompt content for prompt optimization. - - Returns: - Optional[PromptOptimizerModelConfig]: The created or updated model configuration. - """ - db_logger.debug(f"Create/Update prompt optimization model configuration: tenant_id={tenant_id}") - existing_config = self.get_by_config_id(tenant_id, config_id) - - if existing_config: - existing_config.system_prompt = system_prompt - self.db.commit() - self.db.refresh(existing_config) - db_logger.debug(f"Prompt optimization model configuration update: ID:{config_id}") - return existing_config - else: - config = PromptOptimizerModelConfig( - id=config_id, - # model_id=model_id, - tenant_id=tenant_id, - system_prompt=system_prompt - ) - self.db.add(config) - self.db.commit() - self.db.refresh(config) - db_logger.debug(f"Prompt optimization model configuration created: ID:{config.id}") - return config - - class PromptOptimizerSessionRepository: """Repository for managing prompt optimization sessions and session history.""" diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 0cdaabf5..5355474f 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -1,4 +1,3 @@ -import json import re import uuid @@ -12,13 +11,11 @@ from app.core.models import RedBearModelConfig from app.core.models.llm import RedBearLLM from app.models import ModelConfig, ModelApiKey, ModelType, PromptOptimizerSessionHistory from app.models.prompt_optimizer_model import ( - PromptOptimizerModelConfig, PromptOptimizerSession, RoleType ) from app.repositories.model_repository import ModelConfigRepository from app.repositories.prompt_optimizer_repository import ( - PromptOptimizerModelConfigRepository, PromptOptimizerSessionRepository ) from app.schemas.prompt_optimizer_schema import OptimizePromptResult @@ -34,32 +31,24 @@ class PromptOptimizerService: self, tenant_id: uuid.UUID, model_id: uuid.UUID - ) -> tuple[PromptOptimizerModelConfig, ModelConfig]: + ) -> ModelConfig: """ - Retrieve the prompt optimizer model configuration and model configuration. + Retrieve the model configuration for a specific tenant. - This method retrieves the prompt optimizer model configuration associated - with the specified model ID and tenant. It also fetches the corresponding - model configuration. + This method fetches the model configuration associated with the given + tenant_id and model_id. If no configuration is found, a BusinessException + is raised. Args: tenant_id (uuid.UUID): The unique identifier of the tenant. - model_id (uuid.UUID): The unique identifier of the prompt optimization model. + model_id (uuid.UUID): The unique identifier of the model. Returns: - tuple[PromptOptimzerModelConfig, ModelConfig]: - A tuple containing the prompt optimizer model configuration - and the corresponding model configuration. + ModelConfig: The corresponding model configuration object. Raises: - BusinessException: If the prompt optimizer model configuration does not exist. BusinessException: If the model configuration does not exist. """ - prompt_config = PromptOptimizerModelConfigRepository(self.db).get_by_tenant_id( - tenant_id - ) - if not prompt_config: - raise BusinessException("提示词模型配置不存在", BizCode.NOT_FOUND) model = ModelConfigRepository.get_by_id( self.db, model_id, tenant_id=tenant_id @@ -67,35 +56,7 @@ class PromptOptimizerService: if not model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - return prompt_config, model - - def create_update_model_config( - self, - tenant_id: uuid.UUID, - config_id: uuid.UUID, - system_prompt: str, - ) -> PromptOptimizerModelConfig: - """ - Create or update a prompt optimizer model configuration. - - This method creates a new prompt optimizer model configuration or updates - an existing one identified by the given configuration ID. The configuration - defines the system prompt used for prompt optimization. - - Args: - tenant_id (uuid.UUID): The unique identifier of the tenant. - config_id (uuid.UUID): The unique identifier of the configuration to create or update. - system_prompt (str): The system prompt content used for prompt optimization. - - Returns: - PromptOptimzerModelConfig: The created or updated prompt optimizer model configuration. - """ - prompt_config = PromptOptimizerModelConfigRepository(self.db).create_or_update( - config_id=config_id, - tenant_id=tenant_id, - system_prompt=system_prompt, - ) - return prompt_config + return model def create_session( self, @@ -159,37 +120,46 @@ class PromptOptimizerService: session_id: uuid.UUID, user_id: uuid.UUID, current_prompt: str, - message: str + user_require: str ) -> OptimizePromptResult: """ - Optimize a prompt using a prompt optimizer LLM. + Optimize a user-provided prompt using a configured prompt optimizer LLM. - This method uses a configured prompt optimizer model to refine an existing - prompt based on the user's requirements. The optimized prompt is generated - according to predefined system rules, including Jinja2 variable syntax and - a strict JSON output format. + This method refines the original prompt according to the user's requirements, + generating an optimized version that is directly usable by AI tools. The + optimization process follows strict rules, including: + - Wrapping user-inserted variables in double curly braces {{}}. + - Adhering to Jinja2 variable syntax if applicable. + - Ensuring a clear logic flow, explicit instructions, and strong executability. + - Producing output in a strict JSON format. + + Steps performed: + 1. Retrieve the model configuration for the given tenant and model. + 2. Fetch the session message history for context. + 3. Instantiate the LLM with the appropriate API key and model configuration. + 4. Build system messages outlining optimization rules. + 5. Format the user's original prompt and requirements as a user message. + 6. Send messages to the LLM to generate the optimized prompt. + 7. Generate a concise description summarizing the changes made during optimization. Args: - tenant_id (uuid.UUID): The unique identifier of the tenant. - model_id (uuid.UUID): The unique identifier of the prompt optimizer model. - session_id (uuid.UUID): The unique identifier of the prompt optimization session. - user_id (uuid.UUID): The unique identifier of the user associated with the session. - current_prompt (str): The original prompt to be optimized. - message (str): The user's requirements or modification instructions. + tenant_id (uuid.UUID): Tenant identifier. + model_id (uuid.UUID): Prompt optimizer model identifier. + session_id (uuid.UUID): Prompt optimization session identifier. + user_id (uuid.UUID): Identifier of the user associated with the session. + current_prompt (str): Original prompt to optimize. + user_require (str): User's requirements or instructions for optimization. Returns: - dict: A dictionary containing the optimized prompt and the description - of changes, in the following format: - { - "prompt": "", - "desc": "" - } + OptimizePromptResult: An object containing: + - prompt: The optimized prompt string. + - desc: A short description summarizing the changes. Raises: - BusinessException: If the model response cannot be parsed as valid JSON + BusinessException: If the LLM response cannot be parsed as valid JSON or does not conform to the expected output format. """ - prompt_config, model_config = self.get_model_config(tenant_id, model_id) + model_config = self.get_model_config(tenant_id, model_id) session_history = self.get_session_message_history(session_id=session_id, user_id=user_id) # Create LLM instance @@ -204,36 +174,65 @@ class PromptOptimizerService: # build message messages = [ # init system_prompt - (RoleType.SYSTEM.value, prompt_config.system_prompt), + ( + RoleType.SYSTEM.value, + "Your task is to optimize the original prompt provided by the user so that it can be directly used by AI tools," + "and the variables that the user needs to insert must be wrapped in {{}}. " + "The optimized prompt should align with the optimization direction specified by the user (if any) and ensure clear logic, explicit instructions, and strong executability. " + "Please follow these rules when optimizing: " + '1. Ensure variables are wrapped in {{}}, e.g., optimize "Please enter your question" to "Please enter your {{question}}"' + "2. Instructions must be specific and operable, avoiding vague expressions" + "3. If the original prompt lacks key elements (such as output format requirements), supplement them completely " + "4. Keep the language concise and avoid redundancy " + "5. If the user does not specify an optimization direction, the default optimization is to make the prompt structurally clear and with explicit instructions" + "Please directly output the optimized prompt without additional explanations. The optimized prompt should be directly usable with correct variable positions." + ), # base model limit (RoleType.SYSTEM.value, "Optimization Rules:\n" "1. Fully adjust the prompt content according to the user's requirements.\n" - "2. When the user requests the insertion of variables, you must use Jinja2 syntax {{variable_name}} " - "(the variable name should be determined based on the user's requirement).\n" + "When variables are required, use double curly braces {{variable_name}} as placeholders." + "Variable names must be derived from the user's requirements.\n" "3. Keep the prompt logic clear and instructions explicit.\n" - "4. Ensure that the modified prompt can be directly used.\n\n" - "Output Requirements:\n" - "Provide the result in JSON format, containing exactly two fields:\n" - " - prompt: The modified prompt (string).\n" - " - desc: A response addressing the user's optimization request (string).") + "4. Ensure that the modified prompt can be directly used.\n\n") ] messages.extend(session_history[:-1]) # last message is current message user_message_template = ChatPromptTemplate.from_messages([ - (RoleType.USER.value, "[current_prompt]\n{current_prompt}\n[user_require]\n{message}") + (RoleType.USER.value, "[original_prompt]\n{current_prompt}\n[user_require]\n{user_require}") ]) - formatted_user_message = user_message_template.format(current_prompt=current_prompt, message=message) + formatted_user_message = user_message_template.format(current_prompt=current_prompt, user_require=user_require) messages.extend([(RoleType.USER.value, formatted_user_message)]) logger.info(f"Prompt optimization message: {messages}") - result = await llm.ainvoke(messages) - try: - data_dict = json.loads(result.content) - model_resp = OptimizePromptResult.model_validate(data_dict) - except Exception as e: - logger.error(f"Failed to parse model reponse to json - Error: {str(e)}", exc_info=True) - raise BusinessException("Failed to parse model response", BizCode.PARSER_NOT_SUPPORTED) - return model_resp + optim_prompt = await llm.ainvoke(messages) + optim_desc = [ + ( + RoleType.SYSTEM.value, + "You are a prompt optimization assistant.\n" + "Compare the original prompt, the user's requirements, " + "and the optimized prompt.\n" + "Summarize the changes made during optimization.\n\n" + "Rules:\n" + "1. Output must be a single short sentence.\n" + "2. Be concise and factual.\n" + "3. Do not explain the prompts themselves.\n" + "4. Do not include any extra text." + ), + ( + "[Original Prompt]\n" + f"{current_prompt}\n\n" + "[User Requirements]\n" + f"{user_require}\n\n" + "[Optimized Prompt]\n" + f"{optim_prompt.content}" + ) + ] + optim_desc = await llm.ainvoke(optim_desc) + + return OptimizePromptResult( + prompt=optim_prompt.content, + desc=optim_desc.content + ) @staticmethod def parser_prompt_variables(prompt: str): @@ -277,4 +276,3 @@ class PromptOptimizerService: content=content ) return message - From 01ac36195aefc1c7b3ccef2c233bb5878b68640f Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 14:19:18 +0800 Subject: [PATCH 10/24] feat(workflow): add conditional branch (If-Else) node - Introduce a new conditional branch node for workflows. - Supports multiple case branches with logical operators (AND/OR). - Enables workflow routing based on evaluated conditions. --- api/app/core/workflow/executor.py | 81 +++++---- .../core/workflow/nodes/if_else/__init__.py | 5 + api/app/core/workflow/nodes/if_else/config.py | 122 +++++++++++++ api/app/core/workflow/nodes/if_else/node.py | 168 ++++++++++++++++++ 4 files changed, 343 insertions(+), 33 deletions(-) create mode 100644 api/app/core/workflow/nodes/if_else/__init__.py create mode 100644 api/app/core/workflow/nodes/if_else/config.py create mode 100644 api/app/core/workflow/nodes/if_else/node.py diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 9cf711db..3710e4ed 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -13,8 +13,9 @@ from langchain_core.messages import HumanMessage from langgraph.graph import StateGraph, START, END from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.expression_evaluator import evaluate_condition +from app.core.workflow.nodes import WorkflowState, NodeFactory +from app.core.workflow.nodes.enums import NodeType from app.core.tools.registry import ToolRegistry from app.core.tools.executor import ToolExecutor from app.core.tools.langchain_adapter import LangchainAdapter @@ -30,11 +31,11 @@ class WorkflowExecutor: """ def __init__( - self, - workflow_config: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str + self, + workflow_config: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str ): """初始化执行器 @@ -95,8 +96,6 @@ class WorkflowExecutor: "error_node": None } - - def build_graph(self) -> CompiledStateGraph: """构建 LangGraph @@ -117,19 +116,36 @@ class WorkflowExecutor: node_id = node.get("id") # 记录 start 和 end 节点 ID - if node_type == "start": + if node_type == NodeType.START: start_node_id = node_id - elif node_type == "end": + elif node_type == NodeType.END: end_node_ids.append(node_id) # 创建节点实例(现在 start 和 end 也会被创建) node_instance = NodeFactory.create_node(node, self.workflow_config) + + if node_type in [NodeType.IF_ELSE]: + # Build ordered boolean expression strings for each branch. + # These expressions will be attached to outgoing edges as + # LangGraph conditional routing rules. + expressions = node_instance.build_conditional_edge_expressions() + + # Collect all outgoing edges from the current node. + # The order of edges must match the order of generated expressions. + related_edge = [edge for edge in self.edges if edge.get("source") == node_id] + + # Attach each condition expression to the corresponding edge + # based on branch priority + for idx in range(len(expressions)): + related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" + if node_instance: # 包装节点的 run 方法 # 使用函数工厂避免闭包问题 def make_node_func(inst): async def node_func(state: WorkflowState): return await inst.run(state) + return node_func workflow.add_node(node_id, make_node_func(node_instance)) @@ -170,14 +186,14 @@ class WorkflowExecutor: def router(state: WorkflowState, cond=condition, tgt=target): """条件路由函数""" if evaluate_condition( - cond, - state.get("variables", {}), - state.get("node_outputs", {}), - { - "execution_id": state.get("execution_id"), - "workspace_id": state.get("workspace_id"), - "user_id": state.get("user_id") - } + cond, + state.get("variables", {}), + state.get("node_outputs", {}), + { + "execution_id": state.get("execution_id"), + "workspace_id": state.get("workspace_id"), + "user_id": state.get("user_id") + } ): return tgt return END # 条件不满足,结束 @@ -201,8 +217,8 @@ class WorkflowExecutor: return graph async def execute( - self, - input_data: dict[str, Any] + self, + input_data: dict[str, Any] ) -> dict[str, Any]: """执行工作流(非流式) @@ -276,8 +292,8 @@ class WorkflowExecutor: } async def execute_stream( - self, - input_data: dict[str, Any] + self, + input_data: dict[str, Any] ): """执行工作流(流式) @@ -331,7 +347,6 @@ class WorkflowExecutor: "token_usage": None } - def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None: """从节点输出中提取最终输出 @@ -391,11 +406,11 @@ class WorkflowExecutor: async def execute_workflow( - workflow_config: dict[str, Any], - input_data: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str + workflow_config: dict[str, Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str ) -> dict[str, Any]: """执行工作流(便捷函数) @@ -419,11 +434,11 @@ async def execute_workflow( async def execute_workflow_stream( - workflow_config: dict[str, Any], - input_data: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str + workflow_config: dict[str, Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str ): """执行工作流(流式,便捷函数) diff --git a/api/app/core/workflow/nodes/if_else/__init__.py b/api/app/core/workflow/nodes/if_else/__init__.py new file mode 100644 index 00000000..ffdf3b5b --- /dev/null +++ b/api/app/core/workflow/nodes/if_else/__init__.py @@ -0,0 +1,5 @@ +"""Condition Node""" +from app.core.workflow.nodes.if_else.config import IfElseNodeConfig +from app.core.workflow.nodes.if_else.node import IfElseNode + +__all__ = ["IfElseNode", "IfElseNodeConfig"] diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py new file mode 100644 index 00000000..1a9adbbb --- /dev/null +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -0,0 +1,122 @@ +"""Condition Configuration""" +from pydantic import Field, BaseModel, field_validator +from enum import StrEnum +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class LogicOperator(StrEnum): + AND = "and" + OR = "or" + + +class ComparisonOpeartor(StrEnum): + EMPTY = "empty" + NOT_EMPTY = "not_empty" + CONTAINS = "contains" + NOT_CONTAINS = "not_contains" + START_WITH = "startwith" + END_WITH = "endwith" + EQ = "eq" + NE = "ne" + LT = "lt" + LE = "le" + GT = "gt" + GE = "ge" + + +class ConditionDetail(BaseModel): + comparison_operator: ComparisonOpeartor = Field( + ..., + description="Comparison operator used to evaluate the condition" + ) + + left: str = Field( + ..., + description="Value to compare against" + ) + + right: str = Field( + ..., + description="Value to compare with" + ) + + +class ConditionBranchConfig(BaseModel): + """Configuration for a conditional branch""" + + logical_operator: LogicOperator = Field( + default=LogicOperator.AND.value, + description="Logical operator used to combine multiple condition expressions" + ) + + conditions: list[ConditionDetail] = Field( + ..., + description="List of condition expressions within this branch" + ) + + +class IfElseNodeConfig(BaseNodeConfig): + cases: list[ConditionBranchConfig] = Field( + ..., + description="List of branch conditions or expressions" + ) + + @field_validator("cases") + @classmethod + def validate_case_number(cls, v, info): + if len(v) < 1: + raise ValueError("At least one cases are required") + return v + + class Config: + json_schema_extra = { + "examples": [ + { + "cases": [ + # if/CASE1 + { + "logical_operator": "and", + "conditions": [ + { + "left": "sys.message", + "comparison_operator": "eq", + "right": "'test'" + } + ] + }, + ] + }, + { + "case_number": 3, + "cases": [ + # if/CASE1 + { + "logic": "or", + "conditions": [ + { + "left": "sys.message", + "comparison_operator": "eq", + "right": "'test'" + } + ] + }, + # elif/CASE2 + { + "logic": "and", + "conditions": [ + { + "left": "sys.message", + "comparison_operator": "eq", + "right": "'test'" + }, + { + "left": "sys.message", + "comparison_operator": "contains", + "right": "'test'" + } + ] + }, + ] + } + ] + } diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py new file mode 100644 index 00000000..3219edae --- /dev/null +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -0,0 +1,168 @@ +import logging +from typing import Any + +from simpleeval import NameNotDefined, InvalidExpression + +from app.core.workflow.nodes import BaseNode, WorkflowState +from app.core.workflow.nodes.if_else import IfElseNodeConfig +from app.core.workflow.nodes.if_else.config import LogicOperator, ConditionDetail, ComparisonOpeartor + +logger = logging.getLogger(__name__) + + +class ConditionExpressionBuilder: + """ + Build a Python boolean expression string based on a comparison operator. + + This class does not evaluate the expression. + It only generates a valid Python expression string + that can be evaluated later in a workflow context. + """ + + def __init__(self, left: str, operator: ComparisonOpeartor, right: str): + self.left = left + self.operator = operator + self.right = right + + def _empty(self): + return f"{self.left} == ''" + + def _not_empty(self): + return f"{self.left} != ''" + + def _contains(self): + return f"{self.right} in {self.left}" + + def _not_contains(self): + return f"{self.right} not in {self.left}" + + def _startwith(self): + return f'{self.left}.startswith({self.right})' + + def _endwith(self): + return f'{self.left}.endswith({self.right})' + + def _eq(self): + return f"{self.left} == {self.right}" + + def _ne(self): + return f"{self.left} != {self.right}" + + def _lt(self): + return f"{self.left} < {self.right}" + + def _le(self): + return f"{self.left} <= {self.right}" + + def _gt(self): + return f"{self.left} > {self.right}" + + def _ge(self): + return f"{self.left} >= {self.right}" + + def build(self): + match self.operator: + case ComparisonOpeartor.EMPTY: + return self._empty() + case ComparisonOpeartor.NOT_EMPTY: + return self._not_empty() + case ComparisonOpeartor.CONTAINS: + return self._contains() + case ComparisonOpeartor.NOT_CONTAINS: + return self._not_contains() + case ComparisonOpeartor.START_WITH: + return self._startwith() + case ComparisonOpeartor.END_WITH: + return self._endwith() + case ComparisonOpeartor.EQ: + return self._eq() + case ComparisonOpeartor.NE: + return self._ne() + case ComparisonOpeartor.LT: + return self._lt() + case ComparisonOpeartor.LE: + return self._le() + case ComparisonOpeartor.GT: + return self._gt() + case ComparisonOpeartor.GE: + return self._ge() + case _: + raise ValueError(f"Invalid condition: {self.operator}") + + +class IfElseNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = IfElseNodeConfig(**self.config) + + @staticmethod + def _build_condition_expression( + condition: ConditionDetail, + ) -> str: + """ + Build a single boolean condition expression string. + + This method does NOT evaluate the condition. + It only generates a valid Python boolean expression string + (e.g. "x > 10", "'a' in name") that can later be used + in a conditional edge or evaluated by the workflow engine. + + Args: + condition (ConditionDetail): Definition of a single comparison condition. + + Returns: + str: A Python boolean expression string. + """ + return ConditionExpressionBuilder( + left=condition.left, + operator=condition.comparison_operator, + right=condition.right + ).build() + + def build_conditional_edge_expressions(self) -> list[str]: + """ + Build conditional edge expressions for the If-Else node. + + This method does NOT evaluate any condition at runtime. + Instead, it converts each case branch into a Python boolean + expression string, which will later be attached to LangGraph + as conditional edges. + + Each returned expression corresponds to one branch and is + evaluated in order. A fallback 'True' condition is appended + to ensure a default branch when no previous conditions match. + + Returns: + list[str]: A list of Python boolean expression strings, + ordered by branch priority. + """ + branch_index = 0 + conditions = [] + + for case_branch in self.typed_config.cases: + branch_index += 1 + + branch_conditions = [ + self._build_condition_expression(condition) + for condition in case_branch.conditions + ] + if len(branch_conditions) > 1: + combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions) + else: + combined_condition = branch_conditions[0] + conditions.append(combined_condition) + + # Default fallback branch + conditions.append("True") + + return conditions + + async def execute(self, state: WorkflowState) -> Any: + """ + """ + expressions = self.build_conditional_edge_expressions() + for i in range(len(expressions)): + logger.info(expressions[i]) + if self._evaluate_condition(expressions[i], state): + return f'CASE{i+1}' + return f'CASE{len(expressions)}' From 647fc27bb5cad5ce13bf36bbd5eac1761e2474b2 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 14:21:27 +0800 Subject: [PATCH 11/24] perf(types): add Union type declaration for workflow nodes - Introduce a `Nodes` type as a Union of all workflow node classes. - Improves type checking and IDE autocompletion. --- api/app/core/workflow/nodes/enums.py | 21 +++++++++++++++++++++ api/app/core/workflow/nodes/node_factory.py | 10 ++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 9cec19d2..5e586a9c 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -1,4 +1,14 @@ from enum import StrEnum +from typing import Union + +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.if_else import IfElseNode +from app.core.workflow.nodes.llm import LLMNode +from app.core.workflow.nodes.agent import AgentNode +from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.start import StartNode +from app.core.workflow.nodes.end import EndNode + class NodeType(StrEnum): START = "start" @@ -13,3 +23,14 @@ class NodeType(StrEnum): HTTP_REQUEST = "http-request" TOOL = "tool" AGENT = "agent" + + +WorkflowNode = Union[ + BaseNode, + StartNode, + EndNode, + LLMNode, + IfElseNode, + AgentNode, + TransformNode, +] diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index f279d13a..e1f32308 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -8,7 +8,8 @@ import logging from typing import Any from app.core.workflow.nodes.base_node import BaseNode -from app.core.workflow.nodes.enums import NodeType +from app.core.workflow.nodes.enums import NodeType, WorkflowNode +from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.transform import TransformNode @@ -25,16 +26,17 @@ class NodeFactory: """ # 节点类型注册表 - _node_types: dict[str, type[BaseNode]] = { + _node_types: dict[str, type[WorkflowNode]] = { NodeType.START: StartNode, NodeType.END: EndNode, NodeType.LLM: LLMNode, NodeType.AGENT: AgentNode, NodeType.TRANSFORM: TransformNode, + NodeType.IF_ELSE: IfElseNode } @classmethod - def register_node_type(cls, node_type: str, node_class: type[BaseNode]): + def register_node_type(cls, node_type: str, node_class: type[WorkflowNode]): """注册新的节点类型 Args: @@ -55,7 +57,7 @@ class NodeFactory: cls, node_config: dict[str, Any], workflow_config: dict[str, Any] - ) -> BaseNode | None: + ) -> WorkflowNode | None: """创建节点实例 Args: From aa44b8df71a5ffa7c80849d0631d0b1022fd6ebe Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 14:23:29 +0800 Subject: [PATCH 12/24] fix(expression-eval): fix variable extraction issue in Jinja2 templates - Resolve the bug where variables inside Jinja2 template expressions were not correctly extracted. - Ensure expressions containing {{ ... }} are parsed reliably. --- api/app/core/workflow/expression_evaluator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/api/app/core/workflow/expression_evaluator.py b/api/app/core/workflow/expression_evaluator.py index c8875d79..81ab25dc 100644 --- a/api/app/core/workflow/expression_evaluator.py +++ b/api/app/core/workflow/expression_evaluator.py @@ -5,6 +5,7 @@ """ import logging +import re from typing import Any from simpleeval import simple_eval, NameNotDefined, InvalidExpression @@ -59,9 +60,10 @@ class ExpressionEvaluator: """ # 移除 Jinja2 模板语法的花括号(如果存在) expression = expression.strip() - if expression.startswith("{{") and expression.endswith("}}"): - expression = expression[2:-2].strip() - + # "{{system.message}} == {{ user.messge }}" -> "system.message == user.message" + pattern = r"\{\{\s*(.*?)\s*\}\}" + expression = re.sub(pattern, r"\1", expression).strip() + # 构建命名空间上下文 context = { "var": variables, # 用户变量 From cb6d7b04f960b8532eab807fa31d7ddc66419d14 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 14:34:01 +0800 Subject: [PATCH 13/24] docs(samples): add config example for If-Else node - Provide a sample configuration for the If-Else workflow node. - Helps users understand how to define conditional branches. --- api/app/core/workflow/nodes/if_else/config.py | 56 +++++++++---------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 1a9adbbb..1eaddc63 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -73,49 +73,43 @@ class IfElseNodeConfig(BaseNodeConfig): "examples": [ { "cases": [ - # if/CASE1 + # CASE1 / IF Branch { "logical_operator": "and", "conditions": [ { - "left": "sys.message", - "comparison_operator": "eq", - "right": "'test'" + { + "left": "node.userinput.message", + "comparison_operator": "eq", + "right": "'123'" + }, + { + "left": "node.userinput.test", + "comparison_operator": "eq", + "right": "True" + } } ] }, - ] - }, - { - "case_number": 3, - "cases": [ - # if/CASE1 + # CASE1 / ELIF Branch { - "logic": "or", + "logical_operator": "or", "conditions": [ { - "left": "sys.message", - "comparison_operator": "eq", - "right": "'test'" + { + "left": "node.userinput.test", + "comparison_operator": "eq", + "right": "False" + }, + { + "left": "node.userinput.message", + "comparison_operator": "contains", + "right": "'123'" + } } ] - }, - # elif/CASE2 - { - "logic": "and", - "conditions": [ - { - "left": "sys.message", - "comparison_operator": "eq", - "right": "'test'" - }, - { - "left": "sys.message", - "comparison_operator": "contains", - "right": "'test'" - } - ] - }, + } + # CASE3 / ELSE Branch ] } ] From debb2f01623ba52198cadee6b2f1f032ccf42823 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 14:43:47 +0800 Subject: [PATCH 14/24] style(workflow): update condition edge comments for conditional nodes --- api/app/core/workflow/executor.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 3710e4ed..6effaa5b 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -125,18 +125,20 @@ class WorkflowExecutor: node_instance = NodeFactory.create_node(node, self.workflow_config) if node_type in [NodeType.IF_ELSE]: - # Build ordered boolean expression strings for each branch. - # These expressions will be attached to outgoing edges as - # LangGraph conditional routing rules. expressions = node_instance.build_conditional_edge_expressions() - # Collect all outgoing edges from the current node. - # The order of edges must match the order of generated expressions. + # Number of branches, usually matches the number of conditional expressions + branch_number = len(expressions) + + # Find all edges whose source is the current node related_edge = [edge for edge in self.edges if edge.get("source") == node_id] - # Attach each condition expression to the corresponding edge - # based on branch priority - for idx in range(len(expressions)): + # Iterate over each branch + for idx in range(branch_number): + # Generate a condition expression for each edge + # Used later to determine which branch to take based on the node's output + # Assumes node output `node..output` matches the edge's label + # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" if node_instance: From 4e0c5ed3c16835acae2ad9bcc5c60bd1a471d6e8 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 15:16:00 +0800 Subject: [PATCH 15/24] style(enums): correct enum class name spelling --- api/app/core/workflow/nodes/if_else/config.py | 4 +-- api/app/core/workflow/nodes/if_else/node.py | 28 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 1eaddc63..0e759569 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -9,7 +9,7 @@ class LogicOperator(StrEnum): OR = "or" -class ComparisonOpeartor(StrEnum): +class ComparisonOperator(StrEnum): EMPTY = "empty" NOT_EMPTY = "not_empty" CONTAINS = "contains" @@ -25,7 +25,7 @@ class ComparisonOpeartor(StrEnum): class ConditionDetail(BaseModel): - comparison_operator: ComparisonOpeartor = Field( + comparison_operator: ComparisonOperator = Field( ..., description="Comparison operator used to evaluate the condition" ) diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 3219edae..fcfbd9ac 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -5,7 +5,7 @@ from simpleeval import NameNotDefined, InvalidExpression from app.core.workflow.nodes import BaseNode, WorkflowState from app.core.workflow.nodes.if_else import IfElseNodeConfig -from app.core.workflow.nodes.if_else.config import LogicOperator, ConditionDetail, ComparisonOpeartor +from app.core.workflow.nodes.if_else.config import LogicOperator, ConditionDetail, ComparisonOperator logger = logging.getLogger(__name__) @@ -19,7 +19,7 @@ class ConditionExpressionBuilder: that can be evaluated later in a workflow context. """ - def __init__(self, left: str, operator: ComparisonOpeartor, right: str): + def __init__(self, left: str, operator: ComparisonOperator, right: str): self.left = left self.operator = operator self.right = right @@ -62,29 +62,29 @@ class ConditionExpressionBuilder: def build(self): match self.operator: - case ComparisonOpeartor.EMPTY: + case ComparisonOperator.EMPTY: return self._empty() - case ComparisonOpeartor.NOT_EMPTY: + case ComparisonOperator.NOT_EMPTY: return self._not_empty() - case ComparisonOpeartor.CONTAINS: + case ComparisonOperator.CONTAINS: return self._contains() - case ComparisonOpeartor.NOT_CONTAINS: + case ComparisonOperator.NOT_CONTAINS: return self._not_contains() - case ComparisonOpeartor.START_WITH: + case ComparisonOperator.START_WITH: return self._startwith() - case ComparisonOpeartor.END_WITH: + case ComparisonOperator.END_WITH: return self._endwith() - case ComparisonOpeartor.EQ: + case ComparisonOperator.EQ: return self._eq() - case ComparisonOpeartor.NE: + case ComparisonOperator.NE: return self._ne() - case ComparisonOpeartor.LT: + case ComparisonOperator.LT: return self._lt() - case ComparisonOpeartor.LE: + case ComparisonOperator.LE: return self._le() - case ComparisonOpeartor.GT: + case ComparisonOperator.GT: return self._gt() - case ComparisonOpeartor.GE: + case ComparisonOperator.GE: return self._ge() case _: raise ValueError(f"Invalid condition: {self.operator}") From d12b1e4a51fcddab5aa6fcc982b9bc674e048c68 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 15:43:56 +0800 Subject: [PATCH 16/24] refactor(workflow): unify all enum classes in one file and restructure workflow node type definitions --- api/app/core/workflow/nodes/__init__.py | 13 ++++--- api/app/core/workflow/nodes/enums.py | 36 +++++++++---------- api/app/core/workflow/nodes/if_else/config.py | 31 ++++------------ api/app/core/workflow/nodes/if_else/node.py | 5 ++- api/app/core/workflow/nodes/node_factory.py | 26 +++++++++----- 5 files changed, 52 insertions(+), 59 deletions(-) diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 820c9301..d143c693 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -4,13 +4,14 @@ 提供各种类型的节点实现,用于工作流执行。 """ -from app.core.workflow.nodes.base_node import BaseNode, WorkflowState -from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.agent import AgentNode -from app.core.workflow.nodes.transform import TransformNode -from app.core.workflow.nodes.start import StartNode +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.end import EndNode -from app.core.workflow.nodes.node_factory import NodeFactory +from app.core.workflow.nodes.if_else import IfElseNode +from app.core.workflow.nodes.llm import LLMNode +from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode +from app.core.workflow.nodes.start import StartNode +from app.core.workflow.nodes.transform import TransformNode __all__ = [ "BaseNode", @@ -18,7 +19,9 @@ __all__ = [ "LLMNode", "AgentNode", "TransformNode", + "IfElseNode", "StartNode", "EndNode", "NodeFactory", + "WorkflowNode" ] diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 5e586a9c..af5ddbaa 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -1,13 +1,4 @@ from enum import StrEnum -from typing import Union - -from app.core.workflow.nodes.base_node import BaseNode -from app.core.workflow.nodes.if_else import IfElseNode -from app.core.workflow.nodes.llm import LLMNode -from app.core.workflow.nodes.agent import AgentNode -from app.core.workflow.nodes.transform import TransformNode -from app.core.workflow.nodes.start import StartNode -from app.core.workflow.nodes.end import EndNode class NodeType(StrEnum): @@ -25,12 +16,21 @@ class NodeType(StrEnum): AGENT = "agent" -WorkflowNode = Union[ - BaseNode, - StartNode, - EndNode, - LLMNode, - IfElseNode, - AgentNode, - TransformNode, -] +class ComparisonOperator(StrEnum): + EMPTY = "empty" + NOT_EMPTY = "not_empty" + CONTAINS = "contains" + NOT_CONTAINS = "not_contains" + START_WITH = "startwith" + END_WITH = "endwith" + EQ = "eq" + NE = "ne" + LT = "lt" + LE = "le" + GT = "gt" + GE = "ge" + + +class LogicOperator(StrEnum): + AND = "and" + OR = "or" diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 0e759569..4e424b54 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -1,27 +1,8 @@ """Condition Configuration""" from pydantic import Field, BaseModel, field_validator -from enum import StrEnum + from app.core.workflow.nodes.base_config import BaseNodeConfig - - -class LogicOperator(StrEnum): - AND = "and" - OR = "or" - - -class ComparisonOperator(StrEnum): - EMPTY = "empty" - NOT_EMPTY = "not_empty" - CONTAINS = "contains" - NOT_CONTAINS = "not_contains" - START_WITH = "startwith" - END_WITH = "endwith" - EQ = "eq" - NE = "ne" - LT = "lt" - LE = "le" - GT = "gt" - GE = "ge" +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator class ConditionDetail(BaseModel): @@ -77,7 +58,7 @@ class IfElseNodeConfig(BaseNodeConfig): { "logical_operator": "and", "conditions": [ - { + [ { "left": "node.userinput.message", "comparison_operator": "eq", @@ -88,14 +69,14 @@ class IfElseNodeConfig(BaseNodeConfig): "comparison_operator": "eq", "right": "True" } - } + ] ] }, # CASE1 / ELIF Branch { "logical_operator": "or", "conditions": [ - { + [ { "left": "node.userinput.test", "comparison_operator": "eq", @@ -106,7 +87,7 @@ class IfElseNodeConfig(BaseNodeConfig): "comparison_operator": "contains", "right": "'123'" } - } + ] ] } # CASE3 / ELSE Branch diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index fcfbd9ac..ed3dbbd6 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -1,11 +1,10 @@ import logging from typing import Any -from simpleeval import NameNotDefined, InvalidExpression - from app.core.workflow.nodes import BaseNode, WorkflowState +from app.core.workflow.nodes.enums import ComparisonOperator from app.core.workflow.nodes.if_else import IfElseNodeConfig -from app.core.workflow.nodes.if_else.config import LogicOperator, ConditionDetail, ComparisonOperator +from app.core.workflow.nodes.if_else.config import ConditionDetail logger = logging.getLogger(__name__) diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index e1f32308..1abace67 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -5,19 +5,29 @@ """ import logging -from typing import Any +from typing import Any, Union +from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.base_node import BaseNode -from app.core.workflow.nodes.enums import NodeType, WorkflowNode +from app.core.workflow.nodes.end import EndNode +from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.llm import LLMNode -from app.core.workflow.nodes.agent import AgentNode -from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.start import StartNode -from app.core.workflow.nodes.end import EndNode +from app.core.workflow.nodes.transform import TransformNode logger = logging.getLogger(__name__) +WorkflowNode = Union[ + BaseNode, + StartNode, + EndNode, + LLMNode, + IfElseNode, + AgentNode, + TransformNode, +] + class NodeFactory: """节点工厂 @@ -54,9 +64,9 @@ class NodeFactory: @classmethod def create_node( - cls, - node_config: dict[str, Any], - workflow_config: dict[str, Any] + cls, + node_config: dict[str, Any], + workflow_config: dict[str, Any] ) -> WorkflowNode | None: """创建节点实例 From 00a5016c066f559ebd19b2295fe9da807ca46302 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 15:59:28 +0800 Subject: [PATCH 17/24] feat(workflow): add import for if-else node configuration --- api/app/core/workflow/nodes/configs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index 99d06036..15ab0ce9 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -13,6 +13,7 @@ from app.core.workflow.nodes.end.config import EndNodeConfig from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.agent.config import AgentNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig +from app.core.workflow.nodes.if_else.config import IfElseNodeConfig __all__ = [ # 基础类 @@ -26,4 +27,5 @@ __all__ = [ "MessageConfig", "AgentNodeConfig", "TransformNodeConfig", + "IfElseNodeConfig", ] From 056411f47dfe8eefa7a74630c63fe2bcb34ea867 Mon Sep 17 00:00:00 2001 From: Mark Date: Fri, 19 Dec 2025 18:21:54 +0800 Subject: [PATCH 18/24] [add] migration script --- .../versions/70e94dd4a8d1_202512191820.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 api/migrations/versions/70e94dd4a8d1_202512191820.py diff --git a/api/migrations/versions/70e94dd4a8d1_202512191820.py b/api/migrations/versions/70e94dd4a8d1_202512191820.py new file mode 100644 index 00000000..114340a5 --- /dev/null +++ b/api/migrations/versions/70e94dd4a8d1_202512191820.py @@ -0,0 +1,40 @@ +"""202512191820 + +Revision ID: 70e94dd4a8d1 +Revises: f96a53af914c +Create Date: 2025-12-19 18:20:21.998247 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '70e94dd4a8d1' +down_revision: Union[str, None] = 'f96a53af914c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_prompt_model_config_id'), table_name='prompt_model_config') + op.drop_table('prompt_model_config') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('prompt_model_config', + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False, comment='Tenant ID'), + sa.Column('system_prompt', sa.TEXT(), autoincrement=False, nullable=False, comment='System Prompt'), + sa.Column('created_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=True, comment='Creation Time'), + sa.Column('updated_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=True, comment='Update Time'), + sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], name=op.f('prompt_model_config_tenant_id_fkey')), + sa.PrimaryKeyConstraint('id', name=op.f('prompt_model_config_pkey')) + ) + op.create_index(op.f('ix_prompt_model_config_id'), 'prompt_model_config', ['id'], unique=False) + # ### end Alembic commands ### From 6ecf5edfb32b075d968d5d8e16dbc130f2d0ffb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=96=B0=E6=9C=88?= Date: Fri, 19 Dec 2025 10:37:28 +0000 Subject: [PATCH 19/24] Merge #19 into develop from fix/memory_reflection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 统一输出 * fix/memory_reflection: (35 commits squashed) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py # api/app/schemas/memory_reflection_schemas.py - 反思优化 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 - 统一输出 - 统一输出 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py - 统一输出 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 Signed-off-by: aliyun8644380055 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/19 --- .../memory_reflection_controller.py | 50 ++++++++----------- .../reflection_engine/self_reflexion.py | 5 +- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index bd9e0e09..8dfa6c50 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -1,4 +1,5 @@ import asyncio +import time from dotenv import load_dotenv from fastapi import APIRouter, Depends, HTTPException, status @@ -6,17 +7,17 @@ from sqlalchemy.orm import Session from sqlalchemy import text from app.core.logging_config import get_api_logger +from app.core.response_utils import success from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine from app.dependencies import get_current_user from app.db import get_db from app.models.user_model import User from app.repositories.data_config_repository import DataConfigRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector - from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService - from app.schemas.memory_reflection_schemas import Memory_Reflection from app.services.model_service import ModelConfigService + load_dotenv() api_logger = get_api_logger() @@ -80,13 +81,8 @@ async def save_reflection_config( ) api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}") - - # 返回结果 - return { - "status": "成功", - "message": "反思配置已保存", - "config_id": config_id, - "database_record": { + + reflection_result={ "config_id": result.config_id, "enable_self_reflexion": result.enable_self_reflexion, "iteration_period": result.iteration_period, @@ -95,9 +91,11 @@ async def save_reflection_config( "reflection_model_id": result.reflection_model_id, "memory_verify": result.memory_verify, "quality_assessment": result.quality_assessment, - "user_id": result.user_id - } - } + "user_id": result.user_id} + + return success(data=reflection_result, msg="反思配置成功") + + except ValueError as ve: api_logger.error(f"参数错误: {str(ve)}") @@ -156,13 +154,7 @@ async def start_workspace_reflection( "reflection_result": reflection_result }) - return { - "status": "完成", - "message": f"成功处理 {len(reflection_results)} 个反思任务", - "workspace_id": str(workspace_id), - "reflection_count": len(reflection_results), - "reflection_results": reflection_results - } + return success(data=reflection_results, msg="反思配置成功") except Exception as e: api_logger.error(f"启动workspace反思失败: {str(e)}") @@ -179,7 +171,6 @@ async def start_reflection_configs( db: Session = Depends(get_db), ) -> dict: """通过config_id查询data_config表中的反思配置信息""" - try: api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") @@ -196,8 +187,8 @@ async def start_reflection_configs( # 构建返回数据 reflection_config = { "config_id": result.config_id, - "enable_self_reflexion": result.enable_self_reflexion, - "iteration_period": result.iteration_period, + "reflection_enabled": result.enable_self_reflexion, + "reflection_period_in_hours": result.iteration_period, "reflexion_range": result.reflexion_range, "baseline": result.baseline, "reflection_model_id": result.reflection_model_id, @@ -205,15 +196,10 @@ async def start_reflection_configs( "quality_assessment": result.quality_assessment, "user_id": result.user_id } - api_logger.info(f"成功查询反思配置,config_id: {config_id}") + return success(data=reflection_config, msg="反思配置查询成功") - return { - "status": "成功", - "message": "反思配置查询成功", - "data": reflection_config - } - + except HTTPException: # 重新抛出HTTP异常 raise @@ -276,4 +262,8 @@ async def reflection_run( ) result=await (engine.reflection_run()) - return result \ No newline at end of file + return success(data=result, msg="反思试运行") + + + + diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index 8f5b9bae..6ccec500 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -19,6 +19,7 @@ import uuid from pydantic import BaseModel +from app.core.response_utils import success from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all from app.repositories.neo4j.neo4j_update import neo4j_data from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -314,8 +315,8 @@ class ReflectionEngine: for result in item['results']: reflexion_data.append(result['reflexion']) result_data['reflexion_data'] = reflexion_data - execution_time = time.time() - start_time - return {"status": "SUCCESS", "message": "反思试运行", "data": result_data, "time": execution_time} + return result_data + async def extract_fields_from_json(self): """从example.json中提取source_data和databasets字段""" From cdeace7e585a84439e6c2af65c4c685c1e9ffb57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= Date: Sat, 20 Dec 2025 07:02:46 +0000 Subject: [PATCH 20/24] Merge #21 into develop from feature/emotion-engine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feature/情绪引擎 * feature/emotion-engine: (7 commits squashed) - [feature]Emotion Engine Development - [feature]Emotion Engine Development - Merge branch 'feature/emotion-engine' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/emotion-engine - [fix]1.Fix the front-end files;2.Cache Management Deletion;3.Delete "check_code.py" - [fix]1.Fix the front-end files;2.Cache Management Deletion;3.Delete "check_code.py" - Merge branch 'feature/emotion-engine' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/emotion-engine - [fix]fix vite.config.ts Signed-off-by: 乐力齐 Commented-by: aliyun6762716068 Commented-by: 乐力齐 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/21 --- api/app/controllers/__init__.py | 4 + .../controllers/emotion_config_controller.py | 207 ++++++ api/app/controllers/emotion_controller.py | 255 +++++++ .../agent/langgraph_graph/write_graph.py | 71 +- .../core/memory/agent/utils/write_tools.py | 11 + api/app/core/memory/models/emotion_models.py | 85 +++ api/app/core/memory/models/graph_models.py | 75 +- api/app/core/memory/models/message_models.py | 11 + .../deduplication/entity_dedup_llm.py | 1 - .../extraction_orchestrator.py | 173 ++++- api/app/core/memory/utils/config/overrides.py | 18 +- .../core/memory/utils/prompt/prompt_utils.py | 78 ++ .../prompt/prompts/extract_emotion.jinja2 | 57 ++ .../generate_emotion_suggestions.jinja2 | 63 ++ api/app/models/data_config_model.py | 9 +- api/app/repositories/neo4j/add_nodes.py | 8 +- api/app/repositories/neo4j/cypher_queries.py | 19 +- .../repositories/neo4j/emotion_repository.py | 246 +++++++ .../neo4j/statement_repository.py | 15 +- api/app/schemas/emotion_schema.py | 32 + api/app/services/emotion_analytics_service.py | 670 ++++++++++++++++++ api/app/services/emotion_config_service.py | 212 ++++++ .../services/emotion_extraction_service.py | 200 ++++++ 23 files changed, 2453 insertions(+), 67 deletions(-) create mode 100644 api/app/controllers/emotion_config_controller.py create mode 100644 api/app/controllers/emotion_controller.py create mode 100644 api/app/core/memory/models/emotion_models.py create mode 100644 api/app/core/memory/utils/prompt/prompts/extract_emotion.jinja2 create mode 100644 api/app/core/memory/utils/prompt/prompts/generate_emotion_suggestions.jinja2 create mode 100644 api/app/repositories/neo4j/emotion_repository.py create mode 100644 api/app/schemas/emotion_schema.py create mode 100644 api/app/services/emotion_analytics_service.py create mode 100644 api/app/services/emotion_config_service.py create mode 100644 api/app/services/emotion_extraction_service.py diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 47cc8688..5cfbe536 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -29,6 +29,8 @@ from . import ( public_share_controller, multi_agent_controller, workflow_controller, + emotion_controller, + emotion_config_controller, prompt_optimizer_controller, tool_controller, tool_execution_controller, @@ -62,6 +64,8 @@ manager_router.include_router(public_share_controller.router) # 公开路由( manager_router.include_router(memory_dashboard_controller.router) manager_router.include_router(multi_agent_controller.router) manager_router.include_router(workflow_controller.router) +manager_router.include_router(emotion_controller.router) +manager_router.include_router(emotion_config_controller.router) manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(memory_reflection_controller.router) manager_router.include_router(tool_controller.router) diff --git a/api/app/controllers/emotion_config_controller.py b/api/app/controllers/emotion_config_controller.py new file mode 100644 index 00000000..76450d8a --- /dev/null +++ b/api/app/controllers/emotion_config_controller.py @@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- +"""情绪配置控制器模块 + +本模块提供情绪引擎配置管理的API端点,包括获取和更新配置。 + +Routes: + GET /memory/config/emotion - 获取情绪引擎配置 + POST /memory/config/emotion - 更新情绪引擎配置 +""" + +from fastapi import APIRouter, Depends, Query, HTTPException, status +from pydantic import BaseModel, Field +from typing import Optional +from sqlalchemy.orm import Session + +from app.core.response_utils import success +from app.dependencies import get_current_user +from app.models.user_model import User +from app.schemas.response_schema import ApiResponse +from app.services.emotion_config_service import EmotionConfigService +from app.core.logging_config import get_api_logger +from app.db import get_db + +# 获取API专用日志器 +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory/emotion", + tags=["Emotion Config"], + dependencies=[Depends(get_current_user)] # 所有路由都需要认证 +) + +class EmotionConfigQuery(BaseModel): + """情绪配置查询请求模型""" + config_id: int = Field(..., description="配置ID") + +class EmotionConfigUpdate(BaseModel): + """情绪配置更新请求模型""" + config_id: int = Field(..., description="配置ID") + emotion_enabled: bool = Field(..., description="是否启用情绪提取") + emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID") + emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词") + emotion_min_intensity: float = Field(..., ge=0.0, le=1.0, description="最小情绪强度阈值(0.0-1.0)") + emotion_enable_subject: bool = Field(..., description="是否启用主体分类") + +@router.get("/read_config", response_model=ApiResponse) +def get_emotion_config( + config_id: int = Query(..., description="配置ID"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """获取情绪引擎配置 + + 查询指定配置ID的情绪相关配置字段。 + + Args: + config_id: 配置ID + + Returns: + ApiResponse: 包含情绪配置数据 + + Example Response: + { + "code": 2000, + "msg": "情绪配置获取成功", + "data": { + "config_id": 17, + "emotion_enabled": true, + "emotion_model_id": "gpt-4", + "emotion_extract_keywords": true, + "emotion_min_intensity": 0.1, + "emotion_enable_subject": true + } + } + """ + try: + api_logger.info( + f"用户 {current_user.username} 请求获取情绪配置", + extra={"config_id": config_id} + ) + + # 初始化服务 + config_service = EmotionConfigService(db) + + # 调用服务层 + data = config_service.get_emotion_config(config_id) + + api_logger.info( + "情绪配置获取成功", + extra={ + "config_id": config_id, + "emotion_enabled": data.get("emotion_enabled", False) + } + ) + + return success(data=data, msg="情绪配置获取成功") + + except ValueError as e: + api_logger.warning( + f"获取情绪配置失败: {str(e)}", + extra={"config_id": config_id} + ) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e) + ) + except Exception as e: + api_logger.error( + f"获取情绪配置失败: {str(e)}", + extra={"config_id": config_id}, + exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取情绪配置失败: {str(e)}" + ) + + + +@router.post("/updated_config", response_model=ApiResponse) +def update_emotion_config( + config: EmotionConfigUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """更新情绪引擎配置 + + 更新指定配置ID的情绪相关配置字段。 + + Args: + config: 配置更新数据(包含config_id) + + Returns: + ApiResponse: 包含更新后的情绪配置数据 + + Example Request: + { + "config_id": 2, + "emotion_enabled": true, + "emotion_model_id": "gpt-4", + "emotion_extract_keywords": true, + "emotion_min_intensity": 0.1, + "emotion_enable_subject": true + } + + Example Response: + { + "code": 2000, + "msg": "情绪配置更新成功", + "data": { + "config_id": 17, + "emotion_enabled": true, + "emotion_model_id": "gpt-4", + "emotion_extract_keywords": true, + "emotion_min_intensity": 0.2, + "emotion_enable_subject": true + } + } + """ + try: + api_logger.info( + f"用户 {current_user.username} 请求更新情绪配置", + extra={ + "config_id": config.config_id, + "emotion_enabled": config.emotion_enabled, + "emotion_min_intensity": config.emotion_min_intensity + } + ) + + # 初始化服务 + config_service = EmotionConfigService(db) + + # 转换为字典(排除config_id,因为它作为参数传递) + config_data = config.model_dump(exclude={'config_id'}) + + # 调用服务层 + data = config_service.update_emotion_config(config.config_id, config_data) + + api_logger.info( + "情绪配置更新成功", + extra={ + "config_id": config.config_id, + "emotion_enabled": data.get("emotion_enabled", False) + } + ) + + return success(data=data, msg="情绪配置更新成功") + + except ValueError as e: + api_logger.warning( + f"更新情绪配置失败: {str(e)}", + extra={"config_id": config.config_id} + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) + except Exception as e: + api_logger.error( + f"更新情绪配置失败: {str(e)}", + extra={"config_id": config.config_id}, + exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"更新情绪配置失败: {str(e)}" + ) diff --git a/api/app/controllers/emotion_controller.py b/api/app/controllers/emotion_controller.py new file mode 100644 index 00000000..2ed00c43 --- /dev/null +++ b/api/app/controllers/emotion_controller.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +"""情绪分析控制器模块 + +本模块提供情绪分析相关的API端点,包括情绪标签、词云、健康指数和个性化建议。 + +Routes: + POST /emotion/tags - 获取情绪标签统计 + POST /emotion/wordcloud - 获取情绪词云数据 + POST /emotion/health - 获取情绪健康指数 + POST /emotion/suggestions - 获取个性化情绪建议 +""" + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from app.core.response_utils import success, fail +from app.core.error_codes import BizCode +from app.dependencies import get_current_user, get_db +from app.models.user_model import User +from app.schemas.response_schema import ApiResponse +from app.schemas.emotion_schema import ( + EmotionTagsRequest, + EmotionWordcloudRequest, + EmotionHealthRequest, + EmotionSuggestionsRequest +) +from app.services.emotion_analytics_service import EmotionAnalyticsService +from app.core.logging_config import get_api_logger + +# 获取API专用日志器 +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory/emotion", + tags=["Emotion Analysis"], + dependencies=[Depends(get_current_user)] # 所有路由都需要认证 +) + + +# 初始化情绪分析服务uv +emotion_service = EmotionAnalyticsService() + + + +@router.post("/tags", response_model=ApiResponse) +async def get_emotion_tags( + request: EmotionTagsRequest, + current_user: User = Depends(get_current_user), +): + + try: + api_logger.info( + f"用户 {current_user.username} 请求获取情绪标签统计", + extra={ + "group_id": request.group_id, + "emotion_type": request.emotion_type, + "start_date": request.start_date, + "end_date": request.end_date, + "limit": request.limit + } + ) + + # 调用服务层 + data = await emotion_service.get_emotion_tags( + end_user_id=request.group_id, + emotion_type=request.emotion_type, + start_date=request.start_date, + end_date=request.end_date, + limit=request.limit + ) + + api_logger.info( + "情绪标签统计获取成功", + extra={ + "group_id": request.group_id, + "total_count": data.get("total_count", 0), + "tags_count": len(data.get("tags", [])) + } + ) + + return success(data=data, msg="情绪标签获取成功") + + except Exception as e: + api_logger.error( + f"获取情绪标签统计失败: {str(e)}", + extra={"group_id": request.group_id}, + exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取情绪标签统计失败: {str(e)}" + ) + + + +@router.post("/wordcloud", response_model=ApiResponse) +async def get_emotion_wordcloud( + request: EmotionWordcloudRequest, + current_user: User = Depends(get_current_user), +): + + try: + api_logger.info( + f"用户 {current_user.username} 请求获取情绪词云数据", + extra={ + "group_id": request.group_id, + "emotion_type": request.emotion_type, + "limit": request.limit + } + ) + + # 调用服务层 + data = await emotion_service.get_emotion_wordcloud( + end_user_id=request.group_id, + emotion_type=request.emotion_type, + limit=request.limit + ) + + api_logger.info( + "情绪词云数据获取成功", + extra={ + "group_id": request.group_id, + "total_keywords": data.get("total_keywords", 0) + } + ) + + return success(data=data, msg="情绪词云获取成功") + + except Exception as e: + api_logger.error( + f"获取情绪词云数据失败: {str(e)}", + extra={"group_id": request.group_id}, + exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取情绪词云数据失败: {str(e)}" + ) + + + +@router.post("/health", response_model=ApiResponse) +async def get_emotion_health( + request: EmotionHealthRequest, + current_user: User = Depends(get_current_user), +): + + try: + # 验证时间范围参数 + if request.time_range not in ["7d", "30d", "90d"]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="时间范围参数无效,必须是 7d、30d 或 90d" + ) + + api_logger.info( + f"用户 {current_user.username} 请求获取情绪健康指数", + extra={ + "group_id": request.group_id, + "time_range": request.time_range + } + ) + + # 调用服务层 + data = await emotion_service.calculate_emotion_health_index( + end_user_id=request.group_id, + time_range=request.time_range + ) + + api_logger.info( + "情绪健康指数获取成功", + extra={ + "group_id": request.group_id, + "health_score": data.get("health_score", 0), + "level": data.get("level", "未知") + } + ) + + return success(data=data, msg="情绪健康指数获取成功") + + except HTTPException: + raise + except Exception as e: + api_logger.error( + f"获取情绪健康指数失败: {str(e)}", + extra={"group_id": request.group_id}, + exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取情绪健康指数失败: {str(e)}" + ) + + + +@router.post("/suggestions", response_model=ApiResponse) +async def get_emotion_suggestions( + request: EmotionSuggestionsRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """获取个性化情绪建议 + + Args: + request: 包含 group_id 和可选的 config_id + db: 数据库会话 + current_user: 当前用户 + + Returns: + 个性化情绪建议响应 + """ + try: + # 验证 config_id(如果提供) + config_id = request.config_id + if config_id is not None: + from app.controllers.memory_agent_controller import validate_config_id + try: + config_id = validate_config_id(config_id, db) + except ValueError as e: + return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e)) + + api_logger.info( + f"用户 {current_user.username} 请求获取个性化情绪建议", + extra={ + "group_id": request.group_id, + "config_id": config_id + } + ) + + # 调用服务层 + data = await emotion_service.generate_emotion_suggestions( + end_user_id=request.group_id, + config_id=config_id + ) + + api_logger.info( + "个性化建议获取成功", + extra={ + "group_id": request.group_id, + "suggestions_count": len(data.get("suggestions", [])) + } + ) + + return success(data=data, msg="个性化建议获取成功") + + except Exception as e: + api_logger.error( + f"获取个性化建议失败: {str(e)}", + extra={"group_id": request.group_id}, + exc_info=True + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取个性化建议失败: {str(e)}" + ) diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index dbdc51d6..cfcc1c4a 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -38,14 +38,53 @@ async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None): messages = state["messages"] last_message = messages[-1] - result = await data_type_tool.ainvoke({ - "context": last_message[1] if isinstance(last_message, tuple) else last_message.content - }) - result=json.loads( result) + # 调用 Data_type_differentiation 工具 + try: + raw_result = await data_type_tool.ainvoke({ + "context": last_message[1] if isinstance(last_message, tuple) else last_message.content + }) + + # MCP工具返回的是列表格式,需要提取内容 + logger.debug(f"Data_type_differentiation raw result type: {type(raw_result)}, value: {raw_result}") + + # 处理不同的返回格式 + if isinstance(raw_result, list) and len(raw_result) > 0: + # MCP工具返回格式: [{"type": "text", "text": "..."}] + result_text = raw_result[0].get("text", "{}") if isinstance(raw_result[0], dict) else str(raw_result[0]) + elif isinstance(raw_result, str): + result_text = raw_result + else: + result_text = str(raw_result) + + # 解析JSON字符串 + try: + result = json.loads(result_text) + except json.JSONDecodeError as je: + logger.error(f"Failed to parse result as JSON: {result_text}, error: {je}") + return {"messages": [AIMessage(content=json.dumps({ + "status": "error", + "message": f"Invalid JSON response from Data_type_differentiation: {str(je)}" + }))]} + + # 检查是否有错误 + if isinstance(result, dict) and result.get("type") == "error": + error_msg = result.get("message", "Unknown error in Data_type_differentiation") + logger.error(f"Data_type_differentiation 返回错误: {error_msg}") + return {"messages": [AIMessage(content=json.dumps({ + "status": "error", + "message": error_msg + }))]} + + except Exception as e: + logger.error(f"调用 Data_type_differentiation 失败: {e}", exc_info=True) + return {"messages": [AIMessage(content=json.dumps({ + "status": "error", + "message": f"Data type differentiation failed: {str(e)}" + }))]} # 调用 Data_write,传递 config_id write_params = { - "content": result["context"], + "content": result.get("context", last_message.content if hasattr(last_message, 'content') else str(last_message)), "apply_id": apply_id, "group_id": group_id, "user_id": user_id @@ -56,14 +95,22 @@ async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None): write_params["config_id"] = config_id logger.debug(f"传递 config_id 到 Data_write: {config_id}") - write_result = await data_write_tool.ainvoke(write_params) + try: + write_result = await data_write_tool.ainvoke(write_params) - if isinstance(write_result, dict): - content = write_result.get("data", str(write_result)) - else: - content = str(write_result) - logger.info("写入内容: %s", content) - return {"messages": [AIMessage(content=content)]} + if isinstance(write_result, dict): + content = write_result.get("data", str(write_result)) + else: + content = str(write_result) + logger.info("写入内容: %s", content) + return {"messages": [AIMessage(content=content)]} + + except Exception as e: + logger.error(f"调用 Data_write 失败: {e}", exc_info=True) + return {"messages": [AIMessage(content=json.dumps({ + "status": "error", + "message": f"Data write failed: {str(e)}" + }))]} workflow = StateGraph(WriteState) workflow.add_node("content_input", call_model) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index ebfbcc6c..f792ea9d 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -39,6 +39,17 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id ref_id: 参考ID,默认为 "wyl20251027" config_id: 配置ID,用于标记数据处理配置 """ + # 如果提供了config_id,重新加载配置 + if config_id: + from app.core.memory.utils.config.definitions import reload_configuration_from_database + logger.info(f"Reloading configuration for config_id: {config_id}") + config_loaded = reload_configuration_from_database(config_id) + if not config_loaded: + error_msg = f"Failed to load configuration for config_id: {config_id}" + logger.error(error_msg) + raise ValueError(error_msg) + logger.info(f"Configuration reloaded successfully for config_id: {config_id}") + logger.info("=== MemSci Knowledge Extraction Pipeline ===") logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}") logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}") diff --git a/api/app/core/memory/models/emotion_models.py b/api/app/core/memory/models/emotion_models.py new file mode 100644 index 00000000..f84165a7 --- /dev/null +++ b/api/app/core/memory/models/emotion_models.py @@ -0,0 +1,85 @@ +"""Emotion extraction models for LLM structured output. + +This module contains Pydantic models for emotion extraction from statements, +designed to be used with LLM structured output capabilities. + +Classes: + EmotionExtraction: Model for emotion extraction results from statements +""" + +from pydantic import BaseModel, Field, field_validator +from typing import List, Optional + + +class EmotionExtraction(BaseModel): + """Emotion extraction result model for LLM structured output. + + This model represents the structured emotion information extracted from + a statement using LLM. It includes emotion type, intensity, keywords, + subject classification, and optional target. + + Attributes: + emotion_type: Type of emotion (joy/sadness/anger/fear/surprise/neutral) + emotion_intensity: Intensity of emotion (0.0-1.0) + emotion_keywords: List of emotion keywords from the statement (max 3) + emotion_subject: Subject of emotion (self/other/object) + emotion_target: Optional target of emotion (person or object name) + """ + + emotion_type: str = Field( + ..., + description="Emotion type: joy/sadness/anger/fear/surprise/neutral" + ) + emotion_intensity: float = Field( + ..., + ge=0.0, + le=1.0, + description="Emotion intensity from 0.0 to 1.0" + ) + emotion_keywords: List[str] = Field( + default_factory=list, + description="Emotion keywords extracted from the statement (max 3)" + ) + emotion_subject: str = Field( + ..., + description="Emotion subject: self/other/object" + ) + emotion_target: Optional[str] = Field( + None, + description="Emotion target: person or object name" + ) + + @field_validator('emotion_type') + @classmethod + def validate_emotion_type(cls, v): + """Validate emotion type is one of the valid values.""" + valid_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral'] + if v not in valid_types: + raise ValueError(f"emotion_type must be one of {valid_types}, got {v}") + return v + + @field_validator('emotion_subject') + @classmethod + def validate_emotion_subject(cls, v): + """Validate emotion subject is one of the valid values.""" + valid_subjects = ['self', 'other', 'object'] + if v not in valid_subjects: + raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}") + return v + + @field_validator('emotion_keywords') + @classmethod + def validate_emotion_keywords(cls, v): + """Validate and limit emotion keywords to max 3 items.""" + if not isinstance(v, list): + return [] + # Limit to max 3 keywords + return v[:3] + + @field_validator('emotion_intensity') + @classmethod + def validate_emotion_intensity(cls, v): + """Validate emotion intensity is within valid range.""" + if not (0.0 <= v <= 1.0): + raise ValueError(f"emotion_intensity must be between 0.0 and 1.0, got {v}") + return v diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 58b8271c..a8c3f7b0 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -215,24 +215,58 @@ class StatementNode(Node): Attributes: chunk_id: ID of the parent chunk this statement belongs to stmt_type: Type of the statement (from ontology) - temporal_info: Temporal information extracted from the statement statement: The actual statement text content - connect_strength: Classification of connection strength ('Strong' or 'Weak') + emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node + emotion_target: Optional emotion target (person or object name) + emotion_subject: Optional emotion subject (self/other/object) + emotion_type: Optional emotion type (joy/sadness/anger/fear/surprise/neutral) + emotion_keywords: Optional list of emotion keywords (max 3) + temporal_info: Temporal information extracted from the statement valid_at: Optional start date of temporal validity invalid_at: Optional end date of temporal validity statement_embedding: Optional embedding vector for the statement chunk_embedding: Optional embedding vector for the parent chunk + connect_strength: Classification of connection strength ('Strong' or 'Weak') config_id: Configuration ID used to process this statement """ + # Core fields (ordered as requested) chunk_id: str = Field(..., description="ID of the parent chunk") stmt_type: str = Field(..., description="Type of the statement") - temporal_info: TemporalInfo = Field(..., description="Temporal information") statement: str = Field(..., description="The statement text content") - connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") + + # Emotion fields (ordered as requested, emotion_intensity first for display) + emotion_intensity: Optional[float] = Field( + None, + ge=0.0, + le=1.0, + description="Emotion intensity: 0.0-1.0 (displayed on node)" + ) + emotion_target: Optional[str] = Field( + None, + description="Emotion target: person or object name" + ) + emotion_subject: Optional[str] = Field( + None, + description="Emotion subject: self/other/object" + ) + emotion_type: Optional[str] = Field( + None, + description="Emotion type: joy/sadness/anger/fear/surprise/neutral" + ) + emotion_keywords: Optional[List[str]] = Field( + default_factory=list, + description="Emotion keywords list, max 3 items" + ) + + # Temporal fields + temporal_info: TemporalInfo = Field(..., description="Temporal information") valid_at: Optional[datetime] = Field(None, description="Temporal validity start") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") + + # Embedding and other fields statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") + connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") @field_validator('valid_at', 'invalid_at', mode='before') @@ -240,6 +274,39 @@ class StatementNode(Node): def validate_datetime(cls, v): """使用通用的历史日期解析函数""" return parse_historical_datetime(v) + + @field_validator('emotion_type', mode='before') + @classmethod + def validate_emotion_type(cls, v): + """Validate emotion type is one of the valid values""" + if v is None: + return v + valid_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral'] + if v not in valid_types: + raise ValueError(f"emotion_type must be one of {valid_types}, got {v}") + return v + + @field_validator('emotion_subject', mode='before') + @classmethod + def validate_emotion_subject(cls, v): + """Validate emotion subject is one of the valid values""" + if v is None: + return v + valid_subjects = ['self', 'other', 'object'] + if v not in valid_subjects: + raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}") + return v + + @field_validator('emotion_keywords', mode='before') + @classmethod + def validate_emotion_keywords(cls, v): + """Validate emotion keywords list has max 3 items""" + if v is None: + return [] + if not isinstance(v, list): + return [] + # Limit to max 3 keywords + return v[:3] class ChunkNode(Node): diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index 192816fd..199bdd75 100644 --- a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -64,6 +64,11 @@ class Statement(BaseModel): connect_strength: Optional connection strength ('Strong' or 'Weak') temporal_validity: Optional temporal validity range triplet_extraction_info: Optional triplet extraction results + emotion_type: Optional emotion type (joy/sadness/anger/fear/surprise/neutral) + emotion_intensity: Optional emotion intensity (0.0-1.0) + emotion_keywords: Optional list of emotion keywords + emotion_subject: Optional emotion subject (self/other/object) + emotion_target: Optional emotion target (person or object name) """ id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.") chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") @@ -80,6 +85,12 @@ class Statement(BaseModel): triplet_extraction_info: Optional[TripletExtractionResponse] = Field( None, description="The triplet extraction information of the statement." ) + # Emotion fields + emotion_type: Optional[str] = Field(None, description="Emotion type: joy/sadness/anger/fear/surprise/neutral") + emotion_intensity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Emotion intensity: 0.0-1.0") + emotion_keywords: Optional[List[str]] = Field(default_factory=list, description="Emotion keywords, max 3") + emotion_subject: Optional[str] = Field(None, description="Emotion subject: self/other/object") + emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name") class ConversationContext(BaseModel): diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py index 2c784d42..734f7b69 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py @@ -480,7 +480,6 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 - global_redirect: dict losing_id -> canonical_id accumulated across rounds - records: textual logs including per-round/per-block summaries and per-pair decisions """ - import asyncio import random # 初始化全局日志和全局ID映射(存储所有轮次的结果) records: List[str] = [] diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index e00bcf0a..91529aa9 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -35,7 +35,6 @@ from app.core.memory.models.graph_models import ( from app.core.memory.utils.data.ontology import TemporalInfo from app.core.memory.models.variate_config import ( ExtractionPipelineConfig, - StatementExtractionConfig, ) from app.core.memory.llm_tools.openai_client import LLMClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient @@ -53,7 +52,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.tem ) from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import ( embedding_generation, - embedding_generation_all, generate_entity_embeddings_from_triplets, ) from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import ( @@ -179,24 +177,12 @@ class ExtractionOrchestrator: all_statements_list.extend(chunk.statements) total_statements = len(all_statements_list) - # 🔥 陈述句提取完成后,立即发送知识抽取完成消息 - if self.progress_callback: - extraction_stats = { - "statements_count": total_statements, - "entities_count": 0, # 暂时为0,后续会更新 - "triplets_count": 0, # 暂时为0,后续会更新 - "temporal_ranges_count": 0, # 暂时为0,后续会更新 - } - await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats) - - # 🔥 立即发送下一阶段的开始消息,让前端知道进入了创建节点和边阶段 - await self.progress_callback("creating_nodes_edges", "正在创建节点和边...") - - # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成(后台静默执行) - logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成(后台静默执行)") + # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 + logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") ( triplet_maps, temporal_maps, + emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -225,6 +211,7 @@ class ExtractionOrchestrator: dialog_data_list, temporal_maps, triplet_maps, + emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -552,9 +539,108 @@ class ExtractionOrchestrator: return temporal_maps + async def _extract_emotions( + self, dialog_data_list: List[DialogData] + ) -> List[Dict[str, Any]]: + """ + 从对话中提取情绪信息(优化版:全局陈述句级并行) + + Args: + dialog_data_list: 对话数据列表 + + Returns: + 情绪信息映射列表,每个对话对应一个字典 + """ + logger.info("开始情绪信息提取(全局陈述句级并行)") + + # 收集所有陈述句及其配置 + all_statements = [] + statement_metadata = [] # (dialog_idx, statement_id) + + # 获取第一个对话的config_id来加载配置 + config_id = None + if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): + config_id = dialog_data_list[0].config_id + + # 加载DataConfig + data_config = None + if config_id: + try: + from app.db import SessionLocal + from app.repositories.data_config_repository import DataConfigRepository + + db = SessionLocal() + try: + data_config = DataConfigRepository.get_by_id(db, config_id) + finally: + db.close() + + if data_config and not data_config.emotion_enabled: + logger.info("情绪提取已在配置中禁用,跳过情绪提取") + return [{} for _ in dialog_data_list] + + except Exception as e: + logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取") + return [{} for _ in dialog_data_list] + else: + logger.info("未找到config_id,跳过情绪提取") + return [{} for _ in dialog_data_list] + + # 如果配置未启用情绪提取,直接返回空映射 + if not data_config or not data_config.emotion_enabled: + logger.info("情绪提取未启用,跳过") + return [{} for _ in dialog_data_list] + + # 收集所有陈述句 + for d_idx, dialog in enumerate(dialog_data_list): + for chunk in dialog.chunks: + for statement in chunk.statements: + all_statements.append((statement, data_config)) + statement_metadata.append((d_idx, statement.id)) + + logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪") + + # 初始化情绪提取服务 + from app.services.emotion_extraction_service import EmotionExtractionService + emotion_service = EmotionExtractionService( + llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None + ) + + # 全局并行处理所有陈述句 + async def extract_for_statement(stmt_data): + statement, config = stmt_data + try: + return await emotion_service.extract_emotion(statement.statement, config) + except Exception as e: + logger.error(f"陈述句 {statement.id} 情绪提取失败: {e}") + return None + + tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 将结果组织成对话级别的映射 + emotion_maps = [{} for _ in dialog_data_list] + successful_extractions = 0 + + for i, result in enumerate(results): + d_idx, stmt_id = statement_metadata[i] + if isinstance(result, Exception): + logger.error(f"陈述句处理异常: {result}") + emotion_maps[d_idx][stmt_id] = None + else: + emotion_maps[d_idx][stmt_id] = result + if result is not None: + successful_extractions += 1 + + # 统计提取结果 + logger.info(f"情绪信息提取完成,共成功提取 {successful_extractions}/{len(all_statements)} 个情绪") + + return emotion_maps + async def _parallel_extract_and_embed( self, dialog_data_list: List[DialogData] ) -> Tuple[ + List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, List[float]]], @@ -562,35 +648,39 @@ class ExtractionOrchestrator: List[List[float]], ]: """ - 并行执行三元组提取、时间信息提取和基础嵌入生成 + 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 - 这三个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行: + 这四个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行: - 三元组提取:从陈述句中提取实体和关系 - 时间信息提取:从陈述句中提取时间范围 + - 情绪提取:从陈述句中提取情绪信息 - 嵌入生成:为陈述句、分块和对话生成向量(不依赖三元组) Args: dialog_data_list: 对话数据列表 Returns: - 五个列表的元组: + 六个列表的元组: - 三元组映射列表 - 时间信息映射列表 + - 情绪映射列表 - 陈述句嵌入映射列表 - 分块嵌入映射列表 - 对话嵌入列表 """ - logger.info("并行执行:三元组提取 + 时间信息提取 + 基础嵌入生成") + logger.info("并行执行:三元组提取 + 时间信息提取 + 情绪提取 + 基础嵌入生成") - # 创建三个并行任务 + # 创建四个并行任务 triplet_task = self._extract_triplets(dialog_data_list) temporal_task = self._extract_temporal(dialog_data_list) + emotion_task = self._extract_emotions(dialog_data_list) embedding_task = self._generate_basic_embeddings(dialog_data_list) # 并行执行 results = await asyncio.gather( triplet_task, temporal_task, + emotion_task, embedding_task, return_exceptions=True ) @@ -598,19 +688,21 @@ class ExtractionOrchestrator: # 解包结果 triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list] temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list] + emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list] - if isinstance(results[2], Exception): - logger.error(f"基础嵌入生成失败: {results[2]}") + if isinstance(results[3], Exception): + logger.error(f"基础嵌入生成失败: {results[3]}") statement_embedding_maps = [{} for _ in dialog_data_list] chunk_embedding_maps = [{} for _ in dialog_data_list] dialog_embeddings = [[] for _ in dialog_data_list] else: - statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[2] + statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[3] logger.info("并行任务执行完成") return ( triplet_maps, temporal_maps, + emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -727,6 +819,7 @@ class ExtractionOrchestrator: dialog_data_list: List[DialogData], temporal_maps: List[Dict[str, Any]], triplet_maps: List[Dict[str, Any]], + emotion_maps: List[Dict[str, Any]], statement_embedding_maps: List[Dict[str, List[float]]], chunk_embedding_maps: List[Dict[str, List[float]]], dialog_embeddings: List[List[float]], @@ -738,6 +831,7 @@ class ExtractionOrchestrator: dialog_data_list: 对话数据列表 temporal_maps: 时间信息映射列表 triplet_maps: 三元组映射列表 + emotion_maps: 情绪信息映射列表 statement_embedding_maps: 陈述句嵌入映射列表 chunk_embedding_maps: 分块嵌入映射列表 dialog_embeddings: 对话嵌入列表 @@ -752,6 +846,7 @@ class ExtractionOrchestrator: if ( len(temporal_maps) != expected_length or len(triplet_maps) != expected_length + or len(emotion_maps) != expected_length or len(statement_embedding_maps) != expected_length or len(chunk_embedding_maps) != expected_length or len(dialog_embeddings) != expected_length @@ -759,6 +854,7 @@ class ExtractionOrchestrator: logger.warning( f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " f"时间映射: {len(temporal_maps)}, 三元组映射: {len(triplet_maps)}, " + f"情绪映射: {len(emotion_maps)}, " f"陈述句嵌入: {len(statement_embedding_maps)}, " f"分块嵌入: {len(chunk_embedding_maps)}, " f"对话嵌入: {len(dialog_embeddings)}" @@ -767,6 +863,7 @@ class ExtractionOrchestrator: total_statements = 0 assigned_temporal = 0 assigned_triplets = 0 + assigned_emotions = 0 assigned_statement_embeddings = 0 assigned_chunk_embeddings = 0 assigned_dialog_embeddings = 0 @@ -774,12 +871,13 @@ class ExtractionOrchestrator: # 处理每个对话 for i, dialog_data in enumerate(dialog_data_list): # 检查是否有缺失的数据 - if i >= len(temporal_maps) or i >= len(triplet_maps): + if i >= len(temporal_maps) or i >= len(triplet_maps) or i >= len(emotion_maps): logger.warning(f"对话 {dialog_data.id} 缺少提取数据,跳过赋值") continue temporal_map = temporal_maps[i] triplet_map = triplet_maps[i] + emotion_map = emotion_maps[i] statement_embedding_map = statement_embedding_maps[i] if i < len(statement_embedding_maps) else {} chunk_embedding_map = chunk_embedding_maps[i] if i < len(chunk_embedding_maps) else {} dialog_embedding = dialog_embeddings[i] if i < len(dialog_embeddings) else [] @@ -810,6 +908,18 @@ class ExtractionOrchestrator: statement.triplet_extraction_info = triplet_map[statement.id] assigned_triplets += 1 + # 赋值情绪信息 + if statement.id in emotion_map: + emotion_data = emotion_map[statement.id] + if emotion_data is not None: + # 将EmotionExtraction对象的字段赋值到Statement + statement.emotion_type = emotion_data.emotion_type + statement.emotion_intensity = emotion_data.emotion_intensity + statement.emotion_keywords = emotion_data.emotion_keywords + statement.emotion_subject = emotion_data.emotion_subject + statement.emotion_target = emotion_data.emotion_target + assigned_emotions += 1 + # 赋值陈述句嵌入 if statement.id in statement_embedding_map: statement.statement_embedding = statement_embedding_map[statement.id] @@ -818,6 +928,7 @@ class ExtractionOrchestrator: logger.info( f"数据赋值完成 - 总陈述句: {total_statements}, " f"时间信息: {assigned_temporal}, 三元组: {assigned_triplets}, " + f"情绪信息: {assigned_emotions}, " f"陈述句嵌入: {assigned_statement_embeddings}, " f"分块嵌入: {assigned_chunk_embeddings}, " f"对话嵌入: {assigned_dialog_embeddings}" @@ -927,6 +1038,12 @@ class ExtractionOrchestrator: created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, + # Emotion fields + emotion_type=getattr(statement, 'emotion_type', None), + emotion_intensity=getattr(statement, 'emotion_intensity', None), + emotion_keywords=getattr(statement, 'emotion_keywords', None), + emotion_subject=getattr(statement, 'emotion_subject', None), + emotion_target=getattr(statement, 'emotion_target', None), ) statement_nodes.append(statement_node) @@ -1333,7 +1450,7 @@ class ExtractionOrchestrator: if match: entity1_name = match.group(1).strip() entity1_type = match.group(2) - entity2_name = match.group(3).strip() + match.group(3).strip() entity2_type = match.group(4) # 提取置信度和原因 @@ -1646,7 +1763,6 @@ async def get_chunked_dialogs( """ import json import re - import os # 加载测试数据 testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json") @@ -1822,7 +1938,6 @@ async def get_chunked_dialogs_with_preprocessing( Returns: 带 chunks 的 DialogData 列表 """ - import os print("\n=== 完整数据处理流程(包含预处理)===") if input_data_path is None: diff --git a/api/app/core/memory/utils/config/overrides.py b/api/app/core/memory/utils/config/overrides.py index e333bb29..0dd7b2d1 100644 --- a/api/app/core/memory/utils/config/overrides.py +++ b/api/app/core/memory/utils/config/overrides.py @@ -28,7 +28,6 @@ """ import os import json -import socket from typing import Optional, Dict, Any, Literal NetworkMode = Literal['internal', 'external'] @@ -105,7 +104,6 @@ def _make_pgsql_conn() -> Optional[object]: try: import psycopg2 # type: ignore - from psycopg2.extras import RealDictCursor # type: ignore port = int(port_str) if port_str else 5432 conn = psycopg2.connect( @@ -193,7 +191,7 @@ def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, An # config_id 在数据库中是 Integer 类型,需要转换 try: config_id_int = int(config_id) - except (ValueError, TypeError) as e: + except (ValueError, TypeError): try: pass except Exception: @@ -207,7 +205,7 @@ def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, An " statement_granularity, include_dialogue_context, max_context, " " \"offset\" AS offset, lambda_time, lambda_mem, " " pruning_enabled, pruning_scene, pruning_threshold, " - " llm_id, embedding_id " + " llm_id, embedding_id, rerank_id " "FROM data_config WHERE config_id = %s LIMIT 1" ) cur.execute(sql, (config_id_int,)) @@ -222,7 +220,7 @@ def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, An pass return row if row else None - except Exception as e: + except Exception: pass return None finally: @@ -325,7 +323,7 @@ def _apply_overrides_from_db_row( _set_if_present(selections, tk, db_row, tk, str) # 特殊处理 UUID 字段,确保转换为字符串格式 - for uuid_field in ("llm_id", "embedding_id"): + for uuid_field in ("llm_id", "embedding_id", "rerank_id"): if uuid_field in db_row and db_row.get(uuid_field) is not None: try: value = db_row.get(uuid_field) @@ -370,7 +368,7 @@ def _apply_overrides_from_db_row( pass return runtime_cfg - except Exception as e: + except Exception: pass return runtime_cfg @@ -460,7 +458,7 @@ def apply_runtime_overrides_with_config_id( updated_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, selected_cid, "config_id") return updated_cfg, True - except Exception as e: + except Exception: pass return runtime_cfg, False @@ -570,7 +568,7 @@ def load_unified_config( try: with open(runtime_config_path, "r", encoding="utf-8") as f: runtime_cfg = json.load(f) - except (FileNotFoundError, json.JSONDecodeError) as e: + except (FileNotFoundError, json.JSONDecodeError): runtime_cfg = {"selections": {}} # 步骤 2: 尝试从 dbrun.json 读取 config_id 并应用数据库配置(最高优先级) @@ -603,7 +601,7 @@ def load_unified_config( pass return runtime_cfg - except Exception as e: + except Exception: return {"selections": {}} diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index 77a23e0f..c39a3f89 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -238,3 +238,81 @@ async def render_memory_summary_prompt( 'json_schema': 'MemorySummaryResponse.schema' }) return rendered_prompt + +async def render_emotion_extraction_prompt( + statement: str, + extract_keywords: bool, + enable_subject: bool +) -> str: + """ + Renders the emotion extraction prompt using the extract_emotion.jinja2 template. + + Args: + statement: The statement to analyze + extract_keywords: Whether to extract emotion keywords + enable_subject: Whether to enable subject classification + + Returns: + Rendered prompt content as string + """ + template = prompt_env.get_template("extract_emotion.jinja2") + rendered_prompt = template.render( + statement=statement, + extract_keywords=extract_keywords, + enable_subject=enable_subject + ) + + # 记录渲染结果到提示日志 + log_prompt_rendering('emotion extraction', rendered_prompt) + # 可选:记录模板渲染信息 + log_template_rendering('extract_emotion.jinja2', { + 'statement': 'str', + 'extract_keywords': extract_keywords, + 'enable_subject': enable_subject + }) + + return rendered_prompt + +async def render_emotion_suggestions_prompt( + health_data: dict, + patterns: dict, + user_profile: dict +) -> str: + """ + Renders the emotion suggestions generation prompt using the generate_emotion_suggestions.jinja2 template. + + Args: + health_data: 情绪健康数据 + patterns: 情绪模式分析结果 + user_profile: 用户画像数据 + + Returns: + Rendered prompt content as string + """ + import json + + # 预处理 emotion_distribution 为 JSON 字符串 + emotion_distribution_json = json.dumps( + health_data.get('emotion_distribution', {}), + ensure_ascii=False, + indent=2 + ) + + template = prompt_env.get_template("generate_emotion_suggestions.jinja2") + rendered_prompt = template.render( + health_data=health_data, + patterns=patterns, + user_profile=user_profile, + emotion_distribution_json=emotion_distribution_json + ) + + # 记录渲染结果到提示日志 + log_prompt_rendering('emotion suggestions', rendered_prompt) + # 可选:记录模板渲染信息 + log_template_rendering('generate_emotion_suggestions.jinja2', { + 'health_score': health_data.get('health_score'), + 'health_level': health_data.get('level'), + 'user_interests': user_profile.get('interests', []) + }) + + return rendered_prompt diff --git a/api/app/core/memory/utils/prompt/prompts/extract_emotion.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_emotion.jinja2 new file mode 100644 index 00000000..5e1e425f --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/extract_emotion.jinja2 @@ -0,0 +1,57 @@ +你是一个专业的情绪分析专家。请分析以下陈述句的情绪信息。 + +陈述句:{{ statement }} + +请提取以下信息: + +1. emotion_type(情绪类型): + - joy: 喜悦、开心、高兴、满意、愉快 + - sadness: 悲伤、难过、失落、沮丧、遗憾 + - anger: 愤怒、生气、不满、恼火、烦躁 + - fear: 恐惧、害怕、担心、焦虑、紧张 + - surprise: 惊讶、意外、震惊、吃惊 + - neutral: 中性、客观陈述、无明显情绪 + +2. emotion_intensity(情绪强度): + - 0.0-0.3: 弱情绪 + - 0.3-0.7: 中等情绪 + - 0.7-1.0: 强情绪 + +{% if extract_keywords %} +3. emotion_keywords(情绪关键词): + - 原句中直接表达情绪的词语 + - 最多提取3个关键词 + - 如果没有明显的情绪词,返回空列表 +{% else %} +3. emotion_keywords(情绪关键词): + - 返回空列表 +{% endif %} + +{% if enable_subject %} +4. emotion_subject(情绪主体): + - self: 用户本人的情绪(包含"我"、"我们"、"咱们"等第一人称) + - other: 他人的情绪(包含人名、"他/她"等第三人称) + - object: 对事物的评价(针对产品、地点、事件等) + + 注意: + - 如果同时包含多个主体,优先识别用户本人(self) + - 如果无法明确判断主体,默认为 self + +5. emotion_target(情绪对象): + - 如果有明确的情绪对象,提取其名称 + - 如果没有明确对象,返回 null +{% else %} +4. emotion_subject(情绪主体): + - 默认为 self + +5. emotion_target(情绪对象): + - 返回 null +{% endif %} + +注意事项: +- 如果陈述句是客观事实陈述,无明显情绪,标记为 neutral +- 情绪强度要符合语境,不要过度解读 +- 情绪关键词要准确,不要添加原句中没有的词 +- 主体分类要准确,优先识别用户本人(self) + +请以 JSON 格式返回结果。 diff --git a/api/app/core/memory/utils/prompt/prompts/generate_emotion_suggestions.jinja2 b/api/app/core/memory/utils/prompt/prompts/generate_emotion_suggestions.jinja2 new file mode 100644 index 00000000..6a29edd9 --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/generate_emotion_suggestions.jinja2 @@ -0,0 +1,63 @@ +你是一位专业的心理健康顾问。请根据以下用户的情绪健康数据和个人信息,生成3-5条个性化的情绪改善建议。 + +## 用户情绪健康数据 + +健康分数:{{ health_data.health_score }}/100 +健康等级:{{ health_data.level }} + +维度分析: +- 积极率:{{ health_data.dimensions.positivity_rate.score }}/100 + - 正面情绪:{{ health_data.dimensions.positivity_rate.positive_count }}次 + - 负面情绪:{{ health_data.dimensions.positivity_rate.negative_count }}次 + - 中性情绪:{{ health_data.dimensions.positivity_rate.neutral_count }}次 + +- 稳定性:{{ health_data.dimensions.stability.score }}/100 + - 标准差:{{ health_data.dimensions.stability.std_deviation }} + +- 恢复力:{{ health_data.dimensions.resilience.score }}/100 + - 恢复率:{{ health_data.dimensions.resilience.recovery_rate }} + +情绪分布: +{{ emotion_distribution_json }} + +## 情绪模式分析 + +主要负面情绪:{{ patterns.dominant_negative_emotion|default('无') }} +情绪波动性:{{ patterns.emotion_volatility|default('未知') }} +高强度情绪次数:{{ patterns.high_intensity_emotions|default([])|length }} + +## 用户兴趣 + +{{ user_profile.interests|default(['未知'])|join(', ') }} + +## 任务要求 + +请生成3-5条个性化建议,每条建议包含: +1. type: 建议类型(emotion_balance/activity_recommendation/social_connection/stress_management) +2. title: 建议标题(简短有力) +3. content: 建议内容(详细说明,50-100字) +4. priority: 优先级(high/medium/low) +5. actionable_steps: 3个可执行的具体步骤 + +同时提供一个health_summary(不超过50字),概括用户的整体情绪状态。 + +请以JSON格式返回,格式如下: +{ + "health_summary": "您的情绪健康状况...", + "suggestions": [ + { + "type": "emotion_balance", + "title": "建议标题", + "content": "建议内容...", + "priority": "high", + "actionable_steps": ["步骤1", "步骤2", "步骤3"] + } + ] +} + +注意事项: +- 建议要具体、可执行,避免空泛 +- 结合用户的兴趣爱好提供个性化建议 +- 针对主要问题(如主要负面情绪)提供针对性建议 +- 优先级要合理分配(至少1个high,1-2个medium,其余low) +- 每个建议的3个步骤要循序渐进、易于实施 diff --git a/api/app/models/data_config_model.py b/api/app/models/data_config_model.py index be43bd8d..870d46b2 100644 --- a/api/app/models/data_config_model.py +++ b/api/app/models/data_config_model.py @@ -64,7 +64,14 @@ class DataConfig(Base): lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度,0-1 小数") lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数") offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数") - + + # 情绪引擎配置 + emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取") + emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID") + emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词") + emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值") + emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类") + # 时间戳 created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index d339879f..ce4a6876 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -100,7 +100,13 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC # "triplets": [triplet.model_dump() for triplet in statement.triplet_extraction_info.triplets] if statement.triplet_extraction_info else [], # "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else [] # }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}), - "statement_embedding": statement.statement_embedding if statement.statement_embedding else None + "statement_embedding": statement.statement_embedding if statement.statement_embedding else None, + # 添加情绪字段处理 + "emotion_type": statement.emotion_type, + "emotion_intensity": statement.emotion_intensity, + "emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [], + "emotion_subject": statement.emotion_subject, + "emotion_target": statement.emotion_target } flattened_statements.append(flattened_statement) diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 95e2ee03..0f6e32aa 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -20,20 +20,25 @@ UNWIND $statements AS statement MERGE (s:Statement {id: statement.id}) SET s += { id: statement.id, + run_id: statement.run_id, + chunk_id: statement.chunk_id, group_id: statement.group_id, user_id: statement.user_id, apply_id: statement.apply_id, - chunk_id: statement.chunk_id, - run_id: statement.run_id, + stmt_type: statement.stmt_type, + statement: statement.statement, + emotion_intensity: statement.emotion_intensity, + emotion_target: statement.emotion_target, + emotion_subject: statement.emotion_subject, + emotion_type: statement.emotion_type, + emotion_keywords: statement.emotion_keywords, + temporal_info: statement.temporal_info, created_at: statement.created_at, expired_at: statement.expired_at, - stmt_type: statement.stmt_type, - temporal_info: statement.temporal_info, - relevence_info: statement.relevence_info, - statement: statement.statement, valid_at: statement.valid_at, invalid_at: statement.invalid_at, - statement_embedding: statement.statement_embedding + statement_embedding: statement.statement_embedding, + relevence_info: statement.relevence_info } RETURN s.id AS uuid """ diff --git a/api/app/repositories/neo4j/emotion_repository.py b/api/app/repositories/neo4j/emotion_repository.py new file mode 100644 index 00000000..d445c8d4 --- /dev/null +++ b/api/app/repositories/neo4j/emotion_repository.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +"""情绪数据仓储模块 + +本模块提供情绪数据的查询功能,用于情绪分析和统计。 + +Classes: + EmotionRepository: 情绪数据仓储,提供情绪标签、词云、健康指数等查询方法 +""" + +from typing import List, Dict, Optional, Any +from datetime import datetime, timedelta +import json + +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class EmotionRepository: + """情绪数据仓储 + + 提供情绪数据的查询和统计功能,包括: + - 情绪标签统计 + - 情绪词云数据 + - 时间范围内的情绪数据查询 + + Attributes: + connector: Neo4j连接器实例 + """ + + def __init__(self, connector: Neo4jConnector): + """初始化情绪数据仓储 + + Args: + connector: Neo4j连接器实例 + """ + self.connector = connector + logger.info("情绪数据仓储初始化完成") + + async def get_emotion_tags( + self, + group_id: str, + emotion_type: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + limit: int = 10 + ) -> List[Dict[str, Any]]: + """获取情绪标签统计 + + 查询指定用户的情绪类型分布,包括计数、百分比和平均强度。 + + Args: + group_id: 用户组ID(宿主ID) + emotion_type: 可选的情绪类型过滤(joy/sadness/anger/fear/surprise/neutral) + start_date: 可选的开始日期(ISO格式字符串) + end_date: 可选的结束日期(ISO格式字符串) + limit: 返回结果的最大数量 + + Returns: + List[Dict]: 情绪标签列表,每个包含: + - emotion_type: 情绪类型 + - count: 该类型的数量 + - percentage: 占比百分比 + - avg_intensity: 平均强度 + """ + # 构建查询条件 + where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"] + params = {"group_id": group_id, "limit": limit} + + if emotion_type: + where_clauses.append("s.emotion_type = $emotion_type") + params["emotion_type"] = emotion_type + + if start_date: + where_clauses.append("s.created_at >= $start_date") + params["start_date"] = start_date + + if end_date: + where_clauses.append("s.created_at <= $end_date") + params["end_date"] = end_date + + where_str = " AND ".join(where_clauses) + + # 优化的 Cypher 查询:使用索引,减少中间结果 + query = f""" + MATCH (s:Statement) + WHERE {where_str} + WITH s.emotion_type as emotion_type, + count(*) as count, + avg(s.emotion_intensity) as avg_intensity + WITH collect({{emotion_type: emotion_type, count: count, avg_intensity: avg_intensity}}) as results, + sum(count) as total_count + UNWIND results as result + RETURN result.emotion_type as emotion_type, + result.count as count, + toFloat(result.count) / total_count * 100 as percentage, + result.avg_intensity as avg_intensity + ORDER BY count DESC + LIMIT $limit + """ + + try: + results = await self.connector.execute_query(query, **params) + formatted_results = [ + { + "emotion_type": record["emotion_type"], + "count": record["count"], + "percentage": round(record["percentage"], 2), + "avg_intensity": round(record["avg_intensity"], 3) if record["avg_intensity"] else 0.0 + } + for record in results + ] + + return formatted_results + except Exception as e: + logger.error(f"查询情绪标签失败: {str(e)}", exc_info=True) + return [] + + async def get_emotion_wordcloud( + self, + group_id: str, + emotion_type: Optional[str] = None, + limit: int = 50 + ) -> List[Dict[str, Any]]: + """获取情绪词云数据 + + 查询情绪关键词及其频率,用于生成词云可视化。 + + Args: + group_id: 用户组ID(宿主ID) + emotion_type: 可选的情绪类型过滤 + limit: 返回关键词的最大数量 + + Returns: + List[Dict]: 关键词列表,每个包含: + - keyword: 关键词 + - frequency: 出现频率 + - emotion_type: 关联的情绪类型 + - avg_intensity: 平均强度 + """ + # 构建查询条件 + where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"] + params = {"group_id": group_id, "limit": limit} + + if emotion_type: + where_clauses.append("s.emotion_type = $emotion_type") + params["emotion_type"] = emotion_type + + where_str = " AND ".join(where_clauses) + + # 优化的 Cypher 查询:使用索引,减少不必要的计算 + query = f""" + MATCH (s:Statement) + WHERE {where_str} + UNWIND s.emotion_keywords as keyword + WITH keyword, + s.emotion_type as emotion_type, + count(*) as frequency, + avg(s.emotion_intensity) as avg_intensity + WHERE keyword IS NOT NULL AND keyword <> '' + RETURN keyword, + frequency, + emotion_type, + avg_intensity + ORDER BY frequency DESC + LIMIT $limit + """ + + try: + results = await self.connector.execute_query(query, **params) + formatted_results = [ + { + "keyword": record["keyword"], + "frequency": record["frequency"], + "emotion_type": record["emotion_type"], + "avg_intensity": round(record["avg_intensity"], 3) if record["avg_intensity"] else 0.0 + } + for record in results + ] + + return formatted_results + except Exception as e: + logger.error(f"查询情绪词云失败: {str(e)}", exc_info=True) + return [] + + async def get_emotions_in_range( + self, + group_id: str, + time_range: str = "30d" + ) -> List[Dict[str, Any]]: + """获取时间范围内的情绪数据 + + 查询指定时间范围内的所有情绪数据,用于健康指数计算。 + + Args: + group_id: 用户组ID(宿主ID) + time_range: 时间范围(7d/30d/90d) + + Returns: + List[Dict]: 情绪数据列表,每个包含: + - emotion_type: 情绪类型 + - emotion_intensity: 情绪强度 + - created_at: 创建时间 + - statement_id: 陈述句ID + """ + # 解析时间范围 + days_map = {"7d": 7, "30d": 30, "90d": 90} + days = days_map.get(time_range, 30) + + # 计算起始日期(使用字符串比较,避免时区问题) + start_date = (datetime.now() - timedelta(days=days)).isoformat() + + # 优化的 Cypher 查询:使用字符串比较避免时区问题 + query = """ + MATCH (s:Statement) + WHERE s.group_id = $group_id + AND s.emotion_type IS NOT NULL + AND s.created_at >= $start_date + RETURN s.id as statement_id, + s.emotion_type as emotion_type, + s.emotion_intensity as emotion_intensity, + s.created_at as created_at + ORDER BY s.created_at ASC + """ + + try: + results = await self.connector.execute_query( + query, + group_id=group_id, + start_date=start_date + ) + formatted_results = [ + { + "statement_id": record["statement_id"], + "emotion_type": record["emotion_type"], + "emotion_intensity": record["emotion_intensity"], + "created_at": record["created_at"].isoformat() if hasattr(record["created_at"], "isoformat") else str(record["created_at"]) + } + for record in results + ] + + return formatted_results + except Exception as e: + logger.error(f"查询时间范围情绪数据失败: {str(e)}", exc_info=True) + return [] diff --git a/api/app/repositories/neo4j/statement_repository.py b/api/app/repositories/neo4j/statement_repository.py index ec2d6660..34858444 100644 --- a/api/app/repositories/neo4j/statement_repository.py +++ b/api/app/repositories/neo4j/statement_repository.py @@ -58,11 +58,22 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]): n['invalid_at'] = datetime.fromisoformat(n['invalid_at']) # 处理temporal_info字段 - if isinstance(n.get('temporal_info'), dict): + if isinstance(n.get('temporal_info'), str): + # 从字符串转换为枚举值 + n['temporal_info'] = TemporalInfo(n['temporal_info']) + elif isinstance(n.get('temporal_info'), dict): n['temporal_info'] = TemporalInfo(**n['temporal_info']) elif not n.get('temporal_info'): # 如果没有temporal_info,创建一个默认的 - n['temporal_info'] = TemporalInfo() + n['temporal_info'] = TemporalInfo.STATIC + + # 处理情绪字段 - 映射 Neo4j 节点属性到 StatementNode 模型 + # 处理空值情况,确保字段存在 + n['emotion_type'] = n.get('emotion_type') + n['emotion_intensity'] = n.get('emotion_intensity') + n['emotion_keywords'] = n.get('emotion_keywords', []) + n['emotion_subject'] = n.get('emotion_subject') + n['emotion_target'] = n.get('emotion_target') return StatementNode(**n) diff --git a/api/app/schemas/emotion_schema.py b/api/app/schemas/emotion_schema.py new file mode 100644 index 00000000..9f14884d --- /dev/null +++ b/api/app/schemas/emotion_schema.py @@ -0,0 +1,32 @@ +"""情绪分析相关的请求和响应模型""" + +from typing import Optional +from pydantic import BaseModel, Field + + +class EmotionTagsRequest(BaseModel): + """获取情绪标签统计请求""" + group_id: str = Field(..., description="组ID") + emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)") + start_date: Optional[str] = Field(None, description="开始日期(ISO格式,如:2024-01-01)") + end_date: Optional[str] = Field(None, description="结束日期(ISO格式,如:2024-12-31)") + limit: int = Field(10, ge=1, le=100, description="返回数量限制") + + +class EmotionWordcloudRequest(BaseModel): + """获取情绪词云数据请求""" + group_id: str = Field(..., description="组ID") + emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)") + limit: int = Field(50, ge=1, le=200, description="返回词语数量") + + +class EmotionHealthRequest(BaseModel): + """获取情绪健康指数请求""" + group_id: str = Field(..., description="组ID") + time_range: str = Field("30d", description="时间范围(7d/30d/90d)") + + +class EmotionSuggestionsRequest(BaseModel): + """获取个性化情绪建议请求""" + group_id: str = Field(..., description="组ID") + config_id: Optional[int] = Field(None, description="配置ID(用于指定LLM模型)") diff --git a/api/app/services/emotion_analytics_service.py b/api/app/services/emotion_analytics_service.py new file mode 100644 index 00000000..6952256e --- /dev/null +++ b/api/app/services/emotion_analytics_service.py @@ -0,0 +1,670 @@ +# -*- coding: utf-8 -*- +"""情绪分析服务模块 + +本模块提供情绪数据的分析和统计功能,包括情绪标签、词云、健康指数计算等。 + +Classes: + EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能 +""" + +from typing import Dict, Any, Optional, List +import statistics +import json +from pydantic import BaseModel, Field + +from app.repositories.neo4j.emotion_repository import EmotionRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class EmotionSuggestion(BaseModel): + """情绪建议模型""" + type: str = Field(..., description="建议类型:emotion_balance/activity_recommendation/social_connection/stress_management") + title: str = Field(..., description="建议标题") + content: str = Field(..., description="建议内容") + priority: str = Field(..., description="优先级:high/medium/low") + actionable_steps: List[str] = Field(..., description="可执行步骤列表(3个)") + + +class EmotionSuggestionsResponse(BaseModel): + """情绪建议响应模型""" + health_summary: str = Field(..., description="健康状态摘要(不超过50字)") + suggestions: List[EmotionSuggestion] = Field(..., description="建议列表(3-5条)") + + +class EmotionAnalyticsService: + """情绪分析服务 + + 提供情绪数据的分析和统计功能,包括: + - 情绪标签统计 + - 情绪词云数据 + - 情绪健康指数计算 + - 个性化情绪建议生成 + + Attributes: + emotion_repo: 情绪数据仓储实例 + """ + + def __init__(self): + """初始化情绪分析服务""" + connector = Neo4jConnector() + self.emotion_repo = EmotionRepository(connector) + logger.info("情绪分析服务初始化完成") + + async def get_emotion_tags( + self, + end_user_id: str, + emotion_type: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + limit: int = 10 + ) -> Dict[str, Any]: + """获取情绪标签统计 + + 查询指定用户的情绪类型分布,包括计数、百分比和平均强度。 + + Args: + end_user_id: 宿主ID(用户组ID) + emotion_type: 可选的情绪类型过滤 + start_date: 可选的开始日期(ISO格式) + end_date: 可选的结束日期(ISO格式) + limit: 返回结果的最大数量 + + Returns: + Dict: 包含情绪标签统计的响应数据: + - tags: 情绪标签列表 + - total_count: 总情绪数量 + - time_range: 时间范围信息 + """ + try: + logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, " + f"start={start_date}, end={end_date}, limit={limit}") + + # 调用仓储层查询 + tags = await self.emotion_repo.get_emotion_tags( + group_id=end_user_id, + emotion_type=emotion_type, + start_date=start_date, + end_date=end_date, + limit=limit + ) + + # 计算总数 + total_count = sum(tag["count"] for tag in tags) + + # 构建时间范围信息 + time_range = {} + if start_date: + time_range["start_date"] = start_date + if end_date: + time_range["end_date"] = end_date + + # 格式化响应 + response = { + "tags": tags, + "total_count": total_count, + "time_range": time_range if time_range else None + } + + logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(tags)}") + return response + + except Exception as e: + logger.error(f"获取情绪标签统计失败: {str(e)}", exc_info=True) + raise + + async def get_emotion_wordcloud( + self, + end_user_id: str, + emotion_type: Optional[str] = None, + limit: int = 50 + ) -> Dict[str, Any]: + """获取情绪词云数据 + + 查询情绪关键词及其频率,用于生成词云可视化。 + + Args: + end_user_id: 宿主ID(用户组ID) + emotion_type: 可选的情绪类型过滤 + limit: 返回关键词的最大数量 + + Returns: + Dict: 包含情绪词云数据的响应: + - keywords: 关键词列表 + - total_keywords: 总关键词数量 + """ + try: + logger.info(f"获取情绪词云数据: user={end_user_id}, type={emotion_type}, limit={limit}") + + # 调用仓储层查询 + keywords = await self.emotion_repo.get_emotion_wordcloud( + group_id=end_user_id, + emotion_type=emotion_type, + limit=limit + ) + + # 计算总关键词数量 + total_keywords = len(keywords) + + # 格式化响应 + response = { + "keywords": keywords, + "total_keywords": total_keywords + } + + logger.info(f"情绪词云数据获取完成: total_keywords={total_keywords}") + return response + + except Exception as e: + logger.error(f"获取情绪词云数据失败: {str(e)}", exc_info=True) + raise + + def _calculate_positivity_rate(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]: + """计算积极率 + + 根据情绪类型分类正面、负面和中性情绪,计算积极率。 + 公式:(正面数 / (正面数 + 负面数)) * 100 + + Args: + emotions: 情绪数据列表,每个包含 emotion_type 字段 + + Returns: + Dict: 包含积极率计算结果: + - score: 积极率分数(0-100) + - positive_count: 正面情绪数量 + - negative_count: 负面情绪数量 + - neutral_count: 中性情绪数量 + """ + # 定义情绪分类 + positive_emotions = {'joy', 'surprise'} + negative_emotions = {'sadness', 'anger', 'fear'} + + # 统计各类情绪数量 + positive_count = sum(1 for e in emotions if e.get('emotion_type') in positive_emotions) + negative_count = sum(1 for e in emotions if e.get('emotion_type') in negative_emotions) + neutral_count = sum(1 for e in emotions if e.get('emotion_type') == 'neutral') + + # 计算积极率 + total_non_neutral = positive_count + negative_count + if total_non_neutral > 0: + score = (positive_count / total_non_neutral) * 100 + else: + score = 50.0 # 如果没有非中性情绪,默认为50 + + logger.debug(f"积极率计算: positive={positive_count}, negative={negative_count}, " + f"neutral={neutral_count}, score={score:.2f}") + + return { + "score": round(score, 2), + "positive_count": positive_count, + "negative_count": negative_count, + "neutral_count": neutral_count + } + + def _calculate_stability(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]: + """计算稳定性 + + 基于情绪强度的标准差计算情绪稳定性。 + 公式:(1 - min(std_deviation, 1.0)) * 100 + + Args: + emotions: 情绪数据列表,每个包含 emotion_intensity 字段 + + Returns: + Dict: 包含稳定性计算结果: + - score: 稳定性分数(0-100) + - std_deviation: 标准差 + """ + # 提取所有情绪强度 + intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None] + + # 计算标准差 + if len(intensities) >= 2: + std_deviation = statistics.stdev(intensities) + elif len(intensities) == 1: + std_deviation = 0.0 # 只有一个数据点,标准差为0 + else: + std_deviation = 0.0 # 没有数据,标准差为0 + + # 计算稳定性分数 + # 标准差越小,稳定性越高 + score = (1 - min(std_deviation, 1.0)) * 100 + + logger.debug(f"稳定性计算: intensities_count={len(intensities)}, " + f"std_deviation={std_deviation:.3f}, score={score:.2f}") + + return { + "score": round(score, 2), + "std_deviation": round(std_deviation, 3) + } + + def _calculate_resilience(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]: + """计算恢复力 + + 分析情绪转换模式,统计从负面情绪恢复到正面情绪的能力。 + 公式:(负面到正面转换次数 / 总负面情绪数) * 100 + + Args: + emotions: 情绪数据列表,每个包含 emotion_type 和 created_at 字段 + 应该按时间顺序排列 + + Returns: + Dict: 包含恢复力计算结果: + - score: 恢复力分数(0-100) + - recovery_rate: 恢复率(转换次数/负面情绪数) + """ + # 定义情绪分类 + positive_emotions = {'joy', 'surprise'} + negative_emotions = {'sadness', 'anger', 'fear'} + + # 统计负面到正面的转换次数 + recovery_count = 0 + negative_count = 0 + + for i in range(len(emotions)): + current_emotion = emotions[i].get('emotion_type') + + # 统计负面情绪总数 + if current_emotion in negative_emotions: + negative_count += 1 + + # 检查下一个情绪是否为正面 + if i + 1 < len(emotions): + next_emotion = emotions[i + 1].get('emotion_type') + if next_emotion in positive_emotions: + recovery_count += 1 + + # 计算恢复力分数 + if negative_count > 0: + recovery_rate = recovery_count / negative_count + score = recovery_rate * 100 + else: + # 如果没有负面情绪,恢复力设为100(最佳状态) + recovery_rate = 1.0 + score = 100.0 + + logger.debug(f"恢复力计算: negative_count={negative_count}, " + f"recovery_count={recovery_count}, score={score:.2f}") + + return { + "score": round(score, 2), + "recovery_rate": round(recovery_rate, 3) + } + + async def calculate_emotion_health_index( + self, + end_user_id: str, + time_range: str = "30d" + ) -> Dict[str, Any]: + """计算情绪健康指数 + + 综合积极率、稳定性和恢复力计算情绪健康指数。 + + Args: + end_user_id: 宿主ID(用户组ID) + time_range: 时间范围(7d/30d/90d) + + Returns: + Dict: 包含情绪健康指数的完整响应: + - health_score: 综合健康分数(0-100) + - level: 健康等级(优秀/良好/一般/较差) + - dimensions: 各维度详细数据 + - positivity_rate: 积极率 + - stability: 稳定性 + - resilience: 恢复力 + - emotion_distribution: 情绪分布统计 + - time_range: 时间范围 + """ + try: + logger.info(f"计算情绪健康指数: user={end_user_id}, time_range={time_range}") + + # 获取时间范围内的情绪数据 + emotions = await self.emotion_repo.get_emotions_in_range( + group_id=end_user_id, + time_range=time_range + ) + + # 如果没有数据,返回默认值 + if not emotions: + logger.warning(f"用户 {end_user_id} 在时间范围 {time_range} 内没有情绪数据") + return { + "health_score": 0.0, + "level": "无数据", + "dimensions": { + "positivity_rate": {"score": 0.0, "positive_count": 0, "negative_count": 0, "neutral_count": 0}, + "stability": {"score": 0.0, "std_deviation": 0.0}, + "resilience": {"score": 0.0, "recovery_rate": 0.0} + }, + "emotion_distribution": {}, + "time_range": time_range + } + + # 计算各维度指标 + positivity_rate = self._calculate_positivity_rate(emotions) + stability = self._calculate_stability(emotions) + resilience = self._calculate_resilience(emotions) + + # 计算综合健康分数 + # 公式:positivity_rate * 0.4 + stability * 0.3 + resilience * 0.3 + health_score = ( + positivity_rate["score"] * 0.4 + + stability["score"] * 0.3 + + resilience["score"] * 0.3 + ) + + # 确定健康等级 + if health_score >= 80: + level = "优秀" + elif health_score >= 60: + level = "良好" + elif health_score >= 40: + level = "一般" + else: + level = "较差" + + # 统计情绪分布 + emotion_distribution = {} + for emotion_type in ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']: + count = sum(1 for e in emotions if e.get('emotion_type') == emotion_type) + emotion_distribution[emotion_type] = count + + # 格式化响应 + response = { + "health_score": round(health_score, 2), + "level": level, + "dimensions": { + "positivity_rate": positivity_rate, + "stability": stability, + "resilience": resilience + }, + "emotion_distribution": emotion_distribution, + "time_range": time_range + } + + logger.info(f"情绪健康指数计算完成: score={health_score:.2f}, level={level}") + return response + + except Exception as e: + logger.error(f"计算情绪健康指数失败: {str(e)}", exc_info=True) + raise + + def _analyze_emotion_patterns(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]: + """分析情绪模式 + + 识别主要负面情绪、情绪触发因素和波动时段。 + + Args: + emotions: 情绪数据列表,每个包含 emotion_type、emotion_intensity、created_at 字段 + + Returns: + Dict: 包含情绪模式分析结果: + - dominant_negative_emotion: 主要负面情绪类型 + - high_intensity_emotions: 高强度情绪列表 + - emotion_volatility: 情绪波动性(高/中/低) + """ + negative_emotions = {'sadness', 'anger', 'fear'} + + # 统计负面情绪分布 + negative_emotion_counts = {} + for emotion in emotions: + emotion_type = emotion.get('emotion_type') + if emotion_type in negative_emotions: + negative_emotion_counts[emotion_type] = negative_emotion_counts.get(emotion_type, 0) + 1 + + # 识别主要负面情绪 + dominant_negative_emotion = None + if negative_emotion_counts: + dominant_negative_emotion = max(negative_emotion_counts, key=negative_emotion_counts.get) + + # 识别高强度情绪(强度 >= 0.7) + high_intensity_emotions = [ + { + "type": e.get('emotion_type'), + "intensity": e.get('emotion_intensity'), + "created_at": e.get('created_at') + } + for e in emotions + if e.get('emotion_intensity', 0) >= 0.7 + ] + + # 评估情绪波动性 + intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None] + if len(intensities) >= 2: + std_dev = statistics.stdev(intensities) + if std_dev > 0.3: + volatility = "高" + elif std_dev > 0.15: + volatility = "中" + else: + volatility = "低" + else: + volatility = "未知" + + logger.debug(f"情绪模式分析: dominant_negative={dominant_negative_emotion}, " + f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}") + + return { + "dominant_negative_emotion": dominant_negative_emotion, + "high_intensity_emotions": high_intensity_emotions[:5], # 最多返回5个 + "emotion_volatility": volatility + } + + async def generate_emotion_suggestions( + self, + end_user_id: str, + config_id: Optional[int] = None + ) -> Dict[str, Any]: + """生成个性化情绪建议 + + 基于情绪健康数据和用户画像生成个性化建议。 + + Args: + end_user_id: 宿主ID(用户组ID) + config_id: 配置ID(可选,用于从数据库加载LLM配置) + + Returns: + Dict: 包含个性化建议的响应: + - health_summary: 健康状态摘要 + - suggestions: 建议列表(3-5条) + """ + try: + logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}") + + # 1. 如果提供了 config_id,从数据库加载配置 + if config_id is not None: + from app.core.memory.utils.config.definitions import reload_configuration_from_database + config_loaded = reload_configuration_from_database(config_id) + if not config_loaded: + logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置") + + # 2. 获取情绪健康数据 + health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d") + + # 3. 获取情绪数据用于模式分析 + emotions = await self.emotion_repo.get_emotions_in_range( + group_id=end_user_id, + time_range="30d" + ) + + # 4. 分析情绪模式 + patterns = self._analyze_emotion_patterns(emotions) + + # 5. 获取用户画像数据(简化版,直接从Neo4j获取) + user_profile = await self._get_simple_user_profile(end_user_id) + + # 6. 构建LLM prompt + prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile) + + # 7. 调用LLM生成建议(使用配置中的LLM) + from app.core.memory.utils.llm.llm_utils import get_llm_client + llm_client = get_llm_client() + + # 将 prompt 转换为 messages 格式 + messages = [ + {"role": "user", "content": prompt} + ] + + response = await llm_client.chat(messages=messages) + response_text = response.content.strip() + + # 8. 解析LLM响应 + try: + response_data = json.loads(response_text) + suggestions_response = EmotionSuggestionsResponse(**response_data) + except (json.JSONDecodeError, Exception) as e: + logger.error(f"解析LLM响应失败: {str(e)}, response={response_text}") + # 返回默认建议 + suggestions_response = self._get_default_suggestions(health_data) + + # 8. 验证建议数量(3-5条) + if len(suggestions_response.suggestions) < 3: + logger.warning(f"建议数量不足: {len(suggestions_response.suggestions)}") + suggestions_response = self._get_default_suggestions(health_data) + elif len(suggestions_response.suggestions) > 5: + logger.warning(f"建议数量过多: {len(suggestions_response.suggestions)}") + suggestions_response.suggestions = suggestions_response.suggestions[:5] + + # 9. 格式化响应 + response = { + "health_summary": suggestions_response.health_summary, + "suggestions": [ + { + "type": s.type, + "title": s.title, + "content": s.content, + "priority": s.priority, + "actionable_steps": s.actionable_steps + } + for s in suggestions_response.suggestions + ] + } + + logger.info(f"个性化建议生成完成: suggestions_count={len(response['suggestions'])}") + return response + + except Exception as e: + logger.error(f"生成个性化建议失败: {str(e)}", exc_info=True) + raise + + async def _get_simple_user_profile(self, end_user_id: str) -> Dict[str, Any]: + """获取简化的用户画像数据 + + Args: + end_user_id: 用户ID + + Returns: + Dict: 用户画像数据 + """ + try: + connector = Neo4jConnector() + + # 查询用户的实体和标签 + query = """ + MATCH (e:Entity) + WHERE e.group_id = $group_id + RETURN e.name as name, e.type as type + ORDER BY e.created_at DESC + LIMIT 20 + """ + + entities = await connector.execute_query(query, group_id=end_user_id) + + # 提取兴趣标签 + interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5] + # 后期会引入用户的习惯。。 + return { + "interests": interests if interests else ["未知"] + } + + except Exception as e: + logger.error(f"获取用户画像失败: {str(e)}") + return {"interests": ["未知"]} + + async def _build_suggestion_prompt( + self, + health_data: Dict[str, Any], + patterns: Dict[str, Any], + user_profile: Dict[str, Any] + ) -> str: + """构建情绪建议生成的prompt + + Args: + health_data: 情绪健康数据 + patterns: 情绪模式分析结果 + user_profile: 用户画像数据 + + Returns: + str: LLM prompt + """ + from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt + + prompt = await render_emotion_suggestions_prompt( + health_data=health_data, + patterns=patterns, + user_profile=user_profile + ) + + return prompt + + def _get_default_suggestions(self, health_data: Dict[str, Any]) -> EmotionSuggestionsResponse: + """获取默认建议(当LLM调用失败时使用) + + Args: + health_data: 情绪健康数据 + + Returns: + EmotionSuggestionsResponse: 默认建议 + """ + health_score = health_data.get('health_score', 0) + + if health_score >= 80: + summary = "您的情绪健康状况优秀,请继续保持积极的生活态度。" + elif health_score >= 60: + summary = "您的情绪健康状况良好,可以通过一些调整进一步提升。" + elif health_score >= 40: + summary = "您的情绪健康需要关注,建议采取一些改善措施。" + else: + summary = "您的情绪健康需要重点关注,建议寻求专业帮助。" + + suggestions = [ + EmotionSuggestion( + type="emotion_balance", + title="保持情绪平衡", + content="通过正念冥想和深呼吸练习,帮助您更好地管理情绪波动,提升情绪稳定性。", + priority="high", + actionable_steps=[ + "每天早晨进行5-10分钟的正念冥想", + "感到情绪波动时,进行3次深呼吸", + "记录每天的情绪变化,识别触发因素" + ] + ), + EmotionSuggestion( + type="activity_recommendation", + title="增加户外活动", + content="适度的户外运动可以有效改善情绪,增强身心健康。建议每周进行3-4次户外活动。", + priority="medium", + actionable_steps=[ + "每周安排2-3次30分钟的散步", + "周末尝试户外运动如骑行或爬山", + "在户外活动时关注周围环境,放松心情" + ] + ), + EmotionSuggestion( + type="social_connection", + title="加强社交联系", + content="与朋友和家人保持良好的社交联系,可以提供情感支持,改善情绪健康。", + priority="medium", + actionable_steps=[ + "每周至少与一位朋友或家人深入交流", + "参加感兴趣的社交活动或兴趣小组", + "主动分享自己的感受和想法" + ] + ) + ] + + return EmotionSuggestionsResponse( + health_summary=summary, + suggestions=suggestions + ) diff --git a/api/app/services/emotion_config_service.py b/api/app/services/emotion_config_service.py new file mode 100644 index 00000000..37171640 --- /dev/null +++ b/api/app/services/emotion_config_service.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- +"""情绪配置服务模块 + +本模块提供情绪引擎配置的管理功能,包括获取和更新配置。 + +Classes: + EmotionConfigService: 情绪配置服务,提供配置管理功能 +""" + +from typing import Dict, Any +from sqlalchemy.orm import Session + +from app.models.data_config_model import DataConfig +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class EmotionConfigService: + """情绪配置服务 + + 提供情绪引擎配置的管理功能,包括: + - 获取情绪配置 + - 更新情绪配置 + - 验证配置参数 + + Attributes: + db: 数据库会话 + """ + + def __init__(self, db: Session): + """初始化情绪配置服务 + + Args: + db: 数据库会话 + """ + self.db = db + logger.info("情绪配置服务初始化完成") + + def get_emotion_config(self, config_id: int) -> Dict[str, Any]: + """获取情绪引擎配置 + + 查询指定配置ID的情绪相关配置字段。 + + Args: + config_id: 配置ID + + Returns: + Dict: 包含情绪配置的响应数据: + - config_id: 配置ID + - emotion_enabled: 是否启用情绪提取 + - emotion_model_id: 情绪分析专用模型ID + - emotion_extract_keywords: 是否提取情绪关键词 + - emotion_min_intensity: 最小情绪强度阈值 + - emotion_enable_subject: 是否启用主体分类 + + Raises: + ValueError: 当配置不存在时 + """ + try: + logger.info(f"获取情绪配置: config_id={config_id}") + + # 查询配置 + config = self.db.query(DataConfig).filter( + DataConfig.config_id == config_id + ).first() + + if not config: + logger.error(f"配置不存在: config_id={config_id}") + raise ValueError(f"配置不存在: config_id={config_id}") + + # 提取情绪相关字段 + emotion_config = { + "config_id": config.config_id, + "emotion_enabled": config.emotion_enabled, + "emotion_model_id": config.emotion_model_id, + "emotion_extract_keywords": config.emotion_extract_keywords, + "emotion_min_intensity": config.emotion_min_intensity, + "emotion_enable_subject": config.emotion_enable_subject + } + + logger.info(f"情绪配置获取成功: config_id={config_id}") + return emotion_config + + except ValueError: + raise + except Exception as e: + logger.error(f"获取情绪配置失败: {str(e)}", exc_info=True) + raise + + def validate_emotion_config(self, config_data: Dict[str, Any]) -> bool: + """验证情绪配置参数 + + 验证配置参数的有效性,包括: + - emotion_min_intensity 在 [0.0, 1.0] 范围内 + - 布尔字段类型正确 + - emotion_model_id 格式有效(如果提供) + + Args: + config_data: 配置数据字典 + + Returns: + bool: 验证是否通过 + + Raises: + ValueError: 当配置参数无效时 + """ + try: + logger.debug(f"验证情绪配置参数: {config_data}") + + # 验证 emotion_min_intensity 范围 + if "emotion_min_intensity" in config_data: + min_intensity = config_data["emotion_min_intensity"] + if not isinstance(min_intensity, (int, float)): + raise ValueError("emotion_min_intensity 必须是数字类型") + if not (0.0 <= min_intensity <= 1.0): + raise ValueError("emotion_min_intensity 必须在 0.0 到 1.0 之间") + + # 验证布尔字段 + bool_fields = ["emotion_enabled", "emotion_extract_keywords", "emotion_enable_subject"] + for field in bool_fields: + if field in config_data: + value = config_data[field] + if not isinstance(value, bool): + raise ValueError(f"{field} 必须是布尔类型") + + # 验证 emotion_model_id(如果提供) + if "emotion_model_id" in config_data: + model_id = config_data["emotion_model_id"] + if model_id is not None and not isinstance(model_id, str): + raise ValueError("emotion_model_id 必须是字符串类型或 null") + if model_id is not None and len(model_id.strip()) == 0: + raise ValueError("emotion_model_id 不能为空字符串") + + logger.debug("情绪配置参数验证通过") + return True + + except ValueError as e: + logger.warning(f"配置参数验证失败: {str(e)}") + raise + except Exception as e: + logger.error(f"验证配置参数时发生错误: {str(e)}", exc_info=True) + raise ValueError(f"验证配置参数失败: {str(e)}") + + def update_emotion_config( + self, + config_id: int, + config_data: Dict[str, Any] + ) -> Dict[str, Any]: + """更新情绪引擎配置 + + 更新指定配置ID的情绪相关配置字段。 + + Args: + config_id: 配置ID + config_data: 要更新的配置数据,可包含以下字段: + - emotion_enabled: 是否启用情绪提取 + - emotion_model_id: 情绪分析专用模型ID + - emotion_extract_keywords: 是否提取情绪关键词 + - emotion_min_intensity: 最小情绪强度阈值 + - emotion_enable_subject: 是否启用主体分类 + + Returns: + Dict: 更新后的完整情绪配置 + + Raises: + ValueError: 当配置不存在或参数无效时 + """ + try: + logger.info(f"更新情绪配置: config_id={config_id}, data={config_data}") + + # 验证配置参数 + self.validate_emotion_config(config_data) + + # 查询配置 + config = self.db.query(DataConfig).filter( + DataConfig.config_id == config_id + ).first() + + if not config: + logger.error(f"配置不存在: config_id={config_id}") + raise ValueError(f"配置不存在: config_id={config_id}") + + # 更新字段 + if "emotion_enabled" in config_data: + config.emotion_enabled = config_data["emotion_enabled"] + if "emotion_model_id" in config_data: + config.emotion_model_id = config_data["emotion_model_id"] + if "emotion_extract_keywords" in config_data: + config.emotion_extract_keywords = config_data["emotion_extract_keywords"] + if "emotion_min_intensity" in config_data: + config.emotion_min_intensity = config_data["emotion_min_intensity"] + if "emotion_enable_subject" in config_data: + config.emotion_enable_subject = config_data["emotion_enable_subject"] + + # 提交更改 + self.db.commit() + self.db.refresh(config) + + # 返回更新后的配置 + updated_config = self.get_emotion_config(config_id) + + logger.info(f"情绪配置更新成功: config_id={config_id}") + return updated_config + + except ValueError: + self.db.rollback() + raise + except Exception as e: + self.db.rollback() + logger.error(f"更新情绪配置失败: {str(e)}", exc_info=True) + raise diff --git a/api/app/services/emotion_extraction_service.py b/api/app/services/emotion_extraction_service.py new file mode 100644 index 00000000..b3172df1 --- /dev/null +++ b/api/app/services/emotion_extraction_service.py @@ -0,0 +1,200 @@ +"""Emotion extraction service for analyzing emotions from statements. + +This service extracts emotion information from user statements using LLM, +including emotion type, intensity, keywords, subject classification, and target. + +Classes: + EmotionExtractionService: Service for extracting emotions from statements +""" + +import logging +from typing import Optional +from app.core.memory.models.emotion_models import EmotionExtraction +from app.models.data_config_model import DataConfig +from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.llm_tools.llm_client import LLMClientException + +logger = logging.getLogger(__name__) + + +class EmotionExtractionService: + """Service for extracting emotion information from statements. + + This service uses LLM to analyze statements and extract structured emotion + information including type, intensity, keywords, subject, and target. + It respects configuration settings for enabling/disabling extraction and + filtering by intensity threshold. + + Attributes: + llm_client: LLM client for making structured output calls + """ + + def __init__(self, llm_id: Optional[str] = None): + """Initialize the emotion extraction service. + + Args: + llm_id: Optional LLM model ID. If None, uses default from config. + """ + self.llm_client = None + self.llm_id = llm_id + logger.info(f"Initialized EmotionExtractionService with llm_id={llm_id}") + + def _get_llm_client(self, model_id: Optional[str] = None): + """Get or create LLM client instance. + + Args: + model_id: Optional model ID to use. If None, uses instance llm_id. + + Returns: + LLM client instance + """ + if self.llm_client is None or model_id: + effective_model_id = model_id or self.llm_id + self.llm_client = get_llm_client(effective_model_id) + return self.llm_client + + async def extract_emotion( + self, + statement: str, + config: DataConfig + ) -> Optional[EmotionExtraction]: + """Extract emotion information from a statement. + + This method checks if emotion extraction is enabled in the config, + builds an appropriate prompt, calls the LLM for structured output, + and applies intensity threshold filtering. + + Args: + statement: The statement text to analyze + config: Data configuration object containing emotion settings + + Returns: + EmotionExtraction object if extraction succeeds and passes threshold, + None if extraction is disabled, fails, or doesn't meet threshold + + Raises: + No exceptions are raised - failures are logged and return None + """ + # Check if emotion extraction is enabled + if not config.emotion_enabled: + logger.debug("Emotion extraction is disabled in config") + return None + + # Validate statement + if not statement or not statement.strip(): + logger.warning("Empty statement provided for emotion extraction") + return None + + try: + # Build the emotion extraction prompt + prompt = await self._build_emotion_prompt( + statement=statement, + extract_keywords=config.emotion_extract_keywords, + enable_subject=config.emotion_enable_subject + ) + + # Call LLM for structured output + emotion = await self._call_llm_structured( + prompt=prompt, + model_id=config.emotion_model_id + ) + + # Apply intensity threshold filtering + if emotion.emotion_intensity < config.emotion_min_intensity: + logger.debug( + f"Emotion intensity {emotion.emotion_intensity} below threshold " + f"{config.emotion_min_intensity}, skipping storage" + ) + return None + + logger.info( + f"Successfully extracted emotion: type={emotion.emotion_type}, " + f"intensity={emotion.emotion_intensity}, subject={emotion.emotion_subject}" + ) + + return emotion + + except Exception as e: + logger.error( + f"Emotion extraction failed for statement: {statement[:50]}..., " + f"error: {str(e)}", + exc_info=True + ) + return None + + async def _build_emotion_prompt( + self, + statement: str, + extract_keywords: bool, + enable_subject: bool + ) -> str: + """Build the emotion extraction prompt based on configuration. + + This method constructs a detailed prompt for the LLM that includes + instructions for emotion type classification, intensity assessment, + and optionally keyword extraction and subject classification. + + Args: + statement: The statement to analyze + extract_keywords: Whether to extract emotion keywords + enable_subject: Whether to enable subject classification + + Returns: + Formatted prompt string for LLM + """ + from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt + + prompt = await render_emotion_extraction_prompt( + statement=statement, + extract_keywords=extract_keywords, + enable_subject=enable_subject + ) + + return prompt + + async def _call_llm_structured( + self, + prompt: str, + model_id: Optional[str] = None + ) -> EmotionExtraction: + """Call LLM for structured emotion extraction output. + + This method uses the LLM client's response_structured method to get + a validated EmotionExtraction object from the LLM. + + Args: + prompt: The formatted prompt for emotion extraction + model_id: Optional model ID to use for this call + + Returns: + EmotionExtraction object with validated emotion data + + Raises: + LLMClientException: If LLM call fails or times out + ValidationError: If LLM response doesn't match expected schema + """ + try: + # Get LLM client + llm_client = self._get_llm_client(model_id) + + # Prepare messages + messages = [ + {"role": "user", "content": prompt} + ] + + # Call LLM with structured output + emotion = await llm_client.response_structured( + messages=messages, + response_model=EmotionExtraction, + temperature=0.3, + max_tokens=500 + ) + + return emotion + + except LLMClientException as e: + logger.error(f"LLM call failed: {str(e)}") + raise + except Exception as e: + logger.error(f"Unexpected error in LLM structured call: {str(e)}") + raise LLMClientException(f"Emotion extraction LLM call failed: {str(e)}") From fa6e1c9d937673694088be87f64c71d84e5fec01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= Date: Thu, 18 Dec 2025 09:56:35 +0000 Subject: [PATCH 21/24] Merge #13 into develop from fix/stream-output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 'fix/stream-output' * fix/stream-output: (17 commits squashed) - [fix]Fix the issue where the streaming output effect is not obvious. - [fix]Fix the issue where the streaming output effect is not obvious. - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix] - [fix]Skip time extraction - [fix] - [fix]Skip time extraction - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix]Remove human-induced delays - [fix]Fix the issue where the streaming output effect is not obvious. - [fix] - [fix]Skip time extraction - [fix]Fix the issue where the streaming output effect is not obvious. - [fix] - [fix]Skip time extraction - [fix]Remove human-induced delays - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output Signed-off-by: 乐力齐 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/13 --- .../extraction_orchestrator.py | 173 +++--------------- 1 file changed, 29 insertions(+), 144 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 91529aa9..e00bcf0a 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -35,6 +35,7 @@ from app.core.memory.models.graph_models import ( from app.core.memory.utils.data.ontology import TemporalInfo from app.core.memory.models.variate_config import ( ExtractionPipelineConfig, + StatementExtractionConfig, ) from app.core.memory.llm_tools.openai_client import LLMClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient @@ -52,6 +53,7 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.tem ) from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import ( embedding_generation, + embedding_generation_all, generate_entity_embeddings_from_triplets, ) from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import ( @@ -177,12 +179,24 @@ class ExtractionOrchestrator: all_statements_list.extend(chunk.statements) total_statements = len(all_statements_list) - # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 - logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") + # 🔥 陈述句提取完成后,立即发送知识抽取完成消息 + if self.progress_callback: + extraction_stats = { + "statements_count": total_statements, + "entities_count": 0, # 暂时为0,后续会更新 + "triplets_count": 0, # 暂时为0,后续会更新 + "temporal_ranges_count": 0, # 暂时为0,后续会更新 + } + await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats) + + # 🔥 立即发送下一阶段的开始消息,让前端知道进入了创建节点和边阶段 + await self.progress_callback("creating_nodes_edges", "正在创建节点和边...") + + # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成(后台静默执行) + logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成(后台静默执行)") ( triplet_maps, temporal_maps, - emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -211,7 +225,6 @@ class ExtractionOrchestrator: dialog_data_list, temporal_maps, triplet_maps, - emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -539,108 +552,9 @@ class ExtractionOrchestrator: return temporal_maps - async def _extract_emotions( - self, dialog_data_list: List[DialogData] - ) -> List[Dict[str, Any]]: - """ - 从对话中提取情绪信息(优化版:全局陈述句级并行) - - Args: - dialog_data_list: 对话数据列表 - - Returns: - 情绪信息映射列表,每个对话对应一个字典 - """ - logger.info("开始情绪信息提取(全局陈述句级并行)") - - # 收集所有陈述句及其配置 - all_statements = [] - statement_metadata = [] # (dialog_idx, statement_id) - - # 获取第一个对话的config_id来加载配置 - config_id = None - if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): - config_id = dialog_data_list[0].config_id - - # 加载DataConfig - data_config = None - if config_id: - try: - from app.db import SessionLocal - from app.repositories.data_config_repository import DataConfigRepository - - db = SessionLocal() - try: - data_config = DataConfigRepository.get_by_id(db, config_id) - finally: - db.close() - - if data_config and not data_config.emotion_enabled: - logger.info("情绪提取已在配置中禁用,跳过情绪提取") - return [{} for _ in dialog_data_list] - - except Exception as e: - logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取") - return [{} for _ in dialog_data_list] - else: - logger.info("未找到config_id,跳过情绪提取") - return [{} for _ in dialog_data_list] - - # 如果配置未启用情绪提取,直接返回空映射 - if not data_config or not data_config.emotion_enabled: - logger.info("情绪提取未启用,跳过") - return [{} for _ in dialog_data_list] - - # 收集所有陈述句 - for d_idx, dialog in enumerate(dialog_data_list): - for chunk in dialog.chunks: - for statement in chunk.statements: - all_statements.append((statement, data_config)) - statement_metadata.append((d_idx, statement.id)) - - logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪") - - # 初始化情绪提取服务 - from app.services.emotion_extraction_service import EmotionExtractionService - emotion_service = EmotionExtractionService( - llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None - ) - - # 全局并行处理所有陈述句 - async def extract_for_statement(stmt_data): - statement, config = stmt_data - try: - return await emotion_service.extract_emotion(statement.statement, config) - except Exception as e: - logger.error(f"陈述句 {statement.id} 情绪提取失败: {e}") - return None - - tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 将结果组织成对话级别的映射 - emotion_maps = [{} for _ in dialog_data_list] - successful_extractions = 0 - - for i, result in enumerate(results): - d_idx, stmt_id = statement_metadata[i] - if isinstance(result, Exception): - logger.error(f"陈述句处理异常: {result}") - emotion_maps[d_idx][stmt_id] = None - else: - emotion_maps[d_idx][stmt_id] = result - if result is not None: - successful_extractions += 1 - - # 统计提取结果 - logger.info(f"情绪信息提取完成,共成功提取 {successful_extractions}/{len(all_statements)} 个情绪") - - return emotion_maps - async def _parallel_extract_and_embed( self, dialog_data_list: List[DialogData] ) -> Tuple[ - List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, List[float]]], @@ -648,39 +562,35 @@ class ExtractionOrchestrator: List[List[float]], ]: """ - 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 + 并行执行三元组提取、时间信息提取和基础嵌入生成 - 这四个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行: + 这三个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行: - 三元组提取:从陈述句中提取实体和关系 - 时间信息提取:从陈述句中提取时间范围 - - 情绪提取:从陈述句中提取情绪信息 - 嵌入生成:为陈述句、分块和对话生成向量(不依赖三元组) Args: dialog_data_list: 对话数据列表 Returns: - 六个列表的元组: + 五个列表的元组: - 三元组映射列表 - 时间信息映射列表 - - 情绪映射列表 - 陈述句嵌入映射列表 - 分块嵌入映射列表 - 对话嵌入列表 """ - logger.info("并行执行:三元组提取 + 时间信息提取 + 情绪提取 + 基础嵌入生成") + logger.info("并行执行:三元组提取 + 时间信息提取 + 基础嵌入生成") - # 创建四个并行任务 + # 创建三个并行任务 triplet_task = self._extract_triplets(dialog_data_list) temporal_task = self._extract_temporal(dialog_data_list) - emotion_task = self._extract_emotions(dialog_data_list) embedding_task = self._generate_basic_embeddings(dialog_data_list) # 并行执行 results = await asyncio.gather( triplet_task, temporal_task, - emotion_task, embedding_task, return_exceptions=True ) @@ -688,21 +598,19 @@ class ExtractionOrchestrator: # 解包结果 triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list] temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list] - emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list] - if isinstance(results[3], Exception): - logger.error(f"基础嵌入生成失败: {results[3]}") + if isinstance(results[2], Exception): + logger.error(f"基础嵌入生成失败: {results[2]}") statement_embedding_maps = [{} for _ in dialog_data_list] chunk_embedding_maps = [{} for _ in dialog_data_list] dialog_embeddings = [[] for _ in dialog_data_list] else: - statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[3] + statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[2] logger.info("并行任务执行完成") return ( triplet_maps, temporal_maps, - emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -819,7 +727,6 @@ class ExtractionOrchestrator: dialog_data_list: List[DialogData], temporal_maps: List[Dict[str, Any]], triplet_maps: List[Dict[str, Any]], - emotion_maps: List[Dict[str, Any]], statement_embedding_maps: List[Dict[str, List[float]]], chunk_embedding_maps: List[Dict[str, List[float]]], dialog_embeddings: List[List[float]], @@ -831,7 +738,6 @@ class ExtractionOrchestrator: dialog_data_list: 对话数据列表 temporal_maps: 时间信息映射列表 triplet_maps: 三元组映射列表 - emotion_maps: 情绪信息映射列表 statement_embedding_maps: 陈述句嵌入映射列表 chunk_embedding_maps: 分块嵌入映射列表 dialog_embeddings: 对话嵌入列表 @@ -846,7 +752,6 @@ class ExtractionOrchestrator: if ( len(temporal_maps) != expected_length or len(triplet_maps) != expected_length - or len(emotion_maps) != expected_length or len(statement_embedding_maps) != expected_length or len(chunk_embedding_maps) != expected_length or len(dialog_embeddings) != expected_length @@ -854,7 +759,6 @@ class ExtractionOrchestrator: logger.warning( f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " f"时间映射: {len(temporal_maps)}, 三元组映射: {len(triplet_maps)}, " - f"情绪映射: {len(emotion_maps)}, " f"陈述句嵌入: {len(statement_embedding_maps)}, " f"分块嵌入: {len(chunk_embedding_maps)}, " f"对话嵌入: {len(dialog_embeddings)}" @@ -863,7 +767,6 @@ class ExtractionOrchestrator: total_statements = 0 assigned_temporal = 0 assigned_triplets = 0 - assigned_emotions = 0 assigned_statement_embeddings = 0 assigned_chunk_embeddings = 0 assigned_dialog_embeddings = 0 @@ -871,13 +774,12 @@ class ExtractionOrchestrator: # 处理每个对话 for i, dialog_data in enumerate(dialog_data_list): # 检查是否有缺失的数据 - if i >= len(temporal_maps) or i >= len(triplet_maps) or i >= len(emotion_maps): + if i >= len(temporal_maps) or i >= len(triplet_maps): logger.warning(f"对话 {dialog_data.id} 缺少提取数据,跳过赋值") continue temporal_map = temporal_maps[i] triplet_map = triplet_maps[i] - emotion_map = emotion_maps[i] statement_embedding_map = statement_embedding_maps[i] if i < len(statement_embedding_maps) else {} chunk_embedding_map = chunk_embedding_maps[i] if i < len(chunk_embedding_maps) else {} dialog_embedding = dialog_embeddings[i] if i < len(dialog_embeddings) else [] @@ -908,18 +810,6 @@ class ExtractionOrchestrator: statement.triplet_extraction_info = triplet_map[statement.id] assigned_triplets += 1 - # 赋值情绪信息 - if statement.id in emotion_map: - emotion_data = emotion_map[statement.id] - if emotion_data is not None: - # 将EmotionExtraction对象的字段赋值到Statement - statement.emotion_type = emotion_data.emotion_type - statement.emotion_intensity = emotion_data.emotion_intensity - statement.emotion_keywords = emotion_data.emotion_keywords - statement.emotion_subject = emotion_data.emotion_subject - statement.emotion_target = emotion_data.emotion_target - assigned_emotions += 1 - # 赋值陈述句嵌入 if statement.id in statement_embedding_map: statement.statement_embedding = statement_embedding_map[statement.id] @@ -928,7 +818,6 @@ class ExtractionOrchestrator: logger.info( f"数据赋值完成 - 总陈述句: {total_statements}, " f"时间信息: {assigned_temporal}, 三元组: {assigned_triplets}, " - f"情绪信息: {assigned_emotions}, " f"陈述句嵌入: {assigned_statement_embeddings}, " f"分块嵌入: {assigned_chunk_embeddings}, " f"对话嵌入: {assigned_dialog_embeddings}" @@ -1038,12 +927,6 @@ class ExtractionOrchestrator: created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, - # Emotion fields - emotion_type=getattr(statement, 'emotion_type', None), - emotion_intensity=getattr(statement, 'emotion_intensity', None), - emotion_keywords=getattr(statement, 'emotion_keywords', None), - emotion_subject=getattr(statement, 'emotion_subject', None), - emotion_target=getattr(statement, 'emotion_target', None), ) statement_nodes.append(statement_node) @@ -1450,7 +1333,7 @@ class ExtractionOrchestrator: if match: entity1_name = match.group(1).strip() entity1_type = match.group(2) - match.group(3).strip() + entity2_name = match.group(3).strip() entity2_type = match.group(4) # 提取置信度和原因 @@ -1763,6 +1646,7 @@ async def get_chunked_dialogs( """ import json import re + import os # 加载测试数据 testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json") @@ -1938,6 +1822,7 @@ async def get_chunked_dialogs_with_preprocessing( Returns: 带 chunks 的 DialogData 列表 """ + import os print("\n=== 完整数据处理流程(包含预处理)===") if input_data_path is None: From 902dd18bc829f7ce0c55189d0134da35d2748992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= Date: Sat, 20 Dec 2025 07:02:46 +0000 Subject: [PATCH 22/24] Merge #21 into develop from feature/emotion-engine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feature/情绪引擎 * feature/emotion-engine: (7 commits squashed) - [feature]Emotion Engine Development - [feature]Emotion Engine Development - Merge branch 'feature/emotion-engine' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/emotion-engine - [fix]1.Fix the front-end files;2.Cache Management Deletion;3.Delete "check_code.py" - [fix]1.Fix the front-end files;2.Cache Management Deletion;3.Delete "check_code.py" - Merge branch 'feature/emotion-engine' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/emotion-engine - [fix]fix vite.config.ts Signed-off-by: 乐力齐 Commented-by: aliyun6762716068 Commented-by: 乐力齐 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/21 --- .../extraction_orchestrator.py | 173 +++++++++++++++--- 1 file changed, 144 insertions(+), 29 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index e00bcf0a..91529aa9 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -35,7 +35,6 @@ from app.core.memory.models.graph_models import ( from app.core.memory.utils.data.ontology import TemporalInfo from app.core.memory.models.variate_config import ( ExtractionPipelineConfig, - StatementExtractionConfig, ) from app.core.memory.llm_tools.openai_client import LLMClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient @@ -53,7 +52,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.tem ) from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import ( embedding_generation, - embedding_generation_all, generate_entity_embeddings_from_triplets, ) from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import ( @@ -179,24 +177,12 @@ class ExtractionOrchestrator: all_statements_list.extend(chunk.statements) total_statements = len(all_statements_list) - # 🔥 陈述句提取完成后,立即发送知识抽取完成消息 - if self.progress_callback: - extraction_stats = { - "statements_count": total_statements, - "entities_count": 0, # 暂时为0,后续会更新 - "triplets_count": 0, # 暂时为0,后续会更新 - "temporal_ranges_count": 0, # 暂时为0,后续会更新 - } - await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats) - - # 🔥 立即发送下一阶段的开始消息,让前端知道进入了创建节点和边阶段 - await self.progress_callback("creating_nodes_edges", "正在创建节点和边...") - - # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成(后台静默执行) - logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成(后台静默执行)") + # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 + logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") ( triplet_maps, temporal_maps, + emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -225,6 +211,7 @@ class ExtractionOrchestrator: dialog_data_list, temporal_maps, triplet_maps, + emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -552,9 +539,108 @@ class ExtractionOrchestrator: return temporal_maps + async def _extract_emotions( + self, dialog_data_list: List[DialogData] + ) -> List[Dict[str, Any]]: + """ + 从对话中提取情绪信息(优化版:全局陈述句级并行) + + Args: + dialog_data_list: 对话数据列表 + + Returns: + 情绪信息映射列表,每个对话对应一个字典 + """ + logger.info("开始情绪信息提取(全局陈述句级并行)") + + # 收集所有陈述句及其配置 + all_statements = [] + statement_metadata = [] # (dialog_idx, statement_id) + + # 获取第一个对话的config_id来加载配置 + config_id = None + if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): + config_id = dialog_data_list[0].config_id + + # 加载DataConfig + data_config = None + if config_id: + try: + from app.db import SessionLocal + from app.repositories.data_config_repository import DataConfigRepository + + db = SessionLocal() + try: + data_config = DataConfigRepository.get_by_id(db, config_id) + finally: + db.close() + + if data_config and not data_config.emotion_enabled: + logger.info("情绪提取已在配置中禁用,跳过情绪提取") + return [{} for _ in dialog_data_list] + + except Exception as e: + logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取") + return [{} for _ in dialog_data_list] + else: + logger.info("未找到config_id,跳过情绪提取") + return [{} for _ in dialog_data_list] + + # 如果配置未启用情绪提取,直接返回空映射 + if not data_config or not data_config.emotion_enabled: + logger.info("情绪提取未启用,跳过") + return [{} for _ in dialog_data_list] + + # 收集所有陈述句 + for d_idx, dialog in enumerate(dialog_data_list): + for chunk in dialog.chunks: + for statement in chunk.statements: + all_statements.append((statement, data_config)) + statement_metadata.append((d_idx, statement.id)) + + logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪") + + # 初始化情绪提取服务 + from app.services.emotion_extraction_service import EmotionExtractionService + emotion_service = EmotionExtractionService( + llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None + ) + + # 全局并行处理所有陈述句 + async def extract_for_statement(stmt_data): + statement, config = stmt_data + try: + return await emotion_service.extract_emotion(statement.statement, config) + except Exception as e: + logger.error(f"陈述句 {statement.id} 情绪提取失败: {e}") + return None + + tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 将结果组织成对话级别的映射 + emotion_maps = [{} for _ in dialog_data_list] + successful_extractions = 0 + + for i, result in enumerate(results): + d_idx, stmt_id = statement_metadata[i] + if isinstance(result, Exception): + logger.error(f"陈述句处理异常: {result}") + emotion_maps[d_idx][stmt_id] = None + else: + emotion_maps[d_idx][stmt_id] = result + if result is not None: + successful_extractions += 1 + + # 统计提取结果 + logger.info(f"情绪信息提取完成,共成功提取 {successful_extractions}/{len(all_statements)} 个情绪") + + return emotion_maps + async def _parallel_extract_and_embed( self, dialog_data_list: List[DialogData] ) -> Tuple[ + List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, List[float]]], @@ -562,35 +648,39 @@ class ExtractionOrchestrator: List[List[float]], ]: """ - 并行执行三元组提取、时间信息提取和基础嵌入生成 + 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 - 这三个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行: + 这四个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行: - 三元组提取:从陈述句中提取实体和关系 - 时间信息提取:从陈述句中提取时间范围 + - 情绪提取:从陈述句中提取情绪信息 - 嵌入生成:为陈述句、分块和对话生成向量(不依赖三元组) Args: dialog_data_list: 对话数据列表 Returns: - 五个列表的元组: + 六个列表的元组: - 三元组映射列表 - 时间信息映射列表 + - 情绪映射列表 - 陈述句嵌入映射列表 - 分块嵌入映射列表 - 对话嵌入列表 """ - logger.info("并行执行:三元组提取 + 时间信息提取 + 基础嵌入生成") + logger.info("并行执行:三元组提取 + 时间信息提取 + 情绪提取 + 基础嵌入生成") - # 创建三个并行任务 + # 创建四个并行任务 triplet_task = self._extract_triplets(dialog_data_list) temporal_task = self._extract_temporal(dialog_data_list) + emotion_task = self._extract_emotions(dialog_data_list) embedding_task = self._generate_basic_embeddings(dialog_data_list) # 并行执行 results = await asyncio.gather( triplet_task, temporal_task, + emotion_task, embedding_task, return_exceptions=True ) @@ -598,19 +688,21 @@ class ExtractionOrchestrator: # 解包结果 triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list] temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list] + emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list] - if isinstance(results[2], Exception): - logger.error(f"基础嵌入生成失败: {results[2]}") + if isinstance(results[3], Exception): + logger.error(f"基础嵌入生成失败: {results[3]}") statement_embedding_maps = [{} for _ in dialog_data_list] chunk_embedding_maps = [{} for _ in dialog_data_list] dialog_embeddings = [[] for _ in dialog_data_list] else: - statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[2] + statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[3] logger.info("并行任务执行完成") return ( triplet_maps, temporal_maps, + emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -727,6 +819,7 @@ class ExtractionOrchestrator: dialog_data_list: List[DialogData], temporal_maps: List[Dict[str, Any]], triplet_maps: List[Dict[str, Any]], + emotion_maps: List[Dict[str, Any]], statement_embedding_maps: List[Dict[str, List[float]]], chunk_embedding_maps: List[Dict[str, List[float]]], dialog_embeddings: List[List[float]], @@ -738,6 +831,7 @@ class ExtractionOrchestrator: dialog_data_list: 对话数据列表 temporal_maps: 时间信息映射列表 triplet_maps: 三元组映射列表 + emotion_maps: 情绪信息映射列表 statement_embedding_maps: 陈述句嵌入映射列表 chunk_embedding_maps: 分块嵌入映射列表 dialog_embeddings: 对话嵌入列表 @@ -752,6 +846,7 @@ class ExtractionOrchestrator: if ( len(temporal_maps) != expected_length or len(triplet_maps) != expected_length + or len(emotion_maps) != expected_length or len(statement_embedding_maps) != expected_length or len(chunk_embedding_maps) != expected_length or len(dialog_embeddings) != expected_length @@ -759,6 +854,7 @@ class ExtractionOrchestrator: logger.warning( f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " f"时间映射: {len(temporal_maps)}, 三元组映射: {len(triplet_maps)}, " + f"情绪映射: {len(emotion_maps)}, " f"陈述句嵌入: {len(statement_embedding_maps)}, " f"分块嵌入: {len(chunk_embedding_maps)}, " f"对话嵌入: {len(dialog_embeddings)}" @@ -767,6 +863,7 @@ class ExtractionOrchestrator: total_statements = 0 assigned_temporal = 0 assigned_triplets = 0 + assigned_emotions = 0 assigned_statement_embeddings = 0 assigned_chunk_embeddings = 0 assigned_dialog_embeddings = 0 @@ -774,12 +871,13 @@ class ExtractionOrchestrator: # 处理每个对话 for i, dialog_data in enumerate(dialog_data_list): # 检查是否有缺失的数据 - if i >= len(temporal_maps) or i >= len(triplet_maps): + if i >= len(temporal_maps) or i >= len(triplet_maps) or i >= len(emotion_maps): logger.warning(f"对话 {dialog_data.id} 缺少提取数据,跳过赋值") continue temporal_map = temporal_maps[i] triplet_map = triplet_maps[i] + emotion_map = emotion_maps[i] statement_embedding_map = statement_embedding_maps[i] if i < len(statement_embedding_maps) else {} chunk_embedding_map = chunk_embedding_maps[i] if i < len(chunk_embedding_maps) else {} dialog_embedding = dialog_embeddings[i] if i < len(dialog_embeddings) else [] @@ -810,6 +908,18 @@ class ExtractionOrchestrator: statement.triplet_extraction_info = triplet_map[statement.id] assigned_triplets += 1 + # 赋值情绪信息 + if statement.id in emotion_map: + emotion_data = emotion_map[statement.id] + if emotion_data is not None: + # 将EmotionExtraction对象的字段赋值到Statement + statement.emotion_type = emotion_data.emotion_type + statement.emotion_intensity = emotion_data.emotion_intensity + statement.emotion_keywords = emotion_data.emotion_keywords + statement.emotion_subject = emotion_data.emotion_subject + statement.emotion_target = emotion_data.emotion_target + assigned_emotions += 1 + # 赋值陈述句嵌入 if statement.id in statement_embedding_map: statement.statement_embedding = statement_embedding_map[statement.id] @@ -818,6 +928,7 @@ class ExtractionOrchestrator: logger.info( f"数据赋值完成 - 总陈述句: {total_statements}, " f"时间信息: {assigned_temporal}, 三元组: {assigned_triplets}, " + f"情绪信息: {assigned_emotions}, " f"陈述句嵌入: {assigned_statement_embeddings}, " f"分块嵌入: {assigned_chunk_embeddings}, " f"对话嵌入: {assigned_dialog_embeddings}" @@ -927,6 +1038,12 @@ class ExtractionOrchestrator: created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, + # Emotion fields + emotion_type=getattr(statement, 'emotion_type', None), + emotion_intensity=getattr(statement, 'emotion_intensity', None), + emotion_keywords=getattr(statement, 'emotion_keywords', None), + emotion_subject=getattr(statement, 'emotion_subject', None), + emotion_target=getattr(statement, 'emotion_target', None), ) statement_nodes.append(statement_node) @@ -1333,7 +1450,7 @@ class ExtractionOrchestrator: if match: entity1_name = match.group(1).strip() entity1_type = match.group(2) - entity2_name = match.group(3).strip() + match.group(3).strip() entity2_type = match.group(4) # 提取置信度和原因 @@ -1646,7 +1763,6 @@ async def get_chunked_dialogs( """ import json import re - import os # 加载测试数据 testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json") @@ -1822,7 +1938,6 @@ async def get_chunked_dialogs_with_preprocessing( Returns: 带 chunks 的 DialogData 列表 """ - import os print("\n=== 完整数据处理流程(包含预处理)===") if input_data_path is None: From e4f7fb43f577fd03592f74f270b061205cb90df9 Mon Sep 17 00:00:00 2001 From: Mark Date: Sat, 20 Dec 2025 15:27:47 +0800 Subject: [PATCH 23/24] [add] migration script --- .../versions/626abf154a6a_202512201526.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 api/migrations/versions/626abf154a6a_202512201526.py diff --git a/api/migrations/versions/626abf154a6a_202512201526.py b/api/migrations/versions/626abf154a6a_202512201526.py new file mode 100644 index 00000000..7d89766e --- /dev/null +++ b/api/migrations/versions/626abf154a6a_202512201526.py @@ -0,0 +1,38 @@ +"""202512201526 + +Revision ID: 626abf154a6a +Revises: 70e94dd4a8d1 +Create Date: 2025-12-20 15:26:50.634470 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '626abf154a6a' +down_revision: Union[str, None] = '70e94dd4a8d1' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('data_config', sa.Column('emotion_enabled', sa.Boolean(), nullable=True, comment='是否启用情绪提取')) + op.add_column('data_config', sa.Column('emotion_model_id', sa.String(), nullable=True, comment='情绪分析专用模型ID')) + op.add_column('data_config', sa.Column('emotion_extract_keywords', sa.Boolean(), nullable=True, comment='是否提取情绪关键词')) + op.add_column('data_config', sa.Column('emotion_min_intensity', sa.Float(), nullable=True, comment='最小情绪强度阈值')) + op.add_column('data_config', sa.Column('emotion_enable_subject', sa.Boolean(), nullable=True, comment='是否启用主体分类')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('data_config', 'emotion_enable_subject') + op.drop_column('data_config', 'emotion_min_intensity') + op.drop_column('data_config', 'emotion_extract_keywords') + op.drop_column('data_config', 'emotion_model_id') + op.drop_column('data_config', 'emotion_enabled') + # ### end Alembic commands ### From b00d6e37e310352171d928800f6ad5c901b1fca2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Sat, 20 Dec 2025 16:03:06 +0800 Subject: [PATCH 24/24] feat(tool system): tool system development --- api/app/core/workflow/executor.py | 356 +++++++++++++++--------------- api/app/services/agent_tools.py | 219 +----------------- 2 files changed, 179 insertions(+), 396 deletions(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 6effaa5b..46f8cf08 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -5,7 +5,7 @@ """ import logging -import uuid +# import uuid import datetime from typing import Any @@ -16,10 +16,11 @@ from langgraph.graph.state import CompiledStateGraph from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.nodes.enums import NodeType -from app.core.tools.registry import ToolRegistry -from app.core.tools.executor import ToolExecutor -from app.core.tools.langchain_adapter import LangchainAdapter -TOOL_MANAGEMENT_AVAILABLE = True +# from app.core.tools.registry import ToolRegistry +# from app.core.tools.executor import ToolExecutor +# from app.core.tools.langchain_adapter import LangchainAdapter +# TOOL_MANAGEMENT_AVAILABLE = True +# from app.db import get_db logger = logging.getLogger(__name__) @@ -466,176 +467,175 @@ async def execute_workflow_stream( # ==================== 工具管理系统集成 ==================== -def get_workflow_tools(workspace_id: str, user_id: str) -> list: - """获取工作流可用的工具列表 - - Args: - workspace_id: 工作空间ID - user_id: 用户ID - - Returns: - 可用工具列表 - """ - if not TOOL_MANAGEMENT_AVAILABLE: - logger.warning("工具管理系统不可用") - return [] - - try: - from sqlalchemy.orm import Session - db = next(get_db()) - - # 创建工具注册表 - registry = ToolRegistry(db) - - # 注册内置工具类 - from app.core.tools.builtin import ( - DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool - ) - registry.register_tool_class(DateTimeTool) - registry.register_tool_class(JsonTool) - registry.register_tool_class(BaiduSearchTool) - registry.register_tool_class(MinerUTool) - registry.register_tool_class(TextInTool) - - # 获取活跃的工具 - import uuid - tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id)) - active_tools = [tool for tool in tools if tool.status.value == "active"] - - # 转换为Langchain工具 - langchain_tools = [] - for tool_info in active_tools: - try: - tool_instance = registry.get_tool(tool_info.id) - if tool_instance: - langchain_tool = LangchainAdapter.convert_tool(tool_instance) - langchain_tools.append(langchain_tool) - except Exception as e: - logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}") - - logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具") - return langchain_tools - - except Exception as e: - logger.error(f"获取工作流工具失败: {e}") - return [] - - -class ToolWorkflowNode: - """工具工作流节点 - 在工作流中执行工具""" - - def __init__(self, node_config: dict, workflow_config: dict): - """初始化工具节点 - - Args: - node_config: 节点配置 - workflow_config: 工作流配置 - """ - self.node_config = node_config - self.workflow_config = workflow_config - self.tool_id = node_config.get("tool_id") - self.tool_parameters = node_config.get("parameters", {}) - - async def run(self, state: WorkflowState) -> WorkflowState: - """执行工具节点""" - if not TOOL_MANAGEMENT_AVAILABLE: - logger.error("工具管理系统不可用") - state["error"] = "工具管理系统不可用" - return state - - try: - from sqlalchemy.orm import Session - db = next(get_db()) - - # 创建工具执行器 - registry = ToolRegistry(db) - executor = ToolExecutor(db, registry) - - # 准备参数(支持变量替换) - parameters = self._prepare_parameters(state) - - # 执行工具 - result = await executor.execute_tool( - tool_id=self.tool_id, - parameters=parameters, - user_id=uuid.UUID(state["user_id"]), - workspace_id=uuid.UUID(state["workspace_id"]) - ) - - # 更新状态 - node_id = self.node_config.get("id") - if result.success: - state["node_outputs"][node_id] = { - "type": "tool", - "tool_id": self.tool_id, - "output": result.data, - "execution_time": result.execution_time, - "token_usage": result.token_usage - } - - # 更新运行时变量 - if isinstance(result.data, dict): - for key, value in result.data.items(): - state["runtime_vars"][f"{node_id}.{key}"] = value - else: - state["runtime_vars"][f"{node_id}.result"] = result.data - else: - state["error"] = result.error - state["error_node"] = node_id - state["node_outputs"][node_id] = { - "type": "tool", - "tool_id": self.tool_id, - "error": result.error, - "execution_time": result.execution_time - } - - return state - - except Exception as e: - logger.error(f"工具节点执行失败: {e}") - state["error"] = str(e) - state["error_node"] = self.node_config.get("id") - return state - - def _prepare_parameters(self, state: WorkflowState) -> dict: - """准备工具参数(支持变量替换)""" - parameters = {} - - for key, value in self.tool_parameters.items(): - if isinstance(value, str) and value.startswith("${") and value.endswith("}"): - # 变量替换 - var_path = value[2:-1] - - # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result} - if "." in var_path: - parts = var_path.split(".") - current = state.get("variables", {}) - - for part in parts: - if isinstance(current, dict) and part in current: - current = current[part] - else: - # 尝试从运行时变量获取 - runtime_key = ".".join(parts) - current = state.get("runtime_vars", {}).get(runtime_key, value) - break - - parameters[key] = current - else: - # 简单变量 - variables = state.get("variables", {}) - parameters[key] = variables.get(var_path, value) - else: - parameters[key] = value - - return parameters - - -# 注册工具节点到NodeFactory(如果存在) -try: - from app.core.workflow.nodes import NodeFactory - if hasattr(NodeFactory, 'register_node_type'): - NodeFactory.register_node_type("tool", ToolWorkflowNode) - logger.info("工具节点已注册到工作流系统") -except Exception as e: - logger.warning(f"注册工具节点失败: {e}") \ No newline at end of file +# def get_workflow_tools(workspace_id: str, user_id: str) -> list: +# """获取工作流可用的工具列表 +# +# Args: +# workspace_id: 工作空间ID +# user_id: 用户ID +# +# Returns: +# 可用工具列表 +# """ +# if not TOOL_MANAGEMENT_AVAILABLE: +# logger.warning("工具管理系统不可用") +# return [] +# +# try: +# db = next(get_db()) +# +# # 创建工具注册表 +# registry = ToolRegistry(db) +# +# # 注册内置工具类 +# from app.core.tools.builtin import ( +# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool +# ) +# registry.register_tool_class(DateTimeTool) +# registry.register_tool_class(JsonTool) +# registry.register_tool_class(BaiduSearchTool) +# registry.register_tool_class(MinerUTool) +# registry.register_tool_class(TextInTool) +# +# # 获取活跃的工具 +# import uuid +# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id)) +# active_tools = [tool for tool in tools if tool.status.value == "active"] +# +# # 转换为Langchain工具 +# langchain_tools = [] +# for tool_info in active_tools: +# try: +# tool_instance = registry.get_tool(tool_info.id) +# if tool_instance: +# langchain_tool = LangchainAdapter.convert_tool(tool_instance) +# langchain_tools.append(langchain_tool) +# except Exception as e: +# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}") +# +# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具") +# return langchain_tools +# +# except Exception as e: +# logger.error(f"获取工作流工具失败: {e}") +# return [] +# +# +# class ToolWorkflowNode: +# """工具工作流节点 - 在工作流中执行工具""" +# +# def __init__(self, node_config: dict, workflow_config: dict): +# """初始化工具节点 +# +# Args: +# node_config: 节点配置 +# workflow_config: 工作流配置 +# """ +# self.node_config = node_config +# self.workflow_config = workflow_config +# self.tool_id = node_config.get("tool_id") +# self.tool_parameters = node_config.get("parameters", {}) +# +# async def run(self, state: WorkflowState) -> WorkflowState: +# """执行工具节点""" +# if not TOOL_MANAGEMENT_AVAILABLE: +# logger.error("工具管理系统不可用") +# state["error"] = "工具管理系统不可用" +# return state +# +# try: +# from sqlalchemy.orm import Session +# db = next(get_db()) +# +# # 创建工具执行器 +# registry = ToolRegistry(db) +# executor = ToolExecutor(db, registry) +# +# # 准备参数(支持变量替换) +# parameters = self._prepare_parameters(state) +# +# # 执行工具 +# result = await executor.execute_tool( +# tool_id=self.tool_id, +# parameters=parameters, +# user_id=uuid.UUID(state["user_id"]), +# workspace_id=uuid.UUID(state["workspace_id"]) +# ) +# +# # 更新状态 +# node_id = self.node_config.get("id") +# if result.success: +# state["node_outputs"][node_id] = { +# "type": "tool", +# "tool_id": self.tool_id, +# "output": result.data, +# "execution_time": result.execution_time, +# "token_usage": result.token_usage +# } +# +# # 更新运行时变量 +# if isinstance(result.data, dict): +# for key, value in result.data.items(): +# state["runtime_vars"][f"{node_id}.{key}"] = value +# else: +# state["runtime_vars"][f"{node_id}.result"] = result.data +# else: +# state["error"] = result.error +# state["error_node"] = node_id +# state["node_outputs"][node_id] = { +# "type": "tool", +# "tool_id": self.tool_id, +# "error": result.error, +# "execution_time": result.execution_time +# } +# +# return state +# +# except Exception as e: +# logger.error(f"工具节点执行失败: {e}") +# state["error"] = str(e) +# state["error_node"] = self.node_config.get("id") +# return state +# +# def _prepare_parameters(self, state: WorkflowState) -> dict: +# """准备工具参数(支持变量替换)""" +# parameters = {} +# +# for key, value in self.tool_parameters.items(): +# if isinstance(value, str) and value.startswith("${") and value.endswith("}"): +# # 变量替换 +# var_path = value[2:-1] +# +# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result} +# if "." in var_path: +# parts = var_path.split(".") +# current = state.get("variables", {}) +# +# for part in parts: +# if isinstance(current, dict) and part in current: +# current = current[part] +# else: +# # 尝试从运行时变量获取 +# runtime_key = ".".join(parts) +# current = state.get("runtime_vars", {}).get(runtime_key, value) +# break +# +# parameters[key] = current +# else: +# # 简单变量 +# variables = state.get("variables", {}) +# parameters[key] = variables.get(var_path, value) +# else: +# parameters[key] = value +# +# return parameters +# +# +# # 注册工具节点到NodeFactory(如果存在) +# try: +# from app.core.workflow.nodes import NodeFactory +# if hasattr(NodeFactory, 'register_node_type'): +# NodeFactory.register_node_type("tool", ToolWorkflowNode) +# logger.info("工具节点已注册到工作流系统") +# except Exception as e: +# logger.warning(f"注册工具节点失败: {e}") \ No newline at end of file diff --git a/api/app/services/agent_tools.py b/api/app/services/agent_tools.py index 7fe6a0c0..3ca7bddd 100644 --- a/api/app/services/agent_tools.py +++ b/api/app/services/agent_tools.py @@ -13,10 +13,6 @@ from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.error_codes import BizCode from app.core.logging_config import get_business_logger from app.repositories import workspace_repository, knowledge_repository -from app.core.tools.registry import ToolRegistry -from app.core.tools.executor import ToolExecutor -from app.core.tools.langchain_adapter import LangchainAdapter -TOOL_MANAGEMENT_AVAILABLE = True logger = get_business_logger() @@ -333,217 +329,4 @@ def create_agent_invocation_tool( ) return f"调用 Agent 失败: {str(e)}" - return invoke_agent - -def get_available_tools_for_agent( - db: Session, - workspace_id: uuid.UUID, - agent_id: Optional[uuid.UUID] = None -) -> List[Dict[str, Any]]: - """获取Agent可用的工具列表 - - Args: - db: 数据库会话 - workspace_id: 工作空间ID - agent_id: Agent ID(可选) - - Returns: - 可用工具列表 - """ - if not TOOL_MANAGEMENT_AVAILABLE: - logger.warning("工具管理系统不可用") - return [] - - try: - # 创建工具注册表 - registry = ToolRegistry(db) - - # 获取工具列表 - tools = registry.list_tools(workspace_id=workspace_id) - - # 转换为Agent可用的格式 - available_tools = [] - for tool_info in tools: - if tool_info.status.value == "active": - available_tools.append({ - "id": tool_info.id, - "name": tool_info.name, - "description": tool_info.description, - "type": tool_info.tool_type.value, - "version": tool_info.version, - "tags": tool_info.tags, - "parameters": [ - { - "name": param.name, - "type": param.type.value, - "description": param.description, - "required": param.required, - "default": param.default - } - for param in tool_info.parameters - ] - }) - - logger.info(f"为Agent获取到 {len(available_tools)} 个可用工具") - return available_tools - - except Exception as e: - logger.error(f"获取Agent可用工具失败: {e}") - return [] - - -def create_langchain_tools_for_agent( - db: Session, - workspace_id: uuid.UUID, - agent_id: Optional[uuid.UUID] = None -) -> List[Any]: - """为Agent创建Langchain兼容的工具列表 - - Args: - db: 数据库会话 - workspace_id: 工作空间ID - agent_id: Agent ID(可选) - - Returns: - Langchain工具列表 - """ - if not TOOL_MANAGEMENT_AVAILABLE: - logger.warning("工具管理系统不可用") - return [] - - try: - # 创建工具注册表 - registry = ToolRegistry(db) - - # 注册内置工具类 - from app.core.tools.builtin import ( - DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool - ) - registry.register_tool_class(DateTimeTool) - registry.register_tool_class(JsonTool) - registry.register_tool_class(BaiduSearchTool) - registry.register_tool_class(MinerUTool) - registry.register_tool_class(TextInTool) - - # 获取活跃的工具 - tools = registry.list_tools(workspace_id=workspace_id) - active_tools = [tool for tool in tools if tool.status.value == "active"] - - # 转换为Langchain工具 - langchain_tools = [] - for tool_info in active_tools: - try: - tool_instance = registry.get_tool(tool_info.id) - if tool_instance: - langchain_tool = LangchainAdapter.convert_tool(tool_instance) - langchain_tools.append(langchain_tool) - except Exception as e: - logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}") - - logger.info(f"为Agent创建了 {len(langchain_tools)} 个Langchain工具") - return langchain_tools - - except Exception as e: - logger.error(f"创建Agent Langchain工具失败: {e}") - return [] - - -class ToolExecutionInput(BaseModel): - """工具执行输入参数""" - tool_id: str = Field(..., description="工具ID") - parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") - timeout: Optional[float] = Field(None, description="超时时间(秒)") - - -def create_tool_execution_tool( - db: Session, - workspace_id: uuid.UUID, - user_id: uuid.UUID -): - """创建工具执行工具 - - Args: - db: 数据库会话 - workspace_id: 工作空间ID - user_id: 用户ID - - Returns: - 工具执行工具 - """ - if not TOOL_MANAGEMENT_AVAILABLE: - logger.warning("工具管理系统不可用") - return None - - @tool(args_schema=ToolExecutionInput) - async def execute_tool( - tool_id: str, - parameters: Dict[str, Any] = None, - timeout: Optional[float] = None - ) -> str: - """执行指定的工具。当需要使用系统中的工具来完成特定任务时使用。 - - Args: - tool_id: 工具ID(通过工具列表获取) - parameters: 工具参数(根据工具要求提供) - timeout: 超时时间(秒,可选) - - Returns: - 工具执行结果 - """ - try: - # 创建工具执行器 - registry = ToolRegistry(db) - executor = ToolExecutor(db, registry) - - # 执行工具 - result = await executor.execute_tool( - tool_id=tool_id, - parameters=parameters or {}, - user_id=user_id, - workspace_id=workspace_id, - timeout=timeout - ) - - if result.success: - # 格式化成功结果 - if isinstance(result.data, str): - return result.data - else: - import json - return json.dumps(result.data, ensure_ascii=False, indent=2) - else: - return f"工具执行失败: {result.error}" - - except Exception as e: - logger.error(f"工具执行异常: {tool_id}, 错误: {e}") - return f"工具执行异常: {str(e)}" - - return execute_tool - - -def get_tool_management_tools( - db: Session, - workspace_id: uuid.UUID, - user_id: uuid.UUID -) -> List[Any]: - """获取工具管理相关的工具 - - Args: - db: 数据库会话 - workspace_id: 工作空间ID - user_id: 用户ID - - Returns: - 工具管理工具列表 - """ - if not TOOL_MANAGEMENT_AVAILABLE: - return [] - - tools = [] - - # 添加工具执行工具 - execution_tool = create_tool_execution_tool(db, workspace_id, user_id) - if execution_tool: - tools.append(execution_tool) - - return tools \ No newline at end of file + return invoke_agent \ No newline at end of file