feat(agent tool): add agent tool plugin

This commit is contained in:
谢俊男
2026-01-06 15:25:25 +08:00
parent 190155f438
commit 492401f9b7
11 changed files with 349 additions and 90 deletions

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
@@ -26,4 +27,9 @@ def get_workspace_list(
):
"""获取工作空间列表"""
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
return success(data=workspace_list, msg="工作空间列表获取成功")
return success(data=workspace_list, msg="工作空间列表获取成功")
@router.get("/version", response_model=ApiResponse)
def get_system_version():
"""获取系统版本号"""
return success(data={"version": settings.SYSTEM_VERSION}, msg="系统版本获取成功")

View File

@@ -153,7 +153,8 @@ async def chat(
config=agent_config,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
@@ -177,7 +178,8 @@ async def chat(
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:

View File

@@ -97,8 +97,7 @@ class LangChainAgent:
"temperature": temperature,
"streaming": streaming,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
"tool_names": [tool.name for tool in self.tools] if self.tools else []
}
)
@@ -139,8 +138,11 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))
return messages
async def term_memory_save(self,messages,end_user_end,aimessages):
'''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
"""
短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j
"""
end_user_end=f"Term_{end_user_end}"
print(messages)
print(aimessages)

View File

@@ -164,6 +164,9 @@ class Settings:
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v1.0.0")
def get_memory_output_path(self, filename: str = "") -> str:
"""

View File

@@ -191,10 +191,14 @@ class BaseTool(ABC):
execution_time=execution_time
)
def to_langchain_tool(self):
"""转换为Langchain工具格式"""
def to_langchain_tool(self, operation: Optional[str] = None):
"""转换为Langchain工具格式
Args:
operation: 特定操作(适用于有操作的工具)
"""
from app.core.tools.langchain_adapter import LangchainAdapter
return LangchainAdapter.convert_tool(self)
return LangchainAdapter.convert_tool(self, operation)
def __repr__(self):
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"

View File

@@ -0,0 +1,216 @@
"""操作工具 - 为特定操作创建的工具包装器"""
from typing import List
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.models import ToolType
class OperationTool(BaseTool):
"""操作工具 - 包装基础工具的特定操作"""
def __init__(self, base_tool: BaseTool, operation: str):
self.base_tool = base_tool
self.operation = operation
super().__init__(base_tool.tool_id, base_tool.config)
@property
def name(self) -> str:
return f"{self.base_tool.name}_{self.operation}"
@property
def tool_type(self) -> ToolType:
"""工具类型"""
return ToolType.BUILTIN
@property
def description(self) -> str:
return f"{self.base_tool.description} - {self.operation}"
@property
def parameters(self) -> List[ToolParameter]:
"""返回特定操作的参数"""
if self.base_tool.name == 'datetime_tool':
return self._get_datetime_params()
elif self.base_tool.name == 'json_tool':
return self._get_json_params()
else:
# 默认返回除operation外的所有参数
return [p for p in self.base_tool.parameters if p.name != "operation"]
def _get_datetime_params(self) -> List[ToolParameter]:
"""获取datetime_tool特定操作的参数"""
if self.operation == "now":
return [
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "format":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "convert_timezone":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="from_timezone",
type=ParameterType.STRING,
description="源时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
elif self.operation == "timestamp_to_datetime":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
else:
return []
def _get_json_params(self) -> List[ToolParameter]:
"""获取json_tool特定操作的参数"""
base_params = [
ToolParameter(
name="input_data",
type=ParameterType.STRING,
description="输入数据JSON字符串、YAML字符串或XML字符串",
required=True
)
]
if self.operation == "insert":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="new_value",
type=ParameterType.STRING,
description="新值用于insert操作",
required=True
)
]
elif self.operation == "replace":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="old_text",
type=ParameterType.STRING,
description="要替换的原文本用于replace操作",
required=True
),
ToolParameter(
name="new_text",
type=ParameterType.STRING,
description="替换后的新文本用于replace操作",
required=True
)
]
elif self.operation == "delete":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
elif self.operation == "parse":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
else:
return base_params
async def execute(self, **kwargs) -> ToolResult:
"""执行特定操作"""
# 添加operation参数
kwargs["operation"] = self.operation
return await self.base_tool.execute(**kwargs)

View File

@@ -1,4 +1,5 @@
"""自定义工具基类"""
import json
import time
from typing import Dict, Any, List, Optional
import aiohttp
@@ -135,6 +136,13 @@ class CustomTool(BaseTool):
if not self.schema_content:
return operations
if isinstance(self.schema_content, str):
try:
self.schema_content = json.loads(self.schema_content)
except json.JSONDecodeError:
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
return operations
paths = self.schema_content.get("paths", {})

View File

@@ -38,7 +38,7 @@ class LangchainToolWrapper(LangchainBaseTool):
name=tool_instance.name,
description=tool_instance.description,
args_schema=args_schema,
_tool_instance=tool_instance,
tool_instance=tool_instance,
**kwargs
)
@@ -59,7 +59,7 @@ class LangchainToolWrapper(LangchainBaseTool):
"""异步执行工具"""
try:
# 执行内部工具
result = await self._tool_instance.safe_execute(**kwargs)
result = await self.tool_instance.safe_execute(**kwargs)
# 转换结果为Langchain格式
return LangchainAdapter._format_result_for_langchain(result)
@@ -73,24 +73,39 @@ class LangchainAdapter:
"""Langchain适配器 - 负责工具格式转换和标准化"""
@staticmethod
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
"""将内部工具转换为Langchain工具
Args:
tool: 内部工具实例
operation: 特定操作(适用于有操作的工具)
Returns:
Langchain兼容的工具包装器
"""
try:
wrapper = LangchainToolWrapper(tool_instance=tool)
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
return wrapper
if operation and tool.name in ['datetime_tool', 'json_tool']:
# 为特定操作创建工具
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
return wrapper
else:
# 单个工具
wrapper = LangchainToolWrapper(tool_instance=tool)
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
return wrapper
except Exception as e:
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
raise
@staticmethod
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
"""为特定操作创建工具实例"""
from app.core.tools.builtin.operation_tool import OperationTool
return OperationTool(base_tool, operation)
@staticmethod
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
"""批量转换工具
@@ -110,7 +125,7 @@ class LangchainAdapter:
except Exception as e:
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
return converted_tools
@staticmethod
@@ -169,9 +184,10 @@ class LangchainAdapter:
"ToolArgsSchema",
(BaseModel,),
{
"__module__": __name__,
"__annotations__": annotations,
**fields,
"Config": type("Config", (), {"extra": "forbid"})
"model_config": {"extra": "forbid"},
**fields
}
)

View File

@@ -16,14 +16,17 @@ logger = get_business_logger()
class MCPServiceManager:
"""MCP服务管理器 - 管理MCP服务的生命周期"""
def __init__(self, db: Session):
def __init__(self, db: Session = None):
"""初始化MCP服务管理器
Args:
db: 数据库会话
db: 数据库会话(可选)
"""
self.db = db
self.connection_pool = MCPConnectionPool(max_connections=20)
if db:
self.connection_pool = MCPConnectionPool(max_connections=20)
else:
self.connection_pool = None
# 服务状态管理
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
@@ -592,7 +595,7 @@ class MCPServiceManager:
except Exception as e:
logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
return {

View File

@@ -10,6 +10,8 @@ from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent
from app.core.logging_config import get_business_logger
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
@@ -40,6 +42,7 @@ class AppChatService:
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None
) -> Dict[str, Any]:
"""聊天(非流式)"""
@@ -64,6 +67,20 @@ class AppChatService:
# 准备工具列表
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools:
for tool_id, tool_config in config.tools.items():
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
if tool_instance:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
@@ -83,21 +100,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.tools
# web_search_choice = web_tools.get("web_search", {})
# web_search_enable = web_search_choice.get("enabled", False)
# if web_search == True:
# if web_search_enable == True:
# search_tool = create_web_search_tool({})
# tools.append(search_tool)
#
# logger.debug(
# "已添加网络搜索工具",
# extra={
# "tool_count": len(tools)
# }
# )
# 获取模型参数
model_parameters = config.model_parameters
@@ -170,6 +172,7 @@ class AppChatService:
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
@@ -641,6 +644,20 @@ class AppChatService:
# 准备工具列表
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools:
for tool_id, tool_config in config.tools.items():
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
if tool_instance:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
@@ -660,21 +677,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})

View File

@@ -10,19 +10,22 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.services.tool_service import ToolService
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -291,24 +294,21 @@ class DraftRunService:
# 4. 准备工具列表
tools = []
# 添加网络搜索工具
if web_search:
if agent_config.tools:
web_search_config = agent_config.tools.get("web_search", {})
web_search_enable = web_search_config.get("enabled", False)
tool_service = ToolService(self.db)
if web_search_enable:
logger.info("网络搜索已启用")
# 创建网络搜索工具
search_tool = create_web_search_tool(web_search_config)
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_id, tool_config in agent_config.tools.items():
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id,
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(
tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -503,24 +503,21 @@ class DraftRunService:
# 4. 准备工具列表
tools = []
# 添加网络搜索工具
if web_search:
if agent_config.tools:
web_search_config = agent_config.tools.get("web_search", {})
web_search_enable = web_search_config.get("enabled", False)
tool_service = ToolService(self.db)
if web_search_enable:
logger.info("网络搜索已启用")
# 创建网络搜索工具
search_tool = create_web_search_tool(web_search_config)
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_id, tool_config in agent_config.tools.items():
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id,
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(
tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
if agent_config.knowledge_retrieval: