diff --git a/api/app/controllers/home_page_controller.py b/api/app/controllers/home_page_controller.py index 6665eec1..f1a5310d 100644 --- a/api/app/controllers/home_page_controller.py +++ b/api/app/controllers/home_page_controller.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends from sqlalchemy.orm import Session +from app.core.config import settings from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user @@ -26,4 +27,9 @@ def get_workspace_list( ): """获取工作空间列表""" workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id) - return success(data=workspace_list, msg="工作空间列表获取成功") \ No newline at end of file + return success(data=workspace_list, msg="工作空间列表获取成功") + +@router.get("/version", response_model=ApiResponse) +def get_system_version(): + """获取系统版本号""" + return success(data={"version": settings.SYSTEM_VERSION}, msg="系统版本获取成功") \ No newline at end of file diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index adb199fb..02c73718 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -433,7 +433,8 @@ async def chat( config=agent_config, memory=payload.memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ): yield event @@ -469,7 +470,8 @@ async def chat( web_search=payload.web_search, memory=payload.memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.MULTI_AGENT: @@ -486,8 +488,8 @@ async def chat( config=config, web_search=payload.web_search, memory=payload.memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): yield event diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 54af0b57..583b4700 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -154,7 +154,8 @@ async def chat( config=agent_config, memory=memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ): yield event @@ -178,7 +179,8 @@ async def chat( web_search=web_search, memory=memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.MULTI_AGENT: @@ -195,8 +197,8 @@ async def chat( config=config, web_search=web_search, memory=memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): yield event @@ -212,7 +214,6 @@ async def chat( # 多 Agent 非流式返回 result = await app_chat_service.multi_agent_chat( - message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID user_id=end_user_id, # 转换为字符串 @@ -291,4 +292,4 @@ async def chat( from app.core.exceptions import BusinessException from app.core.error_codes import BizCode raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) - pass + diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 380b660c..ef9a489f 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -7,7 +7,6 @@ LangChain Agent 封装 - 支持流式输出 - 使用 RedBearLLM 支持多提供商 """ -import os import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence @@ -97,8 +96,7 @@ class LangChainAgent: "temperature": temperature, "streaming": streaming, "tool_count": len(self.tools), - "tool_names": [tool.name for tool in self.tools] if self.tools else [], - "tool_count": len(self.tools) + "tool_names": [tool.name for tool in self.tools] if self.tools else [] } ) @@ -139,8 +137,11 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages + async def term_memory_save(self,messages,end_user_end,aimessages): - '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' + """ + 短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j + """ end_user_end=f"Term_{end_user_end}" print(messages) print(aimessages) @@ -154,6 +155,7 @@ class LangChainAgent: store.delete_duplicate_sessions() # logger.info(f'Redis_Agent:{end_user_end};{session_id}') return session_id + async def term_memory_redis_read(self,end_user_end): end_user_end = f"Term_{end_user_end}" history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) diff --git a/api/app/core/config.py b/api/app/core/config.py index 7494b89d..b02b94a5 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -164,6 +164,9 @@ class Settings: TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60")) TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10")) ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true" + + # official environment system version + SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v1.0.0") def get_memory_output_path(self, filename: str = "") -> str: """ diff --git a/api/app/core/tools/base.py b/api/app/core/tools/base.py index ec15c50f..2cdc0f60 100644 --- a/api/app/core/tools/base.py +++ b/api/app/core/tools/base.py @@ -191,10 +191,14 @@ class BaseTool(ABC): execution_time=execution_time ) - def to_langchain_tool(self): - """转换为Langchain工具格式""" + def to_langchain_tool(self, operation: Optional[str] = None): + """转换为Langchain工具格式 + + Args: + operation: 特定操作(适用于有操作的工具) + """ from app.core.tools.langchain_adapter import LangchainAdapter - return LangchainAdapter.convert_tool(self) + return LangchainAdapter.convert_tool(self, operation) def __repr__(self): return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>" \ No newline at end of file diff --git a/api/app/core/tools/builtin/operation_tool.py b/api/app/core/tools/builtin/operation_tool.py new file mode 100644 index 00000000..126541a8 --- /dev/null +++ b/api/app/core/tools/builtin/operation_tool.py @@ -0,0 +1,216 @@ +"""操作工具 - 为特定操作创建的工具包装器""" +from typing import List +from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType +from app.models import ToolType + + +class OperationTool(BaseTool): + """操作工具 - 包装基础工具的特定操作""" + + def __init__(self, base_tool: BaseTool, operation: str): + self.base_tool = base_tool + self.operation = operation + super().__init__(base_tool.tool_id, base_tool.config) + + @property + def name(self) -> str: + return f"{self.base_tool.name}_{self.operation}" + + @property + def tool_type(self) -> ToolType: + """工具类型""" + return ToolType.BUILTIN + + @property + def description(self) -> str: + return f"{self.base_tool.description} - {self.operation}" + + @property + def parameters(self) -> List[ToolParameter]: + """返回特定操作的参数""" + if self.base_tool.name == 'datetime_tool': + return self._get_datetime_params() + elif self.base_tool.name == 'json_tool': + return self._get_json_params() + else: + # 默认返回除operation外的所有参数 + return [p for p in self.base_tool.parameters if p.name != "operation"] + + def _get_datetime_params(self) -> List[ToolParameter]: + """获取datetime_tool特定操作的参数""" + if self.operation == "now": + return [ + ToolParameter( + name="to_timezone", + type=ParameterType.STRING, + description="目标时区(如:UTC, Asia/Shanghai)", + required=False, + default="Asia/Shanghai" + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ) + ] + elif self.operation == "format": + return [ + ToolParameter( + name="input_value", + type=ParameterType.STRING, + description="输入值(时间字符串或时间戳)", + required=True + ), + ToolParameter( + name="input_format", + type=ParameterType.STRING, + description="输入时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ) + ] + elif self.operation == "convert_timezone": + return [ + ToolParameter( + name="input_value", + type=ParameterType.STRING, + description="输入值(时间字符串或时间戳)", + required=True + ), + ToolParameter( + name="input_format", + type=ParameterType.STRING, + description="输入时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="from_timezone", + type=ParameterType.STRING, + description="源时区(如:UTC, Asia/Shanghai)", + required=False, + default="Asia/Shanghai" + ), + ToolParameter( + name="to_timezone", + type=ParameterType.STRING, + description="目标时区(如:UTC, Asia/Shanghai)", + required=False, + default="Asia/Shanghai" + ) + ] + elif self.operation == "timestamp_to_datetime": + return [ + ToolParameter( + name="input_value", + type=ParameterType.STRING, + description="输入值(时间字符串或时间戳)", + required=True + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="to_timezone", + type=ParameterType.STRING, + description="目标时区(如:UTC, Asia/Shanghai)", + required=False, + default="Asia/Shanghai" + ) + ] + else: + return [] + + def _get_json_params(self) -> List[ToolParameter]: + """获取json_tool特定操作的参数""" + base_params = [ + ToolParameter( + name="input_data", + type=ParameterType.STRING, + description="输入数据(JSON字符串、YAML字符串或XML字符串)", + required=True + ) + ] + + if self.operation == "insert": + return base_params + [ + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(如:$.user.name或users[0].name)", + required=True + ), + ToolParameter( + name="new_value", + type=ParameterType.STRING, + description="新值(用于insert操作)", + required=True + ) + ] + elif self.operation == "replace": + return base_params + [ + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(如:$.user.name或users[0].name)", + required=True + ), + ToolParameter( + name="old_text", + type=ParameterType.STRING, + description="要替换的原文本(用于replace操作)", + required=True + ), + ToolParameter( + name="new_text", + type=ParameterType.STRING, + description="替换后的新文本(用于replace操作)", + required=True + ) + ] + elif self.operation == "delete": + return base_params + [ + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(如:$.user.name或users[0].name)", + required=True + ) + ] + elif self.operation == "parse": + return base_params + [ + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(如:$.user.name或users[0].name)", + required=True + ) + ] + else: + return base_params + + async def execute(self, **kwargs) -> ToolResult: + """执行特定操作""" + # 添加operation参数 + kwargs["operation"] = self.operation + return await self.base_tool.execute(**kwargs) \ No newline at end of file diff --git a/api/app/core/tools/custom/base.py b/api/app/core/tools/custom/base.py index 0d656a8e..3dfe4c93 100644 --- a/api/app/core/tools/custom/base.py +++ b/api/app/core/tools/custom/base.py @@ -1,4 +1,5 @@ """自定义工具基类""" +import json import time from typing import Dict, Any, List, Optional import aiohttp @@ -135,6 +136,13 @@ class CustomTool(BaseTool): if not self.schema_content: return operations + + if isinstance(self.schema_content, str): + try: + self.schema_content = json.loads(self.schema_content) + except json.JSONDecodeError: + logger.error(f"无效的OpenAPI schema: {self.schema_content}") + return operations paths = self.schema_content.get("paths", {}) diff --git a/api/app/core/tools/langchain_adapter.py b/api/app/core/tools/langchain_adapter.py index 1b6969b9..89ccc205 100644 --- a/api/app/core/tools/langchain_adapter.py +++ b/api/app/core/tools/langchain_adapter.py @@ -38,7 +38,7 @@ class LangchainToolWrapper(LangchainBaseTool): name=tool_instance.name, description=tool_instance.description, args_schema=args_schema, - _tool_instance=tool_instance, + tool_instance=tool_instance, **kwargs ) @@ -59,7 +59,7 @@ class LangchainToolWrapper(LangchainBaseTool): """异步执行工具""" try: # 执行内部工具 - result = await self._tool_instance.safe_execute(**kwargs) + result = await self.tool_instance.safe_execute(**kwargs) # 转换结果为Langchain格式 return LangchainAdapter._format_result_for_langchain(result) @@ -73,24 +73,39 @@ class LangchainAdapter: """Langchain适配器 - 负责工具格式转换和标准化""" @staticmethod - def convert_tool(tool: BaseTool) -> LangchainToolWrapper: + def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper: """将内部工具转换为Langchain工具 Args: tool: 内部工具实例 + operation: 特定操作(适用于有操作的工具) Returns: Langchain兼容的工具包装器 """ try: - wrapper = LangchainToolWrapper(tool_instance=tool) - logger.debug(f"工具转换成功: {tool.name} -> Langchain格式") - return wrapper + if operation and tool.name in ['datetime_tool', 'json_tool']: + # 为特定操作创建工具 + operation_tool = LangchainAdapter._create_operation_tool(tool, operation) + wrapper = LangchainToolWrapper(tool_instance=operation_tool) + logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式") + return wrapper + else: + # 单个工具 + wrapper = LangchainToolWrapper(tool_instance=tool) + logger.debug(f"工具转换成功: {tool.name} -> Langchain格式") + return wrapper except Exception as e: logger.error(f"工具转换失败: {tool.name}, 错误: {e}") raise + @staticmethod + def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool: + """为特定操作创建工具实例""" + from app.core.tools.builtin.operation_tool import OperationTool + return OperationTool(base_tool, operation) + @staticmethod def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]: """批量转换工具 @@ -110,7 +125,7 @@ class LangchainAdapter: except Exception as e: logger.error(f"跳过工具转换: {tool.name}, 错误: {e}") - logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具") + logger.info(f"批量转换完成: {len(converted_tools)} 个工具") return converted_tools @staticmethod @@ -169,9 +184,10 @@ class LangchainAdapter: "ToolArgsSchema", (BaseModel,), { + "__module__": __name__, "__annotations__": annotations, - **fields, - "Config": type("Config", (), {"extra": "forbid"}) + "model_config": {"extra": "forbid"}, + **fields } ) diff --git a/api/app/core/tools/mcp/service_manager.py b/api/app/core/tools/mcp/service_manager.py index f7349201..01312444 100644 --- a/api/app/core/tools/mcp/service_manager.py +++ b/api/app/core/tools/mcp/service_manager.py @@ -16,14 +16,17 @@ logger = get_business_logger() class MCPServiceManager: """MCP服务管理器 - 管理MCP服务的生命周期""" - def __init__(self, db: Session): + def __init__(self, db: Session = None): """初始化MCP服务管理器 Args: - db: 数据库会话 + db: 数据库会话(可选) """ self.db = db - self.connection_pool = MCPConnectionPool(max_connections=20) + if db: + self.connection_pool = MCPConnectionPool(max_connections=20) + else: + self.connection_pool = None # 服务状态管理 self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info @@ -592,7 +595,7 @@ class MCPServiceManager: except Exception as e: logger.error(f"清理失效服务失败: {e}") - + def get_manager_status(self) -> Dict[str, Any]: """获取管理器状态""" return { diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 81cd704d..d20570ce 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -1,6 +1,6 @@ import datetime import uuid -from typing import Optional, Any, List, Dict +from typing import Optional, Any, List, Dict, Union from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator @@ -36,6 +36,12 @@ class KnowledgeRetrievalConfig(BaseModel): class ToolConfig(BaseModel): + """工具配置""" + enabled: bool = Field(default=False, description="是否启用该工具") + tool_id: str = Field(default=None, description="工具ID") + operation: Optional[str] = Field(default_factory=dict, description="工具特定配置") + +class ToolOldConfig(BaseModel): """工具配置""" enabled: bool = Field(default=False, description="是否启用该工具") config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置") @@ -103,9 +109,9 @@ class AgentConfigCreate(BaseModel): ) # 工具配置 - tools: Dict[str, ToolConfig] = Field( - default_factory=dict, - description="工具配置,key 为工具名称(web_search, code_interpreter, image_generation 等)" + tools: List[ToolConfig] = Field( + default_factory=list, + description="Agent 可用的工具列表" ) @@ -158,7 +164,7 @@ class AgentConfigUpdate(BaseModel): variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") # 工具配置 - tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") + tools: Optional[List[ToolConfig]] = Field(default=None, description="工具列表") # ---------- Output Schemas ---------- @@ -216,7 +222,7 @@ class AgentConfig(BaseModel): variables: List[VariableDefinition] = [] # 工具配置 - tools: Dict[str, ToolConfig] = {} + tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = [] is_active: bool created_at: datetime.datetime diff --git a/api/app/services/agent_config_converter.py b/api/app/services/agent_config_converter.py index 3ab14157..eda4b5c4 100644 --- a/api/app/services/agent_config_converter.py +++ b/api/app/services/agent_config_converter.py @@ -2,14 +2,14 @@ Agent 配置格式转换器 用于将 Pydantic 模型转换为数据库存储格式 """ -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Union from app.schemas.app_schema import ( KnowledgeRetrievalConfig, MemoryConfig, VariableDefinition, ToolConfig, AgentConfigCreate, - AgentConfigUpdate, + AgentConfigUpdate, ToolOldConfig, ) @@ -47,10 +47,7 @@ class AgentConfigConverter: # 5. 工具配置 if hasattr(config, 'tools') and config.tools: - result["tools"] = { - name: tool.model_dump() - for name, tool in config.tools.items() - } + result["tools"] = [tool.model_dump() for tool in config.tools] return result @@ -60,7 +57,7 @@ class AgentConfigConverter: knowledge_retrieval: Optional[Dict[str, Any]], memory: Optional[Dict[str, Any]], variables: Optional[list], - tools: Optional[Dict[str, Any]], + tools: Optional[Union[list, Dict[str, Any]]], ) -> Dict[str, Any]: """ 将数据库存储格式转换为 Pydantic 对象 @@ -113,9 +110,12 @@ class AgentConfigConverter: # 5. 解析工具配置 if tools: - result["tools"] = { - name: ToolConfig(**tool_data) - for name, tool_data in tools.items() - } + if isinstance(tools, list): + result["tools"] = [ToolConfig(**tool_config) for tool_config in tools] + else: + result["tools"] = { + name: ToolOldConfig(**tool_data) + for name, tool_data in tools.items() + } return result diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 6b7b3103..537eac8d 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -14,6 +14,10 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.db import get_db, get_db_context from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig +from app.services.tool_service import ToolService +from app.repositories.tool_repository import ToolRepository +from app.db import get_db +from app.models import MultiAgentConfig, AgentConfig from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool @@ -43,6 +47,7 @@ class AppChatService: memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, + workspace_id: Optional[str] = None ) -> Dict[str, Any]: """聊天(非流式)""" @@ -68,6 +73,24 @@ class AppChatService: # 准备工具列表 tools = [] + # 获取工具服务 + tool_service = ToolService(self.db) + + # 从配置中获取启用的工具 + if hasattr(config, 'tools') and config.tools: + for tool_config in config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, workspace_id)) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + # 添加知识库检索工具 knowledge_retrieval = config.knowledge_retrieval if knowledge_retrieval: @@ -86,7 +109,7 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - web_tools = config.tools + # web_tools = config.tools # web_search_choice = web_tools.get("web_search", {}) # web_search_enable = web_search_choice.get("enabled", False) # if web_search == True: @@ -173,6 +196,7 @@ class AppChatService: memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, + workspace_id: Optional[str] = None, ) -> AsyncGenerator[str, None]: """聊天(流式)""" @@ -199,6 +223,23 @@ class AppChatService: # 准备工具列表 tools = [] + # 获取工具服务 + tool_service = ToolService(self.db) + + if hasattr(config, 'tools') and config.tools: + for tool_config in config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, workspace_id)) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + # 添加知识库检索工具 knowledge_retrieval = config.knowledge_retrieval if knowledge_retrieval: @@ -217,20 +258,20 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search == True: - if web_search_enable == True: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) + # web_tools = config.tools + # web_search_choice = web_tools.get("web_search", {}) + # web_search_enable = web_search_choice.get("enabled", False) + # if web_search == True: + # if web_search_enable == True: + # search_tool = create_web_search_tool({}) + # tools.append(search_tool) + # + # logger.debug( + # "已添加网络搜索工具", + # extra={ + # "tool_count": len(tools) + # } + # ) # 获取模型参数 model_parameters = config.model_parameters diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 38097c4e..e15f68fe 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -307,7 +307,7 @@ class AppService: knowledge_retrieval=storage_data.get("knowledge_retrieval"), memory=storage_data.get("memory"), variables=storage_data.get("variables", []), - tools=storage_data.get("tools", {}), + tools=storage_data.get("tools", []), is_active=True, created_at=now, updated_at=now, @@ -689,7 +689,7 @@ class AppService: knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None, memory=source_config.memory.copy() if source_config.memory else None, variables=source_config.variables.copy() if source_config.variables else [], - tools=source_config.tools.copy() if source_config.tools else {}, + tools=source_config.tools.copy() if source_config.tools else [], is_active=True, created_at=now, updated_at=now, @@ -879,7 +879,7 @@ class AppService: # if data.variables is not None: agent_cfg.variables = storage_data.get("variables", []) # if data.tools is not None: - agent_cfg.tools = storage_data.get("tools", {}) + agent_cfg.tools = storage_data.get("tools", []) agent_cfg.updated_at = now @@ -966,7 +966,7 @@ class AppService: "max_history": 10 }, variables=[], - tools={}, + tools=[], is_active=True, created_at=now, updated_at=now, @@ -1183,7 +1183,7 @@ class AppService: "knowledge_retrieval": agent_cfg.knowledge_retrieval, "memory": agent_cfg.memory, "variables": agent_cfg.variables or [], - "tools": agent_cfg.tools or {}, + "tools": agent_cfg.tools or [], } # config = AgentConfigConverter.from_storage_format(agent_cfg) default_model_config_id = agent_cfg.default_model_config_id diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index eefc71c5..9a1dbd32 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -10,19 +10,22 @@ import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional +from langchain.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval from app.models import AgentConfig, ModelApiKey, ModelConfig +from app.repositories.tool_repository import ToolRepository from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy import select -from sqlalchemy.orm import Session +from app.services.tool_service import ToolService logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): @@ -291,24 +294,22 @@ class DraftRunService: # 4. 准备工具列表 tools = [] - # 添加网络搜索工具 - if web_search: - if agent_config.tools: - web_search_config = agent_config.tools.get("web_search", {}) - web_search_enable = web_search_config.get("enabled", False) + tool_service = ToolService(self.db) - if web_search_enable: - logger.info("网络搜索已启用") - # 创建网络搜索工具 - search_tool = create_web_search_tool(web_search_config) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) + # 从配置中获取启用的工具 + if hasattr(agent_config, 'tools') and agent_config.tools: + for tool_config in agent_config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, str(workspace_id))) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) # 添加知识库检索工具 if agent_config.knowledge_retrieval: @@ -503,24 +504,22 @@ class DraftRunService: # 4. 准备工具列表 tools = [] - # 添加网络搜索工具 - if web_search: - if agent_config.tools: - web_search_config = agent_config.tools.get("web_search", {}) - web_search_enable = web_search_config.get("enabled", False) + tool_service = ToolService(self.db) - if web_search_enable: - logger.info("网络搜索已启用") - # 创建网络搜索工具 - search_tool = create_web_search_tool(web_search_config) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) + # 从配置中获取启用的工具 + if hasattr(agent_config, 'tools') and agent_config.tools: + for tool_config in agent_config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, str(workspace_id))) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) # 添加知识库检索工具 if agent_config.knowledge_retrieval: