549 lines
19 KiB
Python
549 lines
19 KiB
Python
"""Agent 发现和调用工具"""
|
||
import uuid
|
||
import time
|
||
import datetime
|
||
from typing import Optional, Dict, Any, List
|
||
from langchain.tools import tool
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.models import AgentConfig, ModelConfig, AgentInvocation
|
||
from app.services.agent_registry import AgentRegistry
|
||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||
from app.core.error_codes import BizCode
|
||
from app.core.logging_config import get_business_logger
|
||
from app.repositories import workspace_repository, knowledge_repository
|
||
from app.core.tools.registry import ToolRegistry
|
||
from app.core.tools.executor import ToolExecutor
|
||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||
TOOL_MANAGEMENT_AVAILABLE = True
|
||
|
||
|
||
logger = get_business_logger()
|
||
|
||
|
||
# ==================== Agent 发现工具 ====================
|
||
|
||
class AgentDiscoveryInput(BaseModel):
|
||
"""Agent 发现工具输入参数"""
|
||
query: Optional[str] = Field(None, description="搜索关键词,如:'客服'、'技术支持'")
|
||
domain: Optional[str] = Field(None, description="专业领域,如:'customer_service'、'technical_support'")
|
||
capabilities: Optional[List[str]] = Field(None, description="所需能力列表,如:['退货处理', '订单查询']")
|
||
|
||
|
||
def create_agent_discovery_tool(registry: AgentRegistry, workspace_id: uuid.UUID):
|
||
"""创建 Agent 发现工具
|
||
|
||
Args:
|
||
registry: Agent 注册表
|
||
workspace_id: 当前工作空间 ID
|
||
|
||
Returns:
|
||
Agent 发现工具
|
||
"""
|
||
|
||
@tool(args_schema=AgentDiscoveryInput)
|
||
def discover_agents(
|
||
query: Optional[str] = None,
|
||
domain: Optional[str] = None,
|
||
capabilities: Optional[List[str]] = None
|
||
) -> str:
|
||
"""发现系统中可用的 Agent。当需要找到能够处理特定任务的 Agent 时使用此工具。
|
||
|
||
Args:
|
||
query: 搜索关键词(如:"客服"、"技术支持")
|
||
domain: 专业领域(如:"customer_service"、"technical_support")
|
||
capabilities: 所需能力(如:["退货处理", "订单查询"])
|
||
|
||
Returns:
|
||
可用 Agent 的列表和描述
|
||
"""
|
||
try:
|
||
agents = registry.discover_agents(
|
||
query=query,
|
||
domain=domain,
|
||
capabilities=capabilities,
|
||
workspace_id=workspace_id
|
||
)
|
||
|
||
if not agents:
|
||
return "未找到匹配的 Agent"
|
||
|
||
# 格式化输出
|
||
result = f"找到 {len(agents)} 个可用的 Agent:\n\n"
|
||
for i, agent in enumerate(agents, 1):
|
||
result += f"{i}. {agent['name']}\n"
|
||
result += f" ID: {agent['id']}\n"
|
||
if agent['description']:
|
||
result += f" 描述: {agent['description']}\n"
|
||
if agent['domain']:
|
||
result += f" 领域: {agent['domain']}\n"
|
||
if agent['capabilities']:
|
||
result += f" 能力: {', '.join(agent['capabilities'])}\n"
|
||
if agent['tools']:
|
||
result += f" 工具: {', '.join(agent['tools'])}\n"
|
||
result += "\n"
|
||
|
||
logger.info(
|
||
"Agent 发现成功",
|
||
extra={
|
||
"query": query,
|
||
"domain": domain,
|
||
"found_count": len(agents)
|
||
}
|
||
)
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
logger.error("Agent 发现失败", extra={"error": str(e)})
|
||
return f"发现 Agent 失败: {str(e)}"
|
||
|
||
return discover_agents
|
||
|
||
|
||
# ==================== Agent 调用工具 ====================
|
||
|
||
class AgentInvocationInput(BaseModel):
|
||
"""Agent 调用工具输入参数"""
|
||
agent_id: str = Field(..., description="要调用的 Agent ID(通过 discover_agents 工具获取)")
|
||
message: str = Field(..., description="发送给 Agent 的消息或任务描述")
|
||
context: Optional[Dict[str, Any]] = Field(None, description="可选的上下文信息(如:用户信息、历史记录等)")
|
||
|
||
|
||
def create_agent_invocation_tool(
|
||
db: Session,
|
||
registry: AgentRegistry,
|
||
workspace_id: uuid.UUID,
|
||
current_agent_id: uuid.UUID,
|
||
conversation_id: Optional[uuid.UUID] = None,
|
||
parent_invocation_id: Optional[uuid.UUID] = None,
|
||
invocation_chain: Optional[List[uuid.UUID]] = None
|
||
):
|
||
"""创建 Agent 调用工具
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
registry: Agent 注册表
|
||
workspace_id: 当前工作空间 ID
|
||
current_agent_id: 当前 Agent ID
|
||
conversation_id: 会话 ID
|
||
parent_invocation_id: 父调用 ID
|
||
invocation_chain: 调用链(用于检测循环调用)
|
||
|
||
Returns:
|
||
Agent 调用工具
|
||
"""
|
||
# 1. 获取工作空间的 storage_type
|
||
storage_type = 'neo4j' # 默认值
|
||
user_rag_memory_id = None
|
||
|
||
try:
|
||
workspace = workspace_repository.get_workspace_by_id(db, workspace_id)
|
||
if workspace and workspace.storage_type:
|
||
storage_type = workspace.storage_type
|
||
logger.debug(
|
||
"获取工作空间存储类型成功",
|
||
extra={
|
||
"workspace_id": str(workspace_id),
|
||
"storage_type": storage_type
|
||
}
|
||
)
|
||
except Exception as e:
|
||
logger.warning(
|
||
"获取工作空间存储类型失败,使用默认值 neo4j",
|
||
extra={"workspace_id": str(workspace_id), "error": str(e)}
|
||
)
|
||
|
||
# 2. 如果 storage_type 是 rag,获取知识库 ID
|
||
if storage_type == 'rag':
|
||
try:
|
||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||
db=db,
|
||
name="USER_RAG_MEMORY",
|
||
workspace_id=workspace_id
|
||
)
|
||
if knowledge:
|
||
user_rag_memory_id = str(knowledge.id)
|
||
logger.debug(
|
||
"获取 RAG 知识库成功",
|
||
extra={
|
||
"workspace_id": str(workspace_id),
|
||
"knowledge_id": user_rag_memory_id
|
||
}
|
||
)
|
||
else:
|
||
logger.warning(
|
||
"未找到名为 'USER_RAG_MEMORY' 的知识库,将使用 neo4j 存储",
|
||
extra={"workspace_id": str(workspace_id)}
|
||
)
|
||
storage_type = 'neo4j'
|
||
except Exception as e:
|
||
logger.warning(
|
||
"获取 RAG 知识库失败,将使用 neo4j 存储",
|
||
extra={"workspace_id": str(workspace_id), "error": str(e)}
|
||
)
|
||
storage_type = 'neo4j'
|
||
|
||
if invocation_chain is None:
|
||
invocation_chain = []
|
||
|
||
@tool(args_schema=AgentInvocationInput)
|
||
async def invoke_agent(
|
||
agent_id: str,
|
||
message: str,
|
||
context: Optional[Dict[str, Any]] = None
|
||
) -> str:
|
||
"""调用另一个 Agent 来处理任务。当当前 Agent 无法处理某个任务,需要其他专业 Agent 帮助时使用。
|
||
|
||
Args:
|
||
agent_id: 要调用的 Agent ID(通过 discover_agents 工具获取)
|
||
message: 发送给 Agent 的消息或任务描述
|
||
context: 可选的上下文信息(如:用户信息、历史记录等)
|
||
|
||
Returns:
|
||
被调用 Agent 的响应结果
|
||
"""
|
||
try:
|
||
# 1. 验证 Agent 存在
|
||
agent_uuid = uuid.UUID(agent_id)
|
||
agent_info = registry.get_agent(agent_uuid)
|
||
if not agent_info:
|
||
return f"Agent {agent_id} 不存在"
|
||
|
||
# 2. 验证权限(同工作空间或公开)
|
||
if agent_info["workspace_id"] != str(workspace_id) and agent_info["visibility"] != "public":
|
||
return f"无权访问 Agent {agent_info['name']}"
|
||
|
||
# 3. 防止自己调用自己
|
||
if agent_id == str(current_agent_id):
|
||
return "不能调用自己"
|
||
|
||
# 4. 防止循环调用
|
||
if agent_uuid in invocation_chain:
|
||
return f"检测到循环调用:{agent_info['name']} 已在调用链中"
|
||
|
||
# 5. 检查调用深度
|
||
max_depth = 5
|
||
if len(invocation_chain) >= max_depth:
|
||
return f"调用深度超过限制(最大 {max_depth} 层)"
|
||
|
||
# 6. 获取 Agent 配置
|
||
agent_config = db.get(AgentConfig, agent_uuid)
|
||
if not agent_config:
|
||
return "Agent 配置不存在"
|
||
|
||
# 7. 获取模型配置
|
||
model_config = db.get(ModelConfig, agent_config.default_model_config_id)
|
||
if not model_config:
|
||
return "Agent 模型配置不存在"
|
||
|
||
# 8. 创建调用记录
|
||
invocation = AgentInvocation(
|
||
caller_agent_id=current_agent_id,
|
||
callee_agent_id=agent_uuid,
|
||
conversation_id=conversation_id,
|
||
parent_invocation_id=parent_invocation_id,
|
||
input_message=message,
|
||
context=context,
|
||
status="running",
|
||
started_at=datetime.datetime.now()
|
||
)
|
||
db.add(invocation)
|
||
db.commit()
|
||
db.refresh(invocation)
|
||
|
||
logger.info(
|
||
"Agent 调用开始",
|
||
extra={
|
||
"invocation_id": str(invocation.id),
|
||
"caller_agent_id": str(current_agent_id),
|
||
"callee_agent_id": agent_id,
|
||
"depth": len(invocation_chain)
|
||
}
|
||
)
|
||
|
||
start_time = time.time()
|
||
|
||
try:
|
||
# 9. 调用 Agent
|
||
from app.services.draft_run_service import DraftRunService
|
||
draft_service = DraftRunService(db)
|
||
|
||
result = await draft_service.run(
|
||
agent_config=agent_config,
|
||
model_config=model_config,
|
||
message=message,
|
||
workspace_id=workspace_id,
|
||
variables=context or {},
|
||
storage_type=storage_type,
|
||
user_rag_memory_id=user_rag_memory_id
|
||
)
|
||
|
||
elapsed_time = time.time() - start_time
|
||
|
||
# 10. 更新调用记录
|
||
invocation.status = "completed"
|
||
invocation.output_message = result["message"]
|
||
invocation.completed_at = datetime.datetime.now()
|
||
invocation.elapsed_time = elapsed_time
|
||
invocation.token_usage = result.get("usage", {})
|
||
db.commit()
|
||
|
||
logger.info(
|
||
"Agent 调用成功",
|
||
extra={
|
||
"invocation_id": str(invocation.id),
|
||
"caller_agent_id": str(current_agent_id),
|
||
"callee_agent_id": agent_id,
|
||
"elapsed_time": elapsed_time
|
||
}
|
||
)
|
||
|
||
return result["message"]
|
||
|
||
except Exception as e:
|
||
# 更新调用记录为失败
|
||
invocation.status = "failed"
|
||
invocation.error_message = str(e)
|
||
invocation.completed_at = datetime.datetime.now()
|
||
invocation.elapsed_time = time.time() - start_time
|
||
db.commit()
|
||
|
||
logger.error(
|
||
"Agent 调用失败",
|
||
extra={
|
||
"invocation_id": str(invocation.id),
|
||
"caller_agent_id": str(current_agent_id),
|
||
"callee_agent_id": agent_id,
|
||
"error": str(e)
|
||
}
|
||
)
|
||
|
||
raise
|
||
|
||
except Exception as e:
|
||
logger.error(
|
||
"Agent 调用异常",
|
||
extra={
|
||
"caller_agent_id": str(current_agent_id),
|
||
"callee_agent_id": agent_id,
|
||
"error": str(e)
|
||
}
|
||
)
|
||
return f"调用 Agent 失败: {str(e)}"
|
||
|
||
return invoke_agent
|
||
|
||
def get_available_tools_for_agent(
|
||
db: Session,
|
||
workspace_id: uuid.UUID,
|
||
agent_id: Optional[uuid.UUID] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""获取Agent可用的工具列表
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
workspace_id: 工作空间ID
|
||
agent_id: Agent ID(可选)
|
||
|
||
Returns:
|
||
可用工具列表
|
||
"""
|
||
if not TOOL_MANAGEMENT_AVAILABLE:
|
||
logger.warning("工具管理系统不可用")
|
||
return []
|
||
|
||
try:
|
||
# 创建工具注册表
|
||
registry = ToolRegistry(db)
|
||
|
||
# 获取工具列表
|
||
tools = registry.list_tools(workspace_id=workspace_id)
|
||
|
||
# 转换为Agent可用的格式
|
||
available_tools = []
|
||
for tool_info in tools:
|
||
if tool_info.status.value == "active":
|
||
available_tools.append({
|
||
"id": tool_info.id,
|
||
"name": tool_info.name,
|
||
"description": tool_info.description,
|
||
"type": tool_info.tool_type.value,
|
||
"version": tool_info.version,
|
||
"tags": tool_info.tags,
|
||
"parameters": [
|
||
{
|
||
"name": param.name,
|
||
"type": param.type.value,
|
||
"description": param.description,
|
||
"required": param.required,
|
||
"default": param.default
|
||
}
|
||
for param in tool_info.parameters
|
||
]
|
||
})
|
||
|
||
logger.info(f"为Agent获取到 {len(available_tools)} 个可用工具")
|
||
return available_tools
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取Agent可用工具失败: {e}")
|
||
return []
|
||
|
||
|
||
def create_langchain_tools_for_agent(
|
||
db: Session,
|
||
workspace_id: uuid.UUID,
|
||
agent_id: Optional[uuid.UUID] = None
|
||
) -> List[Any]:
|
||
"""为Agent创建Langchain兼容的工具列表
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
workspace_id: 工作空间ID
|
||
agent_id: Agent ID(可选)
|
||
|
||
Returns:
|
||
Langchain工具列表
|
||
"""
|
||
if not TOOL_MANAGEMENT_AVAILABLE:
|
||
logger.warning("工具管理系统不可用")
|
||
return []
|
||
|
||
try:
|
||
# 创建工具注册表
|
||
registry = ToolRegistry(db)
|
||
|
||
# 注册内置工具类
|
||
from app.core.tools.builtin import (
|
||
DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
|
||
)
|
||
registry.register_tool_class(DateTimeTool)
|
||
registry.register_tool_class(JsonTool)
|
||
registry.register_tool_class(BaiduSearchTool)
|
||
registry.register_tool_class(MinerUTool)
|
||
registry.register_tool_class(TextInTool)
|
||
|
||
# 获取活跃的工具
|
||
tools = registry.list_tools(workspace_id=workspace_id)
|
||
active_tools = [tool for tool in tools if tool.status.value == "active"]
|
||
|
||
# 转换为Langchain工具
|
||
langchain_tools = []
|
||
for tool_info in active_tools:
|
||
try:
|
||
tool_instance = registry.get_tool(tool_info.id)
|
||
if tool_instance:
|
||
langchain_tool = LangchainAdapter.convert_tool(tool_instance)
|
||
langchain_tools.append(langchain_tool)
|
||
except Exception as e:
|
||
logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
|
||
|
||
logger.info(f"为Agent创建了 {len(langchain_tools)} 个Langchain工具")
|
||
return langchain_tools
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建Agent Langchain工具失败: {e}")
|
||
return []
|
||
|
||
|
||
class ToolExecutionInput(BaseModel):
|
||
"""工具执行输入参数"""
|
||
tool_id: str = Field(..., description="工具ID")
|
||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||
timeout: Optional[float] = Field(None, description="超时时间(秒)")
|
||
|
||
|
||
def create_tool_execution_tool(
|
||
db: Session,
|
||
workspace_id: uuid.UUID,
|
||
user_id: uuid.UUID
|
||
):
|
||
"""创建工具执行工具
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
workspace_id: 工作空间ID
|
||
user_id: 用户ID
|
||
|
||
Returns:
|
||
工具执行工具
|
||
"""
|
||
if not TOOL_MANAGEMENT_AVAILABLE:
|
||
logger.warning("工具管理系统不可用")
|
||
return None
|
||
|
||
@tool(args_schema=ToolExecutionInput)
|
||
async def execute_tool(
|
||
tool_id: str,
|
||
parameters: Dict[str, Any] = None,
|
||
timeout: Optional[float] = None
|
||
) -> str:
|
||
"""执行指定的工具。当需要使用系统中的工具来完成特定任务时使用。
|
||
|
||
Args:
|
||
tool_id: 工具ID(通过工具列表获取)
|
||
parameters: 工具参数(根据工具要求提供)
|
||
timeout: 超时时间(秒,可选)
|
||
|
||
Returns:
|
||
工具执行结果
|
||
"""
|
||
try:
|
||
# 创建工具执行器
|
||
registry = ToolRegistry(db)
|
||
executor = ToolExecutor(db, registry)
|
||
|
||
# 执行工具
|
||
result = await executor.execute_tool(
|
||
tool_id=tool_id,
|
||
parameters=parameters or {},
|
||
user_id=user_id,
|
||
workspace_id=workspace_id,
|
||
timeout=timeout
|
||
)
|
||
|
||
if result.success:
|
||
# 格式化成功结果
|
||
if isinstance(result.data, str):
|
||
return result.data
|
||
else:
|
||
import json
|
||
return json.dumps(result.data, ensure_ascii=False, indent=2)
|
||
else:
|
||
return f"工具执行失败: {result.error}"
|
||
|
||
except Exception as e:
|
||
logger.error(f"工具执行异常: {tool_id}, 错误: {e}")
|
||
return f"工具执行异常: {str(e)}"
|
||
|
||
return execute_tool
|
||
|
||
|
||
def get_tool_management_tools(
|
||
db: Session,
|
||
workspace_id: uuid.UUID,
|
||
user_id: uuid.UUID
|
||
) -> List[Any]:
|
||
"""获取工具管理相关的工具
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
workspace_id: 工作空间ID
|
||
user_id: 用户ID
|
||
|
||
Returns:
|
||
工具管理工具列表
|
||
"""
|
||
if not TOOL_MANAGEMENT_AVAILABLE:
|
||
return []
|
||
|
||
tools = []
|
||
|
||
# 添加工具执行工具
|
||
execution_tool = create_tool_execution_tool(db, workspace_id, user_id)
|
||
if execution_tool:
|
||
tools.append(execution_tool)
|
||
|
||
return tools |