feat(agent tool): add agent tool plugin
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
@@ -26,4 +27,9 @@ def get_workspace_list(
|
||||
):
|
||||
"""获取工作空间列表"""
|
||||
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
|
||||
return success(data=workspace_list, msg="工作空间列表获取成功")
|
||||
return success(data=workspace_list, msg="工作空间列表获取成功")
|
||||
|
||||
@router.get("/version", response_model=ApiResponse)
|
||||
def get_system_version():
|
||||
"""获取系统版本号"""
|
||||
return success(data={"version": settings.SYSTEM_VERSION}, msg="系统版本获取成功")
|
||||
@@ -153,7 +153,8 @@ async def chat(
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -177,7 +178,8 @@ async def chat(
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
|
||||
@@ -97,8 +97,7 @@ class LangChainAgent:
|
||||
"temperature": temperature,
|
||||
"streaming": streaming,
|
||||
"tool_count": len(self.tools),
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||
"tool_count": len(self.tools)
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else []
|
||||
}
|
||||
)
|
||||
|
||||
@@ -139,8 +138,11 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
|
||||
async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
'''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
"""
|
||||
短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j
|
||||
"""
|
||||
end_user_end=f"Term_{end_user_end}"
|
||||
print(messages)
|
||||
print(aimessages)
|
||||
|
||||
@@ -164,6 +164,9 @@ class Settings:
|
||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
|
||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v1.0.0")
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
|
||||
@@ -191,10 +191,14 @@ class BaseTool(ABC):
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def to_langchain_tool(self):
|
||||
"""转换为Langchain工具格式"""
|
||||
def to_langchain_tool(self, operation: Optional[str] = None):
|
||||
"""转换为Langchain工具格式
|
||||
|
||||
Args:
|
||||
operation: 特定操作(适用于有操作的工具)
|
||||
"""
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
return LangchainAdapter.convert_tool(self)
|
||||
return LangchainAdapter.convert_tool(self, operation)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"
|
||||
216
api/app/core/tools/builtin/operation_tool.py
Normal file
216
api/app/core/tools/builtin/operation_tool.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""操作工具 - 为特定操作创建的工具包装器"""
|
||||
from typing import List
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.models import ToolType
|
||||
|
||||
|
||||
class OperationTool(BaseTool):
|
||||
"""操作工具 - 包装基础工具的特定操作"""
|
||||
|
||||
def __init__(self, base_tool: BaseTool, operation: str):
|
||||
self.base_tool = base_tool
|
||||
self.operation = operation
|
||||
super().__init__(base_tool.tool_id, base_tool.config)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"{self.base_tool.name}_{self.operation}"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.BUILTIN
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"{self.base_tool.description} - {self.operation}"
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""返回特定操作的参数"""
|
||||
if self.base_tool.name == 'datetime_tool':
|
||||
return self._get_datetime_params()
|
||||
elif self.base_tool.name == 'json_tool':
|
||||
return self._get_json_params()
|
||||
else:
|
||||
# 默认返回除operation外的所有参数
|
||||
return [p for p in self.base_tool.parameters if p.name != "operation"]
|
||||
|
||||
def _get_datetime_params(self) -> List[ToolParameter]:
|
||||
"""获取datetime_tool特定操作的参数"""
|
||||
if self.operation == "now":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
]
|
||||
elif self.operation == "format":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
]
|
||||
elif self.operation == "convert_timezone":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="from_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="源时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
elif self.operation == "timestamp_to_datetime":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_json_params(self) -> List[ToolParameter]:
|
||||
"""获取json_tool特定操作的参数"""
|
||||
base_params = [
|
||||
ToolParameter(
|
||||
name="input_data",
|
||||
type=ParameterType.STRING,
|
||||
description="输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
|
||||
if self.operation == "insert":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="new_value",
|
||||
type=ParameterType.STRING,
|
||||
description="新值(用于insert操作)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "replace":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="old_text",
|
||||
type=ParameterType.STRING,
|
||||
description="要替换的原文本(用于replace操作)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="new_text",
|
||||
type=ParameterType.STRING,
|
||||
description="替换后的新文本(用于replace操作)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "delete":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "parse":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
else:
|
||||
return base_params
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行特定操作"""
|
||||
# 添加operation参数
|
||||
kwargs["operation"] = self.operation
|
||||
return await self.base_tool.execute(**kwargs)
|
||||
@@ -1,4 +1,5 @@
|
||||
"""自定义工具基类"""
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
import aiohttp
|
||||
@@ -135,6 +136,13 @@ class CustomTool(BaseTool):
|
||||
|
||||
if not self.schema_content:
|
||||
return operations
|
||||
|
||||
if isinstance(self.schema_content, str):
|
||||
try:
|
||||
self.schema_content = json.loads(self.schema_content)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
|
||||
return operations
|
||||
|
||||
paths = self.schema_content.get("paths", {})
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
name=tool_instance.name,
|
||||
description=tool_instance.description,
|
||||
args_schema=args_schema,
|
||||
_tool_instance=tool_instance,
|
||||
tool_instance=tool_instance,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
"""异步执行工具"""
|
||||
try:
|
||||
# 执行内部工具
|
||||
result = await self._tool_instance.safe_execute(**kwargs)
|
||||
result = await self.tool_instance.safe_execute(**kwargs)
|
||||
|
||||
# 转换结果为Langchain格式
|
||||
return LangchainAdapter._format_result_for_langchain(result)
|
||||
@@ -73,24 +73,39 @@ class LangchainAdapter:
|
||||
"""Langchain适配器 - 负责工具格式转换和标准化"""
|
||||
|
||||
@staticmethod
|
||||
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
|
||||
def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
|
||||
"""将内部工具转换为Langchain工具
|
||||
|
||||
Args:
|
||||
tool: 内部工具实例
|
||||
operation: 特定操作(适用于有操作的工具)
|
||||
|
||||
Returns:
|
||||
Langchain兼容的工具包装器
|
||||
"""
|
||||
try:
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||
return wrapper
|
||||
if operation and tool.name in ['datetime_tool', 'json_tool']:
|
||||
# 为特定操作创建工具
|
||||
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
||||
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||
return wrapper
|
||||
else:
|
||||
# 单个工具
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||
return wrapper
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
|
||||
"""为特定操作创建工具实例"""
|
||||
from app.core.tools.builtin.operation_tool import OperationTool
|
||||
return OperationTool(base_tool, operation)
|
||||
|
||||
@staticmethod
|
||||
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
||||
"""批量转换工具
|
||||
@@ -110,7 +125,7 @@ class LangchainAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
|
||||
|
||||
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
|
||||
logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
|
||||
return converted_tools
|
||||
|
||||
@staticmethod
|
||||
@@ -169,9 +184,10 @@ class LangchainAdapter:
|
||||
"ToolArgsSchema",
|
||||
(BaseModel,),
|
||||
{
|
||||
"__module__": __name__,
|
||||
"__annotations__": annotations,
|
||||
**fields,
|
||||
"Config": type("Config", (), {"extra": "forbid"})
|
||||
"model_config": {"extra": "forbid"},
|
||||
**fields
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -16,14 +16,17 @@ logger = get_business_logger()
|
||||
class MCPServiceManager:
|
||||
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
def __init__(self, db: Session = None):
|
||||
"""初始化MCP服务管理器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
db: 数据库会话(可选)
|
||||
"""
|
||||
self.db = db
|
||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||
if db:
|
||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||
else:
|
||||
self.connection_pool = None
|
||||
|
||||
# 服务状态管理
|
||||
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
||||
@@ -592,7 +595,7 @@ class MCPServiceManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理失效服务失败: {e}")
|
||||
|
||||
|
||||
def get_manager_status(self) -> Dict[str, Any]:
|
||||
"""获取管理器状态"""
|
||||
return {
|
||||
|
||||
@@ -10,6 +10,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.tool_service import ToolService
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
@@ -40,6 +42,7 @@ class AppChatService:
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
|
||||
@@ -64,6 +67,20 @@ class AppChatService:
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
for tool_id, tool_config in config.tools.items():
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
|
||||
if tool_instance:
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
@@ -83,21 +100,6 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
web_tools = config.tools
|
||||
# web_search_choice = web_tools.get("web_search", {})
|
||||
# web_search_enable = web_search_choice.get("enabled", False)
|
||||
# if web_search == True:
|
||||
# if web_search_enable == True:
|
||||
# search_tool = create_web_search_tool({})
|
||||
# tools.append(search_tool)
|
||||
#
|
||||
# logger.debug(
|
||||
# "已添加网络搜索工具",
|
||||
# extra={
|
||||
# "tool_count": len(tools)
|
||||
# }
|
||||
# )
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
@@ -170,6 +172,7 @@ class AppChatService:
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
@@ -641,6 +644,20 @@ class AppChatService:
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
for tool_id, tool_config in config.tools.items():
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
|
||||
if tool_instance:
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||
@@ -660,21 +677,6 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
web_tools = config.get("tools")
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.get("model_parameters", {})
|
||||
|
||||
|
||||
@@ -10,19 +10,22 @@ import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from app.services.tool_service import ToolService
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -291,24 +294,21 @@ class DraftRunService:
|
||||
# 4. 准备工具列表
|
||||
tools = []
|
||||
|
||||
# 添加网络搜索工具
|
||||
if web_search:
|
||||
if agent_config.tools:
|
||||
web_search_config = agent_config.tools.get("web_search", {})
|
||||
web_search_enable = web_search_config.get("enabled", False)
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
if web_search_enable:
|
||||
logger.info("网络搜索已启用")
|
||||
# 创建网络搜索工具
|
||||
search_tool = create_web_search_tool(web_search_config)
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||
for tool_id, tool_config in agent_config.tools.items():
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_id,
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
if tool_instance:
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(
|
||||
tool_config.get("config", {}).get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
@@ -503,24 +503,21 @@ class DraftRunService:
|
||||
# 4. 准备工具列表
|
||||
tools = []
|
||||
|
||||
# 添加网络搜索工具
|
||||
if web_search:
|
||||
if agent_config.tools:
|
||||
web_search_config = agent_config.tools.get("web_search", {})
|
||||
web_search_enable = web_search_config.get("enabled", False)
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
if web_search_enable:
|
||||
logger.info("网络搜索已启用")
|
||||
# 创建网络搜索工具
|
||||
search_tool = create_web_search_tool(web_search_config)
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||
for tool_id, tool_config in agent_config.tools.items():
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_id,
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
if tool_instance:
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(
|
||||
tool_config.get("config", {}).get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
|
||||
Reference in New Issue
Block a user