feat(agent tool): add agent tool plugin

This commit is contained in:
谢俊男
2026-01-06 15:25:25 +08:00
parent 190155f438
commit 492401f9b7
11 changed files with 349 additions and 90 deletions

View File

@@ -10,6 +10,8 @@ from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent
from app.core.logging_config import get_business_logger
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
@@ -40,6 +42,7 @@ class AppChatService:
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None
) -> Dict[str, Any]:
"""聊天(非流式)"""
@@ -64,6 +67,20 @@ class AppChatService:
# 准备工具列表
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools:
for tool_id, tool_config in config.tools.items():
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
if tool_instance:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
@@ -83,21 +100,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.tools
# web_search_choice = web_tools.get("web_search", {})
# web_search_enable = web_search_choice.get("enabled", False)
# if web_search == True:
# if web_search_enable == True:
# search_tool = create_web_search_tool({})
# tools.append(search_tool)
#
# logger.debug(
# "已添加网络搜索工具",
# extra={
# "tool_count": len(tools)
# }
# )
# 获取模型参数
model_parameters = config.model_parameters
@@ -170,6 +172,7 @@ class AppChatService:
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
@@ -641,6 +644,20 @@ class AppChatService:
# 准备工具列表
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools:
for tool_id, tool_config in config.tools.items():
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
if tool_instance:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
@@ -660,21 +677,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})

View File

@@ -10,19 +10,22 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.services.tool_service import ToolService
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -291,24 +294,21 @@ class DraftRunService:
# 4. 准备工具列表
tools = []
# 添加网络搜索工具
if web_search:
if agent_config.tools:
web_search_config = agent_config.tools.get("web_search", {})
web_search_enable = web_search_config.get("enabled", False)
tool_service = ToolService(self.db)
if web_search_enable:
logger.info("网络搜索已启用")
# 创建网络搜索工具
search_tool = create_web_search_tool(web_search_config)
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_id, tool_config in agent_config.tools.items():
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id,
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(
tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -503,24 +503,21 @@ class DraftRunService:
# 4. 准备工具列表
tools = []
# 添加网络搜索工具
if web_search:
if agent_config.tools:
web_search_config = agent_config.tools.get("web_search", {})
web_search_enable = web_search_config.get("enabled", False)
tool_service = ToolService(self.db)
if web_search_enable:
logger.info("网络搜索已启用")
# 创建网络搜索工具
search_tool = create_web_search_tool(web_search_config)
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_id, tool_config in agent_config.tools.items():
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id,
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(
tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
if agent_config.knowledge_retrieval: