Files
MemoryBear/api/app/services/agent_tools.py
2025-12-20 16:03:06 +08:00

332 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Agent 发现和调用工具"""
import uuid
import time
import datetime
from typing import Optional, Dict, Any, List
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.models import AgentConfig, ModelConfig, AgentInvocation
from app.services.agent_registry import AgentRegistry
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.repositories import workspace_repository, knowledge_repository
logger = get_business_logger()
# ==================== Agent 发现工具 ====================
class AgentDiscoveryInput(BaseModel):
"""Agent 发现工具输入参数"""
query: Optional[str] = Field(None, description="搜索关键词,如:'客服''技术支持'")
domain: Optional[str] = Field(None, description="专业领域,如:'customer_service''technical_support'")
capabilities: Optional[List[str]] = Field(None, description="所需能力列表,如:['退货处理', '订单查询']")
def create_agent_discovery_tool(registry: AgentRegistry, workspace_id: uuid.UUID):
"""创建 Agent 发现工具
Args:
registry: Agent 注册表
workspace_id: 当前工作空间 ID
Returns:
Agent 发现工具
"""
@tool(args_schema=AgentDiscoveryInput)
def discover_agents(
query: Optional[str] = None,
domain: Optional[str] = None,
capabilities: Optional[List[str]] = None
) -> str:
"""发现系统中可用的 Agent。当需要找到能够处理特定任务的 Agent 时使用此工具。
Args:
query: 搜索关键词(如:"客服""技术支持"
domain: 专业领域(如:"customer_service""technical_support"
capabilities: 所需能力(如:["退货处理", "订单查询"]
Returns:
可用 Agent 的列表和描述
"""
try:
agents = registry.discover_agents(
query=query,
domain=domain,
capabilities=capabilities,
workspace_id=workspace_id
)
if not agents:
return "未找到匹配的 Agent"
# 格式化输出
result = f"找到 {len(agents)} 个可用的 Agent\n\n"
for i, agent in enumerate(agents, 1):
result += f"{i}. {agent['name']}\n"
result += f" ID: {agent['id']}\n"
if agent['description']:
result += f" 描述: {agent['description']}\n"
if agent['domain']:
result += f" 领域: {agent['domain']}\n"
if agent['capabilities']:
result += f" 能力: {', '.join(agent['capabilities'])}\n"
if agent['tools']:
result += f" 工具: {', '.join(agent['tools'])}\n"
result += "\n"
logger.info(
"Agent 发现成功",
extra={
"query": query,
"domain": domain,
"found_count": len(agents)
}
)
return result
except Exception as e:
logger.error("Agent 发现失败", extra={"error": str(e)})
return f"发现 Agent 失败: {str(e)}"
return discover_agents
# ==================== Agent 调用工具 ====================
class AgentInvocationInput(BaseModel):
"""Agent 调用工具输入参数"""
agent_id: str = Field(..., description="要调用的 Agent ID通过 discover_agents 工具获取)")
message: str = Field(..., description="发送给 Agent 的消息或任务描述")
context: Optional[Dict[str, Any]] = Field(None, description="可选的上下文信息(如:用户信息、历史记录等)")
def create_agent_invocation_tool(
db: Session,
registry: AgentRegistry,
workspace_id: uuid.UUID,
current_agent_id: uuid.UUID,
conversation_id: Optional[uuid.UUID] = None,
parent_invocation_id: Optional[uuid.UUID] = None,
invocation_chain: Optional[List[uuid.UUID]] = None
):
"""创建 Agent 调用工具
Args:
db: 数据库会话
registry: Agent 注册表
workspace_id: 当前工作空间 ID
current_agent_id: 当前 Agent ID
conversation_id: 会话 ID
parent_invocation_id: 父调用 ID
invocation_chain: 调用链(用于检测循环调用)
Returns:
Agent 调用工具
"""
# 1. 获取工作空间的 storage_type
storage_type = 'neo4j' # 默认值
user_rag_memory_id = None
try:
workspace = workspace_repository.get_workspace_by_id(db, workspace_id)
if workspace and workspace.storage_type:
storage_type = workspace.storage_type
logger.debug(
"获取工作空间存储类型成功",
extra={
"workspace_id": str(workspace_id),
"storage_type": storage_type
}
)
except Exception as e:
logger.warning(
"获取工作空间存储类型失败,使用默认值 neo4j",
extra={"workspace_id": str(workspace_id), "error": str(e)}
)
# 2. 如果 storage_type 是 rag获取知识库 ID
if storage_type == 'rag':
try:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MEMORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
logger.debug(
"获取 RAG 知识库成功",
extra={
"workspace_id": str(workspace_id),
"knowledge_id": user_rag_memory_id
}
)
else:
logger.warning(
"未找到名为 'USER_RAG_MEMORY' 的知识库,将使用 neo4j 存储",
extra={"workspace_id": str(workspace_id)}
)
storage_type = 'neo4j'
except Exception as e:
logger.warning(
"获取 RAG 知识库失败,将使用 neo4j 存储",
extra={"workspace_id": str(workspace_id), "error": str(e)}
)
storage_type = 'neo4j'
if invocation_chain is None:
invocation_chain = []
@tool(args_schema=AgentInvocationInput)
async def invoke_agent(
agent_id: str,
message: str,
context: Optional[Dict[str, Any]] = None
) -> str:
"""调用另一个 Agent 来处理任务。当当前 Agent 无法处理某个任务,需要其他专业 Agent 帮助时使用。
Args:
agent_id: 要调用的 Agent ID通过 discover_agents 工具获取)
message: 发送给 Agent 的消息或任务描述
context: 可选的上下文信息(如:用户信息、历史记录等)
Returns:
被调用 Agent 的响应结果
"""
try:
# 1. 验证 Agent 存在
agent_uuid = uuid.UUID(agent_id)
agent_info = registry.get_agent(agent_uuid)
if not agent_info:
return f"Agent {agent_id} 不存在"
# 2. 验证权限(同工作空间或公开)
if agent_info["workspace_id"] != str(workspace_id) and agent_info["visibility"] != "public":
return f"无权访问 Agent {agent_info['name']}"
# 3. 防止自己调用自己
if agent_id == str(current_agent_id):
return "不能调用自己"
# 4. 防止循环调用
if agent_uuid in invocation_chain:
return f"检测到循环调用:{agent_info['name']} 已在调用链中"
# 5. 检查调用深度
max_depth = 5
if len(invocation_chain) >= max_depth:
return f"调用深度超过限制(最大 {max_depth} 层)"
# 6. 获取 Agent 配置
agent_config = db.get(AgentConfig, agent_uuid)
if not agent_config:
return "Agent 配置不存在"
# 7. 获取模型配置
model_config = db.get(ModelConfig, agent_config.default_model_config_id)
if not model_config:
return "Agent 模型配置不存在"
# 8. 创建调用记录
invocation = AgentInvocation(
caller_agent_id=current_agent_id,
callee_agent_id=agent_uuid,
conversation_id=conversation_id,
parent_invocation_id=parent_invocation_id,
input_message=message,
context=context,
status="running",
started_at=datetime.datetime.now()
)
db.add(invocation)
db.commit()
db.refresh(invocation)
logger.info(
"Agent 调用开始",
extra={
"invocation_id": str(invocation.id),
"caller_agent_id": str(current_agent_id),
"callee_agent_id": agent_id,
"depth": len(invocation_chain)
}
)
start_time = time.time()
try:
# 9. 调用 Agent
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run(
agent_config=agent_config,
model_config=model_config,
message=message,
workspace_id=workspace_id,
variables=context or {},
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
elapsed_time = time.time() - start_time
# 10. 更新调用记录
invocation.status = "completed"
invocation.output_message = result["message"]
invocation.completed_at = datetime.datetime.now()
invocation.elapsed_time = elapsed_time
invocation.token_usage = result.get("usage", {})
db.commit()
logger.info(
"Agent 调用成功",
extra={
"invocation_id": str(invocation.id),
"caller_agent_id": str(current_agent_id),
"callee_agent_id": agent_id,
"elapsed_time": elapsed_time
}
)
return result["message"]
except Exception as e:
# 更新调用记录为失败
invocation.status = "failed"
invocation.error_message = str(e)
invocation.completed_at = datetime.datetime.now()
invocation.elapsed_time = time.time() - start_time
db.commit()
logger.error(
"Agent 调用失败",
extra={
"invocation_id": str(invocation.id),
"caller_agent_id": str(current_agent_id),
"callee_agent_id": agent_id,
"error": str(e)
}
)
raise
except Exception as e:
logger.error(
"Agent 调用异常",
extra={
"caller_agent_id": str(current_agent_id),
"callee_agent_id": agent_id,
"error": str(e)
}
)
return f"调用 Agent 失败: {str(e)}"
return invoke_agent