feat(agent tool): add agent tool plugin
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
@@ -26,4 +27,9 @@ def get_workspace_list(
|
|||||||
):
|
):
|
||||||
"""获取工作空间列表"""
|
"""获取工作空间列表"""
|
||||||
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
|
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
|
||||||
return success(data=workspace_list, msg="工作空间列表获取成功")
|
return success(data=workspace_list, msg="工作空间列表获取成功")
|
||||||
|
|
||||||
|
@router.get("/version", response_model=ApiResponse)
|
||||||
|
def get_system_version():
|
||||||
|
"""获取系统版本号"""
|
||||||
|
return success(data={"version": settings.SYSTEM_VERSION}, msg="系统版本获取成功")
|
||||||
@@ -153,7 +153,8 @@ async def chat(
|
|||||||
config=agent_config,
|
config=agent_config,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -177,7 +178,8 @@ async def chat(
|
|||||||
web_search=web_search,
|
web_search=web_search,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
|
|||||||
@@ -97,8 +97,7 @@ class LangChainAgent:
|
|||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
"tool_count": len(self.tools),
|
"tool_count": len(self.tools),
|
||||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
"tool_names": [tool.name for tool in self.tools] if self.tools else []
|
||||||
"tool_count": len(self.tools)
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -139,8 +138,11 @@ class LangChainAgent:
|
|||||||
messages.append(HumanMessage(content=user_content))
|
messages.append(HumanMessage(content=user_content))
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def term_memory_save(self,messages,end_user_end,aimessages):
|
async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||||
'''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
"""
|
||||||
|
短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j
|
||||||
|
"""
|
||||||
end_user_end=f"Term_{end_user_end}"
|
end_user_end=f"Term_{end_user_end}"
|
||||||
print(messages)
|
print(messages)
|
||||||
print(aimessages)
|
print(aimessages)
|
||||||
|
|||||||
@@ -164,6 +164,9 @@ class Settings:
|
|||||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||||
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
|
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
|
||||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||||
|
|
||||||
|
# official environment system version
|
||||||
|
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v1.0.0")
|
||||||
|
|
||||||
def get_memory_output_path(self, filename: str = "") -> str:
|
def get_memory_output_path(self, filename: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -191,10 +191,14 @@ class BaseTool(ABC):
|
|||||||
execution_time=execution_time
|
execution_time=execution_time
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_langchain_tool(self):
|
def to_langchain_tool(self, operation: Optional[str] = None):
|
||||||
"""转换为Langchain工具格式"""
|
"""转换为Langchain工具格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operation: 特定操作(适用于有操作的工具)
|
||||||
|
"""
|
||||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||||
return LangchainAdapter.convert_tool(self)
|
return LangchainAdapter.convert_tool(self, operation)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"
|
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"
|
||||||
216
api/app/core/tools/builtin/operation_tool.py
Normal file
216
api/app/core/tools/builtin/operation_tool.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""操作工具 - 为特定操作创建的工具包装器"""
|
||||||
|
from typing import List
|
||||||
|
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||||
|
from app.models import ToolType
|
||||||
|
|
||||||
|
|
||||||
|
class OperationTool(BaseTool):
|
||||||
|
"""操作工具 - 包装基础工具的特定操作"""
|
||||||
|
|
||||||
|
def __init__(self, base_tool: BaseTool, operation: str):
|
||||||
|
self.base_tool = base_tool
|
||||||
|
self.operation = operation
|
||||||
|
super().__init__(base_tool.tool_id, base_tool.config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return f"{self.base_tool.name}_{self.operation}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tool_type(self) -> ToolType:
|
||||||
|
"""工具类型"""
|
||||||
|
return ToolType.BUILTIN
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return f"{self.base_tool.description} - {self.operation}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> List[ToolParameter]:
|
||||||
|
"""返回特定操作的参数"""
|
||||||
|
if self.base_tool.name == 'datetime_tool':
|
||||||
|
return self._get_datetime_params()
|
||||||
|
elif self.base_tool.name == 'json_tool':
|
||||||
|
return self._get_json_params()
|
||||||
|
else:
|
||||||
|
# 默认返回除operation外的所有参数
|
||||||
|
return [p for p in self.base_tool.parameters if p.name != "operation"]
|
||||||
|
|
||||||
|
def _get_datetime_params(self) -> List[ToolParameter]:
|
||||||
|
"""获取datetime_tool特定操作的参数"""
|
||||||
|
if self.operation == "now":
|
||||||
|
return [
|
||||||
|
ToolParameter(
|
||||||
|
name="to_timezone",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
required=False,
|
||||||
|
default="Asia/Shanghai"
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="output_format",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
required=False,
|
||||||
|
default="%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
elif self.operation == "format":
|
||||||
|
return [
|
||||||
|
ToolParameter(
|
||||||
|
name="input_value",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输入值(时间字符串或时间戳)",
|
||||||
|
required=True
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="input_format",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
required=False,
|
||||||
|
default="%Y-%m-%d %H:%M:%S"
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="output_format",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
required=False,
|
||||||
|
default="%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
elif self.operation == "convert_timezone":
|
||||||
|
return [
|
||||||
|
ToolParameter(
|
||||||
|
name="input_value",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输入值(时间字符串或时间戳)",
|
||||||
|
required=True
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="input_format",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
required=False,
|
||||||
|
default="%Y-%m-%d %H:%M:%S"
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="output_format",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
required=False,
|
||||||
|
default="%Y-%m-%d %H:%M:%S"
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="from_timezone",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="源时区(如:UTC, Asia/Shanghai)",
|
||||||
|
required=False,
|
||||||
|
default="Asia/Shanghai"
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="to_timezone",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
required=False,
|
||||||
|
default="Asia/Shanghai"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
elif self.operation == "timestamp_to_datetime":
|
||||||
|
return [
|
||||||
|
ToolParameter(
|
||||||
|
name="input_value",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输入值(时间字符串或时间戳)",
|
||||||
|
required=True
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="output_format",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
required=False,
|
||||||
|
default="%Y-%m-%d %H:%M:%S"
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="to_timezone",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
required=False,
|
||||||
|
default="Asia/Shanghai"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_json_params(self) -> List[ToolParameter]:
|
||||||
|
"""获取json_tool特定操作的参数"""
|
||||||
|
base_params = [
|
||||||
|
ToolParameter(
|
||||||
|
name="input_data",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.operation == "insert":
|
||||||
|
return base_params + [
|
||||||
|
ToolParameter(
|
||||||
|
name="json_path",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
required=True
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="new_value",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="新值(用于insert操作)",
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
elif self.operation == "replace":
|
||||||
|
return base_params + [
|
||||||
|
ToolParameter(
|
||||||
|
name="json_path",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
required=True
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="old_text",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="要替换的原文本(用于replace操作)",
|
||||||
|
required=True
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name="new_text",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="替换后的新文本(用于replace操作)",
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
elif self.operation == "delete":
|
||||||
|
return base_params + [
|
||||||
|
ToolParameter(
|
||||||
|
name="json_path",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
elif self.operation == "parse":
|
||||||
|
return base_params + [
|
||||||
|
ToolParameter(
|
||||||
|
name="json_path",
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return base_params
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> ToolResult:
|
||||||
|
"""执行特定操作"""
|
||||||
|
# 添加operation参数
|
||||||
|
kwargs["operation"] = self.operation
|
||||||
|
return await self.base_tool.execute(**kwargs)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
"""自定义工具基类"""
|
"""自定义工具基类"""
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -135,6 +136,13 @@ class CustomTool(BaseTool):
|
|||||||
|
|
||||||
if not self.schema_content:
|
if not self.schema_content:
|
||||||
return operations
|
return operations
|
||||||
|
|
||||||
|
if isinstance(self.schema_content, str):
|
||||||
|
try:
|
||||||
|
self.schema_content = json.loads(self.schema_content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
|
||||||
|
return operations
|
||||||
|
|
||||||
paths = self.schema_content.get("paths", {})
|
paths = self.schema_content.get("paths", {})
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class LangchainToolWrapper(LangchainBaseTool):
|
|||||||
name=tool_instance.name,
|
name=tool_instance.name,
|
||||||
description=tool_instance.description,
|
description=tool_instance.description,
|
||||||
args_schema=args_schema,
|
args_schema=args_schema,
|
||||||
_tool_instance=tool_instance,
|
tool_instance=tool_instance,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ class LangchainToolWrapper(LangchainBaseTool):
|
|||||||
"""异步执行工具"""
|
"""异步执行工具"""
|
||||||
try:
|
try:
|
||||||
# 执行内部工具
|
# 执行内部工具
|
||||||
result = await self._tool_instance.safe_execute(**kwargs)
|
result = await self.tool_instance.safe_execute(**kwargs)
|
||||||
|
|
||||||
# 转换结果为Langchain格式
|
# 转换结果为Langchain格式
|
||||||
return LangchainAdapter._format_result_for_langchain(result)
|
return LangchainAdapter._format_result_for_langchain(result)
|
||||||
@@ -73,24 +73,39 @@ class LangchainAdapter:
|
|||||||
"""Langchain适配器 - 负责工具格式转换和标准化"""
|
"""Langchain适配器 - 负责工具格式转换和标准化"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
|
def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
|
||||||
"""将内部工具转换为Langchain工具
|
"""将内部工具转换为Langchain工具
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool: 内部工具实例
|
tool: 内部工具实例
|
||||||
|
operation: 特定操作(适用于有操作的工具)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Langchain兼容的工具包装器
|
Langchain兼容的工具包装器
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
if operation and tool.name in ['datetime_tool', 'json_tool']:
|
||||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
# 为特定操作创建工具
|
||||||
return wrapper
|
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
||||||
|
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
||||||
|
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||||
|
return wrapper
|
||||||
|
else:
|
||||||
|
# 单个工具
|
||||||
|
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||||
|
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||||
|
return wrapper
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
|
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
|
||||||
|
"""为特定操作创建工具实例"""
|
||||||
|
from app.core.tools.builtin.operation_tool import OperationTool
|
||||||
|
return OperationTool(base_tool, operation)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
||||||
"""批量转换工具
|
"""批量转换工具
|
||||||
@@ -110,7 +125,7 @@ class LangchainAdapter:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
|
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
|
||||||
|
|
||||||
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
|
logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
|
||||||
return converted_tools
|
return converted_tools
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -169,9 +184,10 @@ class LangchainAdapter:
|
|||||||
"ToolArgsSchema",
|
"ToolArgsSchema",
|
||||||
(BaseModel,),
|
(BaseModel,),
|
||||||
{
|
{
|
||||||
|
"__module__": __name__,
|
||||||
"__annotations__": annotations,
|
"__annotations__": annotations,
|
||||||
**fields,
|
"model_config": {"extra": "forbid"},
|
||||||
"Config": type("Config", (), {"extra": "forbid"})
|
**fields
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -16,14 +16,17 @@ logger = get_business_logger()
|
|||||||
class MCPServiceManager:
|
class MCPServiceManager:
|
||||||
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session = None):
|
||||||
"""初始化MCP服务管理器
|
"""初始化MCP服务管理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话(可选)
|
||||||
"""
|
"""
|
||||||
self.db = db
|
self.db = db
|
||||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
if db:
|
||||||
|
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||||
|
else:
|
||||||
|
self.connection_pool = None
|
||||||
|
|
||||||
# 服务状态管理
|
# 服务状态管理
|
||||||
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
||||||
@@ -592,7 +595,7 @@ class MCPServiceManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"清理失效服务失败: {e}")
|
logger.error(f"清理失效服务失败: {e}")
|
||||||
|
|
||||||
def get_manager_status(self) -> Dict[str, Any]:
|
def get_manager_status(self) -> Dict[str, Any]:
|
||||||
"""获取管理器状态"""
|
"""获取管理器状态"""
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.services.tool_service import ToolService
|
||||||
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models import MultiAgentConfig, AgentConfig
|
from app.models import MultiAgentConfig, AgentConfig
|
||||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
@@ -40,6 +42,7 @@ class AppChatService:
|
|||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
workspace_id: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
|
|
||||||
@@ -64,6 +67,20 @@ class AppChatService:
|
|||||||
|
|
||||||
# 准备工具列表
|
# 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
|
# 获取工具服务
|
||||||
|
tool_service = ToolService(self.db)
|
||||||
|
|
||||||
|
# 从配置中获取启用的工具
|
||||||
|
if hasattr(config, 'tools') and config.tools:
|
||||||
|
for tool_id, tool_config in config.tools.items():
|
||||||
|
if tool_config.get("enabled", False):
|
||||||
|
# 根据工具名称查找工具实例
|
||||||
|
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
|
||||||
|
if tool_instance:
|
||||||
|
# 转换为LangChain工具
|
||||||
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
|
||||||
|
tools.append(langchain_tool)
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
knowledge_retrieval = config.knowledge_retrieval
|
knowledge_retrieval = config.knowledge_retrieval
|
||||||
@@ -83,21 +100,6 @@ class AppChatService:
|
|||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
web_tools = config.tools
|
|
||||||
# web_search_choice = web_tools.get("web_search", {})
|
|
||||||
# web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
# if web_search == True:
|
|
||||||
# if web_search_enable == True:
|
|
||||||
# search_tool = create_web_search_tool({})
|
|
||||||
# tools.append(search_tool)
|
|
||||||
#
|
|
||||||
# logger.debug(
|
|
||||||
# "已添加网络搜索工具",
|
|
||||||
# extra={
|
|
||||||
# "tool_count": len(tools)
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
@@ -170,6 +172,7 @@ class AppChatService:
|
|||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
|
|
||||||
@@ -641,6 +644,20 @@ class AppChatService:
|
|||||||
|
|
||||||
# 准备工具列表
|
# 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
|
# 获取工具服务
|
||||||
|
tool_service = ToolService(self.db)
|
||||||
|
|
||||||
|
# 从配置中获取启用的工具
|
||||||
|
if hasattr(config, 'tools') and config.tools:
|
||||||
|
for tool_id, tool_config in config.tools.items():
|
||||||
|
if tool_config.get("enabled", False):
|
||||||
|
# 根据工具名称查找工具实例
|
||||||
|
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
|
||||||
|
if tool_instance:
|
||||||
|
# 转换为LangChain工具
|
||||||
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
|
||||||
|
tools.append(langchain_tool)
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||||
@@ -660,21 +677,6 @@ class AppChatService:
|
|||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
web_tools = config.get("tools")
|
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search == True:
|
|
||||||
if web_search_enable == True:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加网络搜索工具",
|
|
||||||
extra={
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.get("model_parameters", {})
|
model_parameters = config.get("model_parameters", {})
|
||||||
|
|
||||||
|
|||||||
@@ -10,19 +10,22 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.tools import tool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||||
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||||
from app.services.langchain_tool_server import Search
|
from app.services.langchain_tool_server import Search
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
from app.services.model_parameter_merger import ModelParameterMerger
|
||||||
from langchain.tools import tool
|
from app.services.tool_service import ToolService
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
class KnowledgeRetrievalInput(BaseModel):
|
class KnowledgeRetrievalInput(BaseModel):
|
||||||
@@ -291,24 +294,21 @@ class DraftRunService:
|
|||||||
# 4. 准备工具列表
|
# 4. 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
# 添加网络搜索工具
|
tool_service = ToolService(self.db)
|
||||||
if web_search:
|
|
||||||
if agent_config.tools:
|
|
||||||
web_search_config = agent_config.tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_config.get("enabled", False)
|
|
||||||
|
|
||||||
if web_search_enable:
|
# 从配置中获取启用的工具
|
||||||
logger.info("网络搜索已启用")
|
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||||
# 创建网络搜索工具
|
for tool_id, tool_config in agent_config.tools.items():
|
||||||
search_tool = create_web_search_tool(web_search_config)
|
if tool_config.get("enabled", False):
|
||||||
tools.append(search_tool)
|
# 根据工具名称查找工具实例
|
||||||
|
tool_instance = tool_service._get_tool_instance(tool_id,
|
||||||
logger.debug(
|
ToolRepository.get_tenant_id_by_workspace_id(
|
||||||
"已添加网络搜索工具",
|
self.db, str(workspace_id)))
|
||||||
extra={
|
if tool_instance:
|
||||||
"tool_count": len(tools)
|
# 转换为LangChain工具
|
||||||
}
|
langchain_tool = tool_instance.to_langchain_tool(
|
||||||
)
|
tool_config.get("config", {}).get("operation", None))
|
||||||
|
tools.append(langchain_tool)
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
if agent_config.knowledge_retrieval:
|
if agent_config.knowledge_retrieval:
|
||||||
@@ -503,24 +503,21 @@ class DraftRunService:
|
|||||||
# 4. 准备工具列表
|
# 4. 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
# 添加网络搜索工具
|
tool_service = ToolService(self.db)
|
||||||
if web_search:
|
|
||||||
if agent_config.tools:
|
|
||||||
web_search_config = agent_config.tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_config.get("enabled", False)
|
|
||||||
|
|
||||||
if web_search_enable:
|
# 从配置中获取启用的工具
|
||||||
logger.info("网络搜索已启用")
|
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||||
# 创建网络搜索工具
|
for tool_id, tool_config in agent_config.tools.items():
|
||||||
search_tool = create_web_search_tool(web_search_config)
|
if tool_config.get("enabled", False):
|
||||||
tools.append(search_tool)
|
# 根据工具名称查找工具实例
|
||||||
|
tool_instance = tool_service._get_tool_instance(tool_id,
|
||||||
logger.debug(
|
ToolRepository.get_tenant_id_by_workspace_id(
|
||||||
"已添加网络搜索工具",
|
self.db, str(workspace_id)))
|
||||||
extra={
|
if tool_instance:
|
||||||
"tool_count": len(tools)
|
# 转换为LangChain工具
|
||||||
}
|
langchain_tool = tool_instance.to_langchain_tool(
|
||||||
)
|
tool_config.get("config", {}).get("operation", None))
|
||||||
|
tools.append(langchain_tool)
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
if agent_config.knowledge_retrieval:
|
if agent_config.knowledge_retrieval:
|
||||||
|
|||||||
Reference in New Issue
Block a user