Merge #59 into develop from feature/20251219_xjn
feat(tool system): Tool system reengineering * feature/20251219_xjn: (2 commits) feat(tool system): tool system development feat(tool system): Tool system reengineering Signed-off-by: 谢俊男 <accounts_6853d0ea6f8174722fb0c8f1@mail.teambition.com> Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/59
This commit is contained in:
@@ -33,7 +33,6 @@ from . import (
|
||||
emotion_config_controller,
|
||||
prompt_optimizer_controller,
|
||||
tool_controller,
|
||||
tool_execution_controller,
|
||||
)
|
||||
from . import user_memory_controllers
|
||||
|
||||
@@ -71,6 +70,5 @@ manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
manager_router.include_router(memory_reflection_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(tool_execution_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -20,11 +20,10 @@ async def get_memory_info():
|
||||
return success(data={}, msg="Memory API - Coming Soon")
|
||||
|
||||
|
||||
# /v1/memory/{resource_id}/chat
|
||||
@router.post("/{resource_id}/chat")
|
||||
# /v1/memory/chat
|
||||
@router.post("/chat")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def chat_with_agent_demo(
|
||||
resource_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -36,13 +35,12 @@ async def chat_with_agent_demo(
|
||||
scopes: 所需的权限范围列表["app", "rag", "memory"]
|
||||
|
||||
Args:
|
||||
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
|
||||
message: 请求参数
|
||||
request: 声明请求
|
||||
api_key_auth: 包含验证后的API Key 信息
|
||||
db: db_session
|
||||
"""
|
||||
logger.info(f"API Key Auth: {api_key_auth}")
|
||||
logger.info(f"Resource ID: {resource_id}")
|
||||
logger.info(f"Resource ID: {api_key_auth.resource_id}")
|
||||
logger.info(f"Message: {message}")
|
||||
return success(data={"received": True}, msg="消息已接收")
|
||||
@@ -1,585 +1,250 @@
|
||||
"""工具管理API控制器"""
|
||||
import base64
|
||||
from typing import List, Optional, Dict, Any
|
||||
"""工具控制器 - 简化统一的工具管理接口"""
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from langfuse.api.core import jsonable_encoder
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
from cryptography.fernet import Fernet
|
||||
from app.schemas.tool_schema import (
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
||||
)
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.config import settings
|
||||
from app.core.tools.config_manager import ConfigManager
|
||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["工具管理"])
|
||||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
def get_tool_service(db: Session = Depends(get_db)) -> ToolService:
|
||||
return ToolService(db)
|
||||
|
||||
|
||||
def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""加密敏感参数"""
|
||||
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
|
||||
cipher = Fernet(cipher_key)
|
||||
|
||||
encrypted_params = {}
|
||||
sensitive_keys = ['api_key', 'token', 'api_secret', 'password']
|
||||
|
||||
for key, value in parameters.items():
|
||||
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
|
||||
encrypted_params[key] = cipher.encrypt(str(value).encode()).decode()
|
||||
else:
|
||||
encrypted_params[key] = value
|
||||
|
||||
return encrypted_params
|
||||
@router.get("/statistics", response_model=ApiResponse)
|
||||
async def get_tool_statistics(
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""获取工具统计信息"""
|
||||
try:
|
||||
stats = service.get_tool_statistics(current_user.tenant_id)
|
||||
return success(data=stats, msg="获取统计信息成功")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""解密敏感参数"""
|
||||
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
|
||||
cipher = Fernet(cipher_key)
|
||||
|
||||
decrypted_params = {}
|
||||
sensitive_keys = ['api_key', 'token', 'secret', 'password']
|
||||
|
||||
for key, value in parameters.items():
|
||||
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
|
||||
try:
|
||||
decrypted_params[key] = cipher.decrypt(value.encode()).decode()
|
||||
except Exception as e:
|
||||
decrypted_params[key] = value
|
||||
else:
|
||||
decrypted_params[key] = value
|
||||
|
||||
return decrypted_params
|
||||
|
||||
|
||||
def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str:
|
||||
"""更新工具状态并返回新状态"""
|
||||
if tool_config.tool_type == ToolType.BUILTIN:
|
||||
if not tool_info or not tool_info.get('requires_config', False):
|
||||
new_status = ToolStatus.ACTIVE.value # 不需要配置的内置工具
|
||||
elif not builtin_config or not builtin_config.parameters:
|
||||
new_status = ToolStatus.INACTIVE.value
|
||||
else:
|
||||
# 检查是否有必要的API密钥
|
||||
has_key = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
|
||||
new_status = ToolStatus.ACTIVE.value if has_key else ToolStatus.INACTIVE.value
|
||||
else: # 自定义和MCP工具
|
||||
new_status = ToolStatus.ACTIVE.value if tool_config.config_data else ToolStatus.ERROR.value
|
||||
|
||||
# 更新数据库中的状态
|
||||
if tool_config.status != new_status:
|
||||
tool_config.status = new_status
|
||||
|
||||
return new_status
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ToolListResponse(BaseModel):
|
||||
"""工具列表响应"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
tool_type: str
|
||||
category: str
|
||||
version: str = "1.0.0"
|
||||
status: str # active inactive error loading
|
||||
requires_config: bool = False
|
||||
# is_configured: bool = False
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class BuiltinToolConfigRequest(BaseModel):
|
||||
"""内置工具配置请求"""
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
|
||||
|
||||
class CustomToolCreateRequest(BaseModel):
|
||||
"""自定义工具创建请求体模型,包含参数校验规则"""
|
||||
name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填")
|
||||
description: str = Field(None, description="工具描述")
|
||||
base_url: str = Field(None, description="工具基础URL")
|
||||
schema_url: str = Field(None, description="工具Schema URL")
|
||||
schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容,可选")
|
||||
auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型")
|
||||
auth_config: Optional[Dict[str, Any]] = Field(None, description="认证配置,默认空字典")
|
||||
timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间,1-300秒,默认30")
|
||||
|
||||
# 自定义校验:当auth_type为api_key时,auth_config必须包含api_key字段
|
||||
@field_validator("auth_config")
|
||||
def validate_auth_config(cls, v, values):
|
||||
auth_type = values.data.get("auth_type")
|
||||
if auth_type == "api_key" and (not v or "api_key" not in v):
|
||||
raise ValueError("认证类型为api_key时,auth_config必须包含api_key字段")
|
||||
if auth_type == "bearer_token" and (not v or "bearer_token" not in v):
|
||||
raise ValueError("认证类型为bearer_token时,auth_config必须包含bearer_token字段")
|
||||
return v
|
||||
|
||||
class MCPToolCreateRequest(BaseModel):
|
||||
"""MCP工具创建请求体模型,适配MCP业务特性"""
|
||||
# 基础必填字段(带长度/格式校验)
|
||||
name: str = Field(..., min_length=1, max_length=100,description="MCP工具名称")
|
||||
description: str = Field(None, description="MCP工具描述")
|
||||
# MCP核心字段:服务端URL(强制HTTP/HTTPS格式)
|
||||
server_url: str = Field(..., description="MCP服务端URL,仅支持http/https协议")
|
||||
# 连接配置:默认空字典,可自定义校验规则(根据实际业务调整)
|
||||
connection_config: Dict[str, Any] = Field({},description="MCP连接配置(如认证信息、超时、重试等),默认空字典")
|
||||
|
||||
@field_validator("connection_config")
|
||||
def validate_connection_config(cls, v):
|
||||
# 示例1:若包含timeout,必须是1-300的整数
|
||||
if "timeout" in v:
|
||||
timeout = v["timeout"]
|
||||
if not isinstance(timeout, int) or timeout < 1 or timeout > 300:
|
||||
raise ValueError("connection_config.timeout必须是1-300的整数")
|
||||
return v
|
||||
|
||||
# @field_validator("server_url")
|
||||
# def validate_server_url_protocol(cls, v):
|
||||
# if v.scheme != "https":
|
||||
# raise ValueError("MCP服务端URL仅支持HTTPS协议(安全要求)")
|
||||
# return v
|
||||
|
||||
|
||||
# ==================== API端点 ====================
|
||||
@router.get("", response_model=List[ToolListResponse])
|
||||
@router.get("", response_model=ApiResponse)
|
||||
async def list_tools(
|
||||
name: Optional[str] = None,
|
||||
tool_type: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
name: Optional[str] = Query(None),
|
||||
tool_type: Optional[str] = Query(None),
|
||||
status: Optional[str] = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""获取工具列表(包含内置工具、自定义工具和MCP工具)"""
|
||||
"""获取工具列表"""
|
||||
try:
|
||||
# 初始化内置工具(如果需要)
|
||||
config_manager = ConfigManager()
|
||||
config_manager.ensure_builtin_tools_initialized(
|
||||
current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus
|
||||
# 确保内置工具已初始化
|
||||
service.ensure_builtin_tools_initialized(current_user.tenant_id)
|
||||
|
||||
# 获取工具列表
|
||||
tools = service.list_tools(
|
||||
tenant_id=current_user.tenant_id,
|
||||
name=name,
|
||||
tool_type=ToolType(tool_type) if tool_type else None,
|
||||
status=ToolStatus(status) if status else None
|
||||
)
|
||||
return success(data=tools, msg="获取工具列表成功")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
response_tools = []
|
||||
|
||||
query = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
@router.get("/{tool_id}", response_model=ApiResponse)
|
||||
async def get_tool(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""获取工具详情"""
|
||||
tool = service.get_tool_info(tool_id, current_user.tenant_id)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return success(data=tool, msg="获取工具详情成功")
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
async def create_tool(
|
||||
request: ToolCreateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""创建工具"""
|
||||
try:
|
||||
tool_id = service.create_tool(
|
||||
name=request.name,
|
||||
tool_type=request.tool_type,
|
||||
tenant_id=current_user.tenant_id,
|
||||
icon=request.icon,
|
||||
description=request.description,
|
||||
config=request.config
|
||||
)
|
||||
if tool_type:
|
||||
query = query.filter(ToolConfig.tool_type == tool_type)
|
||||
|
||||
if name:
|
||||
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
|
||||
|
||||
tools = query.all()
|
||||
builtin_tools = config_manager.load_builtin_tools_config()
|
||||
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
|
||||
|
||||
for tool_config in tools:
|
||||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
|
||||
tool_info = configured_tools.get(builtin_config.tool_class)
|
||||
status = _update_tool_status(tool_config, builtin_config, tool_info)
|
||||
else:
|
||||
status = _update_tool_status(tool_config)
|
||||
|
||||
response_tools.append(ToolListResponse(
|
||||
id=str(tool_config.id),
|
||||
name=tool_config.name,
|
||||
description=tool_config.description,
|
||||
tool_type=tool_config.tool_type,
|
||||
category=tool_info['category'] if tool_config.tool_type == ToolType.BUILTIN.value else tool_config.tool_type,
|
||||
version="1.0.0",
|
||||
status=status,
|
||||
requires_config=tool_info['requires_config'] if tool_config.tool_type == ToolType.BUILTIN.value else False,
|
||||
))
|
||||
|
||||
return response_tools
|
||||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/builtin/{tool_id}")
|
||||
async def get_builtin_tool_detail(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
@router.put("/{tool_id}", response_model=ApiResponse)
|
||||
async def update_tool(
|
||||
tool_id: str,
|
||||
request: ToolUpdateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""获取内置工具详情"""
|
||||
"""更新工具"""
|
||||
try:
|
||||
config_manager = ConfigManager()
|
||||
builtin_tools = config_manager.load_builtin_tools_config()
|
||||
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
|
||||
tool_config = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id,
|
||||
ToolConfig.id == tool_id
|
||||
).first()
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
|
||||
tool_info = configured_tools.get(builtin_config.tool_class)
|
||||
|
||||
is_configured = False
|
||||
config_parameters = {}
|
||||
|
||||
if builtin_config and builtin_config.parameters:
|
||||
is_configured = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
|
||||
# 不返回敏感信息,只返回非敏感配置
|
||||
config_parameters = {k: v for k, v in builtin_config.parameters.items()
|
||||
if not any(sensitive in k.lower() for sensitive in ['key', 'secret', 'token', 'password'])}
|
||||
|
||||
return {
|
||||
"id": tool_config.id,
|
||||
"name": tool_config.name,
|
||||
"description": tool_config.description,
|
||||
"category": tool_info['category'],
|
||||
"status": tool_config.tool_type,
|
||||
"requires_config": tool_info['requires_config'],
|
||||
"is_configured": is_configured,
|
||||
"config_parameters": config_parameters
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/builtin/{tool_id}/configure")
|
||||
async def configure_builtin_tool(
|
||||
tool_id: str,
|
||||
request: BuiltinToolConfigRequest = Body(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""配置内置工具参数(租户级别)"""
|
||||
try:
|
||||
# 查询工具配置
|
||||
tool_config = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id,
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tool_type == ToolType.BUILTIN
|
||||
).first()
|
||||
|
||||
if not tool_config:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 获取内置工具配置
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if not builtin_config:
|
||||
raise HTTPException(status_code=404, detail="内置工具配置不存在")
|
||||
|
||||
# 获取全局工具信息
|
||||
config_manager = ConfigManager()
|
||||
builtin_tools_config = config_manager.load_builtin_tools_config()
|
||||
tool_info = None
|
||||
for tool_key, info in builtin_tools_config.items():
|
||||
if info['tool_class'] == builtin_config.tool_class:
|
||||
tool_info = info
|
||||
break
|
||||
|
||||
if not tool_info:
|
||||
raise HTTPException(status_code=404, detail="工具信息不存在")
|
||||
|
||||
# 加密敏感参数
|
||||
encrypted_params = _encrypt_sensitive_params(request.parameters)
|
||||
|
||||
# 更新配置
|
||||
builtin_config.parameters = encrypted_params
|
||||
|
||||
# 更新状态
|
||||
_update_tool_status(tool_config, builtin_config, tool_info)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool_config.name} 配置成功"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"配置内置工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/builtin/{tool_id}/config")
|
||||
async def get_builtin_tool_config(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取内置工具配置(用于使用)"""
|
||||
try:
|
||||
# 查询工具配置
|
||||
tool_config = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id,
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tool_type == ToolType.BUILTIN
|
||||
).first()
|
||||
|
||||
if not tool_config:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 获取内置工具配置
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if not builtin_config:
|
||||
raise HTTPException(status_code=404, detail="内置工具配置不存在")
|
||||
|
||||
# 解密参数
|
||||
decrypted_params = _decrypt_sensitive_params(builtin_config.parameters or {})
|
||||
|
||||
return {
|
||||
"tool_id": tool_id,
|
||||
"tool_class": builtin_config.tool_class,
|
||||
"name": tool_config.name,
|
||||
"parameters": decrypted_params,
|
||||
"status": tool_config.status
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/custom")
|
||||
async def create_custom_tool(
|
||||
request: CustomToolCreateRequest = Body(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建自定义工具"""
|
||||
try:
|
||||
config_data = jsonable_encoder(request.model_dump())
|
||||
config_data["tool_type"] = "custom"
|
||||
|
||||
config_manager = ConfigManager()
|
||||
is_valid, error_msg = config_manager.validate_config(config_data, "custom")
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
# 创建数据库记录
|
||||
tool_config = ToolConfig(
|
||||
success_flag = service.update_tool(
|
||||
tool_id=tool_id,
|
||||
tenant_id=current_user.tenant_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
tool_type=ToolType.CUSTOM,
|
||||
tenant_id=current_user.tenant_id,
|
||||
status=ToolStatus.ACTIVE.value,
|
||||
config_data=config_data
|
||||
icon=request.icon,
|
||||
config=request.config,
|
||||
is_enabled=request.config.get("is_enabled", None)
|
||||
)
|
||||
db.add(tool_config)
|
||||
db.flush()
|
||||
if not success_flag:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return success(msg="工具更新成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# 创建CustomToolConfig记录
|
||||
custom_config = CustomToolConfig(
|
||||
id=tool_config.id,
|
||||
base_url=request.base_url,
|
||||
schema_url=request.schema_url,
|
||||
schema_content=request.schema_content,
|
||||
auth_type=request.auth_type,
|
||||
auth_config=request.auth_config,
|
||||
|
||||
@router.delete("/{tool_id}", response_model=ApiResponse)
|
||||
async def delete_tool(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""删除工具"""
|
||||
try:
|
||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||
if not success_flag:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return success(msg="工具删除成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/execution/execute", response_model=ApiResponse)
|
||||
async def execute_tool(
|
||||
request: ToolExecuteRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""执行工具"""
|
||||
try:
|
||||
result = await service.execute_tool(
|
||||
tool_id=request.tool_id,
|
||||
parameters=request.parameters,
|
||||
tenant_id=current_user.tenant_id,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
timeout=request.timeout
|
||||
)
|
||||
db.add(custom_config)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"自定义工具 {request.name} 创建成功",
|
||||
"tool_id": str(tool_config.id)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
return success(
|
||||
data={
|
||||
"success": result.success,
|
||||
"data": result.data,
|
||||
"error": result.error,
|
||||
"execution_time": result.execution_time,
|
||||
"token_usage": result.token_usage
|
||||
},
|
||||
msg="工具执行完成"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建自定义工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/mcp")
|
||||
async def create_mcp_tool(
|
||||
request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"),
|
||||
|
||||
@router.post("/parse_schema", response_model=ApiResponse)
|
||||
async def parse_openapi_schema(
|
||||
request: ParseSchemaRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""创建MCP工具"""
|
||||
"""解析OpenAPI schema"""
|
||||
try:
|
||||
config_data = jsonable_encoder(request.model_dump())
|
||||
config_data["tool_type"] = "mcp"
|
||||
result = await service.parse_openapi_schema(request.schema_content, request.schema_url)
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=400, detail=result["message"])
|
||||
return success(data=result, msg="Schema解析完成")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
config_manager = ConfigManager()
|
||||
is_valid, error_msg = config_manager.validate_config(config_data, "mcp")
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
@router.post("/{tool_id}/sync_mcp_tools", response_model=ApiResponse)
|
||||
async def sync_mcp_tools(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""同步MCP工具列表"""
|
||||
try:
|
||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
return success(data=result, msg="MCP工具列表同步完成")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# 创建数据库记录
|
||||
try:
|
||||
tool_config = ToolConfig(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
tool_type=ToolType.MCP,
|
||||
tenant_id=current_user.tenant_id,
|
||||
status=ToolStatus.ACTIVE.value,
|
||||
config_data=config_data
|
||||
@router.post("/{tool_id}/test", response_model=ApiResponse)
|
||||
async def test_tool_connection(
|
||||
tool_id: str,
|
||||
test_request: Optional[CustomToolTestRequest] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
if test_request:
|
||||
# 自定义工具测试
|
||||
result = await service.test_custom_tool(
|
||||
tool_id, current_user.tenant_id,
|
||||
test_request.method, test_request.path, test_request.parameters
|
||||
)
|
||||
db.add(tool_config)
|
||||
db.flush()
|
||||
|
||||
# 创建MCPToolConfig记录
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=request.server_url,
|
||||
connection_config=request.connection_config
|
||||
)
|
||||
db.add(mcp_config)
|
||||
|
||||
db.commit()
|
||||
except SQLAlchemyError as db_e:
|
||||
db.rollback()
|
||||
logger.error(f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},工具名:{request.name}): {str(db_e)}",
|
||||
exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},"
|
||||
f"工具名:{request.name}):{str(db_e)}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"MCP工具 {request.name} 创建成功",
|
||||
"tool_id": str(tool_config.id)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建MCP工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete("/{tool_id}")
|
||||
async def delete_tool(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除工具(仅限自定义和MCP工具)"""
|
||||
try:
|
||||
tool = db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
).first()
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
raise HTTPException(status_code=403, detail="内置工具不允许删除")
|
||||
|
||||
db.delete(tool)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool.name} 删除成功"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{tool_id}")
|
||||
async def update_tool(
|
||||
tool_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新工具(仅限自定义和MCP工具)"""
|
||||
try:
|
||||
tool = db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
).first()
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
raise HTTPException(status_code=403, detail="内置工具不允许修改")
|
||||
|
||||
if config_data is not None:
|
||||
tool.config_data = config_data
|
||||
# 更新状态
|
||||
_update_tool_status(tool)
|
||||
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool.name} 更新成功",
|
||||
"status": tool.status
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{tool_id}/toggle")
|
||||
async def toggle_tool_status(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""切换工具活跃/非活跃状态"""
|
||||
try:
|
||||
tool = db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
).first()
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 在active和inactive之间切换
|
||||
if tool.status == ToolStatus.ACTIVE.value:
|
||||
tool.status = ToolStatus.INACTIVE.value
|
||||
elif tool.status == ToolStatus.INACTIVE.value:
|
||||
tool.status = ToolStatus.ACTIVE.value
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换")
|
||||
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool.name} 状态已更新为 {tool.status}",
|
||||
"status": tool.status
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
# 普通连接测试
|
||||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||||
return success(data=result, msg="连接测试完成")
|
||||
except Exception as e:
|
||||
logger.error(f"切换工具状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/enums/tool_types", response_model=ApiResponse)
|
||||
async def get_tool_types():
|
||||
"""获取工具类型枚举"""
|
||||
return success(
|
||||
data=[
|
||||
{"value": ToolType.BUILTIN.value, "label": "内置工具"},
|
||||
{"value": ToolType.CUSTOM.value, "label": "自定义工具"},
|
||||
{"value": ToolType.MCP.value, "label": "MCP工具"}
|
||||
],
|
||||
msg="获取工具类型成功"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/enums/status", response_model=ApiResponse)
|
||||
async def get_tool_status():
|
||||
"""获取工具状态枚举"""
|
||||
return success(data=ToolStatus.get_all_statuses_with_labels(), msg="获取工具状态成功")
|
||||
|
||||
|
||||
@router.get("/auth/types", response_model=ApiResponse)
|
||||
async def get_auth_types():
|
||||
"""获取认证类型枚举"""
|
||||
return success(data=AuthType.get_all_types_with_labels(), msg="获取认证类型成功")
|
||||
|
||||
@@ -1,430 +0,0 @@
|
||||
"""工具执行API控制器"""
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode
|
||||
from app.core.tools.builtin import *
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
router = APIRouter(prefix="/tools/execution", tags=["工具执行"])
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ToolExecutionRequest(BaseModel):
|
||||
"""工具执行请求"""
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
timeout: Optional[float] = Field(None, ge=1, le=300, description="超时时间(秒)")
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据")
|
||||
|
||||
|
||||
class BatchExecutionRequest(BaseModel):
|
||||
"""批量执行请求"""
|
||||
executions: List[ToolExecutionRequest] = Field(..., description="执行列表")
|
||||
max_concurrency: int = Field(5, ge=1, le=20, description="最大并发数")
|
||||
|
||||
|
||||
class ToolExecutionResponse(BaseModel):
|
||||
"""工具执行响应"""
|
||||
success: bool
|
||||
execution_id: str
|
||||
tool_id: str
|
||||
data: Any = None
|
||||
error: Optional[str] = None
|
||||
error_code: Optional[str] = None
|
||||
execution_time: float
|
||||
token_usage: Optional[Dict[str, int]] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChainStepRequest(BaseModel):
|
||||
"""链步骤请求"""
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
condition: Optional[str] = Field(None, description="执行条件")
|
||||
output_mapping: Optional[Dict[str, str]] = Field(None, description="输出映射")
|
||||
error_handling: str = Field("stop", description="错误处理策略")
|
||||
|
||||
|
||||
class ChainExecutionRequest(BaseModel):
|
||||
"""链执行请求"""
|
||||
name: str = Field(..., description="链名称")
|
||||
description: str = Field("", description="链描述")
|
||||
steps: List[ChainStepRequest] = Field(..., description="执行步骤")
|
||||
execution_mode: str = Field("sequential", description="执行模式")
|
||||
initial_variables: Optional[Dict[str, Any]] = Field(None, description="初始变量")
|
||||
global_timeout: Optional[float] = Field(None, description="全局超时")
|
||||
|
||||
|
||||
class ExecutionHistoryResponse(BaseModel):
|
||||
"""执行历史响应"""
|
||||
execution_id: str
|
||||
tool_id: str
|
||||
status: str
|
||||
started_at: Optional[str]
|
||||
completed_at: Optional[str]
|
||||
execution_time: Optional[float]
|
||||
user_id: Optional[str]
|
||||
workspace_id: Optional[str]
|
||||
input_data: Optional[Dict[str, Any]]
|
||||
output_data: Optional[Any]
|
||||
error_message: Optional[str]
|
||||
token_usage: Optional[Dict[str, int]]
|
||||
|
||||
|
||||
class ToolConnectionTestResponse(BaseModel):
|
||||
"""工具连接测试响应"""
|
||||
success: bool
|
||||
message: str
|
||||
error: Optional[str] = None
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# ==================== 依赖注入 ====================
|
||||
|
||||
def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry:
|
||||
"""获取工具注册表"""
|
||||
registry = ToolRegistry(db)
|
||||
|
||||
# 注册内置工具类
|
||||
registry.register_tool_class(DateTimeTool)
|
||||
registry.register_tool_class(JsonTool)
|
||||
registry.register_tool_class(BaiduSearchTool)
|
||||
registry.register_tool_class(MinerUTool)
|
||||
registry.register_tool_class(TextInTool)
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def get_tool_executor(
|
||||
db: Session = Depends(get_db),
|
||||
registry: ToolRegistry = Depends(get_tool_registry)
|
||||
) -> ToolExecutor:
|
||||
"""获取工具执行器"""
|
||||
return ToolExecutor(db, registry)
|
||||
|
||||
|
||||
def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager:
|
||||
"""获取链管理器"""
|
||||
return ChainManager(executor)
|
||||
|
||||
|
||||
# ==================== API端点 ====================
|
||||
|
||||
@router.post("/execute", response_model=ToolExecutionResponse)
|
||||
async def execute_tool(
|
||||
request: ToolExecutionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""执行单个工具"""
|
||||
try:
|
||||
# 生成执行ID
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 执行工具
|
||||
result = await executor.execute_tool(
|
||||
tool_id=request.tool_id,
|
||||
parameters=request.parameters,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
execution_id=execution_id,
|
||||
timeout=request.timeout,
|
||||
metadata=request.metadata
|
||||
)
|
||||
|
||||
return ToolExecutionResponse(
|
||||
success=result.success,
|
||||
execution_id=execution_id,
|
||||
tool_id=request.tool_id,
|
||||
data=result.data,
|
||||
error=result.error,
|
||||
error_code=result.error_code,
|
||||
execution_time=result.execution_time,
|
||||
token_usage=result.token_usage,
|
||||
metadata=result.metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/batch", response_model=List[ToolExecutionResponse])
|
||||
async def execute_tools_batch(
|
||||
request: BatchExecutionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""批量执行工具"""
|
||||
try:
|
||||
# 准备执行配置
|
||||
execution_configs = []
|
||||
execution_ids = []
|
||||
|
||||
for exec_request in request.executions:
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
execution_ids.append(execution_id)
|
||||
|
||||
execution_configs.append({
|
||||
"tool_id": exec_request.tool_id,
|
||||
"parameters": exec_request.parameters,
|
||||
"user_id": current_user.id,
|
||||
"workspace_id": current_user.current_workspace_id,
|
||||
"execution_id": execution_id,
|
||||
"timeout": exec_request.timeout,
|
||||
"metadata": exec_request.metadata
|
||||
})
|
||||
|
||||
# 批量执行
|
||||
results = await executor.execute_tools_batch(
|
||||
execution_configs,
|
||||
max_concurrency=request.max_concurrency
|
||||
)
|
||||
|
||||
# 转换响应格式
|
||||
responses = []
|
||||
for i, result in enumerate(results):
|
||||
responses.append(ToolExecutionResponse(
|
||||
success=result.success,
|
||||
execution_id=execution_ids[i],
|
||||
tool_id=request.executions[i].tool_id,
|
||||
data=result.data,
|
||||
error=result.error,
|
||||
error_code=result.error_code,
|
||||
execution_time=result.execution_time,
|
||||
token_usage=result.token_usage,
|
||||
metadata=result.metadata
|
||||
))
|
||||
|
||||
return responses
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量执行失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/chain", response_model=Dict[str, Any])
|
||||
async def execute_tool_chain(
|
||||
request: ChainExecutionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
chain_manager: ChainManager = Depends(get_chain_manager)
|
||||
):
|
||||
"""执行工具链"""
|
||||
try:
|
||||
# 转换步骤格式
|
||||
steps = []
|
||||
for step_request in request.steps:
|
||||
step = ChainStep(
|
||||
tool_id=step_request.tool_id,
|
||||
parameters=step_request.parameters,
|
||||
condition=step_request.condition,
|
||||
output_mapping=step_request.output_mapping,
|
||||
error_handling=step_request.error_handling
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
# 创建链定义
|
||||
chain_definition = ChainDefinition(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
steps=steps,
|
||||
execution_mode=ChainExecutionMode(request.execution_mode),
|
||||
global_timeout=request.global_timeout
|
||||
)
|
||||
|
||||
# 注册并执行链
|
||||
chain_manager.register_chain(chain_definition)
|
||||
|
||||
result = await chain_manager.execute_chain(
|
||||
chain_name=request.name,
|
||||
initial_variables=request.initial_variables
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具链执行失败: {request.name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/running", response_model=List[Dict[str, Any]])
|
||||
async def get_running_executions(
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""获取正在运行的执行"""
|
||||
try:
|
||||
running_executions = executor.get_running_executions()
|
||||
|
||||
# 过滤当前工作空间的执行
|
||||
workspace_executions = [
|
||||
exec_info for exec_info in running_executions
|
||||
if exec_info.get("workspace_id") == str(current_user.current_workspace_id)
|
||||
]
|
||||
|
||||
return workspace_executions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取运行中执行失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any])
|
||||
async def cancel_execution(
|
||||
execution_id: str = Path(..., description="执行ID"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""取消工具执行"""
|
||||
try:
|
||||
success = await executor.cancel_execution(execution_id)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "执行已取消"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="执行不存在或已完成")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"取消执行失败: {execution_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/history", response_model=List[ExecutionHistoryResponse])
|
||||
async def get_execution_history(
|
||||
tool_id: Optional[str] = Query(None, description="工具ID过滤"),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""获取执行历史"""
|
||||
try:
|
||||
history = executor.get_execution_history(
|
||||
tool_id=tool_id,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# 转换响应格式
|
||||
responses = []
|
||||
for record in history:
|
||||
responses.append(ExecutionHistoryResponse(
|
||||
execution_id=record["execution_id"],
|
||||
tool_id=record["tool_id"],
|
||||
status=record["status"],
|
||||
started_at=record["started_at"],
|
||||
completed_at=record["completed_at"],
|
||||
execution_time=record["execution_time"],
|
||||
user_id=record["user_id"],
|
||||
workspace_id=record["workspace_id"],
|
||||
input_data=record["input_data"],
|
||||
output_data=record["output_data"],
|
||||
error_message=record["error_message"],
|
||||
token_usage=record["token_usage"]
|
||||
))
|
||||
|
||||
return responses
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行历史失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/statistics", response_model=Dict[str, Any])
|
||||
async def get_execution_statistics(
|
||||
days: int = Query(7, ge=1, le=90, description="统计天数"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""获取执行统计"""
|
||||
try:
|
||||
stats = executor.get_execution_statistics(
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
days=days
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"statistics": stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/chains/running", response_model=List[Dict[str, Any]])
|
||||
async def get_running_chains(
|
||||
current_user: User = Depends(get_current_user),
|
||||
chain_manager: ChainManager = Depends(get_chain_manager)
|
||||
):
|
||||
"""获取正在运行的工具链"""
|
||||
try:
|
||||
running_chains = chain_manager.get_running_chains()
|
||||
return running_chains
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取运行中工具链失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/chains", response_model=List[Dict[str, Any]])
|
||||
async def list_tool_chains(
|
||||
current_user: User = Depends(get_current_user),
|
||||
chain_manager: ChainManager = Depends(get_chain_manager)
|
||||
):
|
||||
"""列出工具链"""
|
||||
try:
|
||||
chains = chain_manager.list_chains()
|
||||
return chains
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具链列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse)
|
||||
async def test_tool_connection(
|
||||
tool_id: str = Path(..., description="工具ID"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
result = await executor.test_tool_connection(
|
||||
tool_id=tool_id,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id
|
||||
)
|
||||
|
||||
return ToolConnectionTestResponse(
|
||||
success=result.get("success", False),
|
||||
message=result.get("message", ""),
|
||||
error=result.get("error"),
|
||||
details=result.get("details")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
|
||||
return ToolConnectionTestResponse(
|
||||
success=False,
|
||||
message="连接测试失败",
|
||||
error=str(e)
|
||||
)
|
||||
@@ -33,10 +33,9 @@ def require_api_key(
|
||||
scopes: 所需的权限范围列表[“app”, "rag", "memory"]
|
||||
|
||||
Usage:
|
||||
@router.get("/app/{resource_id}/chat")
|
||||
@router.get("/app/chat")
|
||||
@require_api_key(scopes=["app"])
|
||||
def chat_with_app(
|
||||
resource_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -89,26 +88,6 @@ def require_api_key(
|
||||
context={"required_scopes": scopes, "missing_scopes": missing_scopes}
|
||||
)
|
||||
|
||||
resource_id = kwargs.get("resource_id")
|
||||
if resource_id and not ApiKeyAuthService.check_resource(
|
||||
api_key_obj,
|
||||
resource_id
|
||||
):
|
||||
logger.warning("API Key 资源访问被拒绝", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"required_resource_id": str(resource_id),
|
||||
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
|
||||
"endpoint": str(request.url)
|
||||
})
|
||||
return BusinessException(
|
||||
"API Key 未授权访问该资源",
|
||||
BizCode.API_KEY_INVALID_RESOURCE,
|
||||
context={
|
||||
"required_resource_id": str(resource_id),
|
||||
"bound_resource_id": str(api_key_obj.resource_id)
|
||||
}
|
||||
)
|
||||
|
||||
kwargs["api_key_auth"] = ApiKeyAuth(
|
||||
api_key_id=api_key_obj.id,
|
||||
workspace_id=api_key_obj.workspace_id,
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
"""工具管理核心模块"""
|
||||
|
||||
from .base import BaseTool, ToolResult, ToolParameter
|
||||
from .registry import ToolRegistry
|
||||
from .executor import ToolExecutor
|
||||
from .langchain_adapter import LangchainAdapter
|
||||
from .config_manager import ConfigManager
|
||||
from .chain_manager import ChainManager
|
||||
|
||||
# 可选导入,避免导入错误
|
||||
try:
|
||||
@@ -22,11 +18,7 @@ __all__ = [
|
||||
"BaseTool",
|
||||
"ToolResult",
|
||||
"ToolParameter",
|
||||
"ToolRegistry",
|
||||
"ToolExecutor",
|
||||
"LangchainAdapter",
|
||||
"ConfigManager",
|
||||
"ChainManager"
|
||||
"LangchainAdapter"
|
||||
]
|
||||
|
||||
# 只有在成功导入时才添加到__all__
|
||||
|
||||
@@ -1,98 +1,10 @@
|
||||
"""工具基础接口定义"""
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.models.tool_model import ToolType, ToolStatus
|
||||
|
||||
|
||||
class ParameterType(str, Enum):
|
||||
"""参数类型枚举"""
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
"""工具参数定义"""
|
||||
name: str = Field(..., description="参数名称")
|
||||
type: ParameterType = Field(..., description="参数类型")
|
||||
description: str = Field("", description="参数描述")
|
||||
required: bool = Field(False, description="是否必需")
|
||||
default: Any = Field(None, description="默认值")
|
||||
enum: Optional[List[Any]] = Field(None, description="枚举值")
|
||||
minimum: Optional[Union[int, float]] = Field(None, description="最小值")
|
||||
maximum: Optional[Union[int, float]] = Field(None, description="最大值")
|
||||
pattern: Optional[str] = Field(None, description="正则表达式模式")
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""工具执行结果"""
|
||||
success: bool = Field(..., description="执行是否成功")
|
||||
data: Any = Field(None, description="返回数据")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
error_code: Optional[str] = Field(None, description="错误代码")
|
||||
execution_time: float = Field(..., description="执行时间(秒)")
|
||||
token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
|
||||
|
||||
@classmethod
|
||||
def success_result(
|
||||
cls,
|
||||
data: Any,
|
||||
execution_time: float,
|
||||
token_usage: Optional[Dict[str, int]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> "ToolResult":
|
||||
"""创建成功结果"""
|
||||
return cls(
|
||||
success=True,
|
||||
data=data,
|
||||
execution_time=execution_time,
|
||||
token_usage=token_usage,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def error_result(
|
||||
cls,
|
||||
error: str,
|
||||
execution_time: float,
|
||||
error_code: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> "ToolResult":
|
||||
"""创建错误结果"""
|
||||
return cls(
|
||||
success=False,
|
||||
error=error,
|
||||
error_code=error_code,
|
||||
execution_time=execution_time,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
|
||||
class ToolInfo(BaseModel):
|
||||
"""工具信息"""
|
||||
id: str = Field(..., description="工具ID")
|
||||
name: str = Field(..., description="工具名称")
|
||||
description: str = Field(..., description="工具描述")
|
||||
tool_type: ToolType = Field(..., description="工具类型")
|
||||
version: str = Field("1.0.0", description="工具版本")
|
||||
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
|
||||
status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态")
|
||||
tags: List[str] = Field(default_factory=list, description="工具标签")
|
||||
tenant_id: Optional[str] = Field(None, description="租户ID")
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
from app.schemas.tool_schema import ToolParameter, ParameterType, ToolResult
|
||||
|
||||
|
||||
class BaseTool(ABC):
|
||||
@@ -107,7 +19,7 @@ class BaseTool(ABC):
|
||||
"""
|
||||
self.tool_id = tool_id
|
||||
self.config = config
|
||||
self._status = ToolStatus.ACTIVE
|
||||
self._status = ToolStatus.AVAILABLE
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@@ -153,20 +65,6 @@ class BaseTool(ABC):
|
||||
"""工具标签"""
|
||||
return self.config.get("tags", [])
|
||||
|
||||
def get_info(self) -> ToolInfo:
|
||||
"""获取工具信息"""
|
||||
return ToolInfo(
|
||||
id=self.tool_id,
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
tool_type=self.tool_type,
|
||||
version=self.version,
|
||||
parameters=self.parameters,
|
||||
status=self.status,
|
||||
tags=self.tags,
|
||||
tenant_id=self.config.get("tenant_id")
|
||||
)
|
||||
|
||||
def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""验证参数
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolResult, ToolParameter
|
||||
|
||||
|
||||
class BuiltinTool(BaseTool, ABC):
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime, timezone, timedelta
|
||||
from typing import List
|
||||
import pytz
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
@@ -54,14 +54,14 @@ class DateTimeTool(BuiltinTool):
|
||||
type=ParameterType.STRING,
|
||||
description="源时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="UTC"
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="UTC"
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="calculation",
|
||||
@@ -106,10 +106,11 @@ class DateTimeTool(BuiltinTool):
|
||||
error_code="DATETIME_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def _get_current_time(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _get_current_time(kwargs) -> dict:
|
||||
"""获取当前时间"""
|
||||
timezone_str = kwargs.get("to_timezone", "UTC")
|
||||
timezone_str = kwargs.get("to_timezone", "Asia/Shanghai")
|
||||
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if timezone_str == "UTC":
|
||||
@@ -118,15 +119,20 @@ class DateTimeTool(BuiltinTool):
|
||||
tz = pytz.timezone(timezone_str)
|
||||
|
||||
now = datetime.now(tz)
|
||||
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
|
||||
return {
|
||||
"datetime": now.strftime(output_format),
|
||||
"timestamp": int(now.timestamp()),
|
||||
"timezone": timezone_str,
|
||||
"iso_format": now.isoformat()
|
||||
"iso_format": now.isoformat(),
|
||||
"timestamp_ms": int(now.timestamp() * 1000),
|
||||
"utc_datetime": utc_now.strftime(output_format)
|
||||
}
|
||||
|
||||
def _format_datetime(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _format_datetime(kwargs) -> dict:
|
||||
"""格式化时间"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
@@ -144,8 +150,9 @@ class DateTimeTool(BuiltinTool):
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
def _convert_timezone(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _convert_timezone(kwargs) -> dict:
|
||||
"""时区转换"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
@@ -184,8 +191,9 @@ class DateTimeTool(BuiltinTool):
|
||||
"converted_timezone": to_timezone,
|
||||
"timestamp": int(converted_dt.timestamp())
|
||||
}
|
||||
|
||||
def _timestamp_to_datetime(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _timestamp_to_datetime(kwargs) -> dict:
|
||||
"""时间戳转日期时间"""
|
||||
input_value = kwargs.get("input_value")
|
||||
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
|
||||
@@ -196,6 +204,8 @@ class DateTimeTool(BuiltinTool):
|
||||
|
||||
# 转换时间戳
|
||||
timestamp = float(input_value)
|
||||
if timestamp > 1e12:
|
||||
timestamp = timestamp / 1000
|
||||
|
||||
# 设置时区
|
||||
if timezone_str == "UTC":
|
||||
@@ -211,8 +221,9 @@ class DateTimeTool(BuiltinTool):
|
||||
"timezone": timezone_str,
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
def _datetime_to_timestamp(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _datetime_to_timestamp(kwargs) -> dict:
|
||||
"""日期时间转时间戳"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
@@ -240,7 +251,7 @@ class DateTimeTool(BuiltinTool):
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
|
||||
def _calculate_datetime(self, kwargs) -> dict:
|
||||
"""时间计算"""
|
||||
input_value = kwargs.get("input_value")
|
||||
@@ -278,8 +289,9 @@ class DateTimeTool(BuiltinTool):
|
||||
"timezone": timezone_str,
|
||||
"timestamp": int(calculated_dt.timestamp())
|
||||
}
|
||||
|
||||
def _parse_time_delta(self, calculation: str) -> timedelta:
|
||||
|
||||
@staticmethod
|
||||
def _parse_time_delta(calculation: str) -> timedelta:
|
||||
"""解析时间计算表达式"""
|
||||
import re
|
||||
|
||||
|
||||
@@ -121,8 +121,9 @@ class JsonTool(BuiltinTool):
|
||||
error_code="JSON_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _format_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""格式化JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
ensure_ascii = kwargs.get("ensure_ascii", False)
|
||||
@@ -151,12 +152,13 @@ class JsonTool(BuiltinTool):
|
||||
"sort_keys": sort_keys
|
||||
}
|
||||
}
|
||||
|
||||
def _minify_json(self, input_data: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _minify_json(input_data: str) -> Dict[str, Any]:
|
||||
"""压缩JSON"""
|
||||
# 解析并压缩
|
||||
data = json.loads(input_data)
|
||||
minified = json.dumps(data, separators=(',', ':'))
|
||||
minified = json.dumps(data, ensure_ascii=False, separators=(',', ':'))
|
||||
|
||||
return {
|
||||
"original_size": len(input_data),
|
||||
@@ -165,7 +167,7 @@ class JsonTool(BuiltinTool):
|
||||
"minified_json": minified,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
|
||||
def _validate_json(self, input_data: str) -> Dict[str, Any]:
|
||||
"""验证JSON"""
|
||||
try:
|
||||
@@ -190,17 +192,19 @@ class JsonTool(BuiltinTool):
|
||||
"size": len(input_data)
|
||||
}
|
||||
|
||||
def _convert_json(self, input_data: str) -> Dict[str, Any]:
|
||||
@staticmethod
|
||||
def _convert_json(input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转义"""
|
||||
data = json.loads(input_data)
|
||||
converted = json.dumps(data, ensure_ascii=False)
|
||||
converted = json.dumps(data, ensure_ascii=True, separators=(',', ':'))
|
||||
|
||||
return {
|
||||
"converted_json": converted,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
def _json_to_yaml(self, input_data: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _json_to_yaml(input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转YAML"""
|
||||
data = json.loads(input_data)
|
||||
yaml_output = yaml.dump(data, default_flow_style=False, allow_unicode=True, indent=2)
|
||||
@@ -212,8 +216,9 @@ class JsonTool(BuiltinTool):
|
||||
"converted_size": len(yaml_output),
|
||||
"converted_data": yaml_output
|
||||
}
|
||||
|
||||
def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _yaml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""YAML转JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
ensure_ascii = kwargs.get("ensure_ascii", False)
|
||||
@@ -228,10 +233,11 @@ class JsonTool(BuiltinTool):
|
||||
"converted_size": len(json_output),
|
||||
"converted_data": json_output
|
||||
}
|
||||
|
||||
def _json_to_xml(self, input_data: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _json_to_xml(input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转XML"""
|
||||
data = json.loads(input_data)
|
||||
json_data = json.loads(input_data)
|
||||
|
||||
def dict_to_xml(data, root_name="root"):
|
||||
"""递归转换字典为XML"""
|
||||
@@ -267,7 +273,7 @@ class JsonTool(BuiltinTool):
|
||||
root.text = str(data)
|
||||
return root
|
||||
|
||||
xml_element = dict_to_xml(data)
|
||||
xml_element = dict_to_xml(json_data)
|
||||
xml_string = ET.tostring(xml_element, encoding='unicode')
|
||||
|
||||
# 格式化XML
|
||||
@@ -284,8 +290,9 @@ class JsonTool(BuiltinTool):
|
||||
"converted_size": len(formatted_xml),
|
||||
"converted_data": formatted_xml
|
||||
}
|
||||
|
||||
def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _xml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""XML转JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
|
||||
@@ -328,8 +335,9 @@ class JsonTool(BuiltinTool):
|
||||
"converted_size": len(json_output),
|
||||
"converted_data": json_output
|
||||
}
|
||||
|
||||
def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _merge_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""合并JSON"""
|
||||
merge_data = kwargs.get("merge_data")
|
||||
if not merge_data:
|
||||
@@ -364,8 +372,9 @@ class JsonTool(BuiltinTool):
|
||||
"result_size": len(merged_json),
|
||||
"merged_data": merged_json
|
||||
}
|
||||
|
||||
def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_path( input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取JSON路径"""
|
||||
json_path = kwargs.get("json_path")
|
||||
if not json_path:
|
||||
|
||||
@@ -275,8 +275,9 @@ class TextInTool(BuiltinTool):
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _format_formula_result( result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化公式识别结果"""
|
||||
formulas = result.get("formulas", [])
|
||||
|
||||
@@ -288,8 +289,9 @@ class TextInTool(BuiltinTool):
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _format_table_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化表格识别结果"""
|
||||
tables = result.get("tables", [])
|
||||
|
||||
@@ -301,8 +303,9 @@ class TextInTool(BuiltinTool):
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _format_document_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化文档识别结果"""
|
||||
return {
|
||||
"recognition_mode": "document",
|
||||
@@ -314,8 +317,9 @@ class TextInTool(BuiltinTool):
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def _group_lines_to_paragraphs(lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将行分组为段落"""
|
||||
paragraphs = []
|
||||
current_paragraph = []
|
||||
|
||||
@@ -1,485 +0,0 @@
|
||||
"""工具链管理器 - 支持langchain的工具链模式"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from app.core.tools.base import ToolResult
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ChainExecutionMode(str, Enum):
|
||||
"""链执行模式"""
|
||||
SEQUENTIAL = "sequential" # 顺序执行
|
||||
PARALLEL = "parallel" # 并行执行
|
||||
CONDITIONAL = "conditional" # 条件执行
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChainStep:
|
||||
"""链步骤定义"""
|
||||
tool_id: str
|
||||
parameters: Dict[str, Any]
|
||||
condition: Optional[str] = None # 执行条件
|
||||
output_mapping: Optional[Dict[str, str]] = None # 输出映射
|
||||
error_handling: str = "stop" # 错误处理:stop, continue, retry
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChainDefinition:
|
||||
"""工具链定义"""
|
||||
name: str
|
||||
description: str
|
||||
steps: List[ChainStep]
|
||||
execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL
|
||||
global_timeout: Optional[float] = None
|
||||
retry_policy: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ChainExecutionContext:
|
||||
"""链执行上下文"""
|
||||
|
||||
def __init__(self, chain_id: str):
|
||||
self.chain_id = chain_id
|
||||
self.variables: Dict[str, Any] = {}
|
||||
self.step_results: Dict[int, ToolResult] = {}
|
||||
self.current_step = 0
|
||||
self.is_completed = False
|
||||
self.is_failed = False
|
||||
self.error_message: Optional[str] = None
|
||||
|
||||
|
||||
class ChainManager:
|
||||
"""工具链管理器 - 支持langchain的工具链模式"""
|
||||
|
||||
def __init__(self, executor: ToolExecutor):
|
||||
"""初始化工具链管理器
|
||||
|
||||
Args:
|
||||
executor: 工具执行器
|
||||
"""
|
||||
self.executor = executor
|
||||
self._chains: Dict[str, ChainDefinition] = {}
|
||||
self._running_chains: Dict[str, ChainExecutionContext] = {}
|
||||
|
||||
def register_chain(self, chain: ChainDefinition) -> bool:
|
||||
"""注册工具链
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
|
||||
Returns:
|
||||
注册是否成功
|
||||
"""
|
||||
try:
|
||||
# 验证工具链定义
|
||||
validation_result = self._validate_chain(chain)
|
||||
if not validation_result[0]:
|
||||
logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}")
|
||||
return False
|
||||
|
||||
self._chains[chain.name] = chain
|
||||
logger.info(f"工具链注册成功: {chain.name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具链注册失败: {chain.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def unregister_chain(self, chain_name: str) -> bool:
|
||||
"""注销工具链
|
||||
|
||||
Args:
|
||||
chain_name: 工具链名称
|
||||
|
||||
Returns:
|
||||
注销是否成功
|
||||
"""
|
||||
if chain_name in self._chains:
|
||||
del self._chains[chain_name]
|
||||
logger.info(f"工具链注销成功: {chain_name}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def list_chains(self) -> List[Dict[str, Any]]:
|
||||
"""列出所有工具链
|
||||
|
||||
Returns:
|
||||
工具链信息列表
|
||||
"""
|
||||
chains = []
|
||||
for name, chain in self._chains.items():
|
||||
chains.append({
|
||||
"name": name,
|
||||
"description": chain.description,
|
||||
"step_count": len(chain.steps),
|
||||
"execution_mode": chain.execution_mode.value,
|
||||
"global_timeout": chain.global_timeout
|
||||
})
|
||||
|
||||
return chains
|
||||
|
||||
async def execute_chain(
|
||||
self,
|
||||
chain_name: str,
|
||||
initial_variables: Optional[Dict[str, Any]] = None,
|
||||
chain_id: Optional[str] = None
|
||||
) -> Dict[str, Any] | None:
|
||||
"""执行工具链
|
||||
|
||||
Args:
|
||||
chain_name: 工具链名称
|
||||
initial_variables: 初始变量
|
||||
chain_id: 链执行ID(可选)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
if chain_name not in self._chains:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"工具链不存在: {chain_name}",
|
||||
"chain_id": chain_id
|
||||
}
|
||||
|
||||
chain = self._chains[chain_name]
|
||||
|
||||
# 生成链ID
|
||||
if not chain_id:
|
||||
import uuid
|
||||
chain_id = f"chain_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 创建执行上下文
|
||||
context = ChainExecutionContext(chain_id)
|
||||
context.variables = initial_variables or {}
|
||||
self._running_chains[chain_id] = context
|
||||
|
||||
try:
|
||||
logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})")
|
||||
|
||||
# 根据执行模式执行
|
||||
if chain.execution_mode == ChainExecutionMode.SEQUENTIAL:
|
||||
result = await self._execute_sequential(chain, context)
|
||||
elif chain.execution_mode == ChainExecutionMode.PARALLEL:
|
||||
result = await self._execute_parallel(chain, context)
|
||||
elif chain.execution_mode == ChainExecutionMode.CONDITIONAL:
|
||||
result = await self._execute_conditional(chain, context)
|
||||
else:
|
||||
raise ValueError(f"不支持的执行模式: {chain.execution_mode}")
|
||||
|
||||
logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"chain_id": chain_id,
|
||||
"completed_steps": context.current_step,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
finally:
|
||||
# 清理执行上下文
|
||||
if chain_id in self._running_chains:
|
||||
del self._running_chains[chain_id]
|
||||
|
||||
async def _execute_sequential(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""顺序执行工具链"""
|
||||
for i, step in enumerate(chain.steps):
|
||||
context.current_step = i
|
||||
|
||||
# 检查执行条件
|
||||
if step.condition and not self._evaluate_condition(step.condition, context):
|
||||
logger.debug(f"跳过步骤 {i}: 条件不满足")
|
||||
continue
|
||||
|
||||
# 准备参数
|
||||
parameters = self._prepare_parameters(step.parameters, context)
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
result = await self.executor.execute_tool(
|
||||
tool_id=step.tool_id,
|
||||
parameters=parameters
|
||||
)
|
||||
|
||||
context.step_results[i] = result
|
||||
|
||||
# 处理输出映射
|
||||
if step.output_mapping and result.success:
|
||||
self._apply_output_mapping(step.output_mapping, result.data, context)
|
||||
|
||||
# 处理执行失败
|
||||
if not result.success:
|
||||
if step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = result.error
|
||||
break
|
||||
elif step.error_handling == "continue":
|
||||
logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}")
|
||||
continue
|
||||
elif step.error_handling == "retry":
|
||||
# 简单重试逻辑
|
||||
retry_result = await self.executor.execute_tool(
|
||||
tool_id=step.tool_id,
|
||||
parameters=parameters
|
||||
)
|
||||
context.step_results[i] = retry_result
|
||||
if not retry_result.success and step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = retry_result.error
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"步骤 {i} 执行异常: {e}")
|
||||
if step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = str(e)
|
||||
break
|
||||
|
||||
context.is_completed = not context.is_failed
|
||||
|
||||
return {
|
||||
"success": context.is_completed,
|
||||
"error": context.error_message,
|
||||
"chain_id": context.chain_id,
|
||||
"completed_steps": context.current_step + 1,
|
||||
"total_steps": len(chain.steps),
|
||||
"final_variables": context.variables,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""并行执行工具链"""
|
||||
# 准备所有步骤的执行配置
|
||||
execution_configs = []
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
# 检查执行条件
|
||||
if step.condition and not self._evaluate_condition(step.condition, context):
|
||||
continue
|
||||
|
||||
parameters = self._prepare_parameters(step.parameters, context)
|
||||
execution_configs.append({
|
||||
"step_index": i,
|
||||
"tool_id": step.tool_id,
|
||||
"parameters": parameters
|
||||
})
|
||||
|
||||
# 并行执行所有步骤
|
||||
try:
|
||||
results = await self.executor.execute_tools_batch(execution_configs)
|
||||
|
||||
# 处理结果
|
||||
for i, result in enumerate(results):
|
||||
step_index = execution_configs[i]["step_index"]
|
||||
context.step_results[step_index] = result
|
||||
|
||||
# 处理输出映射
|
||||
step = chain.steps[step_index]
|
||||
if step.output_mapping and result.success:
|
||||
self._apply_output_mapping(step.output_mapping, result.data, context)
|
||||
|
||||
# 检查是否有失败的步骤
|
||||
failed_steps = [i for i, result in context.step_results.items() if not result.success]
|
||||
|
||||
context.is_completed = len(failed_steps) == 0
|
||||
if failed_steps:
|
||||
context.error_message = f"步骤 {failed_steps} 执行失败"
|
||||
|
||||
except Exception as e:
|
||||
context.is_failed = True
|
||||
context.error_message = str(e)
|
||||
|
||||
return {
|
||||
"success": context.is_completed,
|
||||
"error": context.error_message,
|
||||
"chain_id": context.chain_id,
|
||||
"completed_steps": len(context.step_results),
|
||||
"total_steps": len(chain.steps),
|
||||
"final_variables": context.variables,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
async def _execute_conditional(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""条件执行工具链"""
|
||||
# 条件执行类似于顺序执行,但更严格地检查条件
|
||||
return await self._execute_sequential(chain, context)
|
||||
|
||||
def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]:
|
||||
"""验证工具链定义
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
if not chain.name:
|
||||
return False, "工具链名称不能为空"
|
||||
|
||||
if not chain.steps:
|
||||
return False, "工具链必须包含至少一个步骤"
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
if not step.tool_id:
|
||||
return False, f"步骤 {i} 缺少工具ID"
|
||||
|
||||
if step.error_handling not in ["stop", "continue", "retry"]:
|
||||
return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}"
|
||||
|
||||
return True, None
|
||||
|
||||
def _prepare_parameters(
|
||||
self,
|
||||
parameters: Dict[str, Any],
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""准备参数(支持变量替换)
|
||||
|
||||
Args:
|
||||
parameters: 原始参数
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
处理后的参数
|
||||
"""
|
||||
prepared = {}
|
||||
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# 变量替换
|
||||
var_name = value[2:-1]
|
||||
if var_name in context.variables:
|
||||
prepared[key] = context.variables[var_name]
|
||||
else:
|
||||
prepared[key] = value # 保持原值
|
||||
else:
|
||||
prepared[key] = value
|
||||
|
||||
return prepared
|
||||
|
||||
def _evaluate_condition(
|
||||
self,
|
||||
condition: str,
|
||||
context: ChainExecutionContext
|
||||
) -> bool:
|
||||
"""评估执行条件
|
||||
|
||||
Args:
|
||||
condition: 条件表达式
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
条件是否满足
|
||||
"""
|
||||
try:
|
||||
# 简单的条件评估(可以扩展为更复杂的表达式解析)
|
||||
# 支持格式:variable == value, variable != value, variable > value 等
|
||||
|
||||
if "==" in condition:
|
||||
var_name, expected_value = condition.split("==", 1)
|
||||
var_name = var_name.strip()
|
||||
expected_value = expected_value.strip().strip('"\'')
|
||||
|
||||
return str(context.variables.get(var_name, "")) == expected_value
|
||||
|
||||
elif "!=" in condition:
|
||||
var_name, expected_value = condition.split("!=", 1)
|
||||
var_name = var_name.strip()
|
||||
expected_value = expected_value.strip().strip('"\'')
|
||||
|
||||
return str(context.variables.get(var_name, "")) != expected_value
|
||||
|
||||
elif condition in context.variables:
|
||||
# 简单的布尔检查
|
||||
return bool(context.variables[condition])
|
||||
|
||||
else:
|
||||
# 默认为真
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"条件评估失败: {condition}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def _apply_output_mapping(
|
||||
self,
|
||||
mapping: Dict[str, str],
|
||||
output_data: Any,
|
||||
context: ChainExecutionContext
|
||||
):
|
||||
"""应用输出映射
|
||||
|
||||
Args:
|
||||
mapping: 输出映射配置
|
||||
output_data: 输出数据
|
||||
context: 执行上下文
|
||||
"""
|
||||
try:
|
||||
if isinstance(output_data, dict):
|
||||
for source_key, target_var in mapping.items():
|
||||
if source_key in output_data:
|
||||
context.variables[target_var] = output_data[source_key]
|
||||
else:
|
||||
# 如果输出不是字典,将整个输出映射到指定变量
|
||||
if "result" in mapping:
|
||||
context.variables[mapping["result"]] = output_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"输出映射失败: {e}")
|
||||
|
||||
def _serialize_result(self, result: ToolResult) -> Dict[str, Any]:
|
||||
"""序列化工具结果
|
||||
|
||||
Args:
|
||||
result: 工具结果
|
||||
|
||||
Returns:
|
||||
序列化的结果
|
||||
"""
|
||||
return {
|
||||
"success": result.success,
|
||||
"data": result.data,
|
||||
"error": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time,
|
||||
"token_usage": result.token_usage,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
|
||||
def get_running_chains(self) -> List[Dict[str, Any]]:
|
||||
"""获取正在运行的工具链
|
||||
|
||||
Returns:
|
||||
运行中的工具链列表
|
||||
"""
|
||||
chains = []
|
||||
for chain_id, context in self._running_chains.items():
|
||||
chains.append({
|
||||
"chain_id": chain_id,
|
||||
"current_step": context.current_step,
|
||||
"is_completed": context.is_completed,
|
||||
"is_failed": context.is_failed,
|
||||
"variables_count": len(context.variables),
|
||||
"completed_steps": len(context.step_results)
|
||||
})
|
||||
|
||||
return chains
|
||||
@@ -1,264 +0,0 @@
|
||||
"""工具配置管理器 - 管理工具配置的加载和验证"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ToolConfigSchema(BaseModel):
|
||||
"""工具配置基础Schema"""
|
||||
name: str
|
||||
description: str
|
||||
tool_type: str
|
||||
version: str = "1.0.0"
|
||||
enabled: bool = True
|
||||
parameters: Dict[str, Any] = {}
|
||||
tags: list[str] = []
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class BuiltinToolConfigSchema(ToolConfigSchema):
|
||||
"""内置工具配置Schema"""
|
||||
tool_class: str
|
||||
tool_type: str = "builtin"
|
||||
|
||||
|
||||
class CustomToolConfigSchema(ToolConfigSchema):
|
||||
"""自定义工具配置Schema"""
|
||||
schema_url: Optional[str] = None
|
||||
schema_content: Optional[Dict[str, Any]] = None
|
||||
auth_type: str = "none"
|
||||
auth_config: Dict[str, Any] = {}
|
||||
base_url: Optional[str] = None
|
||||
timeout: int = 30
|
||||
tool_type: str = "custom"
|
||||
|
||||
|
||||
class MCPToolConfigSchema(ToolConfigSchema):
|
||||
"""MCP工具配置Schema"""
|
||||
server_url: str
|
||||
connection_config: Dict[str, Any] = {}
|
||||
available_tools: list[str] = []
|
||||
tool_type: str = "mcp"
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""工具配置管理器"""
|
||||
|
||||
def __init__(self, config_dir: Optional[str] = None):
|
||||
"""初始化配置管理器
|
||||
|
||||
Args:
|
||||
config_dir: 配置文件目录,默认使用系统配置
|
||||
"""
|
||||
self.config_dir = Path(config_dir or self._get_default_config_dir())
|
||||
self.config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"配置管理器初始化完成,配置目录: {self.config_dir}")
|
||||
|
||||
def _get_default_config_dir(self) -> str:
|
||||
"""获取默认配置目录"""
|
||||
# 获取tools目录下的configs子目录
|
||||
tools_dir = Path(__file__).parent
|
||||
return str(tools_dir / "configs")
|
||||
|
||||
def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]:
|
||||
"""加载内置工具配置
|
||||
|
||||
Returns:
|
||||
内置工具配置字典
|
||||
"""
|
||||
configs = {}
|
||||
builtin_dir = self.config_dir / "builtin"
|
||||
|
||||
if not builtin_dir.exists():
|
||||
logger.info("内置工具配置目录不存在,创建默认配置")
|
||||
self._create_default_builtin_configs(builtin_dir)
|
||||
|
||||
for config_file in builtin_dir.glob("*.json"):
|
||||
try:
|
||||
config_data = self._load_config_file(config_file)
|
||||
config = BuiltinToolConfigSchema(**config_data)
|
||||
configs[config.name] = config
|
||||
logger.debug(f"加载内置工具配置: {config.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}")
|
||||
|
||||
return configs
|
||||
|
||||
def load_builtin_tools_config(self) -> Dict[str, Any]:
|
||||
"""加载全局内置工具配置(兼容原有接口)
|
||||
|
||||
Returns:
|
||||
内置工具配置字典
|
||||
"""
|
||||
config_file = self.config_dir / "builtin_tools.json"
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载内置工具配置失败: {e}")
|
||||
return {}
|
||||
|
||||
def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum):
|
||||
"""确保内置工具已初始化到数据库
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
db_session: 数据库会话
|
||||
tool_config_model: ToolConfig模型类
|
||||
builtin_tool_config_model: BuiltinToolConfig模型类
|
||||
tool_type_enum: ToolType枚举
|
||||
tool_status_enum: ToolStatus枚举
|
||||
"""
|
||||
# 检查是否已初始化
|
||||
existing_count = db_session.query(tool_config_model).filter(
|
||||
tool_config_model.tenant_id == tenant_id,
|
||||
tool_config_model.tool_type == tool_type_enum.BUILTIN
|
||||
).count()
|
||||
|
||||
if existing_count > 0:
|
||||
return # 已初始化
|
||||
|
||||
# 加载全局配置
|
||||
builtin_tools = self.load_builtin_tools_config()
|
||||
|
||||
# 为租户创建内置工具记录
|
||||
for tool_key, tool_info in builtin_tools.items():
|
||||
# 设置初始状态
|
||||
initial_status = tool_status_enum.ACTIVE.value if not tool_info['requires_config'] else tool_status_enum.INACTIVE.value
|
||||
|
||||
tool_config = tool_config_model(
|
||||
name=tool_info['name'],
|
||||
description=tool_info['description'],
|
||||
tool_type=tool_type_enum.BUILTIN,
|
||||
tenant_id=tenant_id,
|
||||
status=initial_status
|
||||
)
|
||||
db_session.add(tool_config)
|
||||
db_session.flush()
|
||||
|
||||
builtin_config = builtin_tool_config_model(
|
||||
id=tool_config.id,
|
||||
tool_class=tool_info['tool_class'],
|
||||
parameters={}
|
||||
)
|
||||
db_session.add(builtin_config)
|
||||
|
||||
db_session.commit()
|
||||
logger.info(f"租户 {tenant_id} 的内置工具初始化完成")
|
||||
|
||||
def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
|
||||
"""保存工具配置
|
||||
|
||||
Args:
|
||||
config: 工具配置
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
保存是否成功
|
||||
"""
|
||||
try:
|
||||
config_dir = self.config_dir / tool_type
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config_file = config_dir / f"{config.name}.json"
|
||||
config_data = config.model_dump()
|
||||
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(config_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具配置保存失败: {config.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def delete_tool_config(self, tool_name: str, tool_type: str) -> bool:
|
||||
"""删除工具配置
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
try:
|
||||
config_file = self.config_dir / tool_type / f"{tool_name}.json"
|
||||
|
||||
if config_file.exists():
|
||||
config_file.unlink()
|
||||
logger.info(f"工具配置删除成功: {tool_name} ({tool_type})")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
|
||||
"""验证工具配置
|
||||
|
||||
Args:
|
||||
config_data: 配置数据
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
try:
|
||||
schema_map = {
|
||||
"builtin": BuiltinToolConfigSchema,
|
||||
"custom": CustomToolConfigSchema,
|
||||
"mcp": MCPToolConfigSchema
|
||||
}
|
||||
|
||||
schema_class = schema_map.get(tool_type)
|
||||
if not schema_class:
|
||||
return False, f"不支持的工具类型: {tool_type}"
|
||||
|
||||
# 验证配置
|
||||
schema_class(**config_data)
|
||||
return True, None
|
||||
|
||||
except ValidationError as e:
|
||||
error_msg = "; ".join([f"{err['loc'][0]}: {err['msg']}" for err in e.errors()])
|
||||
return False, f"配置验证失败: {error_msg}"
|
||||
except Exception as e:
|
||||
return False, f"配置验证异常: {str(e)}"
|
||||
|
||||
def _load_config_file(self, config_file: Path) -> Dict[str, Any]:
|
||||
"""加载配置文件
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径
|
||||
|
||||
Returns:
|
||||
配置数据字典
|
||||
"""
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {config_file}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def _create_default_builtin_configs(self, builtin_dir: Path):
|
||||
"""创建默认内置工具配置
|
||||
|
||||
Args:
|
||||
builtin_dir: 内置工具配置目录
|
||||
"""
|
||||
builtin_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"内置工具配置目录已创建: {builtin_dir}")
|
||||
# 配置文件已经通过其他方式创建,这里只需要确保目录存在
|
||||
@@ -54,7 +54,8 @@
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
|
||||
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}
|
||||
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
|
||||
"base_url": {"type": "string", "description": "API地址", "default": "https://api.textin.com/v1"}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,6 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from typing import Dict, Any, Tuple
|
||||
from urllib.parse import quote
|
||||
import aiohttp
|
||||
@@ -51,8 +50,9 @@ class AuthManager:
|
||||
|
||||
except Exception as e:
|
||||
return False, f"验证认证配置时出错: {e}"
|
||||
|
||||
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
|
||||
@staticmethod
|
||||
def _validate_api_key_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证API Key认证配置
|
||||
|
||||
Args:
|
||||
@@ -79,8 +79,9 @@ class AuthManager:
|
||||
return False, "API Key位置必须是 header、query 或 cookie"
|
||||
|
||||
return True, ""
|
||||
|
||||
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
|
||||
@staticmethod
|
||||
def _validate_bearer_token_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证Bearer Token认证配置
|
||||
|
||||
Args:
|
||||
@@ -135,9 +136,9 @@ class AuthManager:
|
||||
except Exception as e:
|
||||
logger.error(f"应用认证时出错: {e}")
|
||||
return url, headers, params
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _apply_api_key_auth(
|
||||
self,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
@@ -176,9 +177,9 @@ class AuthManager:
|
||||
headers["Cookie"] = cookie_value
|
||||
|
||||
return url, headers, params
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _apply_bearer_token_auth(
|
||||
self,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
@@ -260,8 +261,9 @@ class AuthManager:
|
||||
except Exception as e:
|
||||
logger.error(f"解密认证配置失败: {e}")
|
||||
return encrypted_config
|
||||
|
||||
def _encrypt_string(self, value: str, key: str) -> str:
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_string(value: str, key: str) -> str:
|
||||
"""加密字符串
|
||||
|
||||
Args:
|
||||
@@ -289,8 +291,9 @@ class AuthManager:
|
||||
except Exception as e:
|
||||
logger.error(f"加密字符串失败: {e}")
|
||||
return value
|
||||
|
||||
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
|
||||
|
||||
@staticmethod
|
||||
def _decrypt_string(encrypted_value: str, key: str) -> str:
|
||||
"""解密字符串
|
||||
|
||||
Args:
|
||||
@@ -471,8 +474,9 @@ class AuthManager:
|
||||
"error": f"测试认证时出错: {e}",
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def get_auth_config_template(auth_type: AuthType) -> Dict[str, Any]:
|
||||
"""获取认证配置模板
|
||||
|
||||
Args:
|
||||
@@ -498,8 +502,9 @@ class AuthManager:
|
||||
}
|
||||
|
||||
return templates.get(auth_type, {})
|
||||
|
||||
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def mask_sensitive_config(auth_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""遮蔽认证配置中的敏感信息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -5,7 +5,8 @@ import aiohttp
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from app.models.tool_model import ToolType, AuthType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -173,8 +174,9 @@ class CustomTool(BaseTool):
|
||||
}
|
||||
|
||||
return operations
|
||||
|
||||
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
|
||||
|
||||
@staticmethod
|
||||
def _convert_openapi_type(openapi_type: str) -> ParameterType:
|
||||
"""转换OpenAPI类型到内部类型"""
|
||||
type_mapping = {
|
||||
"string": ParameterType.STRING,
|
||||
@@ -239,8 +241,9 @@ class CustomTool(BaseTool):
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
return headers
|
||||
|
||||
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def _build_request_data(operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""构建请求数据"""
|
||||
if operation["method"] in ["POST", "PUT", "PATCH"]:
|
||||
request_body = operation.get("request_body")
|
||||
@@ -284,6 +287,7 @@ class CustomTool(BaseTool):
|
||||
try:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
logger.error(f"解析HTTP响应JSON失败: {str(e)}")
|
||||
return await response.text()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -10,6 +10,9 @@ from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
# 为了兼容性,创建别名
|
||||
# SchemaParser = OpenAPISchemaParser = None
|
||||
|
||||
|
||||
class OpenAPISchemaParser:
|
||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||
@@ -88,8 +91,9 @@ class OpenAPISchemaParser:
|
||||
except Exception as e:
|
||||
logger.error(f"解析schema内容失败: {e}")
|
||||
return False, {}, str(e)
|
||||
|
||||
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def _parse_content(content: str, content_type: str) -> Optional[Dict[str, Any]]:
|
||||
"""解析内容为字典
|
||||
|
||||
Args:
|
||||
@@ -101,7 +105,7 @@ class OpenAPISchemaParser:
|
||||
"""
|
||||
try:
|
||||
# 根据内容类型解析
|
||||
if 'json' in content_type:
|
||||
if 'application/json' in content_type:
|
||||
return json.loads(content)
|
||||
elif 'yaml' in content_type or 'yml' in content_type:
|
||||
return yaml.safe_load(content)
|
||||
@@ -228,8 +232,9 @@ class OpenAPISchemaParser:
|
||||
}
|
||||
|
||||
return operations
|
||||
|
||||
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _extract_parameters(operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取操作参数
|
||||
|
||||
Args:
|
||||
@@ -266,8 +271,9 @@ class OpenAPISchemaParser:
|
||||
}
|
||||
|
||||
return parameters
|
||||
|
||||
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def _extract_request_body(operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""提取请求体信息
|
||||
|
||||
Args:
|
||||
@@ -298,8 +304,9 @@ class OpenAPISchemaParser:
|
||||
"schema": schema,
|
||||
"content_types": list(content.keys())
|
||||
}
|
||||
|
||||
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _extract_responses(operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取响应信息
|
||||
|
||||
Args:
|
||||
@@ -331,8 +338,9 @@ class OpenAPISchemaParser:
|
||||
}
|
||||
|
||||
return responses
|
||||
|
||||
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def generate_tool_parameters(operations: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""生成工具参数定义
|
||||
|
||||
Args:
|
||||
@@ -396,7 +404,7 @@ class OpenAPISchemaParser:
|
||||
parameters.extend(all_params.values())
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
|
||||
"""验证操作参数
|
||||
|
||||
@@ -447,8 +455,9 @@ class OpenAPISchemaParser:
|
||||
errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}")
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
|
||||
|
||||
@staticmethod
|
||||
def _validate_parameter_type(value: Any, expected_type: str) -> bool:
|
||||
"""验证参数类型
|
||||
|
||||
Args:
|
||||
@@ -474,4 +483,7 @@ class OpenAPISchemaParser:
|
||||
if expected_python_type:
|
||||
return isinstance(value, expected_python_type)
|
||||
|
||||
return True
|
||||
return True
|
||||
|
||||
# 为了兼容性,创建别名
|
||||
SchemaParser = OpenAPISchemaParser
|
||||
@@ -1,501 +0,0 @@
|
||||
"""工具执行器 - 负责工具的实际调用和执行管理"""
|
||||
import asyncio
|
||||
import uuid
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import ToolExecution, ExecutionStatus
|
||||
from app.core.tools.base import BaseTool, ToolResult
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ExecutionContext:
|
||||
"""执行上下文"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
execution_id: str,
|
||||
tool_id: str,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
self.execution_id = execution_id
|
||||
self.tool_id = tool_id
|
||||
self.user_id = user_id
|
||||
self.workspace_id = workspace_id
|
||||
self.timeout = timeout or 60.0 # 默认60秒超时
|
||||
self.metadata = metadata or {}
|
||||
self.started_at = datetime.now()
|
||||
self.completed_at: Optional[datetime] = None
|
||||
self.status = ExecutionStatus.PENDING
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""工具执行器 - 使用langchain标准接口执行工具"""
|
||||
|
||||
def __init__(self, db: Session, registry: ToolRegistry):
|
||||
"""初始化工具执行器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
registry: 工具注册表
|
||||
"""
|
||||
self.db = db
|
||||
self.registry = registry
|
||||
self._running_executions: Dict[str, ExecutionContext] = {}
|
||||
self._execution_lock = asyncio.Lock()
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_id: str,
|
||||
parameters: Dict[str, Any],
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
execution_id: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> ToolResult:
|
||||
"""执行工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
parameters: 工具参数
|
||||
user_id: 用户ID
|
||||
workspace_id: 工作空间ID
|
||||
execution_id: 执行ID(可选,自动生成)
|
||||
timeout: 超时时间(秒)
|
||||
metadata: 额外元数据
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
# 生成执行ID
|
||||
if not execution_id:
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 创建执行上下文
|
||||
context = ExecutionContext(
|
||||
execution_id=execution_id,
|
||||
tool_id=tool_id,
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
timeout=timeout,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取工具实例
|
||||
tool = self.registry.get_tool(tool_id)
|
||||
if not tool:
|
||||
return ToolResult.error_result(
|
||||
error=f"工具不存在: {tool_id}",
|
||||
error_code="TOOL_NOT_FOUND",
|
||||
execution_time=0.0
|
||||
)
|
||||
|
||||
# 记录执行开始
|
||||
await self._record_execution_start(context, parameters)
|
||||
|
||||
# 执行工具
|
||||
result = await self._execute_with_timeout(tool, parameters, context)
|
||||
|
||||
# 记录执行完成
|
||||
await self._record_execution_complete(context, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具执行异常: {execution_id}, 错误: {e}")
|
||||
|
||||
# 记录执行失败
|
||||
error_result = ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="EXECUTION_ERROR",
|
||||
execution_time=time.time() - context.started_at.timestamp()
|
||||
)
|
||||
await self._record_execution_complete(context, error_result)
|
||||
|
||||
return error_result
|
||||
|
||||
finally:
|
||||
# 清理执行上下文
|
||||
async with self._execution_lock:
|
||||
if execution_id in self._running_executions:
|
||||
del self._running_executions[execution_id]
|
||||
|
||||
async def execute_tools_batch(
|
||||
self,
|
||||
tool_executions: List[Dict[str, Any]],
|
||||
max_concurrency: int = 5
|
||||
) -> List[ToolResult]:
|
||||
"""批量执行工具
|
||||
|
||||
Args:
|
||||
tool_executions: 工具执行配置列表,每个包含tool_id和parameters
|
||||
max_concurrency: 最大并发数
|
||||
|
||||
Returns:
|
||||
执行结果列表
|
||||
"""
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def execute_single(exec_config: Dict[str, Any]) -> ToolResult:
|
||||
async with semaphore:
|
||||
return await self.execute_tool(
|
||||
tool_id=exec_config["tool_id"],
|
||||
parameters=exec_config.get("parameters", {}),
|
||||
user_id=exec_config.get("user_id"),
|
||||
workspace_id=exec_config.get("workspace_id"),
|
||||
timeout=exec_config.get("timeout"),
|
||||
metadata=exec_config.get("metadata")
|
||||
)
|
||||
|
||||
# 并发执行所有工具
|
||||
tasks = [execute_single(config) for config in tool_executions]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理异常结果
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
processed_results.append(
|
||||
ToolResult.error_result(
|
||||
error=str(result),
|
||||
error_code="BATCH_EXECUTION_ERROR",
|
||||
execution_time=0.0
|
||||
)
|
||||
)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
|
||||
async def cancel_execution(self, execution_id: str) -> bool:
|
||||
"""取消工具执行
|
||||
|
||||
Args:
|
||||
execution_id: 执行ID
|
||||
|
||||
Returns:
|
||||
是否成功取消
|
||||
"""
|
||||
async with self._execution_lock:
|
||||
if execution_id not in self._running_executions:
|
||||
return False
|
||||
|
||||
context = self._running_executions[execution_id]
|
||||
context.status = ExecutionStatus.FAILED
|
||||
|
||||
# 更新数据库记录
|
||||
execution_record = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.execution_id == execution_id
|
||||
).first()
|
||||
|
||||
if execution_record:
|
||||
execution_record.status = ExecutionStatus.FAILED.value
|
||||
execution_record.error_message = "执行被取消"
|
||||
execution_record.completed_at = datetime.now()
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"工具执行已取消: {execution_id}")
|
||||
return True
|
||||
|
||||
def get_running_executions(self) -> List[Dict[str, Any]]:
|
||||
"""获取正在运行的执行列表
|
||||
|
||||
Returns:
|
||||
执行信息列表
|
||||
"""
|
||||
executions = []
|
||||
for execution_id, context in self._running_executions.items():
|
||||
executions.append({
|
||||
"execution_id": execution_id,
|
||||
"tool_id": context.tool_id,
|
||||
"user_id": str(context.user_id) if context.user_id else None,
|
||||
"workspace_id": str(context.workspace_id) if context.workspace_id else None,
|
||||
"started_at": context.started_at.isoformat(),
|
||||
"status": context.status.value,
|
||||
"elapsed_time": (datetime.now() - context.started_at).total_seconds()
|
||||
})
|
||||
|
||||
return executions
|
||||
|
||||
async def _execute_with_timeout(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
parameters: Dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> ToolResult:
|
||||
"""带超时的工具执行
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
parameters: 参数
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
async with self._execution_lock:
|
||||
self._running_executions[context.execution_id] = context
|
||||
context.status = ExecutionStatus.RUNNING
|
||||
|
||||
try:
|
||||
# 使用asyncio.wait_for实现超时控制
|
||||
result = await asyncio.wait_for(
|
||||
tool.safe_execute(**parameters),
|
||||
timeout=context.timeout
|
||||
)
|
||||
|
||||
context.status = ExecutionStatus.COMPLETED
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
context.status = ExecutionStatus.TIMEOUT
|
||||
return ToolResult.error_result(
|
||||
error=f"工具执行超时({context.timeout}秒)",
|
||||
error_code="EXECUTION_TIMEOUT",
|
||||
execution_time=context.timeout
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
context.status = ExecutionStatus.FAILED
|
||||
raise
|
||||
|
||||
async def _record_execution_start(
|
||||
self,
|
||||
context: ExecutionContext,
|
||||
parameters: Dict[str, Any]
|
||||
):
|
||||
"""记录执行开始"""
|
||||
try:
|
||||
execution_record = ToolExecution(
|
||||
execution_id=context.execution_id,
|
||||
tool_config_id=uuid.UUID(context.tool_id),
|
||||
status=ExecutionStatus.RUNNING.value,
|
||||
input_data=parameters,
|
||||
started_at=context.started_at,
|
||||
user_id=context.user_id,
|
||||
workspace_id=context.workspace_id
|
||||
)
|
||||
|
||||
self.db.add(execution_record)
|
||||
self.db.commit()
|
||||
|
||||
logger.debug(f"执行记录已创建: {context.execution_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
|
||||
|
||||
async def _record_execution_complete(
|
||||
self,
|
||||
context: ExecutionContext,
|
||||
result: ToolResult
|
||||
):
|
||||
"""记录执行完成"""
|
||||
try:
|
||||
context.completed_at = datetime.now()
|
||||
|
||||
execution_record = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.execution_id == context.execution_id
|
||||
).first()
|
||||
|
||||
if execution_record:
|
||||
execution_record.status = (
|
||||
ExecutionStatus.COMPLETED.value if result.success
|
||||
else ExecutionStatus.FAILED.value
|
||||
)
|
||||
execution_record.output_data = result.data if result.success else None
|
||||
execution_record.error_message = result.error if not result.success else None
|
||||
execution_record.completed_at = context.completed_at
|
||||
execution_record.execution_time = result.execution_time
|
||||
execution_record.token_usage = result.token_usage
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.debug(f"执行记录已更新: {context.execution_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}")
|
||||
|
||||
def get_execution_history(
|
||||
self,
|
||||
tool_id: Optional[str] = None,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取执行历史
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID过滤
|
||||
user_id: 用户ID过滤
|
||||
workspace_id: 工作空间ID过滤
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
执行历史列表
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolExecution).order_by(
|
||||
ToolExecution.started_at.desc()
|
||||
)
|
||||
|
||||
if tool_id:
|
||||
query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id))
|
||||
|
||||
if user_id:
|
||||
query = query.filter(ToolExecution.user_id == user_id)
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(ToolExecution.workspace_id == workspace_id)
|
||||
|
||||
executions = query.limit(limit).all()
|
||||
|
||||
history = []
|
||||
for execution in executions:
|
||||
history.append({
|
||||
"execution_id": execution.execution_id,
|
||||
"tool_id": str(execution.tool_config_id),
|
||||
"status": execution.status,
|
||||
"started_at": execution.started_at.isoformat() if execution.started_at else None,
|
||||
"completed_at": execution.completed_at.isoformat() if execution.completed_at else None,
|
||||
"execution_time": execution.execution_time,
|
||||
"user_id": str(execution.user_id) if execution.user_id else None,
|
||||
"workspace_id": str(execution.workspace_id) if execution.workspace_id else None,
|
||||
"input_data": execution.input_data,
|
||||
"output_data": execution.output_data,
|
||||
"error_message": execution.error_message,
|
||||
"token_usage": execution.token_usage
|
||||
})
|
||||
|
||||
return history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行历史失败, 错误: {e}")
|
||||
return []
|
||||
|
||||
def get_execution_statistics(
|
||||
self,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
days: int = 7
|
||||
) -> Dict[str, Any]:
|
||||
"""获取执行统计信息
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
days: 统计天数
|
||||
|
||||
Returns:
|
||||
统计信息
|
||||
"""
|
||||
try:
|
||||
from datetime import timedelta
|
||||
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
query = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.started_at >= start_date
|
||||
)
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(ToolExecution.workspace_id == workspace_id)
|
||||
|
||||
executions = query.all()
|
||||
|
||||
# 统计数据
|
||||
total_executions = len(executions)
|
||||
successful_executions = len([e for e in executions if e.status == ExecutionStatus.COMPLETED.value])
|
||||
failed_executions = len([e for e in executions if e.status == ExecutionStatus.FAILED.value])
|
||||
|
||||
# 平均执行时间
|
||||
completed_executions = [e for e in executions if e.execution_time is not None]
|
||||
avg_execution_time = (
|
||||
sum(e.execution_time for e in completed_executions) / len(completed_executions)
|
||||
if completed_executions else 0
|
||||
)
|
||||
|
||||
# 按工具统计
|
||||
tool_stats = {}
|
||||
for execution in executions:
|
||||
tool_id = str(execution.tool_config_id)
|
||||
if tool_id not in tool_stats:
|
||||
tool_stats[tool_id] = {"total": 0, "successful": 0, "failed": 0}
|
||||
|
||||
tool_stats[tool_id]["total"] += 1
|
||||
if execution.status == ExecutionStatus.COMPLETED.value:
|
||||
tool_stats[tool_id]["successful"] += 1
|
||||
elif execution.status == ExecutionStatus.FAILED.value:
|
||||
tool_stats[tool_id]["failed"] += 1
|
||||
|
||||
return {
|
||||
"period_days": days,
|
||||
"total_executions": total_executions,
|
||||
"successful_executions": successful_executions,
|
||||
"failed_executions": failed_executions,
|
||||
"success_rate": successful_executions / total_executions if total_executions > 0 else 0,
|
||||
"average_execution_time": avg_execution_time,
|
||||
"tool_statistics": tool_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行统计失败, 错误: {e}")
|
||||
return {}
|
||||
|
||||
async def test_tool_connection(
|
||||
self,
|
||||
tool_id: str,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig
|
||||
from .mcp.client import MCPClient
|
||||
|
||||
tool_config = self.db.query(ToolConfig).filter(
|
||||
ToolConfig.id == uuid.UUID(tool_id)
|
||||
).first()
|
||||
|
||||
if not tool_config:
|
||||
return {"success": False, "message": "工具不存在"}
|
||||
|
||||
if tool_config.tool_type == ToolType.MCP.value:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if not mcp_config:
|
||||
return {"success": False, "message": "MCP配置不存在"}
|
||||
|
||||
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
|
||||
|
||||
if await client.connect():
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
await client.disconnect()
|
||||
return {
|
||||
"success": True,
|
||||
"message": "MCP连接成功",
|
||||
"details": {"server_url": mcp_config.server_url, "tools": len(tools)}
|
||||
}
|
||||
except:
|
||||
await client.disconnect()
|
||||
return {"success": False, "message": "MCP功能测试失败"}
|
||||
else:
|
||||
return {"success": False, "message": "MCP连接失败"}
|
||||
else:
|
||||
tool = self.registry.get_tool(tool_id)
|
||||
if tool and hasattr(tool, 'test_connection'):
|
||||
result = tool.test_connection()
|
||||
return {"success": result.get("success", False), "message": result.get("message", "")}
|
||||
return {"success": True, "message": "工具无需连接测试"}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": "测试失败", "error": str(e)}
|
||||
@@ -4,7 +4,8 @@ from typing import Dict, Any, List
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -123,33 +124,43 @@ class MCPTool(BaseTool):
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器"""
|
||||
try:
|
||||
# 这里应该实现实际的MCP连接逻辑
|
||||
# 为了简化,这里只是模拟连接
|
||||
from .client import MCPClient
|
||||
|
||||
# 测试服务器连接
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
# 尝试获取服务器信息
|
||||
async with session.get(f"{self.server_url}/info") as response:
|
||||
if response.status == 200:
|
||||
server_info = await response.json()
|
||||
self.available_tools = server_info.get("tools", [])
|
||||
self._connected = True
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
raise Exception(f"服务器响应错误: {response.status}")
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
self._client = MCPClient(self.server_url, self.connection_config)
|
||||
|
||||
if await self._client.connect():
|
||||
self._connected = True
|
||||
# 更新可用工具列表
|
||||
await self._update_available_tools()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
|
||||
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
|
||||
async def _update_available_tools(self):
|
||||
"""更新可用工具列表"""
|
||||
try:
|
||||
if self._client and self._connected:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"更新MCP工具列表失败: {e}")
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接"""
|
||||
try:
|
||||
if self._client:
|
||||
# 这里应该实现实际的断开逻辑
|
||||
await self._client.disconnect()
|
||||
self._client = None
|
||||
|
||||
self._connected = False
|
||||
@@ -171,38 +182,15 @@ class MCPTool(BaseTool):
|
||||
|
||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
|
||||
"""调用MCP工具"""
|
||||
# 构建MCP请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
if not self._client or not self._connected:
|
||||
raise Exception("MCP客户端未连接")
|
||||
|
||||
# 发送请求
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
|
||||
|
||||
result = await response.json()
|
||||
|
||||
# 检查MCP响应
|
||||
if "error" in result:
|
||||
error = result["error"]
|
||||
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
|
||||
|
||||
return result.get("result", {})
|
||||
try:
|
||||
result = await self._client.call_tool(tool_name, arguments, timeout)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""列出可用的MCP工具"""
|
||||
@@ -210,27 +198,10 @@ class MCPTool(BaseTool):
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 获取工具列表
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if "result" in result:
|
||||
tools = result["result"].get("tools", [])
|
||||
self.available_tools = [tool.get("name") for tool in tools]
|
||||
return tools
|
||||
if self._client:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
return tools
|
||||
|
||||
return []
|
||||
|
||||
|
||||
@@ -134,11 +134,40 @@ class MCPClient:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
def _build_auth_headers(self) -> Dict[str, str]:
|
||||
"""构建认证头"""
|
||||
headers = {}
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
|
||||
if auth_type == "api_key":
|
||||
api_key = auth_config.get("api_key")
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
if api_key:
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif auth_type == "bearer_token":
|
||||
token = auth_config.get("token")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
elif auth_type == "basic_auth":
|
||||
username = auth_config.get("username")
|
||||
password = auth_config.get("password")
|
||||
if username and password:
|
||||
import base64
|
||||
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {credentials}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _connect_websocket(self) -> bool:
|
||||
"""建立WebSocket连接"""
|
||||
try:
|
||||
# WebSocket连接配置
|
||||
extra_headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
extra_headers.update(auth_headers)
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
@@ -190,6 +219,8 @@ class MCPClient:
|
||||
# HTTP会话配置
|
||||
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
|
||||
headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
headers.update(auth_headers)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
@@ -251,8 +282,9 @@ class MCPClient:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
async def _handle_notification(self, message: Dict[str, Any]):
|
||||
|
||||
@staticmethod
|
||||
async def _handle_notification(message: Dict[str, Any]):
|
||||
"""处理通知消息"""
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
@@ -327,7 +359,7 @@ class MCPClient:
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if not response["error"] is None:
|
||||
if response.get("error", None) is not None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
@@ -372,10 +404,10 @@ class MCPClient:
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
@@ -424,9 +456,9 @@ class MCPClient:
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._send_request(request_data, timeout=5)
|
||||
response_time = time.time() - start_time
|
||||
response_time = round((time.time() - start_time) * 1000)
|
||||
|
||||
self._last_health_check = time.time()
|
||||
self._last_health_check = round(time.time() * 1000)
|
||||
|
||||
return {
|
||||
"healthy": True,
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
|
||||
@@ -148,7 +148,7 @@ class MCPServiceManager:
|
||||
connection_config=connection_config,
|
||||
available_tools=tool_names,
|
||||
health_status="healthy",
|
||||
last_health_check=datetime.utcnow()
|
||||
last_health_check=datetime.now()
|
||||
)
|
||||
|
||||
self.db.add(mcp_config)
|
||||
@@ -410,7 +410,8 @@ class MCPServiceManager:
|
||||
"""加载现有服务"""
|
||||
try:
|
||||
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
|
||||
ToolConfig.is_enabled == True
|
||||
ToolConfig.status == ToolStatus.AVAILABLE.value,
|
||||
ToolConfig.tool_type == ToolType.MCP.value
|
||||
).all()
|
||||
|
||||
for mcp_config in mcp_configs:
|
||||
@@ -531,7 +532,7 @@ class MCPServiceManager:
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
|
||||
mcp_config.last_health_check = datetime.utcnow()
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
|
||||
if not health_status["healthy"]:
|
||||
mcp_config.error_message = health_status.get("error", "")
|
||||
|
||||
@@ -1,436 +0,0 @@
|
||||
"""工具注册表 - 管理所有工具的元数据和状态"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Type, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from app.models.tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolType, ToolStatus, ToolExecution, ExecutionStatus
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .base import BaseTool, ToolInfo
|
||||
from .custom.base import CustomTool
|
||||
from .mcp.base import MCPTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表 - 管理所有工具的元数据和实例"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化工具注册表
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self._tools: Dict[str, BaseTool] = {} # 工具实例缓存
|
||||
self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表
|
||||
self._lock = asyncio.Lock() # 异步锁
|
||||
|
||||
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
|
||||
"""注册工具类
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
class_name: 类名(可选,默认使用类的__name__)
|
||||
"""
|
||||
class_name = class_name or tool_class.__name__
|
||||
self._tool_classes[class_name] = tool_class
|
||||
logger.info(f"工具类已注册: {class_name}")
|
||||
|
||||
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
|
||||
"""注册工具实例到系统
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
tenant_id: 租户ID(内置工具可以为None,表示全局工具)
|
||||
|
||||
Returns:
|
||||
注册是否成功
|
||||
"""
|
||||
async with self._lock:
|
||||
try:
|
||||
# 检查工具是否已存在
|
||||
if tenant_id:
|
||||
existing_config = self.db.query(ToolConfig).filter(
|
||||
and_(
|
||||
ToolConfig.name == tool.name,
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tool_type == tool.tool_type.value
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
# 全局工具(内置工具)
|
||||
existing_config = self.db.query(ToolConfig).filter(
|
||||
and_(
|
||||
ToolConfig.name == tool.name,
|
||||
ToolConfig.tenant_id.is_(None),
|
||||
ToolConfig.tool_type == tool.tool_type.value
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing_config:
|
||||
logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
|
||||
return False
|
||||
|
||||
# 创建工具配置
|
||||
tool_config = ToolConfig(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
tool_type=tool.tool_type.value,
|
||||
tenant_id=tenant_id,
|
||||
version=tool.version,
|
||||
tags=tool.tags,
|
||||
config_data=tool.config
|
||||
)
|
||||
|
||||
self.db.add(tool_config)
|
||||
self.db.flush() # 获取ID
|
||||
|
||||
# 根据工具类型创建特定配置
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
builtin_config = BuiltinToolConfig(
|
||||
id=tool_config.id,
|
||||
tool_class=tool.__class__.__name__,
|
||||
parameters=tool.config.get("parameters", {})
|
||||
)
|
||||
self.db.add(builtin_config)
|
||||
|
||||
elif tool.tool_type == ToolType.CUSTOM:
|
||||
custom_config = CustomToolConfig(
|
||||
id=tool_config.id,
|
||||
schema_url=tool.config.get("schema_url"),
|
||||
schema_content=tool.config.get("schema_content"),
|
||||
auth_type=tool.config.get("auth_type", "none"),
|
||||
auth_config=tool.config.get("auth_config", {}),
|
||||
base_url=tool.config.get("base_url"),
|
||||
timeout=tool.config.get("timeout", 30)
|
||||
)
|
||||
self.db.add(custom_config)
|
||||
|
||||
elif tool.tool_type == ToolType.MCP:
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=tool.config.get("server_url"),
|
||||
connection_config=tool.config.get("connection_config", {}),
|
||||
available_tools=tool.config.get("available_tools", [])
|
||||
)
|
||||
self.db.add(mcp_config)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 缓存工具实例
|
||||
tool.tool_id = str(tool_config.id)
|
||||
self._tools[str(tool_config.id)] = tool
|
||||
|
||||
logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> bool:
|
||||
"""从系统注销工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
|
||||
Returns:
|
||||
注销是否成功
|
||||
"""
|
||||
async with self._lock:
|
||||
try:
|
||||
# 检查工具是否存在
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config:
|
||||
logger.warning(f"工具不存在: {tool_id}")
|
||||
return False
|
||||
|
||||
# 检查是否有正在执行的任务
|
||||
running_executions = self.db.query(ToolExecution).filter(
|
||||
and_(
|
||||
ToolExecution.tool_config_id == uuid.UUID(tool_id),
|
||||
ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
|
||||
)
|
||||
).count()
|
||||
|
||||
if running_executions > 0:
|
||||
logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
|
||||
return False
|
||||
|
||||
# 删除工具配置(级联删除相关记录)
|
||||
self.db.delete(tool_config)
|
||||
self.db.commit()
|
||||
|
||||
# 从缓存中移除
|
||||
if tool_id in self._tools:
|
||||
del self._tools[tool_id]
|
||||
|
||||
logger.info(f"工具注销成功: {tool_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
|
||||
"""获取工具实例
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
|
||||
Returns:
|
||||
工具实例,如果不存在返回None
|
||||
"""
|
||||
# 先从缓存获取
|
||||
if tool_id in self._tools:
|
||||
return self._tools[tool_id]
|
||||
|
||||
# 从数据库加载
|
||||
try:
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
|
||||
return None
|
||||
|
||||
# 根据工具类型加载实例
|
||||
tool_instance = self._load_tool_instance(tool_config)
|
||||
if tool_instance:
|
||||
self._tools[tool_id] = tool_instance
|
||||
return tool_instance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def list_tools(
|
||||
self,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
tool_type: Optional[ToolType] = None,
|
||||
status: Optional[ToolStatus] = None,
|
||||
tags: Optional[List[str]] = None
|
||||
) -> List[ToolInfo]:
|
||||
"""列出工具
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID过滤
|
||||
tool_type: 工具类型过滤
|
||||
status: 工具状态过滤
|
||||
tags: 标签过滤
|
||||
|
||||
Returns:
|
||||
工具信息列表
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolConfig)
|
||||
|
||||
# 应用过滤条件
|
||||
if tenant_id:
|
||||
# 返回全局工具(tenant_id为空)和该租户的工具
|
||||
query = query.filter(
|
||||
or_(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tenant_id.is_(None)
|
||||
)
|
||||
)
|
||||
|
||||
if tool_type:
|
||||
query = query.filter(ToolConfig.tool_type == tool_type.value)
|
||||
|
||||
if status == ToolStatus.ACTIVE:
|
||||
query = query.filter(ToolConfig.is_enabled == True)
|
||||
elif status == ToolStatus.INACTIVE:
|
||||
query = query.filter(ToolConfig.is_enabled == False)
|
||||
|
||||
if tags:
|
||||
for tag in tags:
|
||||
query = query.filter(ToolConfig.tags.contains([tag]))
|
||||
|
||||
tool_configs = query.all()
|
||||
|
||||
# 转换为ToolInfo
|
||||
tool_infos = []
|
||||
for config in tool_configs:
|
||||
tool_info = ToolInfo(
|
||||
id=str(config.id),
|
||||
name=config.name,
|
||||
description=config.description or "",
|
||||
tool_type=ToolType(config.tool_type),
|
||||
version=config.version,
|
||||
status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
|
||||
tags=config.tags or [],
|
||||
tenant_id=str(config.tenant_id) if config.tenant_id else None
|
||||
)
|
||||
|
||||
# 尝试获取参数信息
|
||||
tool_instance = self.get_tool(str(config.id))
|
||||
if tool_instance:
|
||||
tool_info.parameters = tool_instance.parameters
|
||||
|
||||
tool_infos.append(tool_info)
|
||||
|
||||
return tool_infos
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"列出工具失败, 错误: {e}")
|
||||
return []
|
||||
|
||||
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
|
||||
"""更新工具状态
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
status: 新状态
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
try:
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config:
|
||||
logger.warning(f"工具不存在: {tool_id}")
|
||||
return False
|
||||
|
||||
# 更新状态
|
||||
if status == ToolStatus.ACTIVE:
|
||||
tool_config.is_enabled = True
|
||||
elif status == ToolStatus.INACTIVE:
|
||||
tool_config.is_enabled = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 更新缓存中的工具状态
|
||||
if tool_id in self._tools:
|
||||
self._tools[tool_id].status = status
|
||||
|
||||
logger.info(f"工具状态更新成功: {tool_id} -> {status}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]:
|
||||
"""从配置加载工具实例
|
||||
|
||||
Args:
|
||||
tool_config: 工具配置
|
||||
|
||||
Returns:
|
||||
工具实例
|
||||
"""
|
||||
try:
|
||||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||||
# 加载内置工具
|
||||
builtin_config = self.db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if builtin_config and builtin_config.tool_class in self._tool_classes:
|
||||
tool_class = self._tool_classes[builtin_config.tool_class]
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"parameters": builtin_config.parameters,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return tool_class(str(tool_config.id), config)
|
||||
|
||||
elif tool_config.tool_type == ToolType.CUSTOM.value:
|
||||
# 加载自定义工具
|
||||
try:
|
||||
custom_config = self.db.query(CustomToolConfig).filter(
|
||||
CustomToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if custom_config:
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"schema_url": custom_config.schema_url,
|
||||
"schema_content": custom_config.schema_content,
|
||||
"auth_type": custom_config.auth_type,
|
||||
"auth_config": custom_config.auth_config,
|
||||
"base_url": custom_config.base_url,
|
||||
"timeout": custom_config.timeout,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return CustomTool(str(tool_config.id), config)
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入自定义工具模块: {e}")
|
||||
|
||||
elif tool_config.tool_type == ToolType.MCP.value:
|
||||
# 加载MCP工具
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config,
|
||||
"available_tools": mcp_config.available_tools,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return MCPTool(str(tool_config.id), config)
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入MCP工具模块: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
|
||||
"""获取工具统计信息
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolConfig)
|
||||
if tenant_id:
|
||||
query = query.filter(ToolConfig.tenant_id == tenant_id)
|
||||
|
||||
total_tools = query.count()
|
||||
active_tools = query.filter(ToolConfig.is_enabled == True).count()
|
||||
|
||||
# 按类型统计
|
||||
type_stats = {}
|
||||
for tool_type in ToolType:
|
||||
count = query.filter(ToolConfig.tool_type == tool_type.value).count()
|
||||
type_stats[tool_type.value] = count
|
||||
|
||||
return {
|
||||
"total_tools": total_tools,
|
||||
"active_tools": active_tools,
|
||||
"inactive_tools": total_tools - active_tools,
|
||||
"by_type": type_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具统计失败, 错误: {e}")
|
||||
return {}
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空工具缓存"""
|
||||
self._tools.clear()
|
||||
logger.info("工具缓存已清空")
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -19,10 +19,40 @@ class ToolType(StrEnum):
|
||||
|
||||
class ToolStatus(StrEnum):
|
||||
"""工具状态枚举"""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
ERROR = "error"
|
||||
LOADING = "loading"
|
||||
AVAILABLE = "available" # 可用(已配置且已启用)
|
||||
UNCONFIGURED = "unconfigured" # 未配置
|
||||
CONFIGURED_DISABLED = "configured_disabled" # 已配置未启用
|
||||
ERROR = "error" # 错误状态
|
||||
|
||||
@classmethod
|
||||
def get_all_statuses(cls):
|
||||
"""获取所有工具状态"""
|
||||
return [status.value for status in cls]
|
||||
|
||||
@classmethod
|
||||
def get_all_statuses_with_labels(cls):
|
||||
"""获取所有工具状态及其文本描述"""
|
||||
return [
|
||||
{"value": cls.AVAILABLE.value, "label": "可用"},
|
||||
{"value": cls.UNCONFIGURED.value, "label": "未配置"},
|
||||
{"value": cls.CONFIGURED_DISABLED.value, "label": "已配置未启用"},
|
||||
{"value": cls.ERROR.value, "label": "错误状态"}
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def is_valid_status(cls, status):
|
||||
"""检查状态是否有效"""
|
||||
return status in cls._value2member_map_
|
||||
|
||||
@classmethod
|
||||
def get_active_statuses(cls):
|
||||
"""获取所有活跃状态"""
|
||||
return [cls.AVAILABLE.value]
|
||||
|
||||
@classmethod
|
||||
def get_inactive_statuses(cls):
|
||||
"""获取所有非活跃状态"""
|
||||
return [cls.UNCONFIGURED.value, cls.CONFIGURED_DISABLED.value, cls.ERROR.value]
|
||||
|
||||
|
||||
class AuthType(StrEnum):
|
||||
@@ -30,6 +60,27 @@ class AuthType(StrEnum):
|
||||
NONE = "none"
|
||||
API_KEY = "api_key"
|
||||
BEARER_TOKEN = "bearer_token"
|
||||
BASIC_AUTH = "basic_auth"
|
||||
|
||||
@classmethod
|
||||
def get_all_types(cls):
|
||||
"""获取所有认证类型"""
|
||||
return [auth_type.value for auth_type in cls]
|
||||
|
||||
@classmethod
|
||||
def get_all_types_with_labels(cls):
|
||||
"""获取所有认证类型及其文本描述"""
|
||||
return [
|
||||
{"value": cls.NONE.value, "label": "无需认证"},
|
||||
{"value": cls.API_KEY.value, "label": "API Key"},
|
||||
{"value": cls.BEARER_TOKEN.value, "label": "Bearer Token"},
|
||||
{"value": cls.BASIC_AUTH.value, "label": "Basic Auth"}
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def is_valid_types(cls, auth_type):
|
||||
"""检查认证类型是否有效"""
|
||||
return auth_type in cls._value2member_map_
|
||||
|
||||
|
||||
class ExecutionStatus(StrEnum):
|
||||
@@ -48,13 +99,14 @@ class ToolConfig(Base):
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
description = Column(Text)
|
||||
icon = Column(String(255)) # 工具图标
|
||||
tool_type = Column(String(50), nullable=False, index=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True) # 必须属于租户
|
||||
status = Column(String(50), default=ToolStatus.INACTIVE.value, nullable=False, index=True) # 工具状态
|
||||
status = Column(String(50), default=ToolStatus.UNCONFIGURED.value, nullable=False, index=True) # 工具状态
|
||||
|
||||
# 工具特定配置(JSON格式存储)
|
||||
config_data = Column(JSON, default=dict)
|
||||
|
||||
|
||||
# 元数据
|
||||
version = Column(String(50), default="1.0.0")
|
||||
tags = Column(JSON, default=list) # 标签列表
|
||||
@@ -78,12 +130,14 @@ class BuiltinToolConfig(Base):
|
||||
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
||||
tool_class = Column(String(255), nullable=False) # 工具类名
|
||||
parameters = Column(JSON, default=dict) # 工具参数配置
|
||||
|
||||
is_enabled = Column(Boolean, default=False, nullable=False) # 启用开关
|
||||
requires_config = Column(Boolean, default=False, nullable=False) # 是否需要配置
|
||||
|
||||
# 关联关系
|
||||
base_config = relationship("ToolConfig", foreign_keys=[id])
|
||||
|
||||
def __repr__(self):
|
||||
return f"<BuiltinToolConfig(id={self.id}, tool_class={self.tool_class})>"
|
||||
return f"<BuiltinToolConfig(id={self.id}, tool_class={self.tool_class}, enabled={self.is_enabled})>"
|
||||
|
||||
|
||||
class CustomToolConfig(Base):
|
||||
@@ -115,7 +169,7 @@ class MCPToolConfig(Base):
|
||||
|
||||
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
||||
server_url = Column(String(1000), nullable=False) # MCP服务器URL
|
||||
connection_config = Column(JSON, default=dict) # 连接配置
|
||||
connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息)
|
||||
|
||||
# 服务状态
|
||||
last_health_check = Column(DateTime)
|
||||
|
||||
157
api/app/repositories/tool_repository.py
Normal file
157
api/app/repositories/tool_repository.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""工具数据访问层"""
|
||||
import uuid
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
from app.repositories.base_repository import BaseRepository
|
||||
from app.models.tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolExecution, ToolType, ToolStatus
|
||||
)
|
||||
|
||||
|
||||
class ToolRepository:
|
||||
"""工具仓储类"""
|
||||
|
||||
@staticmethod
|
||||
def find_by_tenant(
|
||||
db: Session,
|
||||
tenant_id: uuid.UUID,
|
||||
name: Optional[str] = None,
|
||||
tool_type: Optional[ToolType] = None,
|
||||
status: Optional[ToolStatus] = None,
|
||||
is_enabled: Optional[bool] = None
|
||||
) -> List[ToolConfig]:
|
||||
"""根据租户查找工具"""
|
||||
query = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
)
|
||||
|
||||
if name:
|
||||
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
|
||||
if tool_type:
|
||||
query = query.filter(ToolConfig.tool_type == tool_type.value)
|
||||
if status:
|
||||
query = query.filter(ToolConfig.status == status.value)
|
||||
if is_enabled is not None:
|
||||
query = query.filter(ToolConfig.is_enabled == is_enabled)
|
||||
|
||||
return query.all()
|
||||
|
||||
@staticmethod
|
||||
def find_by_id_and_tenant(db:Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||
"""根据ID和租户查找工具"""
|
||||
return db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
|
||||
"""统计租户工具数量"""
|
||||
return db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
).count()
|
||||
|
||||
@staticmethod
|
||||
def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
||||
"""获取状态统计"""
|
||||
return db.query(
|
||||
ToolConfig.status,
|
||||
func.count(ToolConfig.id).label('count')
|
||||
).filter(
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
).group_by(ToolConfig.status).all()
|
||||
|
||||
@staticmethod
|
||||
def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
||||
"""获取类型统计"""
|
||||
return db.query(
|
||||
ToolConfig.tool_type,
|
||||
func.count(ToolConfig.id).label('count')
|
||||
).filter(
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
).group_by(ToolConfig.tool_type).all()
|
||||
|
||||
@staticmethod
|
||||
def count_enabled_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
|
||||
"""统计租户启用的工具数量"""
|
||||
return db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.is_enabled == True
|
||||
).count()
|
||||
|
||||
@staticmethod
|
||||
def exists_builtin_for_tenant(db: Session, tenant_id: uuid.UUID) -> bool:
|
||||
"""检查租户是否已有内置工具"""
|
||||
return db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tool_type == ToolType.BUILTIN.value
|
||||
).count() > 0
|
||||
|
||||
|
||||
class BuiltinToolRepository:
|
||||
"""内置工具仓储类"""
|
||||
|
||||
@staticmethod
|
||||
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[BuiltinToolConfig]:
|
||||
"""根据工具ID查找内置工具配置"""
|
||||
return db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_id
|
||||
).first()
|
||||
|
||||
|
||||
class CustomToolRepository:
|
||||
"""自定义工具仓储类"""
|
||||
|
||||
@staticmethod
|
||||
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[CustomToolConfig]:
|
||||
"""根据工具ID查找自定义工具配置"""
|
||||
return db.query(CustomToolConfig).filter(
|
||||
CustomToolConfig.id == tool_id
|
||||
).first()
|
||||
|
||||
|
||||
class MCPToolRepository:
|
||||
"""MCP工具仓储类"""
|
||||
|
||||
@staticmethod
|
||||
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[MCPToolConfig]:
|
||||
"""根据工具ID查找MCP工具配置"""
|
||||
return db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_id
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def find_error_connections(db: Session) -> List[MCPToolConfig]:
|
||||
"""查找连接错误的MCP工具"""
|
||||
return db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.connection_status == "error"
|
||||
).all()
|
||||
|
||||
|
||||
class ToolExecutionRepository:
|
||||
"""工具执行仓储类"""
|
||||
|
||||
@staticmethod
|
||||
def find_by_execution_id(db: Session, execution_id: str) -> Optional[ToolExecution]:
|
||||
"""根据执行ID查找执行记录"""
|
||||
return db.query(ToolExecution).filter(
|
||||
ToolExecution.execution_id == execution_id
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def find_by_tool_and_tenant(
|
||||
db: Session,
|
||||
tool_id: uuid.UUID,
|
||||
tenant_id: uuid.UUID,
|
||||
limit: int = 100
|
||||
) -> List[ToolExecution]:
|
||||
"""根据工具和租户查找执行记录"""
|
||||
return db.query(ToolExecution).join(
|
||||
ToolConfig, ToolExecution.tool_config_id == ToolConfig.id
|
||||
).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
).order_by(ToolExecution.started_at.desc()).limit(limit).all()
|
||||
259
api/app/schemas/tool_schema.py
Normal file
259
api/app/schemas/tool_schema.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""工具相关的数据模式定义"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from app.core.api_key_utils import datetime_to_timestamp
|
||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||
|
||||
|
||||
class ParameterType(str, Enum):
|
||||
"""参数类型枚举"""
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
"""工具参数定义"""
|
||||
name: str = Field(..., description="参数名称")
|
||||
type: ParameterType = Field(..., description="参数类型")
|
||||
description: str = Field("", description="参数描述")
|
||||
required: bool = Field(False, description="是否必需")
|
||||
default: Any = Field(None, description="默认值")
|
||||
enum: Optional[List[Any]] = Field(None, description="枚举值")
|
||||
minimum: Optional[float] = Field(None, description="最小值")
|
||||
maximum: Optional[float] = Field(None, description="最大值")
|
||||
pattern: Optional[str] = Field(None, description="正则表达式模式")
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""工具执行结果"""
|
||||
success: bool = Field(..., description="执行是否成功")
|
||||
data: Any = Field(None, description="返回数据")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
error_code: Optional[str] = Field(None, description="错误代码")
|
||||
execution_time: float = Field(..., description="执行时间(秒)")
|
||||
token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
|
||||
|
||||
@classmethod
|
||||
def success_result(
|
||||
cls,
|
||||
data: Any,
|
||||
execution_time: float,
|
||||
token_usage: Optional[Dict[str, int]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> "ToolResult":
|
||||
"""创建成功结果"""
|
||||
return cls(
|
||||
success=True,
|
||||
data=data,
|
||||
execution_time=execution_time,
|
||||
token_usage=token_usage,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def error_result(
|
||||
cls,
|
||||
error: str,
|
||||
execution_time: float,
|
||||
error_code: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> "ToolResult":
|
||||
"""创建错误结果"""
|
||||
return cls(
|
||||
success=False,
|
||||
error=error,
|
||||
error_code=error_code,
|
||||
execution_time=execution_time,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
|
||||
class ToolInfo(BaseModel):
|
||||
"""工具信息"""
|
||||
id: str = Field(..., description="工具ID")
|
||||
name: str = Field(..., description="工具名称")
|
||||
description: str = Field(..., description="工具描述")
|
||||
icon: Optional[str] = Field(None, description="工具图标")
|
||||
tool_type: ToolType = Field(..., description="工具类型")
|
||||
version: str = Field("1.0.0", description="工具版本")
|
||||
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
|
||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="工具配置")
|
||||
status: ToolStatus = Field(ToolStatus.AVAILABLE, description="工具状态")
|
||||
tags: List[str] = Field(default_factory=list, description="工具标签")
|
||||
tenant_id: Optional[str] = Field(None, description="租户ID")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
@field_serializer('created_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v):
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
|
||||
class ToolConfigSchema(BaseModel):
|
||||
"""工具配置基础模式"""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
tool_type: ToolType
|
||||
status: ToolStatus
|
||||
config_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
version: str = "1.0.0"
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
tenant_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class BuiltinToolConfigSchema(BaseModel):
|
||||
"""内置工具配置模式"""
|
||||
tool_class: str
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict)
|
||||
is_enabled: bool
|
||||
requires_config: bool = False
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CustomToolConfigSchema(BaseModel):
|
||||
"""自定义工具配置模式"""
|
||||
base_url: Optional[str] = None
|
||||
auth_type: AuthType = AuthType.NONE
|
||||
auth_config: Dict[str, Any] = Field(default_factory=dict)
|
||||
timeout: int = 30
|
||||
schema_content: Optional[Dict[str, Any]] = None
|
||||
schema_url: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MCPToolConfigSchema(BaseModel):
|
||||
"""MCP工具配置模式"""
|
||||
server_url: str
|
||||
connection_config: Dict[str, Any] = Field(default_factory=dict)
|
||||
last_health_check: Optional[datetime] = None
|
||||
health_status: str = "unknown"
|
||||
error_message: Optional[str] = None
|
||||
available_tools: List[str] = Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ToolDetailSchema(ToolConfigSchema):
|
||||
"""工具详情模式(包含类型特定配置)"""
|
||||
builtin_config: Optional[BuiltinToolConfigSchema] = None
|
||||
custom_config: Optional[CustomToolConfigSchema] = None
|
||||
mcp_config: Optional[MCPToolConfigSchema] = None
|
||||
|
||||
|
||||
class ToolExecutionSchema(BaseModel):
|
||||
"""工具执行记录模式"""
|
||||
id: str
|
||||
execution_id: str
|
||||
status: str
|
||||
input_data: Optional[Dict[str, Any]] = None
|
||||
output_data: Optional[Dict[str, Any]] = None
|
||||
error_message: Optional[str] = None
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
execution_time: Optional[float] = None
|
||||
token_usage: Optional[Dict[str, int]] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ToolCreateRequest(BaseModel):
|
||||
"""创建工具请求"""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
icon: Optional[str] = Field(None, max_length=255)
|
||||
tool_type: ToolType
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ToolUpdateRequest(BaseModel):
|
||||
"""更新工具请求"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
icon: Optional[str] = Field(None, max_length=255)
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
is_enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class ToolExecuteRequest(BaseModel):
|
||||
"""执行工具请求"""
|
||||
tool_id: str
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict)
|
||||
timeout: Optional[float] = Field(60.0, gt=0, le=300)
|
||||
|
||||
|
||||
class CustomToolCreateRequest(BaseModel):
|
||||
"""创建自定义工具请求"""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
icon: Optional[str] = Field(None, max_length=255)
|
||||
auth_type: AuthType = Field(AuthType.NONE, description="认证类型")
|
||||
auth_config: Dict[str, Any] = Field(default_factory=dict, description="认证配置")
|
||||
timeout: int = Field(30, ge=1, le=300, description="超时时间")
|
||||
schema_content: Optional[Dict[str, Any]] = Field(None, description="OpenAPI schema内容")
|
||||
schema_url: Optional[str] = Field(None, description="OpenAPI schema URL")
|
||||
|
||||
|
||||
class ParseSchemaRequest(BaseModel):
|
||||
"""解析Schema请求"""
|
||||
schema_content: Optional[Dict[str, Any]] = Field(None, description="OpenAPI schema内容")
|
||||
schema_url: Optional[str] = Field(None, description="OpenAPI schema URL")
|
||||
|
||||
|
||||
class ToolListQuery(BaseModel):
|
||||
"""工具列表查询参数"""
|
||||
name: Optional[str] = None
|
||||
tool_type: Optional[ToolType] = None
|
||||
status: Optional[ToolStatus] = None
|
||||
is_enabled: Optional[bool] = None
|
||||
page: int = Field(1, ge=1)
|
||||
page_size: int = Field(20, ge=1, le=100)
|
||||
|
||||
|
||||
class ToolStatusCount(BaseModel):
|
||||
"""工具状态统计"""
|
||||
status: ToolStatus
|
||||
count: int
|
||||
|
||||
|
||||
class ToolStatistics(BaseModel):
|
||||
"""工具统计信息"""
|
||||
total_tools: int
|
||||
status_counts: List[ToolStatusCount]
|
||||
type_counts: Dict[str, int]
|
||||
enabled_count: int
|
||||
disabled_count: int
|
||||
|
||||
|
||||
class CustomToolTestRequest(BaseModel):
|
||||
"""自定义工具测试请求"""
|
||||
method: str = Field(..., description="HTTP方法")
|
||||
path: str = Field(..., description="API路径")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="请求参数")
|
||||
977
api/app/services/tool_service.py
Normal file
977
api/app/services/tool_service.py
Normal file
@@ -0,0 +1,977 @@
|
||||
"""工具服务 - 统一的工具管理和执行服务"""
|
||||
import json
|
||||
import uuid
|
||||
import time
|
||||
import importlib
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.tools.mcp import MCPClient
|
||||
from app.repositories.tool_repository import (
|
||||
ToolRepository, BuiltinToolRepository, CustomToolRepository,
|
||||
MCPToolRepository, ToolExecutionRepository
|
||||
)
|
||||
|
||||
from app.models.tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolExecution, ToolType, ToolStatus, ExecutionStatus, AuthType
|
||||
)
|
||||
from app.schemas.tool_schema import ToolInfo, ToolResult
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.core.tools.custom.base import CustomTool
|
||||
from app.core.tools.mcp.base import MCPTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
# 内置工具映射
|
||||
BUILTIN_TOOLS = {
|
||||
"DateTimeTool": "app.core.tools.builtin.datetime_tool",
|
||||
"JsonTool": "app.core.tools.builtin.json_tool",
|
||||
"BaiduSearchTool": "app.core.tools.builtin.baidu_search_tool",
|
||||
"MinerUTool": "app.core.tools.builtin.mineru_tool",
|
||||
"TextInTool": "app.core.tools.builtin.textin_tool"
|
||||
}
|
||||
|
||||
|
||||
class ToolService:
|
||||
"""统一工具服务 - 管理工具的完整生命周期"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._tool_cache: Dict[str, BaseTool] = {}
|
||||
|
||||
# 初始化仓储
|
||||
self.tool_repo = ToolRepository()
|
||||
self.builtin_repo = BuiltinToolRepository()
|
||||
self.custom_repo = CustomToolRepository()
|
||||
self.mcp_repo = MCPToolRepository()
|
||||
self.execution_repo = ToolExecutionRepository()
|
||||
|
||||
def list_tools(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
name: Optional[str] = None,
|
||||
tool_type: Optional[ToolType] = None,
|
||||
status: Optional[ToolStatus] = None
|
||||
) -> List[ToolInfo]:
|
||||
"""获取工具列表"""
|
||||
try:
|
||||
configs = self.tool_repo.find_by_tenant(
|
||||
db=self.db,
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
tool_type=tool_type,
|
||||
status=status
|
||||
)
|
||||
return [self._config_to_info(config) for config in configs]
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}")
|
||||
return []
|
||||
|
||||
def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]:
|
||||
"""获取工具详情"""
|
||||
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
||||
return self._config_to_info(config) if config else None
|
||||
|
||||
def create_tool(
|
||||
self,
|
||||
name: str,
|
||||
tool_type: ToolType,
|
||||
tenant_id: uuid.UUID,
|
||||
icon: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""创建工具"""
|
||||
if tool_type == ToolType.BUILTIN:
|
||||
raise ValueError("内置工具不允许创建")
|
||||
|
||||
try:
|
||||
# 创建基础配置
|
||||
tool_config = ToolConfig(
|
||||
name=name,
|
||||
description=description,
|
||||
icon=icon,
|
||||
tool_type=tool_type.value,
|
||||
tenant_id=tenant_id,
|
||||
status=ToolStatus.AVAILABLE.value,
|
||||
config_data=config or {}
|
||||
)
|
||||
self.db.add(tool_config)
|
||||
self.db.flush()
|
||||
|
||||
# 创建类型特定配置
|
||||
self._create_type_config(tool_config, config or {})
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"工具创建成功: {tool_config.id}")
|
||||
return str(tool_config.id)
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"创建工具失败: {e}")
|
||||
raise
|
||||
|
||||
def update_tool(
|
||||
self,
|
||||
tool_id: str,
|
||||
tenant_id: uuid.UUID,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
icon: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
is_enabled: Optional[bool] = None
|
||||
) -> bool:
|
||||
"""更新工具"""
|
||||
config_obj = self._get_tool_config(tool_id, tenant_id)
|
||||
if not config_obj:
|
||||
return False
|
||||
|
||||
if config_obj.tool_type == ToolType.BUILTIN.value:
|
||||
if name or description or icon:
|
||||
raise ValueError("内置工具不允许修改名称、描述和图标")
|
||||
try:
|
||||
if name:
|
||||
config_obj.name = name
|
||||
if description:
|
||||
config_obj.description = description
|
||||
if icon:
|
||||
config_obj.icon = icon
|
||||
if config:
|
||||
config_obj.config_data = config.copy()
|
||||
|
||||
# 同步到类型表
|
||||
self._sync_type_config(config_obj, config, is_enabled)
|
||||
|
||||
# 更新状态逻辑
|
||||
self._update_tool_status(config_obj)
|
||||
|
||||
# 清除缓存
|
||||
self._clear_tool_cache(tool_id)
|
||||
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"更新工具失败: {tool_id}, {e}")
|
||||
return False
|
||||
|
||||
def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool:
|
||||
"""删除工具"""
|
||||
config = self._get_tool_config(tool_id, tenant_id)
|
||||
if not config:
|
||||
return False
|
||||
|
||||
if config.tool_type == ToolType.BUILTIN.value:
|
||||
raise ValueError("内置工具不允许删除")
|
||||
|
||||
try:
|
||||
# 删除关联表记录
|
||||
if config.tool_type == ToolType.CUSTOM.value:
|
||||
self.db.query(CustomToolConfig).filter(CustomToolConfig.id == config.id).delete()
|
||||
elif config.tool_type == ToolType.MCP.value:
|
||||
self.db.query(MCPToolConfig).filter(MCPToolConfig.id == config.id).delete()
|
||||
|
||||
# 删除主表记录(ToolExecution会通过cascade自动删除)
|
||||
self.db.delete(config)
|
||||
self._clear_tool_cache(tool_id)
|
||||
self.db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"删除工具失败: {tool_id}, {e}")
|
||||
return False
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_id: str,
|
||||
parameters: Dict[str, Any],
|
||||
tenant_id: uuid.UUID,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
timeout: float = 60.0
|
||||
) -> ToolResult:
|
||||
"""执行工具"""
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 获取工具实例
|
||||
tool = self._get_tool_instance(tool_id, tenant_id)
|
||||
if not tool:
|
||||
return ToolResult.error_result(
|
||||
error=f"工具不存在: {tool_id}",
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
|
||||
# 记录执行开始
|
||||
self._record_execution_start(
|
||||
execution_id, tool_id, parameters, user_id, workspace_id
|
||||
)
|
||||
|
||||
# 执行工具
|
||||
result = await tool.safe_execute(**parameters)
|
||||
|
||||
# 记录执行完成
|
||||
self._record_execution_complete(execution_id, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
error_result = ToolResult.error_result(
|
||||
error=str(e),
|
||||
execution_time=execution_time
|
||||
)
|
||||
self._record_execution_complete(execution_id, error_result)
|
||||
return error_result
|
||||
|
||||
async def test_connection(self, tool_id: str, tenant_id: uuid.UUID) -> Dict[str, Any]:
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
config = self._get_tool_config(tool_id, tenant_id)
|
||||
if not config:
|
||||
return {"success": False, "message": "工具不存在"}
|
||||
|
||||
if config.tool_type == ToolType.MCP.value:
|
||||
return await self._test_mcp_connection(config)
|
||||
elif config.tool_type == ToolType.CUSTOM.value:
|
||||
return await self._test_custom_connection(config)
|
||||
elif config.tool_type == ToolType.BUILTIN.value:
|
||||
return await self._test_builtin_connection(config)
|
||||
else:
|
||||
return {"success": True, "message": "未知工具类型"}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||||
|
||||
def ensure_builtin_tools_initialized(self, tenant_id: uuid.UUID):
|
||||
"""确保内置工具已初始化"""
|
||||
existing = self.tool_repo.exists_builtin_for_tenant(self.db, tenant_id)
|
||||
|
||||
if existing:
|
||||
return
|
||||
|
||||
# 从配置文件加载内置工具定义
|
||||
builtin_config = self._load_builtin_config()
|
||||
|
||||
for tool_key, tool_info in builtin_config.items():
|
||||
try:
|
||||
# 创建工具配置
|
||||
initial_status = self._determine_initial_status(tool_info)
|
||||
tool_config = ToolConfig(
|
||||
name=tool_info['name'],
|
||||
description=tool_info['description'],
|
||||
tool_type=ToolType.BUILTIN.value,
|
||||
tenant_id=tenant_id,
|
||||
status=initial_status,
|
||||
config_data={"tool_class": tool_info['tool_class'],
|
||||
"requires_config": tool_info.get('requires_config', False),
|
||||
"is_enabled": False},
|
||||
version=tool_info["version"]
|
||||
)
|
||||
self.db.add(tool_config)
|
||||
self.db.flush()
|
||||
|
||||
# 创建内置工具配置
|
||||
builtin_config_obj = BuiltinToolConfig(
|
||||
id=tool_config.id,
|
||||
tool_class=tool_info['tool_class'],
|
||||
parameters={},
|
||||
requires_config=tool_info.get('requires_config', False)
|
||||
)
|
||||
self.db.add(builtin_config_obj)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化内置工具失败: {tool_key}, {e}")
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"租户 {tenant_id} 内置工具初始化完成")
|
||||
|
||||
def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]:
|
||||
"""获取工具统计信息"""
|
||||
try:
|
||||
# 总数统计
|
||||
total_tools = self.tool_repo.count_by_tenant(self.db, tenant_id)
|
||||
|
||||
# 状态统计
|
||||
status_counts = self.tool_repo.get_status_statistics(self.db, tenant_id)
|
||||
|
||||
# 类型统计
|
||||
type_counts = self.tool_repo.get_type_statistics(self.db, tenant_id)
|
||||
|
||||
# 启用/禁用统计
|
||||
enabled_count = self.tool_repo.count_enabled_by_tenant(self.db, tenant_id)
|
||||
disabled_count = total_tools - enabled_count
|
||||
|
||||
return {
|
||||
"total_tools": total_tools,
|
||||
"status_counts": [
|
||||
{"status": status, "count": count}
|
||||
for status, count in status_counts
|
||||
],
|
||||
"type_counts": {
|
||||
tool_type: count for tool_type, count in type_counts
|
||||
},
|
||||
"enabled_count": enabled_count,
|
||||
"disabled_count": disabled_count
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具统计失败: {e}")
|
||||
return {
|
||||
"total_tools": 0,
|
||||
"status_counts": [],
|
||||
"type_counts": {},
|
||||
"enabled_count": 0,
|
||||
"disabled_count": 0
|
||||
}
|
||||
|
||||
def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||
"""获取工具配置"""
|
||||
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
||||
|
||||
def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
|
||||
"""获取工具实例"""
|
||||
if tool_id in self._tool_cache:
|
||||
return self._tool_cache[tool_id]
|
||||
|
||||
config = self._get_tool_config(tool_id, tenant_id)
|
||||
if not config:
|
||||
return None
|
||||
|
||||
try:
|
||||
tool = self._create_tool_instance(config)
|
||||
if tool:
|
||||
self._tool_cache[tool_id] = tool
|
||||
return tool
|
||||
except Exception as e:
|
||||
logger.error(f"创建工具实例失败: {tool_id}, {e}")
|
||||
return None
|
||||
|
||||
def _create_tool_instance(self, config: ToolConfig) -> Optional[BaseTool]:
|
||||
"""创建工具实例"""
|
||||
if config.tool_type == ToolType.BUILTIN.value:
|
||||
return self._create_builtin_instance(config)
|
||||
elif config.tool_type == ToolType.CUSTOM.value:
|
||||
return self._create_custom_instance(config)
|
||||
elif config.tool_type == ToolType.MCP.value:
|
||||
return self._create_mcp_instance(config)
|
||||
return None
|
||||
|
||||
def _create_builtin_instance(self, config: ToolConfig) -> Optional[BaseTool]:
|
||||
"""创建内置工具实例"""
|
||||
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
|
||||
|
||||
if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS:
|
||||
return None
|
||||
|
||||
try:
|
||||
module_path = BUILTIN_TOOLS[builtin_config.tool_class]
|
||||
module = importlib.import_module(module_path)
|
||||
tool_class = getattr(module, builtin_config.tool_class)
|
||||
|
||||
tool_config = {
|
||||
**config.config_data,
|
||||
"parameters": builtin_config.parameters,
|
||||
}
|
||||
|
||||
return tool_class(str(config.id), tool_config)
|
||||
except Exception as e:
|
||||
logger.error(f"创建内置工具实例失败: {builtin_config.tool_class}, {e}")
|
||||
return None
|
||||
|
||||
def _create_custom_instance(self, config: ToolConfig) -> Optional[CustomTool]:
|
||||
"""创建自定义工具实例"""
|
||||
custom_config = self.custom_repo.find_by_tool_id(self.db, config.id)
|
||||
|
||||
if not custom_config:
|
||||
return None
|
||||
|
||||
tool_config = {
|
||||
"base_url": custom_config.base_url,
|
||||
"auth_type": custom_config.auth_type,
|
||||
"auth_config": custom_config.auth_config or {},
|
||||
"timeout": custom_config.timeout or 30,
|
||||
"schema_content": custom_config.schema_content,
|
||||
"schema_url": custom_config.schema_url
|
||||
}
|
||||
|
||||
return CustomTool(str(config.id), tool_config)
|
||||
|
||||
def _create_mcp_instance(self, config: ToolConfig) -> Optional[MCPTool]:
|
||||
"""创建MCP工具实例"""
|
||||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||
|
||||
if not mcp_config:
|
||||
return None
|
||||
|
||||
tool_config = {
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config or {},
|
||||
"available_tools": mcp_config.available_tools or []
|
||||
}
|
||||
|
||||
return MCPTool(str(config.id), tool_config)
|
||||
|
||||
def _config_to_info(self, config: ToolConfig) -> ToolInfo:
|
||||
"""配置转换为信息对象"""
|
||||
config_data = config.config_data or {}
|
||||
|
||||
# 对于MCP工具,从MCPToolConfig获取额外信息
|
||||
if config.tool_type == ToolType.MCP.value:
|
||||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||
if mcp_config:
|
||||
config_data.update({
|
||||
"last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None,
|
||||
"health_status": mcp_config.health_status,
|
||||
"available_tools": mcp_config.available_tools or []
|
||||
})
|
||||
|
||||
return ToolInfo(
|
||||
id=str(config.id),
|
||||
name=config.name,
|
||||
description=config.description or "",
|
||||
icon=config.icon,
|
||||
tool_type=ToolType(config.tool_type),
|
||||
version=config.version or "1.0.0",
|
||||
status=ToolStatus(config.status),
|
||||
tags=config.tags or [],
|
||||
tenant_id=str(config.tenant_id) if config.tenant_id else None,
|
||||
config_data=config_data,
|
||||
created_at=config.created_at
|
||||
)
|
||||
|
||||
def _create_type_config(self, tool_config: ToolConfig, config: Dict[str, Any]):
|
||||
"""创建类型特定配置"""
|
||||
if tool_config.tool_type == ToolType.CUSTOM.value:
|
||||
# 从 schema 中解析 base_url
|
||||
base_url = config.get("base_url")
|
||||
if not base_url and (config.get("schema_content") or config.get("schema_url")):
|
||||
try:
|
||||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||||
parser = OpenAPISchemaParser()
|
||||
|
||||
if config.get("schema_content"):
|
||||
success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]), "application/json")
|
||||
else:
|
||||
success, schema, _ = parser.parse_from_url(config["schema_url"])
|
||||
|
||||
if success:
|
||||
tool_info = parser.extract_tool_info(schema)
|
||||
servers = tool_info.get("servers", [])
|
||||
base_url = servers[0].get("url") if servers else ""
|
||||
except Exception as e:
|
||||
logger.error(f"解析schema获取base_url失败: {e}")
|
||||
|
||||
custom_config = CustomToolConfig(
|
||||
id=tool_config.id,
|
||||
base_url=base_url,
|
||||
auth_type=config.get("auth_type", "none"),
|
||||
auth_config=config.get("auth_config", {}),
|
||||
timeout=config.get("timeout", 30),
|
||||
schema_content=config.get("schema_content"),
|
||||
schema_url=config.get("schema_url")
|
||||
)
|
||||
self.db.add(custom_config)
|
||||
|
||||
elif tool_config.tool_type == ToolType.MCP.value:
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=config.get("server_url"),
|
||||
connection_config=config.get("connection_config", {}),
|
||||
available_tools=config.get("available_tools", [])
|
||||
)
|
||||
self.db.add(mcp_config)
|
||||
|
||||
def _sync_type_config(self, tool_config: ToolConfig, config: Dict[str, Any], is_enabled: bool):
|
||||
"""同步到类型特定表"""
|
||||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||||
builtin_config = self.db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
if builtin_config:
|
||||
builtin_config.parameters = config.get("parameters", {})
|
||||
if is_enabled is not None:
|
||||
builtin_config.is_enabled = is_enabled
|
||||
|
||||
elif tool_config.tool_type == ToolType.CUSTOM.value:
|
||||
custom_config = self.db.query(CustomToolConfig).filter(
|
||||
CustomToolConfig.id == tool_config.id
|
||||
).first()
|
||||
if custom_config:
|
||||
base_url = config.get("base_url")
|
||||
if not base_url and (config.get("schema_content") or config.get("schema_url")):
|
||||
try:
|
||||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||||
parser = OpenAPISchemaParser()
|
||||
|
||||
if config.get("schema_content"):
|
||||
success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]),
|
||||
"application/json")
|
||||
else:
|
||||
success, schema, _ = parser.parse_from_url(config["schema_url"])
|
||||
|
||||
if success:
|
||||
tool_info = parser.extract_tool_info(schema)
|
||||
servers = tool_info.get("servers", [])
|
||||
base_url = servers[0].get("url") if servers else ""
|
||||
except Exception as e:
|
||||
logger.error(f"解析schema获取base_url失败: {e}")
|
||||
custom_config.base_url = base_url
|
||||
custom_config.auth_type = config.get("auth_type", "none")
|
||||
custom_config.auth_config = config.get("auth_config", {})
|
||||
custom_config.timeout = config.get("timeout", 30)
|
||||
custom_config.schema_content = config.get("schema_content")
|
||||
custom_config.schema_url = config.get("schema_url")
|
||||
|
||||
elif tool_config.tool_type == ToolType.MCP.value:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_config.id
|
||||
).first()
|
||||
if mcp_config:
|
||||
mcp_config.server_url = config.get("server_url")
|
||||
mcp_config.connection_config = config.get("connection_config", {})
|
||||
mcp_config.available_tools = config.get("available_tools", [])
|
||||
|
||||
@staticmethod
|
||||
def _determine_initial_status(tool_info: Dict[str, Any]) -> str:
|
||||
"""确定工具初始状态"""
|
||||
if tool_info.get('requires_config', False):
|
||||
return ToolStatus.UNCONFIGURED
|
||||
else:
|
||||
return ToolStatus.AVAILABLE
|
||||
|
||||
def _update_tool_status(self, tool_config: ToolConfig):
|
||||
"""更新工具状态逻辑"""
|
||||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||||
builtin_config = self.db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if builtin_config:
|
||||
if builtin_config.requires_config:
|
||||
# 需要配置的工具
|
||||
if self._is_tool_configured(builtin_config):
|
||||
if tool_config.config_data.get("is_enabled", None):
|
||||
tool_config.status = ToolStatus.AVAILABLE.value
|
||||
else:
|
||||
tool_config.status = ToolStatus.CONFIGURED_DISABLED.value
|
||||
else:
|
||||
tool_config.status = ToolStatus.UNCONFIGURED.value
|
||||
else:
|
||||
# 不需要配置的工具
|
||||
tool_config.status = ToolStatus.AVAILABLE.value
|
||||
|
||||
elif tool_config.tool_type == ToolType.CUSTOM.value:
|
||||
custom_config = self.db.query(CustomToolConfig).filter(
|
||||
CustomToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if custom_config and tool_config.name and (custom_config.schema_content or custom_config.schema_url):
|
||||
tool_config.status = ToolStatus.AVAILABLE.value
|
||||
else:
|
||||
tool_config.status = ToolStatus.UNCONFIGURED.value
|
||||
|
||||
elif tool_config.tool_type == ToolType.MCP.value:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
if mcp_config.health_status == "healthy":
|
||||
tool_config.status = ToolStatus.AVAILABLE.value
|
||||
elif mcp_config.health_status == "error":
|
||||
tool_config.status = ToolStatus.ERROR.value
|
||||
else:
|
||||
tool_config.status = ToolStatus.UNCONFIGURED.value
|
||||
|
||||
def _is_tool_configured(self, builtin_config: BuiltinToolConfig) -> bool:
|
||||
"""检查工具是否已配置"""
|
||||
# 从配置文件获取必需参数
|
||||
builtin_config_data = self._load_builtin_config()
|
||||
required_params = {}
|
||||
for key, value in builtin_config_data.items():
|
||||
if builtin_config.tool_class == value["tool_class"]:
|
||||
required_params = value.get('parameters', {})
|
||||
break
|
||||
|
||||
# 检查所有必需参数是否已配置
|
||||
for param_name, param_info in required_params.items():
|
||||
if param_info.get('required', False):
|
||||
if not builtin_config.parameters.get(param_name):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _clear_tool_cache(self, tool_id: str):
|
||||
"""清除工具缓存"""
|
||||
if tool_id in self._tool_cache:
|
||||
del self._tool_cache[tool_id]
|
||||
|
||||
def _record_execution_start(
|
||||
self,
|
||||
execution_id: str,
|
||||
tool_id: str,
|
||||
parameters: Dict[str, Any],
|
||||
user_id: Optional[uuid.UUID],
|
||||
workspace_id: Optional[uuid.UUID]
|
||||
):
|
||||
"""记录执行开始"""
|
||||
try:
|
||||
execution = ToolExecution(
|
||||
execution_id=execution_id,
|
||||
tool_config_id=uuid.UUID(tool_id),
|
||||
status=ExecutionStatus.RUNNING.value,
|
||||
input_data=parameters,
|
||||
started_at=datetime.now(),
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
self.db.add(execution)
|
||||
self.db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"记录执行开始失败: {execution_id}, {e}")
|
||||
|
||||
def _record_execution_complete(self, execution_id: str, result: ToolResult):
|
||||
"""记录执行完成"""
|
||||
try:
|
||||
execution = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.execution_id == execution_id
|
||||
).first()
|
||||
|
||||
if execution:
|
||||
execution.status = ExecutionStatus.COMPLETED.value if result.success else ExecutionStatus.FAILED.value
|
||||
execution.output_data = result.data if result.success else None
|
||||
execution.error_message = result.error if not result.success else None
|
||||
execution.completed_at = datetime.now()
|
||||
execution.execution_time = result.execution_time
|
||||
execution.token_usage = result.token_usage
|
||||
self.db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"记录执行完成失败: {execution_id}, {e}")
|
||||
|
||||
@staticmethod
|
||||
def _load_builtin_config() -> Dict[str, Any]:
|
||||
"""加载内置工具配置"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
config_file = Path(__file__).parent.parent / "core" / "tools" / "configs" / "builtin_tools.json"
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载内置工具配置失败: {e}")
|
||||
return {}
|
||||
|
||||
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||||
"""测试MCP连接"""
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == config.id
|
||||
).first()
|
||||
|
||||
if not mcp_config:
|
||||
return {"success": False, "message": "MCP配置不存在"}
|
||||
|
||||
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
|
||||
|
||||
if await client.connect():
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
await client.disconnect()
|
||||
|
||||
# 更新连接状态
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "healthy"
|
||||
mcp_config.error_message = None
|
||||
|
||||
# 更新工具状态
|
||||
self._update_tool_status(config)
|
||||
self.db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "MCP连接成功",
|
||||
"details": {"server_url": mcp_config.server_url, "tools_count": len(tools)}
|
||||
}
|
||||
except Exception as e:
|
||||
await client.disconnect()
|
||||
|
||||
# 更新错误状态
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = str(e)
|
||||
self._update_tool_status(config)
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": f"MCP功能测试失败: {str(e)}"}
|
||||
else:
|
||||
# 更新连接失败状态
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = "连接失败"
|
||||
self._update_tool_status(config)
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": "MCP连接失败"}
|
||||
|
||||
except Exception as e:
|
||||
# 更新异常状态
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == config.id
|
||||
).first()
|
||||
if mcp_config:
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = str(e)
|
||||
self._update_tool_status(config)
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": f"MCP测试异常: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
async def parse_openapi_schema(schema_data: Dict[str, Any] = None, schema_url: str = None) -> Dict[str, Any]:
|
||||
"""解析OpenAPI schema获取接口信息"""
|
||||
try:
|
||||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||||
|
||||
parser = OpenAPISchemaParser()
|
||||
|
||||
# 使用现有的解析器
|
||||
if schema_data:
|
||||
success, schema, error = parser.parse_from_content(json.dumps(schema_data), "application/json")
|
||||
elif schema_url:
|
||||
success, schema, error = await parser.parse_from_url(schema_url)
|
||||
else:
|
||||
return {"success": False, "message": "schema_data或schema_url必须提供一个"}
|
||||
|
||||
if not success:
|
||||
return {"success": False, "message": error}
|
||||
|
||||
# 提取工具信息
|
||||
tool_info = parser.extract_tool_info(schema)
|
||||
|
||||
# 获取base_url
|
||||
servers = tool_info.get("servers", [])
|
||||
base_url = servers[0].get("url") if servers else ""
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"title": tool_info["name"],
|
||||
"description": tool_info["description"],
|
||||
"version": tool_info["version"],
|
||||
"base_url": base_url,
|
||||
"operations": list(tool_info["operations"].values())
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析OpenAPI schema失败: {e}")
|
||||
return {"success": False, "message": f"解析失败: {str(e)}"}
|
||||
|
||||
async def sync_mcp_tools(self, tool_id: str, tenant_id: uuid.UUID) -> Dict[str, Any]:
|
||||
"""同步MCP工具列表到数据库"""
|
||||
try:
|
||||
config = self._get_tool_config(tool_id, tenant_id)
|
||||
if not config or config.tool_type != ToolType.MCP.value:
|
||||
return {"success": False, "message": "工具不存在或不是MCP工具"}
|
||||
|
||||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||
if not mcp_config:
|
||||
return {"success": False, "message": "MCP配置不存在"}
|
||||
|
||||
# 创建MCP客户端
|
||||
connection_config = mcp_config.connection_config or {}
|
||||
|
||||
client = MCPClient(mcp_config.server_url, connection_config)
|
||||
|
||||
if await client.connect():
|
||||
try:
|
||||
# 获取工具列表
|
||||
tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
|
||||
# 更新数据库
|
||||
mcp_config.available_tools = tool_names
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "healthy"
|
||||
mcp_config.error_message = None
|
||||
|
||||
# 更新工具状态
|
||||
config.status = ToolStatus.AVAILABLE.value
|
||||
|
||||
self.db.commit()
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "工具列表同步成功",
|
||||
"tools_count": len(tool_names),
|
||||
"tools": tool_names
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await client.disconnect()
|
||||
|
||||
# 更新错误状态
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = str(e)
|
||||
config.status = ToolStatus.ERROR.value
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": f"获取工具列表失败: {str(e)}"}
|
||||
else:
|
||||
# 连接失败
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = "连接失败"
|
||||
config.status = ToolStatus.ERROR.value
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": "MCP连接失败"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
|
||||
return {"success": False, "message": f"同步失败: {str(e)}"}
|
||||
|
||||
async def _test_custom_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||||
"""测试自定义工具连接(基础连接测试)"""
|
||||
try:
|
||||
custom_config = self.db.query(CustomToolConfig).filter(
|
||||
CustomToolConfig.id == config.id
|
||||
).first()
|
||||
|
||||
if not custom_config or not custom_config.base_url:
|
||||
return {"success": False, "message": "自定义工具配置不完整"}
|
||||
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
custom_config.base_url,
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return {"success": True, "message": "自定义工具连接成功"}
|
||||
else:
|
||||
return {"success": False, "message": f"连接失败,状态码: {response.status}"}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"自定义工具测试失败: {str(e)}"}
|
||||
|
||||
async def test_custom_tool(
|
||||
self,
|
||||
tool_id: str,
|
||||
tenant_id: uuid.UUID,
|
||||
method: str,
|
||||
path: str,
|
||||
parameters: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""测试自定义工具API调用"""
|
||||
try:
|
||||
config = self._get_tool_config(tool_id, tenant_id)
|
||||
if not config or config.tool_type != ToolType.CUSTOM.value:
|
||||
return {"success": False, "message": "工具不存在或不是自定义工具"}
|
||||
|
||||
custom_config = self.db.query(CustomToolConfig).filter(
|
||||
CustomToolConfig.id == config.id
|
||||
).first()
|
||||
|
||||
if not custom_config or not custom_config.base_url:
|
||||
return {"success": False, "message": "自定义工具配置不完整"}
|
||||
|
||||
# 构建完整URL
|
||||
url = custom_config.base_url.rstrip('/') + '/' + path.lstrip('/')
|
||||
|
||||
# 构建请求头
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# 添加认证头
|
||||
if custom_config.auth_type != AuthType.NONE.value:
|
||||
auth_config = custom_config.auth_config or {}
|
||||
if custom_config.auth_type == AuthType.API_KEY.value:
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
api_key = auth_config.get("api_key")
|
||||
if api_key:
|
||||
headers[key_name] = api_key
|
||||
elif custom_config.auth_type == AuthType.BEARER_TOKEN.value:
|
||||
token = auth_config.get("token")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
elif custom_config.auth_type == AuthType.BASIC_AUTH.value:
|
||||
import base64
|
||||
username = auth_config.get("username", "")
|
||||
password = auth_config.get("password", "")
|
||||
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {credentials}"
|
||||
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# 根据方法发送请求
|
||||
if method.upper() == "GET":
|
||||
async with session.get(
|
||||
url,
|
||||
params=parameters,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=custom_config.timeout or 30)
|
||||
) as response:
|
||||
result_data = await response.text()
|
||||
return {
|
||||
"success": True,
|
||||
"message": "测试成功",
|
||||
"status_code": response.status,
|
||||
"response_data": result_data[:1000] # 限制返回数据长度
|
||||
}
|
||||
else:
|
||||
async with session.request(
|
||||
method.upper(),
|
||||
url,
|
||||
json=parameters,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=custom_config.timeout or 30)
|
||||
) as response:
|
||||
result_data = await response.text()
|
||||
return {
|
||||
"success": True,
|
||||
"message": "测试成功",
|
||||
"status_code": response.status,
|
||||
"response_data": result_data[:1000] # 限制返回数据长度
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"测试自定义工具API失败: {tool_id}, 错误: {e}")
|
||||
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||||
|
||||
async def _test_builtin_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||||
"""测试内置工具连接"""
|
||||
try:
|
||||
# 获取工具实例
|
||||
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
|
||||
if not tool_instance:
|
||||
return {"success": False, "message": "无法创建工具实例"}
|
||||
|
||||
# 检查工具是否有test_connection方法
|
||||
if hasattr(tool_instance, 'test_connection'):
|
||||
result = await tool_instance.test_connection()
|
||||
return result
|
||||
else:
|
||||
# 检查是否需要配置
|
||||
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
|
||||
if builtin_config and builtin_config.requires_config:
|
||||
# 检查必需参数是否已配置
|
||||
if self._is_tool_configured(builtin_config):
|
||||
return {"success": True, "message": "内置工具已正确配置"}
|
||||
else:
|
||||
return {"success": False, "message": "工具缺少必需配置参数"}
|
||||
else:
|
||||
return {"success": True, "message": "内置工具无需连接测试"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"测试内置工具失败: {config.id}, 错误: {e}")
|
||||
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||||
@@ -1,374 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
工具管理系统基础测试脚本
|
||||
用于验证系统的基本功能是否正常
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# 测试导入
|
||||
def test_imports():
|
||||
"""测试模块导入"""
|
||||
print("测试模块导入...")
|
||||
|
||||
try:
|
||||
from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType
|
||||
print("✓ 基础工具模块导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ 基础工具模块导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin.datetime_tool import DateTimeTool
|
||||
from app.core.tools.builtin.json_tool import JsonTool
|
||||
print("✓ 内置工具模块导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ 内置工具模块导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
print("✓ Langchain适配器导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ Langchain适配器导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.models.tool_model import ToolConfig, ToolType, ToolStatus
|
||||
print("✓ 工具模型导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ 工具模型导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.tools.custom import CustomTool, OpenAPISchemaParser, AuthManager
|
||||
print("✓ 自定义工具模块导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ 自定义工具模块导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.tools.mcp import MCPTool, MCPClient, MCPServiceManager
|
||||
print("✓ MCP工具模块导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ MCP工具模块导入失败: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_tool_creation():
|
||||
"""测试工具创建"""
|
||||
print("\n测试工具创建...")
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin.datetime_tool import DateTimeTool
|
||||
|
||||
# 创建时间工具实例(全局工具)
|
||||
tool_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"parameters": {"timezone": "UTC"},
|
||||
"tenant_id": None, # 全局工具
|
||||
"version": "1.0.0",
|
||||
"tags": ["time", "utility", "builtin"]
|
||||
}
|
||||
|
||||
datetime_tool = DateTimeTool(tool_id, config)
|
||||
|
||||
# 验证工具属性
|
||||
assert datetime_tool.name == "datetime_tool"
|
||||
assert datetime_tool.tool_type.value == "builtin"
|
||||
assert len(datetime_tool.parameters) > 0
|
||||
|
||||
print("✓ 时间工具创建成功(全局工具)")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 工具创建失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_tool_execution():
|
||||
"""测试工具执行"""
|
||||
print("\n测试工具执行...")
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin.datetime_tool import DateTimeTool
|
||||
|
||||
# 创建时间工具实例
|
||||
tool_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"parameters": {"timezone": "UTC"},
|
||||
"tenant_id": None, # 全局工具
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
datetime_tool = DateTimeTool(tool_id, config)
|
||||
|
||||
# 测试获取当前时间
|
||||
result = await datetime_tool.safe_execute(operation="now")
|
||||
|
||||
assert result.success == True
|
||||
assert "datetime" in result.data
|
||||
assert result.execution_time > 0
|
||||
|
||||
print("✓ 工具执行成功")
|
||||
print(f" 执行时间: {result.execution_time:.3f}秒")
|
||||
print(f" 返回数据: {result.data}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 工具执行失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_langchain_adapter():
|
||||
"""测试Langchain适配器"""
|
||||
print("\n测试Langchain适配器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin.json_tool import JsonTool
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
|
||||
# 创建JSON工具实例
|
||||
tool_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"parameters": {"indent": 2},
|
||||
"tenant_id": None, # 全局工具
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
json_tool = JsonTool(tool_id, config)
|
||||
|
||||
# 验证Langchain兼容性
|
||||
is_compatible, issues = LangchainAdapter.validate_langchain_compatibility(json_tool)
|
||||
|
||||
if not is_compatible:
|
||||
print(f"✗ Langchain兼容性验证失败: {issues}")
|
||||
return False
|
||||
|
||||
# 创建工具描述
|
||||
description = LangchainAdapter.create_tool_description(json_tool)
|
||||
|
||||
assert "name" in description
|
||||
assert "parameters" in description
|
||||
assert description["langchain_compatible"] == True
|
||||
|
||||
print("✓ Langchain适配器测试成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Langchain适配器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_config_manager():
|
||||
"""测试配置管理器"""
|
||||
print("\n测试配置管理器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.config_manager import ConfigManager
|
||||
|
||||
# 创建配置管理器
|
||||
config_manager = ConfigManager()
|
||||
|
||||
# 获取配置摘要
|
||||
summary = config_manager.get_config_summary()
|
||||
|
||||
assert "config_dir" in summary
|
||||
assert "total_configs" in summary
|
||||
|
||||
print("✓ 配置管理器测试成功")
|
||||
print(f" 配置目录: {summary['config_dir']}")
|
||||
print(f" 总配置数: {summary['total_configs']}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 配置管理器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_schema_parser():
|
||||
"""测试OpenAPI Schema解析器"""
|
||||
print("\n测试OpenAPI Schema解析器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||||
|
||||
# 创建解析器
|
||||
parser = OpenAPISchemaParser()
|
||||
|
||||
# 测试简单的OpenAPI schema
|
||||
test_schema = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": "Test API",
|
||||
"version": "1.0.0",
|
||||
"description": "测试API"
|
||||
},
|
||||
"paths": {
|
||||
"/test": {
|
||||
"get": {
|
||||
"summary": "测试接口",
|
||||
"operationId": "test_operation",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "成功"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 验证schema
|
||||
is_valid, error_msg = parser.validate_schema(test_schema)
|
||||
assert is_valid, f"Schema验证失败: {error_msg}"
|
||||
|
||||
# 提取工具信息
|
||||
tool_info = parser.extract_tool_info(test_schema)
|
||||
assert tool_info["name"] == "Test API"
|
||||
assert "test_operation" in tool_info["operations"]
|
||||
|
||||
print("✓ OpenAPI Schema解析器测试成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ OpenAPI Schema解析器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_auth_manager():
|
||||
"""测试认证管理器"""
|
||||
print("\n测试认证管理器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.custom.auth_manager import AuthManager
|
||||
from app.models.tool_model import AuthType
|
||||
|
||||
# 创建认证管理器
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# 测试API Key认证配置
|
||||
api_key_config = {
|
||||
"api_key": "test-key-123",
|
||||
"key_name": "X-API-Key",
|
||||
"location": "header"
|
||||
}
|
||||
|
||||
is_valid, error_msg = auth_manager.validate_auth_config(AuthType.API_KEY, api_key_config)
|
||||
assert is_valid, f"API Key配置验证失败: {error_msg}"
|
||||
|
||||
# 测试Bearer Token认证配置
|
||||
bearer_config = {
|
||||
"token": "bearer-token-123"
|
||||
}
|
||||
|
||||
is_valid, error_msg = auth_manager.validate_auth_config(AuthType.BEARER_TOKEN, bearer_config)
|
||||
assert is_valid, f"Bearer Token配置验证失败: {error_msg}"
|
||||
|
||||
# 测试认证应用
|
||||
url = "https://api.example.com/test"
|
||||
headers = {}
|
||||
params = {}
|
||||
|
||||
new_url, new_headers, new_params = auth_manager.apply_authentication(
|
||||
AuthType.API_KEY, api_key_config, url, headers, params
|
||||
)
|
||||
|
||||
assert "X-API-Key" in new_headers
|
||||
assert new_headers["X-API-Key"] == "test-key-123"
|
||||
|
||||
print("✓ 认证管理器测试成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 认证管理器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_builtin_initializer():
|
||||
"""测试内置工具初始化器"""
|
||||
print("\n测试内置工具初始化器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin_initializer import BuiltinToolInitializer
|
||||
|
||||
# 注意:这里不能真正初始化,因为需要数据库连接
|
||||
# 只测试类的创建和基本方法
|
||||
|
||||
# 模拟数据库会话(实际使用中需要真实的数据库连接)
|
||||
class MockDB:
|
||||
def query(self, *args):
|
||||
return self
|
||||
def filter(self, *args):
|
||||
return self
|
||||
def first(self):
|
||||
return None
|
||||
def all(self):
|
||||
return []
|
||||
|
||||
mock_db = MockDB()
|
||||
initializer = BuiltinToolInitializer(mock_db)
|
||||
|
||||
# 测试获取内置工具状态(会返回空列表,因为没有真实数据)
|
||||
status = initializer.get_builtin_tools_status()
|
||||
assert isinstance(status, list)
|
||||
|
||||
print("✓ 内置工具初始化器测试成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 内置工具初始化器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("=" * 50)
|
||||
print("工具管理系统基础测试")
|
||||
print("=" * 50)
|
||||
|
||||
tests = [
|
||||
("模块导入", test_imports),
|
||||
("工具创建", test_tool_creation),
|
||||
("工具执行", test_tool_execution),
|
||||
("Langchain适配", test_langchain_adapter),
|
||||
("配置管理", test_config_manager),
|
||||
("Schema解析器", test_schema_parser),
|
||||
("认证管理器", test_auth_manager),
|
||||
("内置工具初始化器", test_builtin_initializer)
|
||||
]
|
||||
|
||||
passed = 0
|
||||
total = len(tests)
|
||||
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(test_func):
|
||||
result = await test_func()
|
||||
else:
|
||||
result = test_func()
|
||||
|
||||
if result:
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"✗ {test_name}测试异常: {e}")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print(f"测试结果: {passed}/{total} 通过")
|
||||
|
||||
if passed == total:
|
||||
print("🎉 所有基础测试通过!工具管理系统基本功能正常。")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ 部分测试失败,请检查相关模块。")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user