Merge #59 into develop from feature/20251219_xjn

feat(tool system): Tool system reengineering

* feature/20251219_xjn: (2 commits)
  feat(tool system): tool system development
  feat(tool system): Tool system reengineering

Signed-off-by: 谢俊男 <accounts_6853d0ea6f8174722fb0c8f1@mail.teambition.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/59
This commit is contained in:
朱文辉
2025-12-25 17:38:01 +08:00
27 changed files with 1891 additions and 3352 deletions

View File

@@ -33,7 +33,6 @@ from . import (
emotion_config_controller,
prompt_optimizer_controller,
tool_controller,
tool_execution_controller,
)
from . import user_memory_controllers
@@ -71,6 +70,5 @@ manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(tool_controller.router)
manager_router.include_router(tool_execution_controller.router)
__all__ = ["manager_router"]

View File

@@ -20,11 +20,10 @@ async def get_memory_info():
return success(data={}, msg="Memory API - Coming Soon")
# /v1/memory/{resource_id}/chat
@router.post("/{resource_id}/chat")
# /v1/memory/chat
@router.post("/chat")
@require_api_key(scopes=["memory"])
async def chat_with_agent_demo(
resource_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
@@ -36,13 +35,12 @@ async def chat_with_agent_demo(
scopes: 所需的权限范围列表["app", "rag", "memory"]
Args:
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
message: 请求参数
request: 声明请求
api_key_auth: 包含验证后的API Key 信息
db: db_session
"""
logger.info(f"API Key Auth: {api_key_auth}")
logger.info(f"Resource ID: {resource_id}")
logger.info(f"Resource ID: {api_key_auth.resource_id}")
logger.info(f"Message: {message}")
return success(data={"received": True}, msg="消息已接收")

View File

@@ -1,585 +1,250 @@
"""工具管理API控制器"""
import base64
from typing import List, Optional, Dict, Any
"""工具控制器 - 简化统一的工具管理接口"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Body
from langfuse.api.core import jsonable_encoder
from sqlalchemy.exc import SQLAlchemyError
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field, PositiveInt, field_validator
from cryptography.fernet import Fernet
from app.schemas.tool_schema import (
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
)
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig
from app.core.logging_config import get_business_logger
from app.core.config import settings
from app.core.tools.config_manager import ConfigManager
from app.models.tool_model import ToolType, ToolStatus, AuthType
from app.services.tool_service import ToolService
from app.schemas.response_schema import ApiResponse
logger = get_business_logger()
router = APIRouter(prefix="/tools", tags=["工具管理"])
router = APIRouter(prefix="/tools", tags=["Tool System"])
# ==================== 辅助函数 ====================
def get_tool_service(db: Session = Depends(get_db)) -> ToolService:
return ToolService(db)
def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
"""加密敏感参数"""
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
cipher = Fernet(cipher_key)
encrypted_params = {}
sensitive_keys = ['api_key', 'token', 'api_secret', 'password']
for key, value in parameters.items():
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
encrypted_params[key] = cipher.encrypt(str(value).encode()).decode()
else:
encrypted_params[key] = value
return encrypted_params
@router.get("/statistics", response_model=ApiResponse)
async def get_tool_statistics(
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具统计信息"""
try:
stats = service.get_tool_statistics(current_user.tenant_id)
return success(data=stats, msg="获取统计信息成功")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
"""解密敏感参数"""
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
cipher = Fernet(cipher_key)
decrypted_params = {}
sensitive_keys = ['api_key', 'token', 'secret', 'password']
for key, value in parameters.items():
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
try:
decrypted_params[key] = cipher.decrypt(value.encode()).decode()
except Exception as e:
decrypted_params[key] = value
else:
decrypted_params[key] = value
return decrypted_params
def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str:
"""更新工具状态并返回新状态"""
if tool_config.tool_type == ToolType.BUILTIN:
if not tool_info or not tool_info.get('requires_config', False):
new_status = ToolStatus.ACTIVE.value # 不需要配置的内置工具
elif not builtin_config or not builtin_config.parameters:
new_status = ToolStatus.INACTIVE.value
else:
# 检查是否有必要的API密钥
has_key = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
new_status = ToolStatus.ACTIVE.value if has_key else ToolStatus.INACTIVE.value
else: # 自定义和MCP工具
new_status = ToolStatus.ACTIVE.value if tool_config.config_data else ToolStatus.ERROR.value
# 更新数据库中的状态
if tool_config.status != new_status:
tool_config.status = new_status
return new_status
# ==================== 请求/响应模型 ====================
class ToolListResponse(BaseModel):
"""工具列表响应"""
id: str
name: str
description: str
tool_type: str
category: str
version: str = "1.0.0"
status: str # active inactive error loading
requires_config: bool = False
# is_configured: bool = False
class Config:
from_attributes = True
class BuiltinToolConfigRequest(BaseModel):
"""内置工具配置请求"""
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
class CustomToolCreateRequest(BaseModel):
"""自定义工具创建请求体模型,包含参数校验规则"""
name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填")
description: str = Field(None, description="工具描述")
base_url: str = Field(None, description="工具基础URL")
schema_url: str = Field(None, description="工具Schema URL")
schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容可选")
auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型")
auth_config: Optional[Dict[str, Any]] = Field(None, description="认证配置,默认空字典")
timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间1-300秒默认30")
# 自定义校验当auth_type为api_key时auth_config必须包含api_key字段
@field_validator("auth_config")
def validate_auth_config(cls, v, values):
auth_type = values.data.get("auth_type")
if auth_type == "api_key" and (not v or "api_key" not in v):
raise ValueError("认证类型为api_key时auth_config必须包含api_key字段")
if auth_type == "bearer_token" and (not v or "bearer_token" not in v):
raise ValueError("认证类型为bearer_token时auth_config必须包含bearer_token字段")
return v
class MCPToolCreateRequest(BaseModel):
"""MCP工具创建请求体模型适配MCP业务特性"""
# 基础必填字段(带长度/格式校验)
name: str = Field(..., min_length=1, max_length=100,description="MCP工具名称")
description: str = Field(None, description="MCP工具描述")
# MCP核心字段服务端URL强制HTTP/HTTPS格式
server_url: str = Field(..., description="MCP服务端URL仅支持http/https协议")
# 连接配置:默认空字典,可自定义校验规则(根据实际业务调整)
connection_config: Dict[str, Any] = Field({},description="MCP连接配置如认证信息、超时、重试等默认空字典")
@field_validator("connection_config")
def validate_connection_config(cls, v):
# 示例1若包含timeout必须是1-300的整数
if "timeout" in v:
timeout = v["timeout"]
if not isinstance(timeout, int) or timeout < 1 or timeout > 300:
raise ValueError("connection_config.timeout必须是1-300的整数")
return v
# @field_validator("server_url")
# def validate_server_url_protocol(cls, v):
# if v.scheme != "https":
# raise ValueError("MCP服务端URL仅支持HTTPS协议安全要求")
# return v
# ==================== API端点 ====================
@router.get("", response_model=List[ToolListResponse])
@router.get("", response_model=ApiResponse)
async def list_tools(
name: Optional[str] = None,
tool_type: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
name: Optional[str] = Query(None),
tool_type: Optional[str] = Query(None),
status: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具列表包含内置工具、自定义工具和MCP工具"""
"""获取工具列表"""
try:
# 初始化内置工具(如果需要)
config_manager = ConfigManager()
config_manager.ensure_builtin_tools_initialized(
current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus
# 确保内置工具已初始化
service.ensure_builtin_tools_initialized(current_user.tenant_id)
# 获取工具列表
tools = service.list_tools(
tenant_id=current_user.tenant_id,
name=name,
tool_type=ToolType(tool_type) if tool_type else None,
status=ToolStatus(status) if status else None
)
return success(data=tools, msg="获取工具列表成功")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
response_tools = []
query = db.query(ToolConfig).filter(
ToolConfig.tenant_id == current_user.tenant_id
@router.get("/{tool_id}", response_model=ApiResponse)
async def get_tool(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具详情"""
tool = service.get_tool_info(tool_id, current_user.tenant_id)
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
return success(data=tool, msg="获取工具详情成功")
@router.post("", response_model=ApiResponse)
async def create_tool(
request: ToolCreateRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""创建工具"""
try:
tool_id = service.create_tool(
name=request.name,
tool_type=request.tool_type,
tenant_id=current_user.tenant_id,
icon=request.icon,
description=request.description,
config=request.config
)
if tool_type:
query = query.filter(ToolConfig.tool_type == tool_type)
if name:
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
tools = query.all()
builtin_tools = config_manager.load_builtin_tools_config()
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
for tool_config in tools:
if tool_config.tool_type == ToolType.BUILTIN.value:
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
tool_info = configured_tools.get(builtin_config.tool_class)
status = _update_tool_status(tool_config, builtin_config, tool_info)
else:
status = _update_tool_status(tool_config)
response_tools.append(ToolListResponse(
id=str(tool_config.id),
name=tool_config.name,
description=tool_config.description,
tool_type=tool_config.tool_type,
category=tool_info['category'] if tool_config.tool_type == ToolType.BUILTIN.value else tool_config.tool_type,
version="1.0.0",
status=status,
requires_config=tool_info['requires_config'] if tool_config.tool_type == ToolType.BUILTIN.value else False,
))
return response_tools
return success(data={"tool_id": tool_id}, msg="工具创建成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"获取工具列表失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/builtin/{tool_id}")
async def get_builtin_tool_detail(
tool_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
@router.put("/{tool_id}", response_model=ApiResponse)
async def update_tool(
tool_id: str,
request: ToolUpdateRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取内置工具详情"""
"""更新工具"""
try:
config_manager = ConfigManager()
builtin_tools = config_manager.load_builtin_tools_config()
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
tool_config = db.query(ToolConfig).filter(
ToolConfig.tenant_id == current_user.tenant_id,
ToolConfig.id == tool_id
).first()
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
tool_info = configured_tools.get(builtin_config.tool_class)
is_configured = False
config_parameters = {}
if builtin_config and builtin_config.parameters:
is_configured = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
# 不返回敏感信息,只返回非敏感配置
config_parameters = {k: v for k, v in builtin_config.parameters.items()
if not any(sensitive in k.lower() for sensitive in ['key', 'secret', 'token', 'password'])}
return {
"id": tool_config.id,
"name": tool_config.name,
"description": tool_config.description,
"category": tool_info['category'],
"status": tool_config.tool_type,
"requires_config": tool_info['requires_config'],
"is_configured": is_configured,
"config_parameters": config_parameters
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/builtin/{tool_id}/configure")
async def configure_builtin_tool(
tool_id: str,
request: BuiltinToolConfigRequest = Body(...),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""配置内置工具参数(租户级别)"""
try:
# 查询工具配置
tool_config = db.query(ToolConfig).filter(
ToolConfig.tenant_id == current_user.tenant_id,
ToolConfig.id == tool_id,
ToolConfig.tool_type == ToolType.BUILTIN
).first()
if not tool_config:
raise HTTPException(status_code=404, detail="工具不存在")
# 获取内置工具配置
builtin_config = db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if not builtin_config:
raise HTTPException(status_code=404, detail="内置工具配置不存在")
# 获取全局工具信息
config_manager = ConfigManager()
builtin_tools_config = config_manager.load_builtin_tools_config()
tool_info = None
for tool_key, info in builtin_tools_config.items():
if info['tool_class'] == builtin_config.tool_class:
tool_info = info
break
if not tool_info:
raise HTTPException(status_code=404, detail="工具信息不存在")
# 加密敏感参数
encrypted_params = _encrypt_sensitive_params(request.parameters)
# 更新配置
builtin_config.parameters = encrypted_params
# 更新状态
_update_tool_status(tool_config, builtin_config, tool_info)
db.commit()
return {
"success": True,
"message": f"工具 {tool_config.name} 配置成功"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"配置内置工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/builtin/{tool_id}/config")
async def get_builtin_tool_config(
tool_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取内置工具配置(用于使用)"""
try:
# 查询工具配置
tool_config = db.query(ToolConfig).filter(
ToolConfig.tenant_id == current_user.tenant_id,
ToolConfig.id == tool_id,
ToolConfig.tool_type == ToolType.BUILTIN
).first()
if not tool_config:
raise HTTPException(status_code=404, detail="工具不存在")
# 获取内置工具配置
builtin_config = db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if not builtin_config:
raise HTTPException(status_code=404, detail="内置工具配置不存在")
# 解密参数
decrypted_params = _decrypt_sensitive_params(builtin_config.parameters or {})
return {
"tool_id": tool_id,
"tool_class": builtin_config.tool_class,
"name": tool_config.name,
"parameters": decrypted_params,
"status": tool_config.status
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取工具配置失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/custom")
async def create_custom_tool(
request: CustomToolCreateRequest = Body(...),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""创建自定义工具"""
try:
config_data = jsonable_encoder(request.model_dump())
config_data["tool_type"] = "custom"
config_manager = ConfigManager()
is_valid, error_msg = config_manager.validate_config(config_data, "custom")
if not is_valid:
raise HTTPException(status_code=400, detail=error_msg)
# 创建数据库记录
tool_config = ToolConfig(
success_flag = service.update_tool(
tool_id=tool_id,
tenant_id=current_user.tenant_id,
name=request.name,
description=request.description,
tool_type=ToolType.CUSTOM,
tenant_id=current_user.tenant_id,
status=ToolStatus.ACTIVE.value,
config_data=config_data
icon=request.icon,
config=request.config,
is_enabled=request.config.get("is_enabled", None)
)
db.add(tool_config)
db.flush()
if not success_flag:
raise HTTPException(status_code=404, detail="工具不存在")
return success(msg="工具更新成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# 创建CustomToolConfig记录
custom_config = CustomToolConfig(
id=tool_config.id,
base_url=request.base_url,
schema_url=request.schema_url,
schema_content=request.schema_content,
auth_type=request.auth_type,
auth_config=request.auth_config,
@router.delete("/{tool_id}", response_model=ApiResponse)
async def delete_tool(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""删除工具"""
try:
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
if not success_flag:
raise HTTPException(status_code=404, detail="工具不存在")
return success(msg="工具删除成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/execution/execute", response_model=ApiResponse)
async def execute_tool(
request: ToolExecuteRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""执行工具"""
try:
result = await service.execute_tool(
tool_id=request.tool_id,
parameters=request.parameters,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
timeout=request.timeout
)
db.add(custom_config)
db.commit()
return {
"success": True,
"message": f"自定义工具 {request.name} 创建成功",
"tool_id": str(tool_config.id)
}
except HTTPException:
raise
return success(
data={
"success": result.success,
"data": result.data,
"error": result.error,
"execution_time": result.execution_time,
"token_usage": result.token_usage
},
msg="工具执行完成"
)
except Exception as e:
logger.error(f"创建自定义工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/mcp")
async def create_mcp_tool(
request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"),
@router.post("/parse_schema", response_model=ApiResponse)
async def parse_openapi_schema(
request: ParseSchemaRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
service: ToolService = Depends(get_tool_service)
):
"""创建MCP工具"""
"""解析OpenAPI schema"""
try:
config_data = jsonable_encoder(request.model_dump())
config_data["tool_type"] = "mcp"
result = await service.parse_openapi_schema(request.schema_content, request.schema_url)
if result["success"] is False:
raise HTTPException(status_code=400, detail=result["message"])
return success(data=result, msg="Schema解析完成")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
config_manager = ConfigManager()
is_valid, error_msg = config_manager.validate_config(config_data, "mcp")
if not is_valid:
raise HTTPException(status_code=400, detail=error_msg)
@router.post("/{tool_id}/sync_mcp_tools", response_model=ApiResponse)
async def sync_mcp_tools(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""同步MCP工具列表"""
try:
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
if result["success"] is False:
raise HTTPException(status_code=404, detail=result["message"])
return success(data=result, msg="MCP工具列表同步完成")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# 创建数据库记录
try:
tool_config = ToolConfig(
name=request.name,
description=request.description,
tool_type=ToolType.MCP,
tenant_id=current_user.tenant_id,
status=ToolStatus.ACTIVE.value,
config_data=config_data
@router.post("/{tool_id}/test", response_model=ApiResponse)
async def test_tool_connection(
tool_id: str,
test_request: Optional[CustomToolTestRequest] = None,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""测试工具连接"""
try:
if test_request:
# 自定义工具测试
result = await service.test_custom_tool(
tool_id, current_user.tenant_id,
test_request.method, test_request.path, test_request.parameters
)
db.add(tool_config)
db.flush()
# 创建MCPToolConfig记录
mcp_config = MCPToolConfig(
id=tool_config.id,
server_url=request.server_url,
connection_config=request.connection_config
)
db.add(mcp_config)
db.commit()
except SQLAlchemyError as db_e:
db.rollback()
logger.error(f"创建MCP工具数据库操作失败租户ID{current_user.tenant_id},工具名:{request.name}: {str(db_e)}",
exc_info=True)
raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败租户ID{current_user.tenant_id}"
f"工具名:{request.name}{str(db_e)}")
return {
"success": True,
"message": f"MCP工具 {request.name} 创建成功",
"tool_id": str(tool_config.id)
}
except HTTPException:
raise
except Exception as e:
logger.error(f"创建MCP工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{tool_id}")
async def delete_tool(
tool_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""删除工具仅限自定义和MCP工具"""
try:
tool = db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == current_user.tenant_id
).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
if tool.tool_type == ToolType.BUILTIN:
raise HTTPException(status_code=403, detail="内置工具不允许删除")
db.delete(tool)
db.commit()
return {
"success": True,
"message": f"工具 {tool.name} 删除成功"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"删除工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.put("/{tool_id}")
async def update_tool(
tool_id: str,
config_data: Optional[Dict[str, Any]] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""更新工具仅限自定义和MCP工具"""
try:
tool = db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == current_user.tenant_id
).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
if tool.tool_type == ToolType.BUILTIN:
raise HTTPException(status_code=403, detail="内置工具不允许修改")
if config_data is not None:
tool.config_data = config_data
# 更新状态
_update_tool_status(tool)
db.commit()
db.refresh(tool)
return {
"success": True,
"message": f"工具 {tool.name} 更新成功",
"status": tool.status
}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{tool_id}/toggle")
async def toggle_tool_status(
tool_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""切换工具活跃/非活跃状态"""
try:
tool = db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == current_user.tenant_id
).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
# 在active和inactive之间切换
if tool.status == ToolStatus.ACTIVE.value:
tool.status = ToolStatus.INACTIVE.value
elif tool.status == ToolStatus.INACTIVE.value:
tool.status = ToolStatus.ACTIVE.value
else:
raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换")
db.commit()
db.refresh(tool)
return {
"success": True,
"message": f"工具 {tool.name} 状态已更新为 {tool.status}",
"status": tool.status
}
except HTTPException:
raise
# 普通连接测试
result = await service.test_connection(tool_id, current_user.tenant_id)
return success(data=result, msg="连接测试完成")
except Exception as e:
logger.error(f"切换工具状态失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))
@router.get("/enums/tool_types", response_model=ApiResponse)
async def get_tool_types():
"""获取工具类型枚举"""
return success(
data=[
{"value": ToolType.BUILTIN.value, "label": "内置工具"},
{"value": ToolType.CUSTOM.value, "label": "自定义工具"},
{"value": ToolType.MCP.value, "label": "MCP工具"}
],
msg="获取工具类型成功"
)
@router.get("/enums/status", response_model=ApiResponse)
async def get_tool_status():
"""获取工具状态枚举"""
return success(data=ToolStatus.get_all_statuses_with_labels(), msg="获取工具状态成功")
@router.get("/auth/types", response_model=ApiResponse)
async def get_auth_types():
"""获取认证类型枚举"""
return success(data=AuthType.get_all_types_with_labels(), msg="获取认证类型成功")

View File

@@ -1,430 +0,0 @@
"""工具执行API控制器"""
import uuid
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Path, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.core.tools.registry import ToolRegistry
from app.core.tools.executor import ToolExecutor
from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode
from app.core.tools.builtin import *
from app.core.logging_config import get_business_logger
logger = get_business_logger()
router = APIRouter(prefix="/tools/execution", tags=["工具执行"])
# ==================== 请求/响应模型 ====================
class ToolExecutionRequest(BaseModel):
"""工具执行请求"""
tool_id: str = Field(..., description="工具ID")
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
timeout: Optional[float] = Field(None, ge=1, le=300, description="超时时间(秒)")
metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据")
class BatchExecutionRequest(BaseModel):
"""批量执行请求"""
executions: List[ToolExecutionRequest] = Field(..., description="执行列表")
max_concurrency: int = Field(5, ge=1, le=20, description="最大并发数")
class ToolExecutionResponse(BaseModel):
"""工具执行响应"""
success: bool
execution_id: str
tool_id: str
data: Any = None
error: Optional[str] = None
error_code: Optional[str] = None
execution_time: float
token_usage: Optional[Dict[str, int]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
class ChainStepRequest(BaseModel):
"""链步骤请求"""
tool_id: str = Field(..., description="工具ID")
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
condition: Optional[str] = Field(None, description="执行条件")
output_mapping: Optional[Dict[str, str]] = Field(None, description="输出映射")
error_handling: str = Field("stop", description="错误处理策略")
class ChainExecutionRequest(BaseModel):
"""链执行请求"""
name: str = Field(..., description="链名称")
description: str = Field("", description="链描述")
steps: List[ChainStepRequest] = Field(..., description="执行步骤")
execution_mode: str = Field("sequential", description="执行模式")
initial_variables: Optional[Dict[str, Any]] = Field(None, description="初始变量")
global_timeout: Optional[float] = Field(None, description="全局超时")
class ExecutionHistoryResponse(BaseModel):
"""执行历史响应"""
execution_id: str
tool_id: str
status: str
started_at: Optional[str]
completed_at: Optional[str]
execution_time: Optional[float]
user_id: Optional[str]
workspace_id: Optional[str]
input_data: Optional[Dict[str, Any]]
output_data: Optional[Any]
error_message: Optional[str]
token_usage: Optional[Dict[str, int]]
class ToolConnectionTestResponse(BaseModel):
"""工具连接测试响应"""
success: bool
message: str
error: Optional[str] = None
details: Optional[Dict[str, Any]] = None
# ==================== 依赖注入 ====================
def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry:
"""获取工具注册表"""
registry = ToolRegistry(db)
# 注册内置工具类
registry.register_tool_class(DateTimeTool)
registry.register_tool_class(JsonTool)
registry.register_tool_class(BaiduSearchTool)
registry.register_tool_class(MinerUTool)
registry.register_tool_class(TextInTool)
return registry
def get_tool_executor(
db: Session = Depends(get_db),
registry: ToolRegistry = Depends(get_tool_registry)
) -> ToolExecutor:
"""获取工具执行器"""
return ToolExecutor(db, registry)
def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager:
"""获取链管理器"""
return ChainManager(executor)
# ==================== API端点 ====================
@router.post("/execute", response_model=ToolExecutionResponse)
async def execute_tool(
request: ToolExecutionRequest,
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""执行单个工具"""
try:
# 生成执行ID
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
# 执行工具
result = await executor.execute_tool(
tool_id=request.tool_id,
parameters=request.parameters,
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
execution_id=execution_id,
timeout=request.timeout,
metadata=request.metadata
)
return ToolExecutionResponse(
success=result.success,
execution_id=execution_id,
tool_id=request.tool_id,
data=result.data,
error=result.error,
error_code=result.error_code,
execution_time=result.execution_time,
token_usage=result.token_usage,
metadata=result.metadata
)
except Exception as e:
logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=List[ToolExecutionResponse])
async def execute_tools_batch(
request: BatchExecutionRequest,
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""批量执行工具"""
try:
# 准备执行配置
execution_configs = []
execution_ids = []
for exec_request in request.executions:
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
execution_ids.append(execution_id)
execution_configs.append({
"tool_id": exec_request.tool_id,
"parameters": exec_request.parameters,
"user_id": current_user.id,
"workspace_id": current_user.current_workspace_id,
"execution_id": execution_id,
"timeout": exec_request.timeout,
"metadata": exec_request.metadata
})
# 批量执行
results = await executor.execute_tools_batch(
execution_configs,
max_concurrency=request.max_concurrency
)
# 转换响应格式
responses = []
for i, result in enumerate(results):
responses.append(ToolExecutionResponse(
success=result.success,
execution_id=execution_ids[i],
tool_id=request.executions[i].tool_id,
data=result.data,
error=result.error,
error_code=result.error_code,
execution_time=result.execution_time,
token_usage=result.token_usage,
metadata=result.metadata
))
return responses
except Exception as e:
logger.error(f"批量执行失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/chain", response_model=Dict[str, Any])
async def execute_tool_chain(
request: ChainExecutionRequest,
current_user: User = Depends(get_current_user),
chain_manager: ChainManager = Depends(get_chain_manager)
):
"""执行工具链"""
try:
# 转换步骤格式
steps = []
for step_request in request.steps:
step = ChainStep(
tool_id=step_request.tool_id,
parameters=step_request.parameters,
condition=step_request.condition,
output_mapping=step_request.output_mapping,
error_handling=step_request.error_handling
)
steps.append(step)
# 创建链定义
chain_definition = ChainDefinition(
name=request.name,
description=request.description,
steps=steps,
execution_mode=ChainExecutionMode(request.execution_mode),
global_timeout=request.global_timeout
)
# 注册并执行链
chain_manager.register_chain(chain_definition)
result = await chain_manager.execute_chain(
chain_name=request.name,
initial_variables=request.initial_variables
)
return result
except Exception as e:
logger.error(f"工具链执行失败: {request.name}, 错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/running", response_model=List[Dict[str, Any]])
async def get_running_executions(
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""获取正在运行的执行"""
try:
running_executions = executor.get_running_executions()
# 过滤当前工作空间的执行
workspace_executions = [
exec_info for exec_info in running_executions
if exec_info.get("workspace_id") == str(current_user.current_workspace_id)
]
return workspace_executions
except Exception as e:
logger.error(f"获取运行中执行失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any])
async def cancel_execution(
execution_id: str = Path(..., description="执行ID"),
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""取消工具执行"""
try:
success = await executor.cancel_execution(execution_id)
if success:
return {
"success": True,
"message": "执行已取消"
}
else:
raise HTTPException(status_code=404, detail="执行不存在或已完成")
except HTTPException:
raise
except Exception as e:
logger.error(f"取消执行失败: {execution_id}, 错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/history", response_model=List[ExecutionHistoryResponse])
async def get_execution_history(
tool_id: Optional[str] = Query(None, description="工具ID过滤"),
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""获取执行历史"""
try:
history = executor.get_execution_history(
tool_id=tool_id,
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
limit=limit
)
# 转换响应格式
responses = []
for record in history:
responses.append(ExecutionHistoryResponse(
execution_id=record["execution_id"],
tool_id=record["tool_id"],
status=record["status"],
started_at=record["started_at"],
completed_at=record["completed_at"],
execution_time=record["execution_time"],
user_id=record["user_id"],
workspace_id=record["workspace_id"],
input_data=record["input_data"],
output_data=record["output_data"],
error_message=record["error_message"],
token_usage=record["token_usage"]
))
return responses
except Exception as e:
logger.error(f"获取执行历史失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/statistics", response_model=Dict[str, Any])
async def get_execution_statistics(
days: int = Query(7, ge=1, le=90, description="统计天数"),
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""获取执行统计"""
try:
stats = executor.get_execution_statistics(
workspace_id=current_user.current_workspace_id,
days=days
)
return {
"success": True,
"statistics": stats
}
except Exception as e:
logger.error(f"获取执行统计失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/chains/running", response_model=List[Dict[str, Any]])
async def get_running_chains(
current_user: User = Depends(get_current_user),
chain_manager: ChainManager = Depends(get_chain_manager)
):
"""获取正在运行的工具链"""
try:
running_chains = chain_manager.get_running_chains()
return running_chains
except Exception as e:
logger.error(f"获取运行中工具链失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/chains", response_model=List[Dict[str, Any]])
async def list_tool_chains(
current_user: User = Depends(get_current_user),
chain_manager: ChainManager = Depends(get_chain_manager)
):
"""列出工具链"""
try:
chains = chain_manager.list_chains()
return chains
except Exception as e:
logger.error(f"获取工具链列表失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse)
async def test_tool_connection(
tool_id: str = Path(..., description="工具ID"),
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""测试工具连接"""
try:
result = await executor.test_tool_connection(
tool_id=tool_id,
user_id=current_user.id,
workspace_id=current_user.current_workspace_id
)
return ToolConnectionTestResponse(
success=result.get("success", False),
message=result.get("message", ""),
error=result.get("error"),
details=result.get("details")
)
except Exception as e:
logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
return ToolConnectionTestResponse(
success=False,
message="连接测试失败",
error=str(e)
)

View File

@@ -33,10 +33,9 @@ def require_api_key(
scopes: 所需的权限范围列表[“app”, "rag", "memory"]
Usage:
@router.get("/app/{resource_id}/chat")
@router.get("/app/chat")
@require_api_key(scopes=["app"])
def chat_with_app(
resource_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
@@ -89,26 +88,6 @@ def require_api_key(
context={"required_scopes": scopes, "missing_scopes": missing_scopes}
)
resource_id = kwargs.get("resource_id")
if resource_id and not ApiKeyAuthService.check_resource(
api_key_obj,
resource_id
):
logger.warning("API Key 资源访问被拒绝", extra={
"api_key_id": str(api_key_obj.id),
"required_resource_id": str(resource_id),
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
"endpoint": str(request.url)
})
return BusinessException(
"API Key 未授权访问该资源",
BizCode.API_KEY_INVALID_RESOURCE,
context={
"required_resource_id": str(resource_id),
"bound_resource_id": str(api_key_obj.resource_id)
}
)
kwargs["api_key_auth"] = ApiKeyAuth(
api_key_id=api_key_obj.id,
workspace_id=api_key_obj.workspace_id,

View File

@@ -1,11 +1,7 @@
"""工具管理核心模块"""
from .base import BaseTool, ToolResult, ToolParameter
from .registry import ToolRegistry
from .executor import ToolExecutor
from .langchain_adapter import LangchainAdapter
from .config_manager import ConfigManager
from .chain_manager import ChainManager
# 可选导入,避免导入错误
try:
@@ -22,11 +18,7 @@ __all__ = [
"BaseTool",
"ToolResult",
"ToolParameter",
"ToolRegistry",
"ToolExecutor",
"LangchainAdapter",
"ConfigManager",
"ChainManager"
"LangchainAdapter"
]
# 只有在成功导入时才添加到__all__

View File

@@ -1,98 +1,10 @@
"""工具基础接口定义"""
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from enum import Enum
from typing import Any, Dict, List, Optional
from app.models.tool_model import ToolType, ToolStatus
class ParameterType(str, Enum):
"""参数类型枚举"""
STRING = "string"
INTEGER = "integer"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
class ToolParameter(BaseModel):
"""工具参数定义"""
name: str = Field(..., description="参数名称")
type: ParameterType = Field(..., description="参数类型")
description: str = Field("", description="参数描述")
required: bool = Field(False, description="是否必需")
default: Any = Field(None, description="默认值")
enum: Optional[List[Any]] = Field(None, description="枚举值")
minimum: Optional[Union[int, float]] = Field(None, description="最小值")
maximum: Optional[Union[int, float]] = Field(None, description="最大值")
pattern: Optional[str] = Field(None, description="正则表达式模式")
class Config:
use_enum_values = True
class ToolResult(BaseModel):
"""工具执行结果"""
success: bool = Field(..., description="执行是否成功")
data: Any = Field(None, description="返回数据")
error: Optional[str] = Field(None, description="错误信息")
error_code: Optional[str] = Field(None, description="错误代码")
execution_time: float = Field(..., description="执行时间(秒)")
token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
@classmethod
def success_result(
cls,
data: Any,
execution_time: float,
token_usage: Optional[Dict[str, int]] = None,
metadata: Optional[Dict[str, Any]] = None
) -> "ToolResult":
"""创建成功结果"""
return cls(
success=True,
data=data,
execution_time=execution_time,
token_usage=token_usage,
metadata=metadata or {}
)
@classmethod
def error_result(
cls,
error: str,
execution_time: float,
error_code: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
) -> "ToolResult":
"""创建错误结果"""
return cls(
success=False,
error=error,
error_code=error_code,
execution_time=execution_time,
metadata=metadata or {}
)
class ToolInfo(BaseModel):
"""工具信息"""
id: str = Field(..., description="工具ID")
name: str = Field(..., description="工具名称")
description: str = Field(..., description="工具描述")
tool_type: ToolType = Field(..., description="工具类型")
version: str = Field("1.0.0", description="工具版本")
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态")
tags: List[str] = Field(default_factory=list, description="工具标签")
tenant_id: Optional[str] = Field(None, description="租户ID")
class Config:
use_enum_values = True
from app.schemas.tool_schema import ToolParameter, ParameterType, ToolResult
class BaseTool(ABC):
@@ -107,7 +19,7 @@ class BaseTool(ABC):
"""
self.tool_id = tool_id
self.config = config
self._status = ToolStatus.ACTIVE
self._status = ToolStatus.AVAILABLE
@property
@abstractmethod
@@ -153,20 +65,6 @@ class BaseTool(ABC):
"""工具标签"""
return self.config.get("tags", [])
def get_info(self) -> ToolInfo:
"""获取工具信息"""
return ToolInfo(
id=self.tool_id,
name=self.name,
description=self.description,
tool_type=self.tool_type,
version=self.version,
parameters=self.parameters,
status=self.status,
tags=self.tags,
tenant_id=self.config.get("tenant_id")
)
def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
"""验证参数

View File

@@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
from typing import Dict, Any, List
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolResult, ToolParameter
class BuiltinTool(BaseTool, ABC):

View File

@@ -4,7 +4,7 @@ from datetime import datetime, timezone, timedelta
from typing import List
import pytz
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
@@ -54,14 +54,14 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING,
description="源时区UTC, Asia/Shanghai",
required=False,
default="UTC"
default="Asia/Shanghai"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="UTC"
default="Asia/Shanghai"
),
ToolParameter(
name="calculation",
@@ -106,10 +106,11 @@ class DateTimeTool(BuiltinTool):
error_code="DATETIME_ERROR",
execution_time=execution_time
)
def _get_current_time(self, kwargs) -> dict:
@staticmethod
def _get_current_time(kwargs) -> dict:
"""获取当前时间"""
timezone_str = kwargs.get("to_timezone", "UTC")
timezone_str = kwargs.get("to_timezone", "Asia/Shanghai")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
if timezone_str == "UTC":
@@ -118,15 +119,20 @@ class DateTimeTool(BuiltinTool):
tz = pytz.timezone(timezone_str)
now = datetime.now(tz)
utc_now = datetime.now(timezone.utc)
return {
"datetime": now.strftime(output_format),
"timestamp": int(now.timestamp()),
"timezone": timezone_str,
"iso_format": now.isoformat()
"iso_format": now.isoformat(),
"timestamp_ms": int(now.timestamp() * 1000),
"utc_datetime": utc_now.strftime(output_format)
}
def _format_datetime(self, kwargs) -> dict:
@staticmethod
def _format_datetime(kwargs) -> dict:
"""格式化时间"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
@@ -144,8 +150,9 @@ class DateTimeTool(BuiltinTool):
"timestamp": int(dt.timestamp()),
"iso_format": dt.isoformat()
}
def _convert_timezone(self, kwargs) -> dict:
@staticmethod
def _convert_timezone(kwargs) -> dict:
"""时区转换"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
@@ -184,8 +191,9 @@ class DateTimeTool(BuiltinTool):
"converted_timezone": to_timezone,
"timestamp": int(converted_dt.timestamp())
}
def _timestamp_to_datetime(self, kwargs) -> dict:
@staticmethod
def _timestamp_to_datetime(kwargs) -> dict:
"""时间戳转日期时间"""
input_value = kwargs.get("input_value")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
@@ -196,6 +204,8 @@ class DateTimeTool(BuiltinTool):
# 转换时间戳
timestamp = float(input_value)
if timestamp > 1e12:
timestamp = timestamp / 1000
# 设置时区
if timezone_str == "UTC":
@@ -211,8 +221,9 @@ class DateTimeTool(BuiltinTool):
"timezone": timezone_str,
"iso_format": dt.isoformat()
}
def _datetime_to_timestamp(self, kwargs) -> dict:
@staticmethod
def _datetime_to_timestamp(kwargs) -> dict:
"""日期时间转时间戳"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
@@ -240,7 +251,7 @@ class DateTimeTool(BuiltinTool):
"timestamp": int(dt.timestamp()),
"iso_format": dt.isoformat()
}
def _calculate_datetime(self, kwargs) -> dict:
"""时间计算"""
input_value = kwargs.get("input_value")
@@ -278,8 +289,9 @@ class DateTimeTool(BuiltinTool):
"timezone": timezone_str,
"timestamp": int(calculated_dt.timestamp())
}
def _parse_time_delta(self, calculation: str) -> timedelta:
@staticmethod
def _parse_time_delta(calculation: str) -> timedelta:
"""解析时间计算表达式"""
import re

View File

@@ -121,8 +121,9 @@ class JsonTool(BuiltinTool):
error_code="JSON_ERROR",
execution_time=execution_time
)
def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _format_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""格式化JSON"""
indent = kwargs.get("indent", 2)
ensure_ascii = kwargs.get("ensure_ascii", False)
@@ -151,12 +152,13 @@ class JsonTool(BuiltinTool):
"sort_keys": sort_keys
}
}
def _minify_json(self, input_data: str) -> Dict[str, Any]:
@staticmethod
def _minify_json(input_data: str) -> Dict[str, Any]:
"""压缩JSON"""
# 解析并压缩
data = json.loads(input_data)
minified = json.dumps(data, separators=(',', ':'))
minified = json.dumps(data, ensure_ascii=False, separators=(',', ':'))
return {
"original_size": len(input_data),
@@ -165,7 +167,7 @@ class JsonTool(BuiltinTool):
"minified_json": minified,
"is_valid": True
}
def _validate_json(self, input_data: str) -> Dict[str, Any]:
"""验证JSON"""
try:
@@ -190,17 +192,19 @@ class JsonTool(BuiltinTool):
"size": len(input_data)
}
def _convert_json(self, input_data: str) -> Dict[str, Any]:
@staticmethod
def _convert_json(input_data: str) -> Dict[str, Any]:
"""JSON转义"""
data = json.loads(input_data)
converted = json.dumps(data, ensure_ascii=False)
converted = json.dumps(data, ensure_ascii=True, separators=(',', ':'))
return {
"converted_json": converted,
"is_valid": True
}
def _json_to_yaml(self, input_data: str) -> Dict[str, Any]:
@staticmethod
def _json_to_yaml(input_data: str) -> Dict[str, Any]:
"""JSON转YAML"""
data = json.loads(input_data)
yaml_output = yaml.dump(data, default_flow_style=False, allow_unicode=True, indent=2)
@@ -212,8 +216,9 @@ class JsonTool(BuiltinTool):
"converted_size": len(yaml_output),
"converted_data": yaml_output
}
def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _yaml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""YAML转JSON"""
indent = kwargs.get("indent", 2)
ensure_ascii = kwargs.get("ensure_ascii", False)
@@ -228,10 +233,11 @@ class JsonTool(BuiltinTool):
"converted_size": len(json_output),
"converted_data": json_output
}
def _json_to_xml(self, input_data: str) -> Dict[str, Any]:
@staticmethod
def _json_to_xml(input_data: str) -> Dict[str, Any]:
"""JSON转XML"""
data = json.loads(input_data)
json_data = json.loads(input_data)
def dict_to_xml(data, root_name="root"):
"""递归转换字典为XML"""
@@ -267,7 +273,7 @@ class JsonTool(BuiltinTool):
root.text = str(data)
return root
xml_element = dict_to_xml(data)
xml_element = dict_to_xml(json_data)
xml_string = ET.tostring(xml_element, encoding='unicode')
# 格式化XML
@@ -284,8 +290,9 @@ class JsonTool(BuiltinTool):
"converted_size": len(formatted_xml),
"converted_data": formatted_xml
}
def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _xml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""XML转JSON"""
indent = kwargs.get("indent", 2)
@@ -328,8 +335,9 @@ class JsonTool(BuiltinTool):
"converted_size": len(json_output),
"converted_data": json_output
}
def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _merge_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""合并JSON"""
merge_data = kwargs.get("merge_data")
if not merge_data:
@@ -364,8 +372,9 @@ class JsonTool(BuiltinTool):
"result_size": len(merged_json),
"merged_data": merged_json
}
def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _extract_json_path( input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""提取JSON路径"""
json_path = kwargs.get("json_path")
if not json_path:

View File

@@ -275,8 +275,9 @@ class TextInTool(BuiltinTool):
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
@staticmethod
def _format_formula_result( result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
"""格式化公式识别结果"""
formulas = result.get("formulas", [])
@@ -288,8 +289,9 @@ class TextInTool(BuiltinTool):
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
@staticmethod
def _format_table_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
"""格式化表格识别结果"""
tables = result.get("tables", [])
@@ -301,8 +303,9 @@ class TextInTool(BuiltinTool):
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
@staticmethod
def _format_document_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
"""格式化文档识别结果"""
return {
"recognition_mode": "document",
@@ -314,8 +317,9 @@ class TextInTool(BuiltinTool):
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
@staticmethod
def _group_lines_to_paragraphs(lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""将行分组为段落"""
paragraphs = []
current_paragraph = []

View File

@@ -1,485 +0,0 @@
"""工具链管理器 - 支持langchain的工具链模式"""
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
from app.core.tools.base import ToolResult
from app.core.tools.executor import ToolExecutor
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ChainExecutionMode(str, Enum):
"""链执行模式"""
SEQUENTIAL = "sequential" # 顺序执行
PARALLEL = "parallel" # 并行执行
CONDITIONAL = "conditional" # 条件执行
@dataclass
class ChainStep:
"""链步骤定义"""
tool_id: str
parameters: Dict[str, Any]
condition: Optional[str] = None # 执行条件
output_mapping: Optional[Dict[str, str]] = None # 输出映射
error_handling: str = "stop" # 错误处理stop, continue, retry
@dataclass
class ChainDefinition:
"""工具链定义"""
name: str
description: str
steps: List[ChainStep]
execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL
global_timeout: Optional[float] = None
retry_policy: Optional[Dict[str, Any]] = None
class ChainExecutionContext:
"""链执行上下文"""
def __init__(self, chain_id: str):
self.chain_id = chain_id
self.variables: Dict[str, Any] = {}
self.step_results: Dict[int, ToolResult] = {}
self.current_step = 0
self.is_completed = False
self.is_failed = False
self.error_message: Optional[str] = None
class ChainManager:
"""工具链管理器 - 支持langchain的工具链模式"""
def __init__(self, executor: ToolExecutor):
"""初始化工具链管理器
Args:
executor: 工具执行器
"""
self.executor = executor
self._chains: Dict[str, ChainDefinition] = {}
self._running_chains: Dict[str, ChainExecutionContext] = {}
def register_chain(self, chain: ChainDefinition) -> bool:
"""注册工具链
Args:
chain: 工具链定义
Returns:
注册是否成功
"""
try:
# 验证工具链定义
validation_result = self._validate_chain(chain)
if not validation_result[0]:
logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}")
return False
self._chains[chain.name] = chain
logger.info(f"工具链注册成功: {chain.name}")
return True
except Exception as e:
logger.error(f"工具链注册失败: {chain.name}, 错误: {e}")
return False
def unregister_chain(self, chain_name: str) -> bool:
"""注销工具链
Args:
chain_name: 工具链名称
Returns:
注销是否成功
"""
if chain_name in self._chains:
del self._chains[chain_name]
logger.info(f"工具链注销成功: {chain_name}")
return True
return False
def list_chains(self) -> List[Dict[str, Any]]:
"""列出所有工具链
Returns:
工具链信息列表
"""
chains = []
for name, chain in self._chains.items():
chains.append({
"name": name,
"description": chain.description,
"step_count": len(chain.steps),
"execution_mode": chain.execution_mode.value,
"global_timeout": chain.global_timeout
})
return chains
async def execute_chain(
self,
chain_name: str,
initial_variables: Optional[Dict[str, Any]] = None,
chain_id: Optional[str] = None
) -> Dict[str, Any] | None:
"""执行工具链
Args:
chain_name: 工具链名称
initial_variables: 初始变量
chain_id: 链执行ID可选
Returns:
执行结果
"""
if chain_name not in self._chains:
return {
"success": False,
"error": f"工具链不存在: {chain_name}",
"chain_id": chain_id
}
chain = self._chains[chain_name]
# 生成链ID
if not chain_id:
import uuid
chain_id = f"chain_{uuid.uuid4().hex[:16]}"
# 创建执行上下文
context = ChainExecutionContext(chain_id)
context.variables = initial_variables or {}
self._running_chains[chain_id] = context
try:
logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})")
# 根据执行模式执行
if chain.execution_mode == ChainExecutionMode.SEQUENTIAL:
result = await self._execute_sequential(chain, context)
elif chain.execution_mode == ChainExecutionMode.PARALLEL:
result = await self._execute_parallel(chain, context)
elif chain.execution_mode == ChainExecutionMode.CONDITIONAL:
result = await self._execute_conditional(chain, context)
else:
raise ValueError(f"不支持的执行模式: {chain.execution_mode}")
logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})")
return result
except Exception as e:
logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}")
return {
"success": False,
"error": str(e),
"chain_id": chain_id,
"completed_steps": context.current_step,
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
}
finally:
# 清理执行上下文
if chain_id in self._running_chains:
del self._running_chains[chain_id]
async def _execute_sequential(
self,
chain: ChainDefinition,
context: ChainExecutionContext
) -> Dict[str, Any]:
"""顺序执行工具链"""
for i, step in enumerate(chain.steps):
context.current_step = i
# 检查执行条件
if step.condition and not self._evaluate_condition(step.condition, context):
logger.debug(f"跳过步骤 {i}: 条件不满足")
continue
# 准备参数
parameters = self._prepare_parameters(step.parameters, context)
# 执行工具
try:
result = await self.executor.execute_tool(
tool_id=step.tool_id,
parameters=parameters
)
context.step_results[i] = result
# 处理输出映射
if step.output_mapping and result.success:
self._apply_output_mapping(step.output_mapping, result.data, context)
# 处理执行失败
if not result.success:
if step.error_handling == "stop":
context.is_failed = True
context.error_message = result.error
break
elif step.error_handling == "continue":
logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}")
continue
elif step.error_handling == "retry":
# 简单重试逻辑
retry_result = await self.executor.execute_tool(
tool_id=step.tool_id,
parameters=parameters
)
context.step_results[i] = retry_result
if not retry_result.success and step.error_handling == "stop":
context.is_failed = True
context.error_message = retry_result.error
break
except Exception as e:
logger.error(f"步骤 {i} 执行异常: {e}")
if step.error_handling == "stop":
context.is_failed = True
context.error_message = str(e)
break
context.is_completed = not context.is_failed
return {
"success": context.is_completed,
"error": context.error_message,
"chain_id": context.chain_id,
"completed_steps": context.current_step + 1,
"total_steps": len(chain.steps),
"final_variables": context.variables,
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
}
async def _execute_parallel(
self,
chain: ChainDefinition,
context: ChainExecutionContext
) -> Dict[str, Any]:
"""并行执行工具链"""
# 准备所有步骤的执行配置
execution_configs = []
for i, step in enumerate(chain.steps):
# 检查执行条件
if step.condition and not self._evaluate_condition(step.condition, context):
continue
parameters = self._prepare_parameters(step.parameters, context)
execution_configs.append({
"step_index": i,
"tool_id": step.tool_id,
"parameters": parameters
})
# 并行执行所有步骤
try:
results = await self.executor.execute_tools_batch(execution_configs)
# 处理结果
for i, result in enumerate(results):
step_index = execution_configs[i]["step_index"]
context.step_results[step_index] = result
# 处理输出映射
step = chain.steps[step_index]
if step.output_mapping and result.success:
self._apply_output_mapping(step.output_mapping, result.data, context)
# 检查是否有失败的步骤
failed_steps = [i for i, result in context.step_results.items() if not result.success]
context.is_completed = len(failed_steps) == 0
if failed_steps:
context.error_message = f"步骤 {failed_steps} 执行失败"
except Exception as e:
context.is_failed = True
context.error_message = str(e)
return {
"success": context.is_completed,
"error": context.error_message,
"chain_id": context.chain_id,
"completed_steps": len(context.step_results),
"total_steps": len(chain.steps),
"final_variables": context.variables,
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
}
async def _execute_conditional(
self,
chain: ChainDefinition,
context: ChainExecutionContext
) -> Dict[str, Any]:
"""条件执行工具链"""
# 条件执行类似于顺序执行,但更严格地检查条件
return await self._execute_sequential(chain, context)
def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]:
"""验证工具链定义
Args:
chain: 工具链定义
Returns:
(是否有效, 错误信息)
"""
if not chain.name:
return False, "工具链名称不能为空"
if not chain.steps:
return False, "工具链必须包含至少一个步骤"
for i, step in enumerate(chain.steps):
if not step.tool_id:
return False, f"步骤 {i} 缺少工具ID"
if step.error_handling not in ["stop", "continue", "retry"]:
return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}"
return True, None
def _prepare_parameters(
self,
parameters: Dict[str, Any],
context: ChainExecutionContext
) -> Dict[str, Any]:
"""准备参数(支持变量替换)
Args:
parameters: 原始参数
context: 执行上下文
Returns:
处理后的参数
"""
prepared = {}
for key, value in parameters.items():
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# 变量替换
var_name = value[2:-1]
if var_name in context.variables:
prepared[key] = context.variables[var_name]
else:
prepared[key] = value # 保持原值
else:
prepared[key] = value
return prepared
def _evaluate_condition(
self,
condition: str,
context: ChainExecutionContext
) -> bool:
"""评估执行条件
Args:
condition: 条件表达式
context: 执行上下文
Returns:
条件是否满足
"""
try:
# 简单的条件评估(可以扩展为更复杂的表达式解析)
# 支持格式variable == value, variable != value, variable > value 等
if "==" in condition:
var_name, expected_value = condition.split("==", 1)
var_name = var_name.strip()
expected_value = expected_value.strip().strip('"\'')
return str(context.variables.get(var_name, "")) == expected_value
elif "!=" in condition:
var_name, expected_value = condition.split("!=", 1)
var_name = var_name.strip()
expected_value = expected_value.strip().strip('"\'')
return str(context.variables.get(var_name, "")) != expected_value
elif condition in context.variables:
# 简单的布尔检查
return bool(context.variables[condition])
else:
# 默认为真
return True
except Exception as e:
logger.error(f"条件评估失败: {condition}, 错误: {e}")
return False
def _apply_output_mapping(
self,
mapping: Dict[str, str],
output_data: Any,
context: ChainExecutionContext
):
"""应用输出映射
Args:
mapping: 输出映射配置
output_data: 输出数据
context: 执行上下文
"""
try:
if isinstance(output_data, dict):
for source_key, target_var in mapping.items():
if source_key in output_data:
context.variables[target_var] = output_data[source_key]
else:
# 如果输出不是字典,将整个输出映射到指定变量
if "result" in mapping:
context.variables[mapping["result"]] = output_data
except Exception as e:
logger.error(f"输出映射失败: {e}")
def _serialize_result(self, result: ToolResult) -> Dict[str, Any]:
"""序列化工具结果
Args:
result: 工具结果
Returns:
序列化的结果
"""
return {
"success": result.success,
"data": result.data,
"error": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time,
"token_usage": result.token_usage,
"metadata": result.metadata
}
def get_running_chains(self) -> List[Dict[str, Any]]:
"""获取正在运行的工具链
Returns:
运行中的工具链列表
"""
chains = []
for chain_id, context in self._running_chains.items():
chains.append({
"chain_id": chain_id,
"current_step": context.current_step,
"is_completed": context.is_completed,
"is_failed": context.is_failed,
"variables_count": len(context.variables),
"completed_steps": len(context.step_results)
})
return chains

View File

@@ -1,264 +0,0 @@
"""工具配置管理器 - 管理工具配置的加载和验证"""
import json
from pathlib import Path
from typing import Dict, Any, Optional
from pydantic import BaseModel, ValidationError
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ToolConfigSchema(BaseModel):
"""工具配置基础Schema"""
name: str
description: str
tool_type: str
version: str = "1.0.0"
enabled: bool = True
parameters: Dict[str, Any] = {}
tags: list[str] = []
class Config:
extra = "allow"
class BuiltinToolConfigSchema(ToolConfigSchema):
"""内置工具配置Schema"""
tool_class: str
tool_type: str = "builtin"
class CustomToolConfigSchema(ToolConfigSchema):
"""自定义工具配置Schema"""
schema_url: Optional[str] = None
schema_content: Optional[Dict[str, Any]] = None
auth_type: str = "none"
auth_config: Dict[str, Any] = {}
base_url: Optional[str] = None
timeout: int = 30
tool_type: str = "custom"
class MCPToolConfigSchema(ToolConfigSchema):
"""MCP工具配置Schema"""
server_url: str
connection_config: Dict[str, Any] = {}
available_tools: list[str] = []
tool_type: str = "mcp"
class ConfigManager:
"""工具配置管理器"""
def __init__(self, config_dir: Optional[str] = None):
"""初始化配置管理器
Args:
config_dir: 配置文件目录,默认使用系统配置
"""
self.config_dir = Path(config_dir or self._get_default_config_dir())
self.config_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"配置管理器初始化完成,配置目录: {self.config_dir}")
def _get_default_config_dir(self) -> str:
"""获取默认配置目录"""
# 获取tools目录下的configs子目录
tools_dir = Path(__file__).parent
return str(tools_dir / "configs")
def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]:
"""加载内置工具配置
Returns:
内置工具配置字典
"""
configs = {}
builtin_dir = self.config_dir / "builtin"
if not builtin_dir.exists():
logger.info("内置工具配置目录不存在,创建默认配置")
self._create_default_builtin_configs(builtin_dir)
for config_file in builtin_dir.glob("*.json"):
try:
config_data = self._load_config_file(config_file)
config = BuiltinToolConfigSchema(**config_data)
configs[config.name] = config
logger.debug(f"加载内置工具配置: {config.name}")
except Exception as e:
logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}")
return configs
def load_builtin_tools_config(self) -> Dict[str, Any]:
"""加载全局内置工具配置(兼容原有接口)
Returns:
内置工具配置字典
"""
config_file = self.config_dir / "builtin_tools.json"
try:
with open(config_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"加载内置工具配置失败: {e}")
return {}
def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum):
"""确保内置工具已初始化到数据库
Args:
tenant_id: 租户ID
db_session: 数据库会话
tool_config_model: ToolConfig模型类
builtin_tool_config_model: BuiltinToolConfig模型类
tool_type_enum: ToolType枚举
tool_status_enum: ToolStatus枚举
"""
# 检查是否已初始化
existing_count = db_session.query(tool_config_model).filter(
tool_config_model.tenant_id == tenant_id,
tool_config_model.tool_type == tool_type_enum.BUILTIN
).count()
if existing_count > 0:
return # 已初始化
# 加载全局配置
builtin_tools = self.load_builtin_tools_config()
# 为租户创建内置工具记录
for tool_key, tool_info in builtin_tools.items():
# 设置初始状态
initial_status = tool_status_enum.ACTIVE.value if not tool_info['requires_config'] else tool_status_enum.INACTIVE.value
tool_config = tool_config_model(
name=tool_info['name'],
description=tool_info['description'],
tool_type=tool_type_enum.BUILTIN,
tenant_id=tenant_id,
status=initial_status
)
db_session.add(tool_config)
db_session.flush()
builtin_config = builtin_tool_config_model(
id=tool_config.id,
tool_class=tool_info['tool_class'],
parameters={}
)
db_session.add(builtin_config)
db_session.commit()
logger.info(f"租户 {tenant_id} 的内置工具初始化完成")
def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
"""保存工具配置
Args:
config: 工具配置
tool_type: 工具类型
Returns:
保存是否成功
"""
try:
config_dir = self.config_dir / tool_type
config_dir.mkdir(parents=True, exist_ok=True)
config_file = config_dir / f"{config.name}.json"
config_data = config.model_dump()
with open(config_file, 'w', encoding='utf-8') as f:
json.dump(config_data, f, indent=2, ensure_ascii=False)
logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
return True
except Exception as e:
logger.error(f"工具配置保存失败: {config.name}, 错误: {e}")
return False
def delete_tool_config(self, tool_name: str, tool_type: str) -> bool:
"""删除工具配置
Args:
tool_name: 工具名称
tool_type: 工具类型
Returns:
删除是否成功
"""
try:
config_file = self.config_dir / tool_type / f"{tool_name}.json"
if config_file.exists():
config_file.unlink()
logger.info(f"工具配置删除成功: {tool_name} ({tool_type})")
return True
else:
logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})")
return False
except Exception as e:
logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}")
return False
def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
"""验证工具配置
Args:
config_data: 配置数据
tool_type: 工具类型
Returns:
(是否有效, 错误信息)
"""
try:
schema_map = {
"builtin": BuiltinToolConfigSchema,
"custom": CustomToolConfigSchema,
"mcp": MCPToolConfigSchema
}
schema_class = schema_map.get(tool_type)
if not schema_class:
return False, f"不支持的工具类型: {tool_type}"
# 验证配置
schema_class(**config_data)
return True, None
except ValidationError as e:
error_msg = "; ".join([f"{err['loc'][0]}: {err['msg']}" for err in e.errors()])
return False, f"配置验证失败: {error_msg}"
except Exception as e:
return False, f"配置验证异常: {str(e)}"
def _load_config_file(self, config_file: Path) -> Dict[str, Any]:
"""加载配置文件
Args:
config_file: 配置文件路径
Returns:
配置数据字典
"""
try:
with open(config_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"加载配置文件失败: {config_file}, 错误: {e}")
raise
def _create_default_builtin_configs(self, builtin_dir: Path):
"""创建默认内置工具配置
Args:
builtin_dir: 内置工具配置目录
"""
builtin_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"内置工具配置目录已创建: {builtin_dir}")
# 配置文件已经通过其他方式创建,这里只需要确保目录存在

View File

@@ -54,7 +54,8 @@
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
"base_url": {"type": "string", "description": "API地址", "default": "https://api.textin.com/v1"}
}
}
}

View File

@@ -2,7 +2,6 @@
import base64
import hashlib
import hmac
import time
from typing import Dict, Any, Tuple
from urllib.parse import quote
import aiohttp
@@ -51,8 +50,9 @@ class AuthManager:
except Exception as e:
return False, f"验证认证配置时出错: {e}"
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
@staticmethod
def _validate_api_key_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证API Key认证配置
Args:
@@ -79,8 +79,9 @@ class AuthManager:
return False, "API Key位置必须是 header、query 或 cookie"
return True, ""
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
@staticmethod
def _validate_bearer_token_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证Bearer Token认证配置
Args:
@@ -135,9 +136,9 @@ class AuthManager:
except Exception as e:
logger.error(f"应用认证时出错: {e}")
return url, headers, params
@staticmethod
def _apply_api_key_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
@@ -176,9 +177,9 @@ class AuthManager:
headers["Cookie"] = cookie_value
return url, headers, params
@staticmethod
def _apply_bearer_token_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
@@ -260,8 +261,9 @@ class AuthManager:
except Exception as e:
logger.error(f"解密认证配置失败: {e}")
return encrypted_config
def _encrypt_string(self, value: str, key: str) -> str:
@staticmethod
def _encrypt_string(value: str, key: str) -> str:
"""加密字符串
Args:
@@ -289,8 +291,9 @@ class AuthManager:
except Exception as e:
logger.error(f"加密字符串失败: {e}")
return value
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
@staticmethod
def _decrypt_string(encrypted_value: str, key: str) -> str:
"""解密字符串
Args:
@@ -471,8 +474,9 @@ class AuthManager:
"error": f"测试认证时出错: {e}",
"auth_type": auth_type.value
}
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
@staticmethod
def get_auth_config_template(auth_type: AuthType) -> Dict[str, Any]:
"""获取认证配置模板
Args:
@@ -498,8 +502,9 @@ class AuthManager:
}
return templates.get(auth_type, {})
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def mask_sensitive_config(auth_config: Dict[str, Any]) -> Dict[str, Any]:
"""遮蔽认证配置中的敏感信息
Args:

View File

@@ -5,7 +5,8 @@ import aiohttp
from urllib.parse import urljoin
from app.models.tool_model import ToolType, AuthType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
@@ -173,8 +174,9 @@ class CustomTool(BaseTool):
}
return operations
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
@staticmethod
def _convert_openapi_type(openapi_type: str) -> ParameterType:
"""转换OpenAPI类型到内部类型"""
type_mapping = {
"string": ParameterType.STRING,
@@ -239,8 +241,9 @@ class CustomTool(BaseTool):
headers["Authorization"] = f"Bearer {token}"
return headers
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
@staticmethod
def _build_request_data(operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""构建请求数据"""
if operation["method"] in ["POST", "PUT", "PATCH"]:
request_body = operation.get("request_body")
@@ -284,6 +287,7 @@ class CustomTool(BaseTool):
try:
return await response.json()
except Exception as e:
logger.error(f"解析HTTP响应JSON失败: {str(e)}")
return await response.text()
@classmethod

View File

@@ -10,6 +10,9 @@ from app.core.logging_config import get_business_logger
logger = get_business_logger()
# 为了兼容性,创建别名
# SchemaParser = OpenAPISchemaParser = None
class OpenAPISchemaParser:
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
@@ -88,8 +91,9 @@ class OpenAPISchemaParser:
except Exception as e:
logger.error(f"解析schema内容失败: {e}")
return False, {}, str(e)
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
@staticmethod
def _parse_content(content: str, content_type: str) -> Optional[Dict[str, Any]]:
"""解析内容为字典
Args:
@@ -101,7 +105,7 @@ class OpenAPISchemaParser:
"""
try:
# 根据内容类型解析
if 'json' in content_type:
if 'application/json' in content_type:
return json.loads(content)
elif 'yaml' in content_type or 'yml' in content_type:
return yaml.safe_load(content)
@@ -228,8 +232,9 @@ class OpenAPISchemaParser:
}
return operations
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _extract_parameters(operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取操作参数
Args:
@@ -266,8 +271,9 @@ class OpenAPISchemaParser:
}
return parameters
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
@staticmethod
def _extract_request_body(operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""提取请求体信息
Args:
@@ -298,8 +304,9 @@ class OpenAPISchemaParser:
"schema": schema,
"content_types": list(content.keys())
}
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _extract_responses(operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取响应信息
Args:
@@ -331,8 +338,9 @@ class OpenAPISchemaParser:
}
return responses
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
@staticmethod
def generate_tool_parameters(operations: Dict[str, Any]) -> List[Dict[str, Any]]:
"""生成工具参数定义
Args:
@@ -396,7 +404,7 @@ class OpenAPISchemaParser:
parameters.extend(all_params.values())
return parameters
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
"""验证操作参数
@@ -447,8 +455,9 @@ class OpenAPISchemaParser:
errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}")
return len(errors) == 0, errors
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
@staticmethod
def _validate_parameter_type(value: Any, expected_type: str) -> bool:
"""验证参数类型
Args:
@@ -474,4 +483,7 @@ class OpenAPISchemaParser:
if expected_python_type:
return isinstance(value, expected_python_type)
return True
return True
# 为了兼容性,创建别名
SchemaParser = OpenAPISchemaParser

View File

@@ -1,501 +0,0 @@
"""工具执行器 - 负责工具的实际调用和执行管理"""
import asyncio
import uuid
import time
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import ToolExecution, ExecutionStatus
from app.core.tools.base import BaseTool, ToolResult
from app.core.tools.registry import ToolRegistry
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ExecutionContext:
"""执行上下文"""
def __init__(
self,
execution_id: str,
tool_id: str,
user_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
timeout: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None
):
self.execution_id = execution_id
self.tool_id = tool_id
self.user_id = user_id
self.workspace_id = workspace_id
self.timeout = timeout or 60.0 # 默认60秒超时
self.metadata = metadata or {}
self.started_at = datetime.now()
self.completed_at: Optional[datetime] = None
self.status = ExecutionStatus.PENDING
class ToolExecutor:
"""工具执行器 - 使用langchain标准接口执行工具"""
def __init__(self, db: Session, registry: ToolRegistry):
"""初始化工具执行器
Args:
db: 数据库会话
registry: 工具注册表
"""
self.db = db
self.registry = registry
self._running_executions: Dict[str, ExecutionContext] = {}
self._execution_lock = asyncio.Lock()
async def execute_tool(
self,
tool_id: str,
parameters: Dict[str, Any],
user_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
execution_id: Optional[str] = None,
timeout: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None
) -> ToolResult:
"""执行工具
Args:
tool_id: 工具ID
parameters: 工具参数
user_id: 用户ID
workspace_id: 工作空间ID
execution_id: 执行ID可选自动生成
timeout: 超时时间(秒)
metadata: 额外元数据
Returns:
工具执行结果
"""
# 生成执行ID
if not execution_id:
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
# 创建执行上下文
context = ExecutionContext(
execution_id=execution_id,
tool_id=tool_id,
user_id=user_id,
workspace_id=workspace_id,
timeout=timeout,
metadata=metadata
)
try:
# 获取工具实例
tool = self.registry.get_tool(tool_id)
if not tool:
return ToolResult.error_result(
error=f"工具不存在: {tool_id}",
error_code="TOOL_NOT_FOUND",
execution_time=0.0
)
# 记录执行开始
await self._record_execution_start(context, parameters)
# 执行工具
result = await self._execute_with_timeout(tool, parameters, context)
# 记录执行完成
await self._record_execution_complete(context, result)
return result
except Exception as e:
logger.error(f"工具执行异常: {execution_id}, 错误: {e}")
# 记录执行失败
error_result = ToolResult.error_result(
error=str(e),
error_code="EXECUTION_ERROR",
execution_time=time.time() - context.started_at.timestamp()
)
await self._record_execution_complete(context, error_result)
return error_result
finally:
# 清理执行上下文
async with self._execution_lock:
if execution_id in self._running_executions:
del self._running_executions[execution_id]
async def execute_tools_batch(
self,
tool_executions: List[Dict[str, Any]],
max_concurrency: int = 5
) -> List[ToolResult]:
"""批量执行工具
Args:
tool_executions: 工具执行配置列表每个包含tool_id和parameters
max_concurrency: 最大并发数
Returns:
执行结果列表
"""
semaphore = asyncio.Semaphore(max_concurrency)
async def execute_single(exec_config: Dict[str, Any]) -> ToolResult:
async with semaphore:
return await self.execute_tool(
tool_id=exec_config["tool_id"],
parameters=exec_config.get("parameters", {}),
user_id=exec_config.get("user_id"),
workspace_id=exec_config.get("workspace_id"),
timeout=exec_config.get("timeout"),
metadata=exec_config.get("metadata")
)
# 并发执行所有工具
tasks = [execute_single(config) for config in tool_executions]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理异常结果
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
processed_results.append(
ToolResult.error_result(
error=str(result),
error_code="BATCH_EXECUTION_ERROR",
execution_time=0.0
)
)
else:
processed_results.append(result)
return processed_results
async def cancel_execution(self, execution_id: str) -> bool:
"""取消工具执行
Args:
execution_id: 执行ID
Returns:
是否成功取消
"""
async with self._execution_lock:
if execution_id not in self._running_executions:
return False
context = self._running_executions[execution_id]
context.status = ExecutionStatus.FAILED
# 更新数据库记录
execution_record = self.db.query(ToolExecution).filter(
ToolExecution.execution_id == execution_id
).first()
if execution_record:
execution_record.status = ExecutionStatus.FAILED.value
execution_record.error_message = "执行被取消"
execution_record.completed_at = datetime.now()
self.db.commit()
logger.info(f"工具执行已取消: {execution_id}")
return True
def get_running_executions(self) -> List[Dict[str, Any]]:
"""获取正在运行的执行列表
Returns:
执行信息列表
"""
executions = []
for execution_id, context in self._running_executions.items():
executions.append({
"execution_id": execution_id,
"tool_id": context.tool_id,
"user_id": str(context.user_id) if context.user_id else None,
"workspace_id": str(context.workspace_id) if context.workspace_id else None,
"started_at": context.started_at.isoformat(),
"status": context.status.value,
"elapsed_time": (datetime.now() - context.started_at).total_seconds()
})
return executions
async def _execute_with_timeout(
self,
tool: BaseTool,
parameters: Dict[str, Any],
context: ExecutionContext
) -> ToolResult:
"""带超时的工具执行
Args:
tool: 工具实例
parameters: 参数
context: 执行上下文
Returns:
执行结果
"""
async with self._execution_lock:
self._running_executions[context.execution_id] = context
context.status = ExecutionStatus.RUNNING
try:
# 使用asyncio.wait_for实现超时控制
result = await asyncio.wait_for(
tool.safe_execute(**parameters),
timeout=context.timeout
)
context.status = ExecutionStatus.COMPLETED
return result
except asyncio.TimeoutError:
context.status = ExecutionStatus.TIMEOUT
return ToolResult.error_result(
error=f"工具执行超时({context.timeout}秒)",
error_code="EXECUTION_TIMEOUT",
execution_time=context.timeout
)
except Exception as e:
context.status = ExecutionStatus.FAILED
raise
async def _record_execution_start(
self,
context: ExecutionContext,
parameters: Dict[str, Any]
):
"""记录执行开始"""
try:
execution_record = ToolExecution(
execution_id=context.execution_id,
tool_config_id=uuid.UUID(context.tool_id),
status=ExecutionStatus.RUNNING.value,
input_data=parameters,
started_at=context.started_at,
user_id=context.user_id,
workspace_id=context.workspace_id
)
self.db.add(execution_record)
self.db.commit()
logger.debug(f"执行记录已创建: {context.execution_id}")
except Exception as e:
logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
async def _record_execution_complete(
self,
context: ExecutionContext,
result: ToolResult
):
"""记录执行完成"""
try:
context.completed_at = datetime.now()
execution_record = self.db.query(ToolExecution).filter(
ToolExecution.execution_id == context.execution_id
).first()
if execution_record:
execution_record.status = (
ExecutionStatus.COMPLETED.value if result.success
else ExecutionStatus.FAILED.value
)
execution_record.output_data = result.data if result.success else None
execution_record.error_message = result.error if not result.success else None
execution_record.completed_at = context.completed_at
execution_record.execution_time = result.execution_time
execution_record.token_usage = result.token_usage
self.db.commit()
logger.debug(f"执行记录已更新: {context.execution_id}")
except Exception as e:
logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}")
def get_execution_history(
self,
tool_id: Optional[str] = None,
user_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
limit: int = 100
) -> List[Dict[str, Any]]:
"""获取执行历史
Args:
tool_id: 工具ID过滤
user_id: 用户ID过滤
workspace_id: 工作空间ID过滤
limit: 返回数量限制
Returns:
执行历史列表
"""
try:
query = self.db.query(ToolExecution).order_by(
ToolExecution.started_at.desc()
)
if tool_id:
query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id))
if user_id:
query = query.filter(ToolExecution.user_id == user_id)
if workspace_id:
query = query.filter(ToolExecution.workspace_id == workspace_id)
executions = query.limit(limit).all()
history = []
for execution in executions:
history.append({
"execution_id": execution.execution_id,
"tool_id": str(execution.tool_config_id),
"status": execution.status,
"started_at": execution.started_at.isoformat() if execution.started_at else None,
"completed_at": execution.completed_at.isoformat() if execution.completed_at else None,
"execution_time": execution.execution_time,
"user_id": str(execution.user_id) if execution.user_id else None,
"workspace_id": str(execution.workspace_id) if execution.workspace_id else None,
"input_data": execution.input_data,
"output_data": execution.output_data,
"error_message": execution.error_message,
"token_usage": execution.token_usage
})
return history
except Exception as e:
logger.error(f"获取执行历史失败, 错误: {e}")
return []
def get_execution_statistics(
self,
workspace_id: Optional[uuid.UUID] = None,
days: int = 7
) -> Dict[str, Any]:
"""获取执行统计信息
Args:
workspace_id: 工作空间ID
days: 统计天数
Returns:
统计信息
"""
try:
from datetime import timedelta
start_date = datetime.now() - timedelta(days=days)
query = self.db.query(ToolExecution).filter(
ToolExecution.started_at >= start_date
)
if workspace_id:
query = query.filter(ToolExecution.workspace_id == workspace_id)
executions = query.all()
# 统计数据
total_executions = len(executions)
successful_executions = len([e for e in executions if e.status == ExecutionStatus.COMPLETED.value])
failed_executions = len([e for e in executions if e.status == ExecutionStatus.FAILED.value])
# 平均执行时间
completed_executions = [e for e in executions if e.execution_time is not None]
avg_execution_time = (
sum(e.execution_time for e in completed_executions) / len(completed_executions)
if completed_executions else 0
)
# 按工具统计
tool_stats = {}
for execution in executions:
tool_id = str(execution.tool_config_id)
if tool_id not in tool_stats:
tool_stats[tool_id] = {"total": 0, "successful": 0, "failed": 0}
tool_stats[tool_id]["total"] += 1
if execution.status == ExecutionStatus.COMPLETED.value:
tool_stats[tool_id]["successful"] += 1
elif execution.status == ExecutionStatus.FAILED.value:
tool_stats[tool_id]["failed"] += 1
return {
"period_days": days,
"total_executions": total_executions,
"successful_executions": successful_executions,
"failed_executions": failed_executions,
"success_rate": successful_executions / total_executions if total_executions > 0 else 0,
"average_execution_time": avg_execution_time,
"tool_statistics": tool_stats
}
except Exception as e:
logger.error(f"获取执行统计失败, 错误: {e}")
return {}
async def test_tool_connection(
self,
tool_id: str,
user_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""测试工具连接"""
try:
from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig
from .mcp.client import MCPClient
tool_config = self.db.query(ToolConfig).filter(
ToolConfig.id == uuid.UUID(tool_id)
).first()
if not tool_config:
return {"success": False, "message": "工具不存在"}
if tool_config.tool_type == ToolType.MCP.value:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_config.id
).first()
if not mcp_config:
return {"success": False, "message": "MCP配置不存在"}
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
if await client.connect():
try:
tools = await client.list_tools()
await client.disconnect()
return {
"success": True,
"message": "MCP连接成功",
"details": {"server_url": mcp_config.server_url, "tools": len(tools)}
}
except:
await client.disconnect()
return {"success": False, "message": "MCP功能测试失败"}
else:
return {"success": False, "message": "MCP连接失败"}
else:
tool = self.registry.get_tool(tool_id)
if tool and hasattr(tool, 'test_connection'):
result = tool.test_connection()
return {"success": result.get("success", False), "message": result.get("message", "")}
return {"success": True, "message": "工具无需连接测试"}
except Exception as e:
return {"success": False, "message": "测试失败", "error": str(e)}

View File

@@ -4,7 +4,8 @@ from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
@@ -123,33 +124,43 @@ class MCPTool(BaseTool):
async def connect(self) -> bool:
"""连接到MCP服务器"""
try:
# 这里应该实现实际的MCP连接逻辑
# 为了简化,这里只是模拟连接
from .client import MCPClient
# 测试服务器连接
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
# 尝试获取服务器信息
async with session.get(f"{self.server_url}/info") as response:
if response.status == 200:
server_info = await response.json()
self.available_tools = server_info.get("tools", [])
self._connected = True
logger.info(f"MCP服务器连接成功: {self.server_url}")
return True
else:
raise Exception(f"服务器响应错误: {response.status}")
if self._connected:
return True
self._client = MCPClient(self.server_url, self.connection_config)
if await self._client.connect():
self._connected = True
# 更新可用工具列表
await self._update_available_tools()
logger.info(f"MCP服务器连接成功: {self.server_url}")
return True
else:
logger.error(f"MCP服务器连接失败: {self.server_url}")
return False
except Exception as e:
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
self._connected = False
return False
async def _update_available_tools(self):
"""更新可用工具列表"""
try:
if self._client and self._connected:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
except Exception as e:
logger.error(f"更新MCP工具列表失败: {e}")
async def disconnect(self) -> bool:
"""断开MCP服务器连接"""
try:
if self._client:
# 这里应该实现实际的断开逻辑
await self._client.disconnect()
self._client = None
self._connected = False
@@ -171,38 +182,15 @@ class MCPTool(BaseTool):
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
"""调用MCP工具"""
# 构建MCP请求
request_data = {
"jsonrpc": "2.0",
"id": f"req_{int(time.time() * 1000)}",
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
if not self._client or not self._connected:
raise Exception("MCP客户端未连接")
# 发送请求
client_timeout = aiohttp.ClientTimeout(total=timeout)
async with aiohttp.ClientSession(timeout=client_timeout) as session:
async with session.post(
f"{self.server_url}/mcp",
json=request_data,
headers={"Content-Type": "application/json"}
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
result = await response.json()
# 检查MCP响应
if "error" in result:
error = result["error"]
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
return result.get("result", {})
try:
result = await self._client.call_tool(tool_name, arguments, timeout)
return result
except Exception as e:
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
raise
async def list_available_tools(self) -> List[Dict[str, Any]]:
"""列出可用的MCP工具"""
@@ -210,27 +198,10 @@ class MCPTool(BaseTool):
if not self._connected:
await self.connect()
# 获取工具列表
request_data = {
"jsonrpc": "2.0",
"id": f"req_{int(time.time() * 1000)}",
"method": "tools/list"
}
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(
f"{self.server_url}/mcp",
json=request_data,
headers={"Content-Type": "application/json"}
) as response:
if response.status == 200:
result = await response.json()
if "result" in result:
tools = result["result"].get("tools", [])
self.available_tools = [tool.get("name") for tool in tools]
return tools
if self._client:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
return tools
return []

View File

@@ -134,11 +134,40 @@ class MCPClient:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
def _build_auth_headers(self) -> Dict[str, str]:
"""构建认证头"""
headers = {}
auth_type = self.connection_config.get("auth_type", "none")
auth_config = self.connection_config.get("auth_config", {})
if auth_type == "api_key":
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
if api_key:
headers[key_name] = api_key
elif auth_type == "bearer_token":
token = auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
if username and password:
import base64
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
return headers
async def _connect_websocket(self) -> bool:
"""建立WebSocket连接"""
try:
# WebSocket连接配置
extra_headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
extra_headers.update(auth_headers)
self._websocket = await websockets.connect(
self.server_url,
@@ -190,6 +219,8 @@ class MCPClient:
# HTTP会话配置
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
headers.update(auth_headers)
self._session = aiohttp.ClientSession(
timeout=timeout,
@@ -251,8 +282,9 @@ class MCPClient:
except Exception as e:
logger.error(f"处理消息失败: {e}")
async def _handle_notification(self, message: Dict[str, Any]):
@staticmethod
async def _handle_notification(message: Dict[str, Any]):
"""处理通知消息"""
method = message.get("method")
params = message.get("params", {})
@@ -327,7 +359,7 @@ class MCPClient:
try:
response = await self._send_request(request_data, timeout)
if not response["error"] is None:
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
@@ -372,10 +404,10 @@ class MCPClient:
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
await self._pending_requests.pop(request_id, None)
raise
except Exception as e:
self._pending_requests.pop(request_id, None)
await self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
@@ -424,9 +456,9 @@ class MCPClient:
start_time = time.time()
response = await self._send_request(request_data, timeout=5)
response_time = time.time() - start_time
response_time = round((time.time() - start_time) * 1000)
self._last_health_check = time.time()
self._last_health_check = round(time.time() * 1000)
return {
"healthy": True,

View File

@@ -6,7 +6,7 @@ from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool
@@ -148,7 +148,7 @@ class MCPServiceManager:
connection_config=connection_config,
available_tools=tool_names,
health_status="healthy",
last_health_check=datetime.utcnow()
last_health_check=datetime.now()
)
self.db.add(mcp_config)
@@ -410,7 +410,8 @@ class MCPServiceManager:
"""加载现有服务"""
try:
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
ToolConfig.is_enabled == True
ToolConfig.status == ToolStatus.AVAILABLE.value,
ToolConfig.tool_type == ToolType.MCP.value
).all()
for mcp_config in mcp_configs:
@@ -531,7 +532,7 @@ class MCPServiceManager:
if mcp_config:
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
mcp_config.last_health_check = datetime.utcnow()
mcp_config.last_health_check = datetime.now()
if not health_status["healthy"]:
mcp_config.error_message = health_status.get("error", "")

View File

@@ -1,436 +0,0 @@
"""工具注册表 - 管理所有工具的元数据和状态"""
import uuid
import asyncio
from typing import Dict, List, Optional, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolType, ToolStatus, ToolExecution, ExecutionStatus
)
from app.core.logging_config import get_business_logger
from .base import BaseTool, ToolInfo
from .custom.base import CustomTool
from .mcp.base import MCPTool
logger = get_business_logger()
class ToolRegistry:
"""工具注册表 - 管理所有工具的元数据和实例"""
def __init__(self, db: Session):
"""初始化工具注册表
Args:
db: 数据库会话
"""
self.db = db
self._tools: Dict[str, BaseTool] = {} # 工具实例缓存
self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表
self._lock = asyncio.Lock() # 异步锁
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
"""注册工具类
Args:
tool_class: 工具类
class_name: 类名可选默认使用类的__name__
"""
class_name = class_name or tool_class.__name__
self._tool_classes[class_name] = tool_class
logger.info(f"工具类已注册: {class_name}")
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
"""注册工具实例到系统
Args:
tool: 工具实例
tenant_id: 租户ID内置工具可以为None表示全局工具
Returns:
注册是否成功
"""
async with self._lock:
try:
# 检查工具是否已存在
if tenant_id:
existing_config = self.db.query(ToolConfig).filter(
and_(
ToolConfig.name == tool.name,
ToolConfig.tenant_id == tenant_id,
ToolConfig.tool_type == tool.tool_type.value
)
).first()
else:
# 全局工具(内置工具)
existing_config = self.db.query(ToolConfig).filter(
and_(
ToolConfig.name == tool.name,
ToolConfig.tenant_id.is_(None),
ToolConfig.tool_type == tool.tool_type.value
)
).first()
if existing_config:
logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
return False
# 创建工具配置
tool_config = ToolConfig(
name=tool.name,
description=tool.description,
tool_type=tool.tool_type.value,
tenant_id=tenant_id,
version=tool.version,
tags=tool.tags,
config_data=tool.config
)
self.db.add(tool_config)
self.db.flush() # 获取ID
# 根据工具类型创建特定配置
if tool.tool_type == ToolType.BUILTIN:
builtin_config = BuiltinToolConfig(
id=tool_config.id,
tool_class=tool.__class__.__name__,
parameters=tool.config.get("parameters", {})
)
self.db.add(builtin_config)
elif tool.tool_type == ToolType.CUSTOM:
custom_config = CustomToolConfig(
id=tool_config.id,
schema_url=tool.config.get("schema_url"),
schema_content=tool.config.get("schema_content"),
auth_type=tool.config.get("auth_type", "none"),
auth_config=tool.config.get("auth_config", {}),
base_url=tool.config.get("base_url"),
timeout=tool.config.get("timeout", 30)
)
self.db.add(custom_config)
elif tool.tool_type == ToolType.MCP:
mcp_config = MCPToolConfig(
id=tool_config.id,
server_url=tool.config.get("server_url"),
connection_config=tool.config.get("connection_config", {}),
available_tools=tool.config.get("available_tools", [])
)
self.db.add(mcp_config)
self.db.commit()
# 缓存工具实例
tool.tool_id = str(tool_config.id)
self._tools[str(tool_config.id)] = tool
logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
return False
async def unregister_tool(self, tool_id: str) -> bool:
"""从系统注销工具
Args:
tool_id: 工具ID
Returns:
注销是否成功
"""
async with self._lock:
try:
# 检查工具是否存在
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config:
logger.warning(f"工具不存在: {tool_id}")
return False
# 检查是否有正在执行的任务
running_executions = self.db.query(ToolExecution).filter(
and_(
ToolExecution.tool_config_id == uuid.UUID(tool_id),
ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
)
).count()
if running_executions > 0:
logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
return False
# 删除工具配置(级联删除相关记录)
self.db.delete(tool_config)
self.db.commit()
# 从缓存中移除
if tool_id in self._tools:
del self._tools[tool_id]
logger.info(f"工具注销成功: {tool_id}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
return False
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
"""获取工具实例
Args:
tool_id: 工具ID
Returns:
工具实例如果不存在返回None
"""
# 先从缓存获取
if tool_id in self._tools:
return self._tools[tool_id]
# 从数据库加载
try:
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
return None
# 根据工具类型加载实例
tool_instance = self._load_tool_instance(tool_config)
if tool_instance:
self._tools[tool_id] = tool_instance
return tool_instance
except Exception as e:
logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
return None
def list_tools(
self,
tenant_id: Optional[uuid.UUID] = None,
tool_type: Optional[ToolType] = None,
status: Optional[ToolStatus] = None,
tags: Optional[List[str]] = None
) -> List[ToolInfo]:
"""列出工具
Args:
tenant_id: 租户ID过滤
tool_type: 工具类型过滤
status: 工具状态过滤
tags: 标签过滤
Returns:
工具信息列表
"""
try:
query = self.db.query(ToolConfig)
# 应用过滤条件
if tenant_id:
# 返回全局工具tenant_id为空和该租户的工具
query = query.filter(
or_(
ToolConfig.tenant_id == tenant_id,
ToolConfig.tenant_id.is_(None)
)
)
if tool_type:
query = query.filter(ToolConfig.tool_type == tool_type.value)
if status == ToolStatus.ACTIVE:
query = query.filter(ToolConfig.is_enabled == True)
elif status == ToolStatus.INACTIVE:
query = query.filter(ToolConfig.is_enabled == False)
if tags:
for tag in tags:
query = query.filter(ToolConfig.tags.contains([tag]))
tool_configs = query.all()
# 转换为ToolInfo
tool_infos = []
for config in tool_configs:
tool_info = ToolInfo(
id=str(config.id),
name=config.name,
description=config.description or "",
tool_type=ToolType(config.tool_type),
version=config.version,
status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
tags=config.tags or [],
tenant_id=str(config.tenant_id) if config.tenant_id else None
)
# 尝试获取参数信息
tool_instance = self.get_tool(str(config.id))
if tool_instance:
tool_info.parameters = tool_instance.parameters
tool_infos.append(tool_info)
return tool_infos
except Exception as e:
logger.error(f"列出工具失败, 错误: {e}")
return []
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
"""更新工具状态
Args:
tool_id: 工具ID
status: 新状态
Returns:
更新是否成功
"""
try:
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config:
logger.warning(f"工具不存在: {tool_id}")
return False
# 更新状态
if status == ToolStatus.ACTIVE:
tool_config.is_enabled = True
elif status == ToolStatus.INACTIVE:
tool_config.is_enabled = False
self.db.commit()
# 更新缓存中的工具状态
if tool_id in self._tools:
self._tools[tool_id].status = status
logger.info(f"工具状态更新成功: {tool_id} -> {status}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
return False
def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]:
"""从配置加载工具实例
Args:
tool_config: 工具配置
Returns:
工具实例
"""
try:
if tool_config.tool_type == ToolType.BUILTIN.value:
# 加载内置工具
builtin_config = self.db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if builtin_config and builtin_config.tool_class in self._tool_classes:
tool_class = self._tool_classes[builtin_config.tool_class]
config = {
**tool_config.config_data,
"parameters": builtin_config.parameters,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return tool_class(str(tool_config.id), config)
elif tool_config.tool_type == ToolType.CUSTOM.value:
# 加载自定义工具
try:
custom_config = self.db.query(CustomToolConfig).filter(
CustomToolConfig.id == tool_config.id
).first()
if custom_config:
config = {
**tool_config.config_data,
"schema_url": custom_config.schema_url,
"schema_content": custom_config.schema_content,
"auth_type": custom_config.auth_type,
"auth_config": custom_config.auth_config,
"base_url": custom_config.base_url,
"timeout": custom_config.timeout,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return CustomTool(str(tool_config.id), config)
except ImportError as e:
logger.error(f"无法导入自定义工具模块: {e}")
elif tool_config.tool_type == ToolType.MCP.value:
# 加载MCP工具
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_config.id
).first()
if mcp_config:
config = {
**tool_config.config_data,
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config,
"available_tools": mcp_config.available_tools,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return MCPTool(str(tool_config.id), config)
except ImportError as e:
logger.error(f"无法导入MCP工具模块: {e}")
except Exception as e:
logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")
return None
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
"""获取工具统计信息
Args:
tenant_id: 租户ID
Returns:
统计信息字典
"""
try:
query = self.db.query(ToolConfig)
if tenant_id:
query = query.filter(ToolConfig.tenant_id == tenant_id)
total_tools = query.count()
active_tools = query.filter(ToolConfig.is_enabled == True).count()
# 按类型统计
type_stats = {}
for tool_type in ToolType:
count = query.filter(ToolConfig.tool_type == tool_type.value).count()
type_stats[tool_type.value] = count
return {
"total_tools": total_tools,
"active_tools": active_tools,
"inactive_tools": total_tools - active_tools,
"by_type": type_stats
}
except Exception as e:
logger.error(f"获取工具统计失败, 错误: {e}")
return {}
def clear_cache(self):
"""清空工具缓存"""
self._tools.clear()
logger.info("工具缓存已清空")

View File

@@ -3,7 +3,7 @@ import uuid
from datetime import datetime
from enum import StrEnum
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -19,10 +19,40 @@ class ToolType(StrEnum):
class ToolStatus(StrEnum):
"""工具状态枚举"""
ACTIVE = "active"
INACTIVE = "inactive"
ERROR = "error"
LOADING = "loading"
AVAILABLE = "available" # 可用(已配置且已启用)
UNCONFIGURED = "unconfigured" # 未配置
CONFIGURED_DISABLED = "configured_disabled" # 已配置未启用
ERROR = "error" # 错误状态
@classmethod
def get_all_statuses(cls):
"""获取所有工具状态"""
return [status.value for status in cls]
@classmethod
def get_all_statuses_with_labels(cls):
"""获取所有工具状态及其文本描述"""
return [
{"value": cls.AVAILABLE.value, "label": "可用"},
{"value": cls.UNCONFIGURED.value, "label": "未配置"},
{"value": cls.CONFIGURED_DISABLED.value, "label": "已配置未启用"},
{"value": cls.ERROR.value, "label": "错误状态"}
]
@classmethod
def is_valid_status(cls, status):
"""检查状态是否有效"""
return status in cls._value2member_map_
@classmethod
def get_active_statuses(cls):
"""获取所有活跃状态"""
return [cls.AVAILABLE.value]
@classmethod
def get_inactive_statuses(cls):
"""获取所有非活跃状态"""
return [cls.UNCONFIGURED.value, cls.CONFIGURED_DISABLED.value, cls.ERROR.value]
class AuthType(StrEnum):
@@ -30,6 +60,27 @@ class AuthType(StrEnum):
NONE = "none"
API_KEY = "api_key"
BEARER_TOKEN = "bearer_token"
BASIC_AUTH = "basic_auth"
@classmethod
def get_all_types(cls):
"""获取所有认证类型"""
return [auth_type.value for auth_type in cls]
@classmethod
def get_all_types_with_labels(cls):
"""获取所有认证类型及其文本描述"""
return [
{"value": cls.NONE.value, "label": "无需认证"},
{"value": cls.API_KEY.value, "label": "API Key"},
{"value": cls.BEARER_TOKEN.value, "label": "Bearer Token"},
{"value": cls.BASIC_AUTH.value, "label": "Basic Auth"}
]
@classmethod
def is_valid_types(cls, auth_type):
"""检查认证类型是否有效"""
return auth_type in cls._value2member_map_
class ExecutionStatus(StrEnum):
@@ -48,13 +99,14 @@ class ToolConfig(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(255), nullable=False, index=True)
description = Column(Text)
icon = Column(String(255)) # 工具图标
tool_type = Column(String(50), nullable=False, index=True)
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True) # 必须属于租户
status = Column(String(50), default=ToolStatus.INACTIVE.value, nullable=False, index=True) # 工具状态
status = Column(String(50), default=ToolStatus.UNCONFIGURED.value, nullable=False, index=True) # 工具状态
# 工具特定配置JSON格式存储
config_data = Column(JSON, default=dict)
# 元数据
version = Column(String(50), default="1.0.0")
tags = Column(JSON, default=list) # 标签列表
@@ -78,12 +130,14 @@ class BuiltinToolConfig(Base):
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
tool_class = Column(String(255), nullable=False) # 工具类名
parameters = Column(JSON, default=dict) # 工具参数配置
is_enabled = Column(Boolean, default=False, nullable=False) # 启用开关
requires_config = Column(Boolean, default=False, nullable=False) # 是否需要配置
# 关联关系
base_config = relationship("ToolConfig", foreign_keys=[id])
def __repr__(self):
return f"<BuiltinToolConfig(id={self.id}, tool_class={self.tool_class})>"
return f"<BuiltinToolConfig(id={self.id}, tool_class={self.tool_class}, enabled={self.is_enabled})>"
class CustomToolConfig(Base):
@@ -115,7 +169,7 @@ class MCPToolConfig(Base):
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
server_url = Column(String(1000), nullable=False) # MCP服务器URL
connection_config = Column(JSON, default=dict) # 连接配置
connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息)
# 服务状态
last_health_check = Column(DateTime)

View File

@@ -0,0 +1,157 @@
"""工具数据访问层"""
import uuid
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import func, or_
from app.repositories.base_repository import BaseRepository
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolExecution, ToolType, ToolStatus
)
class ToolRepository:
"""工具仓储类"""
@staticmethod
def find_by_tenant(
db: Session,
tenant_id: uuid.UUID,
name: Optional[str] = None,
tool_type: Optional[ToolType] = None,
status: Optional[ToolStatus] = None,
is_enabled: Optional[bool] = None
) -> List[ToolConfig]:
"""根据租户查找工具"""
query = db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id
)
if name:
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
if tool_type:
query = query.filter(ToolConfig.tool_type == tool_type.value)
if status:
query = query.filter(ToolConfig.status == status.value)
if is_enabled is not None:
query = query.filter(ToolConfig.is_enabled == is_enabled)
return query.all()
@staticmethod
def find_by_id_and_tenant(db:Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""根据ID和租户查找工具"""
return db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == tenant_id
).first()
@staticmethod
def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
"""统计租户工具数量"""
return db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id
).count()
@staticmethod
def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
"""获取状态统计"""
return db.query(
ToolConfig.status,
func.count(ToolConfig.id).label('count')
).filter(
ToolConfig.tenant_id == tenant_id
).group_by(ToolConfig.status).all()
@staticmethod
def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
"""获取类型统计"""
return db.query(
ToolConfig.tool_type,
func.count(ToolConfig.id).label('count')
).filter(
ToolConfig.tenant_id == tenant_id
).group_by(ToolConfig.tool_type).all()
@staticmethod
def count_enabled_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
"""统计租户启用的工具数量"""
return db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.is_enabled == True
).count()
@staticmethod
def exists_builtin_for_tenant(db: Session, tenant_id: uuid.UUID) -> bool:
"""检查租户是否已有内置工具"""
return db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.tool_type == ToolType.BUILTIN.value
).count() > 0
class BuiltinToolRepository:
"""内置工具仓储类"""
@staticmethod
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[BuiltinToolConfig]:
"""根据工具ID查找内置工具配置"""
return db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_id
).first()
class CustomToolRepository:
"""自定义工具仓储类"""
@staticmethod
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[CustomToolConfig]:
"""根据工具ID查找自定义工具配置"""
return db.query(CustomToolConfig).filter(
CustomToolConfig.id == tool_id
).first()
class MCPToolRepository:
"""MCP工具仓储类"""
@staticmethod
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[MCPToolConfig]:
"""根据工具ID查找MCP工具配置"""
return db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_id
).first()
@staticmethod
def find_error_connections(db: Session) -> List[MCPToolConfig]:
"""查找连接错误的MCP工具"""
return db.query(MCPToolConfig).filter(
MCPToolConfig.connection_status == "error"
).all()
class ToolExecutionRepository:
"""工具执行仓储类"""
@staticmethod
def find_by_execution_id(db: Session, execution_id: str) -> Optional[ToolExecution]:
"""根据执行ID查找执行记录"""
return db.query(ToolExecution).filter(
ToolExecution.execution_id == execution_id
).first()
@staticmethod
def find_by_tool_and_tenant(
db: Session,
tool_id: uuid.UUID,
tenant_id: uuid.UUID,
limit: int = 100
) -> List[ToolExecution]:
"""根据工具和租户查找执行记录"""
return db.query(ToolExecution).join(
ToolConfig, ToolExecution.tool_config_id == ToolConfig.id
).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == tenant_id
).order_by(ToolExecution.started_at.desc()).limit(limit).all()

View File

@@ -0,0 +1,259 @@
"""工具相关的数据模式定义"""
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field, field_serializer
from datetime import datetime
from enum import Enum
from app.core.api_key_utils import datetime_to_timestamp
from app.models.tool_model import ToolType, ToolStatus, AuthType
class ParameterType(str, Enum):
"""参数类型枚举"""
STRING = "string"
INTEGER = "integer"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
class ToolParameter(BaseModel):
"""工具参数定义"""
name: str = Field(..., description="参数名称")
type: ParameterType = Field(..., description="参数类型")
description: str = Field("", description="参数描述")
required: bool = Field(False, description="是否必需")
default: Any = Field(None, description="默认值")
enum: Optional[List[Any]] = Field(None, description="枚举值")
minimum: Optional[float] = Field(None, description="最小值")
maximum: Optional[float] = Field(None, description="最大值")
pattern: Optional[str] = Field(None, description="正则表达式模式")
class Config:
use_enum_values = True
class ToolResult(BaseModel):
"""工具执行结果"""
success: bool = Field(..., description="执行是否成功")
data: Any = Field(None, description="返回数据")
error: Optional[str] = Field(None, description="错误信息")
error_code: Optional[str] = Field(None, description="错误代码")
execution_time: float = Field(..., description="执行时间(秒)")
token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
@classmethod
def success_result(
cls,
data: Any,
execution_time: float,
token_usage: Optional[Dict[str, int]] = None,
metadata: Optional[Dict[str, Any]] = None
) -> "ToolResult":
"""创建成功结果"""
return cls(
success=True,
data=data,
execution_time=execution_time,
token_usage=token_usage,
metadata=metadata or {}
)
@classmethod
def error_result(
cls,
error: str,
execution_time: float,
error_code: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
) -> "ToolResult":
"""创建错误结果"""
return cls(
success=False,
error=error,
error_code=error_code,
execution_time=execution_time,
metadata=metadata or {}
)
class ToolInfo(BaseModel):
"""工具信息"""
id: str = Field(..., description="工具ID")
name: str = Field(..., description="工具名称")
description: str = Field(..., description="工具描述")
icon: Optional[str] = Field(None, description="工具图标")
tool_type: ToolType = Field(..., description="工具类型")
version: str = Field("1.0.0", description="工具版本")
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
config_data: Dict[str, Any] = Field(default_factory=dict, description="工具配置")
status: ToolStatus = Field(ToolStatus.AVAILABLE, description="工具状态")
tags: List[str] = Field(default_factory=list, description="工具标签")
tenant_id: Optional[str] = Field(None, description="租户ID")
created_at: datetime = Field(..., description="创建时间")
class Config:
use_enum_values = True
@field_serializer('created_at')
@classmethod
def serialize_datetime(cls, v):
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
class ToolConfigSchema(BaseModel):
"""工具配置基础模式"""
id: str
name: str
description: Optional[str] = None
icon: Optional[str] = None
tool_type: ToolType
status: ToolStatus
config_data: Dict[str, Any] = Field(default_factory=dict)
version: str = "1.0.0"
tags: List[str] = Field(default_factory=list)
tenant_id: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class BuiltinToolConfigSchema(BaseModel):
"""内置工具配置模式"""
tool_class: str
parameters: Dict[str, Any] = Field(default_factory=dict)
is_enabled: bool
requires_config: bool = False
class Config:
from_attributes = True
class CustomToolConfigSchema(BaseModel):
"""自定义工具配置模式"""
base_url: Optional[str] = None
auth_type: AuthType = AuthType.NONE
auth_config: Dict[str, Any] = Field(default_factory=dict)
timeout: int = 30
schema_content: Optional[Dict[str, Any]] = None
schema_url: Optional[str] = None
class Config:
from_attributes = True
class MCPToolConfigSchema(BaseModel):
"""MCP工具配置模式"""
server_url: str
connection_config: Dict[str, Any] = Field(default_factory=dict)
last_health_check: Optional[datetime] = None
health_status: str = "unknown"
error_message: Optional[str] = None
available_tools: List[str] = Field(default_factory=list)
class Config:
from_attributes = True
class ToolDetailSchema(ToolConfigSchema):
"""工具详情模式(包含类型特定配置)"""
builtin_config: Optional[BuiltinToolConfigSchema] = None
custom_config: Optional[CustomToolConfigSchema] = None
mcp_config: Optional[MCPToolConfigSchema] = None
class ToolExecutionSchema(BaseModel):
"""工具执行记录模式"""
id: str
execution_id: str
status: str
input_data: Optional[Dict[str, Any]] = None
output_data: Optional[Dict[str, Any]] = None
error_message: Optional[str] = None
started_at: datetime
completed_at: Optional[datetime] = None
execution_time: Optional[float] = None
token_usage: Optional[Dict[str, int]] = None
class Config:
from_attributes = True
class ToolCreateRequest(BaseModel):
"""创建工具请求"""
name: str = Field(..., min_length=1, max_length=255)
description: Optional[str] = Field(None, max_length=1000)
icon: Optional[str] = Field(None, max_length=255)
tool_type: ToolType
config: Dict[str, Any] = Field(default_factory=dict)
class ToolUpdateRequest(BaseModel):
"""更新工具请求"""
name: Optional[str] = Field(None, min_length=1, max_length=255)
description: Optional[str] = Field(None, max_length=1000)
icon: Optional[str] = Field(None, max_length=255)
config: Optional[Dict[str, Any]] = None
is_enabled: Optional[bool] = None
class ToolExecuteRequest(BaseModel):
"""执行工具请求"""
tool_id: str
parameters: Dict[str, Any] = Field(default_factory=dict)
timeout: Optional[float] = Field(60.0, gt=0, le=300)
class CustomToolCreateRequest(BaseModel):
"""创建自定义工具请求"""
name: str = Field(..., min_length=1, max_length=255)
description: Optional[str] = Field(None, max_length=1000)
icon: Optional[str] = Field(None, max_length=255)
auth_type: AuthType = Field(AuthType.NONE, description="认证类型")
auth_config: Dict[str, Any] = Field(default_factory=dict, description="认证配置")
timeout: int = Field(30, ge=1, le=300, description="超时时间")
schema_content: Optional[Dict[str, Any]] = Field(None, description="OpenAPI schema内容")
schema_url: Optional[str] = Field(None, description="OpenAPI schema URL")
class ParseSchemaRequest(BaseModel):
"""解析Schema请求"""
schema_content: Optional[Dict[str, Any]] = Field(None, description="OpenAPI schema内容")
schema_url: Optional[str] = Field(None, description="OpenAPI schema URL")
class ToolListQuery(BaseModel):
"""工具列表查询参数"""
name: Optional[str] = None
tool_type: Optional[ToolType] = None
status: Optional[ToolStatus] = None
is_enabled: Optional[bool] = None
page: int = Field(1, ge=1)
page_size: int = Field(20, ge=1, le=100)
class ToolStatusCount(BaseModel):
"""工具状态统计"""
status: ToolStatus
count: int
class ToolStatistics(BaseModel):
"""工具统计信息"""
total_tools: int
status_counts: List[ToolStatusCount]
type_counts: Dict[str, int]
enabled_count: int
disabled_count: int
class CustomToolTestRequest(BaseModel):
"""自定义工具测试请求"""
method: str = Field(..., description="HTTP方法")
path: str = Field(..., description="API路径")
parameters: Dict[str, Any] = Field(default_factory=dict, description="请求参数")

View File

@@ -0,0 +1,977 @@
"""工具服务 - 统一的工具管理和执行服务"""
import json
import uuid
import time
import importlib
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from app.core.tools.mcp import MCPClient
from app.repositories.tool_repository import (
ToolRepository, BuiltinToolRepository, CustomToolRepository,
MCPToolRepository, ToolExecutionRepository
)
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolExecution, ToolType, ToolStatus, ExecutionStatus, AuthType
)
from app.schemas.tool_schema import ToolInfo, ToolResult
from app.core.logging_config import get_business_logger
from app.core.tools.base import BaseTool
from app.core.tools.custom.base import CustomTool
from app.core.tools.mcp.base import MCPTool
logger = get_business_logger()
# 内置工具映射
BUILTIN_TOOLS = {
"DateTimeTool": "app.core.tools.builtin.datetime_tool",
"JsonTool": "app.core.tools.builtin.json_tool",
"BaiduSearchTool": "app.core.tools.builtin.baidu_search_tool",
"MinerUTool": "app.core.tools.builtin.mineru_tool",
"TextInTool": "app.core.tools.builtin.textin_tool"
}
class ToolService:
"""统一工具服务 - 管理工具的完整生命周期"""
def __init__(self, db: Session):
self.db = db
self._tool_cache: Dict[str, BaseTool] = {}
# 初始化仓储
self.tool_repo = ToolRepository()
self.builtin_repo = BuiltinToolRepository()
self.custom_repo = CustomToolRepository()
self.mcp_repo = MCPToolRepository()
self.execution_repo = ToolExecutionRepository()
def list_tools(
self,
tenant_id: uuid.UUID,
name: Optional[str] = None,
tool_type: Optional[ToolType] = None,
status: Optional[ToolStatus] = None
) -> List[ToolInfo]:
"""获取工具列表"""
try:
configs = self.tool_repo.find_by_tenant(
db=self.db,
tenant_id=tenant_id,
name=name,
tool_type=tool_type,
status=status
)
return [self._config_to_info(config) for config in configs]
except Exception as e:
logger.error(f"获取工具列表失败: {e}")
return []
def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]:
"""获取工具详情"""
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
return self._config_to_info(config) if config else None
def create_tool(
self,
name: str,
tool_type: ToolType,
tenant_id: uuid.UUID,
icon: Optional[str] = None,
description: Optional[str] = None,
config: Optional[Dict[str, Any]] = None
) -> str:
"""创建工具"""
if tool_type == ToolType.BUILTIN:
raise ValueError("内置工具不允许创建")
try:
# 创建基础配置
tool_config = ToolConfig(
name=name,
description=description,
icon=icon,
tool_type=tool_type.value,
tenant_id=tenant_id,
status=ToolStatus.AVAILABLE.value,
config_data=config or {}
)
self.db.add(tool_config)
self.db.flush()
# 创建类型特定配置
self._create_type_config(tool_config, config or {})
self.db.commit()
logger.info(f"工具创建成功: {tool_config.id}")
return str(tool_config.id)
except Exception as e:
self.db.rollback()
logger.error(f"创建工具失败: {e}")
raise
def update_tool(
self,
tool_id: str,
tenant_id: uuid.UUID,
name: Optional[str] = None,
description: Optional[str] = None,
icon: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
is_enabled: Optional[bool] = None
) -> bool:
"""更新工具"""
config_obj = self._get_tool_config(tool_id, tenant_id)
if not config_obj:
return False
if config_obj.tool_type == ToolType.BUILTIN.value:
if name or description or icon:
raise ValueError("内置工具不允许修改名称、描述和图标")
try:
if name:
config_obj.name = name
if description:
config_obj.description = description
if icon:
config_obj.icon = icon
if config:
config_obj.config_data = config.copy()
# 同步到类型表
self._sync_type_config(config_obj, config, is_enabled)
# 更新状态逻辑
self._update_tool_status(config_obj)
# 清除缓存
self._clear_tool_cache(tool_id)
self.db.commit()
return True
except Exception as e:
self.db.rollback()
logger.error(f"更新工具失败: {tool_id}, {e}")
return False
def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool:
"""删除工具"""
config = self._get_tool_config(tool_id, tenant_id)
if not config:
return False
if config.tool_type == ToolType.BUILTIN.value:
raise ValueError("内置工具不允许删除")
try:
# 删除关联表记录
if config.tool_type == ToolType.CUSTOM.value:
self.db.query(CustomToolConfig).filter(CustomToolConfig.id == config.id).delete()
elif config.tool_type == ToolType.MCP.value:
self.db.query(MCPToolConfig).filter(MCPToolConfig.id == config.id).delete()
# 删除主表记录ToolExecution会通过cascade自动删除
self.db.delete(config)
self._clear_tool_cache(tool_id)
self.db.commit()
return True
except Exception as e:
self.db.rollback()
logger.error(f"删除工具失败: {tool_id}, {e}")
return False
async def execute_tool(
self,
tool_id: str,
parameters: Dict[str, Any],
tenant_id: uuid.UUID,
user_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
timeout: float = 60.0
) -> ToolResult:
"""执行工具"""
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
start_time = time.time()
try:
# 获取工具实例
tool = self._get_tool_instance(tool_id, tenant_id)
if not tool:
return ToolResult.error_result(
error=f"工具不存在: {tool_id}",
execution_time=time.time() - start_time
)
# 记录执行开始
self._record_execution_start(
execution_id, tool_id, parameters, user_id, workspace_id
)
# 执行工具
result = await tool.safe_execute(**parameters)
# 记录执行完成
self._record_execution_complete(execution_id, result)
return result
except Exception as e:
execution_time = time.time() - start_time
error_result = ToolResult.error_result(
error=str(e),
execution_time=execution_time
)
self._record_execution_complete(execution_id, error_result)
return error_result
async def test_connection(self, tool_id: str, tenant_id: uuid.UUID) -> Dict[str, Any]:
"""测试工具连接"""
try:
config = self._get_tool_config(tool_id, tenant_id)
if not config:
return {"success": False, "message": "工具不存在"}
if config.tool_type == ToolType.MCP.value:
return await self._test_mcp_connection(config)
elif config.tool_type == ToolType.CUSTOM.value:
return await self._test_custom_connection(config)
elif config.tool_type == ToolType.BUILTIN.value:
return await self._test_builtin_connection(config)
else:
return {"success": True, "message": "未知工具类型"}
except Exception as e:
return {"success": False, "message": f"测试失败: {str(e)}"}
def ensure_builtin_tools_initialized(self, tenant_id: uuid.UUID):
"""确保内置工具已初始化"""
existing = self.tool_repo.exists_builtin_for_tenant(self.db, tenant_id)
if existing:
return
# 从配置文件加载内置工具定义
builtin_config = self._load_builtin_config()
for tool_key, tool_info in builtin_config.items():
try:
# 创建工具配置
initial_status = self._determine_initial_status(tool_info)
tool_config = ToolConfig(
name=tool_info['name'],
description=tool_info['description'],
tool_type=ToolType.BUILTIN.value,
tenant_id=tenant_id,
status=initial_status,
config_data={"tool_class": tool_info['tool_class'],
"requires_config": tool_info.get('requires_config', False),
"is_enabled": False},
version=tool_info["version"]
)
self.db.add(tool_config)
self.db.flush()
# 创建内置工具配置
builtin_config_obj = BuiltinToolConfig(
id=tool_config.id,
tool_class=tool_info['tool_class'],
parameters={},
requires_config=tool_info.get('requires_config', False)
)
self.db.add(builtin_config_obj)
except Exception as e:
logger.error(f"初始化内置工具失败: {tool_key}, {e}")
self.db.commit()
logger.info(f"租户 {tenant_id} 内置工具初始化完成")
def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]:
"""获取工具统计信息"""
try:
# 总数统计
total_tools = self.tool_repo.count_by_tenant(self.db, tenant_id)
# 状态统计
status_counts = self.tool_repo.get_status_statistics(self.db, tenant_id)
# 类型统计
type_counts = self.tool_repo.get_type_statistics(self.db, tenant_id)
# 启用/禁用统计
enabled_count = self.tool_repo.count_enabled_by_tenant(self.db, tenant_id)
disabled_count = total_tools - enabled_count
return {
"total_tools": total_tools,
"status_counts": [
{"status": status, "count": count}
for status, count in status_counts
],
"type_counts": {
tool_type: count for tool_type, count in type_counts
},
"enabled_count": enabled_count,
"disabled_count": disabled_count
}
except Exception as e:
logger.error(f"获取工具统计失败: {e}")
return {
"total_tools": 0,
"status_counts": [],
"type_counts": {},
"enabled_count": 0,
"disabled_count": 0
}
def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""获取工具配置"""
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
"""获取工具实例"""
if tool_id in self._tool_cache:
return self._tool_cache[tool_id]
config = self._get_tool_config(tool_id, tenant_id)
if not config:
return None
try:
tool = self._create_tool_instance(config)
if tool:
self._tool_cache[tool_id] = tool
return tool
except Exception as e:
logger.error(f"创建工具实例失败: {tool_id}, {e}")
return None
def _create_tool_instance(self, config: ToolConfig) -> Optional[BaseTool]:
"""创建工具实例"""
if config.tool_type == ToolType.BUILTIN.value:
return self._create_builtin_instance(config)
elif config.tool_type == ToolType.CUSTOM.value:
return self._create_custom_instance(config)
elif config.tool_type == ToolType.MCP.value:
return self._create_mcp_instance(config)
return None
def _create_builtin_instance(self, config: ToolConfig) -> Optional[BaseTool]:
"""创建内置工具实例"""
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS:
return None
try:
module_path = BUILTIN_TOOLS[builtin_config.tool_class]
module = importlib.import_module(module_path)
tool_class = getattr(module, builtin_config.tool_class)
tool_config = {
**config.config_data,
"parameters": builtin_config.parameters,
}
return tool_class(str(config.id), tool_config)
except Exception as e:
logger.error(f"创建内置工具实例失败: {builtin_config.tool_class}, {e}")
return None
def _create_custom_instance(self, config: ToolConfig) -> Optional[CustomTool]:
"""创建自定义工具实例"""
custom_config = self.custom_repo.find_by_tool_id(self.db, config.id)
if not custom_config:
return None
tool_config = {
"base_url": custom_config.base_url,
"auth_type": custom_config.auth_type,
"auth_config": custom_config.auth_config or {},
"timeout": custom_config.timeout or 30,
"schema_content": custom_config.schema_content,
"schema_url": custom_config.schema_url
}
return CustomTool(str(config.id), tool_config)
def _create_mcp_instance(self, config: ToolConfig) -> Optional[MCPTool]:
"""创建MCP工具实例"""
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if not mcp_config:
return None
tool_config = {
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config or {},
"available_tools": mcp_config.available_tools or []
}
return MCPTool(str(config.id), tool_config)
def _config_to_info(self, config: ToolConfig) -> ToolInfo:
"""配置转换为信息对象"""
config_data = config.config_data or {}
# 对于MCP工具从MCPToolConfig获取额外信息
if config.tool_type == ToolType.MCP.value:
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if mcp_config:
config_data.update({
"last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None,
"health_status": mcp_config.health_status,
"available_tools": mcp_config.available_tools or []
})
return ToolInfo(
id=str(config.id),
name=config.name,
description=config.description or "",
icon=config.icon,
tool_type=ToolType(config.tool_type),
version=config.version or "1.0.0",
status=ToolStatus(config.status),
tags=config.tags or [],
tenant_id=str(config.tenant_id) if config.tenant_id else None,
config_data=config_data,
created_at=config.created_at
)
def _create_type_config(self, tool_config: ToolConfig, config: Dict[str, Any]):
"""创建类型特定配置"""
if tool_config.tool_type == ToolType.CUSTOM.value:
# 从 schema 中解析 base_url
base_url = config.get("base_url")
if not base_url and (config.get("schema_content") or config.get("schema_url")):
try:
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
parser = OpenAPISchemaParser()
if config.get("schema_content"):
success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]), "application/json")
else:
success, schema, _ = parser.parse_from_url(config["schema_url"])
if success:
tool_info = parser.extract_tool_info(schema)
servers = tool_info.get("servers", [])
base_url = servers[0].get("url") if servers else ""
except Exception as e:
logger.error(f"解析schema获取base_url失败: {e}")
custom_config = CustomToolConfig(
id=tool_config.id,
base_url=base_url,
auth_type=config.get("auth_type", "none"),
auth_config=config.get("auth_config", {}),
timeout=config.get("timeout", 30),
schema_content=config.get("schema_content"),
schema_url=config.get("schema_url")
)
self.db.add(custom_config)
elif tool_config.tool_type == ToolType.MCP.value:
mcp_config = MCPToolConfig(
id=tool_config.id,
server_url=config.get("server_url"),
connection_config=config.get("connection_config", {}),
available_tools=config.get("available_tools", [])
)
self.db.add(mcp_config)
def _sync_type_config(self, tool_config: ToolConfig, config: Dict[str, Any], is_enabled: bool):
"""同步到类型特定表"""
if tool_config.tool_type == ToolType.BUILTIN.value:
builtin_config = self.db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if builtin_config:
builtin_config.parameters = config.get("parameters", {})
if is_enabled is not None:
builtin_config.is_enabled = is_enabled
elif tool_config.tool_type == ToolType.CUSTOM.value:
custom_config = self.db.query(CustomToolConfig).filter(
CustomToolConfig.id == tool_config.id
).first()
if custom_config:
base_url = config.get("base_url")
if not base_url and (config.get("schema_content") or config.get("schema_url")):
try:
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
parser = OpenAPISchemaParser()
if config.get("schema_content"):
success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]),
"application/json")
else:
success, schema, _ = parser.parse_from_url(config["schema_url"])
if success:
tool_info = parser.extract_tool_info(schema)
servers = tool_info.get("servers", [])
base_url = servers[0].get("url") if servers else ""
except Exception as e:
logger.error(f"解析schema获取base_url失败: {e}")
custom_config.base_url = base_url
custom_config.auth_type = config.get("auth_type", "none")
custom_config.auth_config = config.get("auth_config", {})
custom_config.timeout = config.get("timeout", 30)
custom_config.schema_content = config.get("schema_content")
custom_config.schema_url = config.get("schema_url")
elif tool_config.tool_type == ToolType.MCP.value:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_config.id
).first()
if mcp_config:
mcp_config.server_url = config.get("server_url")
mcp_config.connection_config = config.get("connection_config", {})
mcp_config.available_tools = config.get("available_tools", [])
@staticmethod
def _determine_initial_status(tool_info: Dict[str, Any]) -> str:
"""确定工具初始状态"""
if tool_info.get('requires_config', False):
return ToolStatus.UNCONFIGURED
else:
return ToolStatus.AVAILABLE
def _update_tool_status(self, tool_config: ToolConfig):
"""更新工具状态逻辑"""
if tool_config.tool_type == ToolType.BUILTIN.value:
builtin_config = self.db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if builtin_config:
if builtin_config.requires_config:
# 需要配置的工具
if self._is_tool_configured(builtin_config):
if tool_config.config_data.get("is_enabled", None):
tool_config.status = ToolStatus.AVAILABLE.value
else:
tool_config.status = ToolStatus.CONFIGURED_DISABLED.value
else:
tool_config.status = ToolStatus.UNCONFIGURED.value
else:
# 不需要配置的工具
tool_config.status = ToolStatus.AVAILABLE.value
elif tool_config.tool_type == ToolType.CUSTOM.value:
custom_config = self.db.query(CustomToolConfig).filter(
CustomToolConfig.id == tool_config.id
).first()
if custom_config and tool_config.name and (custom_config.schema_content or custom_config.schema_url):
tool_config.status = ToolStatus.AVAILABLE.value
else:
tool_config.status = ToolStatus.UNCONFIGURED.value
elif tool_config.tool_type == ToolType.MCP.value:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_config.id
).first()
if mcp_config:
if mcp_config.health_status == "healthy":
tool_config.status = ToolStatus.AVAILABLE.value
elif mcp_config.health_status == "error":
tool_config.status = ToolStatus.ERROR.value
else:
tool_config.status = ToolStatus.UNCONFIGURED.value
def _is_tool_configured(self, builtin_config: BuiltinToolConfig) -> bool:
"""检查工具是否已配置"""
# 从配置文件获取必需参数
builtin_config_data = self._load_builtin_config()
required_params = {}
for key, value in builtin_config_data.items():
if builtin_config.tool_class == value["tool_class"]:
required_params = value.get('parameters', {})
break
# 检查所有必需参数是否已配置
for param_name, param_info in required_params.items():
if param_info.get('required', False):
if not builtin_config.parameters.get(param_name):
return False
return True
def _clear_tool_cache(self, tool_id: str):
"""清除工具缓存"""
if tool_id in self._tool_cache:
del self._tool_cache[tool_id]
def _record_execution_start(
self,
execution_id: str,
tool_id: str,
parameters: Dict[str, Any],
user_id: Optional[uuid.UUID],
workspace_id: Optional[uuid.UUID]
):
"""记录执行开始"""
try:
execution = ToolExecution(
execution_id=execution_id,
tool_config_id=uuid.UUID(tool_id),
status=ExecutionStatus.RUNNING.value,
input_data=parameters,
started_at=datetime.now(),
user_id=user_id,
workspace_id=workspace_id
)
self.db.add(execution)
self.db.commit()
except Exception as e:
logger.error(f"记录执行开始失败: {execution_id}, {e}")
def _record_execution_complete(self, execution_id: str, result: ToolResult):
"""记录执行完成"""
try:
execution = self.db.query(ToolExecution).filter(
ToolExecution.execution_id == execution_id
).first()
if execution:
execution.status = ExecutionStatus.COMPLETED.value if result.success else ExecutionStatus.FAILED.value
execution.output_data = result.data if result.success else None
execution.error_message = result.error if not result.success else None
execution.completed_at = datetime.now()
execution.execution_time = result.execution_time
execution.token_usage = result.token_usage
self.db.commit()
except Exception as e:
logger.error(f"记录执行完成失败: {execution_id}, {e}")
@staticmethod
def _load_builtin_config() -> Dict[str, Any]:
"""加载内置工具配置"""
import json
from pathlib import Path
config_file = Path(__file__).parent.parent / "core" / "tools" / "configs" / "builtin_tools.json"
try:
with open(config_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"加载内置工具配置失败: {e}")
return {}
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
"""测试MCP连接"""
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == config.id
).first()
if not mcp_config:
return {"success": False, "message": "MCP配置不存在"}
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
if await client.connect():
try:
tools = await client.list_tools()
await client.disconnect()
# 更新连接状态
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy"
mcp_config.error_message = None
# 更新工具状态
self._update_tool_status(config)
self.db.commit()
return {
"success": True,
"message": "MCP连接成功",
"details": {"server_url": mcp_config.server_url, "tools_count": len(tools)}
}
except Exception as e:
await client.disconnect()
# 更新错误状态
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
self._update_tool_status(config)
self.db.commit()
return {"success": False, "message": f"MCP功能测试失败: {str(e)}"}
else:
# 更新连接失败状态
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = "连接失败"
self._update_tool_status(config)
self.db.commit()
return {"success": False, "message": "MCP连接失败"}
except Exception as e:
# 更新异常状态
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == config.id
).first()
if mcp_config:
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
self._update_tool_status(config)
self.db.commit()
return {"success": False, "message": f"MCP测试异常: {str(e)}"}
@staticmethod
async def parse_openapi_schema(schema_data: Dict[str, Any] = None, schema_url: str = None) -> Dict[str, Any]:
"""解析OpenAPI schema获取接口信息"""
try:
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
parser = OpenAPISchemaParser()
# 使用现有的解析器
if schema_data:
success, schema, error = parser.parse_from_content(json.dumps(schema_data), "application/json")
elif schema_url:
success, schema, error = await parser.parse_from_url(schema_url)
else:
return {"success": False, "message": "schema_data或schema_url必须提供一个"}
if not success:
return {"success": False, "message": error}
# 提取工具信息
tool_info = parser.extract_tool_info(schema)
# 获取base_url
servers = tool_info.get("servers", [])
base_url = servers[0].get("url") if servers else ""
return {
"success": True,
"data": {
"title": tool_info["name"],
"description": tool_info["description"],
"version": tool_info["version"],
"base_url": base_url,
"operations": list(tool_info["operations"].values())
}
}
except Exception as e:
logger.error(f"解析OpenAPI schema失败: {e}")
return {"success": False, "message": f"解析失败: {str(e)}"}
async def sync_mcp_tools(self, tool_id: str, tenant_id: uuid.UUID) -> Dict[str, Any]:
"""同步MCP工具列表到数据库"""
try:
config = self._get_tool_config(tool_id, tenant_id)
if not config or config.tool_type != ToolType.MCP.value:
return {"success": False, "message": "工具不存在或不是MCP工具"}
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if not mcp_config:
return {"success": False, "message": "MCP配置不存在"}
# 创建MCP客户端
connection_config = mcp_config.connection_config or {}
client = MCPClient(mcp_config.server_url, connection_config)
if await client.connect():
try:
# 获取工具列表
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
# 更新数据库
mcp_config.available_tools = tool_names
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy"
mcp_config.error_message = None
# 更新工具状态
config.status = ToolStatus.AVAILABLE.value
self.db.commit()
await client.disconnect()
return {
"success": True,
"message": "工具列表同步成功",
"tools_count": len(tool_names),
"tools": tool_names
}
except Exception as e:
await client.disconnect()
# 更新错误状态
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
config.status = ToolStatus.ERROR.value
self.db.commit()
return {"success": False, "message": f"获取工具列表失败: {str(e)}"}
else:
# 连接失败
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = "连接失败"
config.status = ToolStatus.ERROR.value
self.db.commit()
return {"success": False, "message": "MCP连接失败"}
except Exception as e:
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
return {"success": False, "message": f"同步失败: {str(e)}"}
async def _test_custom_connection(self, config: ToolConfig) -> Dict[str, Any]:
"""测试自定义工具连接(基础连接测试)"""
try:
custom_config = self.db.query(CustomToolConfig).filter(
CustomToolConfig.id == config.id
).first()
if not custom_config or not custom_config.base_url:
return {"success": False, "message": "自定义工具配置不完整"}
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(
custom_config.base_url,
timeout=aiohttp.ClientTimeout(total=10)
) as response:
if response.status == 200:
return {"success": True, "message": "自定义工具连接成功"}
else:
return {"success": False, "message": f"连接失败,状态码: {response.status}"}
except Exception as e:
return {"success": False, "message": f"自定义工具测试失败: {str(e)}"}
async def test_custom_tool(
self,
tool_id: str,
tenant_id: uuid.UUID,
method: str,
path: str,
parameters: Dict[str, Any]
) -> Dict[str, Any]:
"""测试自定义工具API调用"""
try:
config = self._get_tool_config(tool_id, tenant_id)
if not config or config.tool_type != ToolType.CUSTOM.value:
return {"success": False, "message": "工具不存在或不是自定义工具"}
custom_config = self.db.query(CustomToolConfig).filter(
CustomToolConfig.id == config.id
).first()
if not custom_config or not custom_config.base_url:
return {"success": False, "message": "自定义工具配置不完整"}
# 构建完整URL
url = custom_config.base_url.rstrip('/') + '/' + path.lstrip('/')
# 构建请求头
headers = {"Content-Type": "application/json"}
# 添加认证头
if custom_config.auth_type != AuthType.NONE.value:
auth_config = custom_config.auth_config or {}
if custom_config.auth_type == AuthType.API_KEY.value:
key_name = auth_config.get("key_name", "X-API-Key")
api_key = auth_config.get("api_key")
if api_key:
headers[key_name] = api_key
elif custom_config.auth_type == AuthType.BEARER_TOKEN.value:
token = auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif custom_config.auth_type == AuthType.BASIC_AUTH.value:
import base64
username = auth_config.get("username", "")
password = auth_config.get("password", "")
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
import aiohttp
async with aiohttp.ClientSession() as session:
# 根据方法发送请求
if method.upper() == "GET":
async with session.get(
url,
params=parameters,
headers=headers,
timeout=aiohttp.ClientTimeout(total=custom_config.timeout or 30)
) as response:
result_data = await response.text()
return {
"success": True,
"message": "测试成功",
"status_code": response.status,
"response_data": result_data[:1000] # 限制返回数据长度
}
else:
async with session.request(
method.upper(),
url,
json=parameters,
headers=headers,
timeout=aiohttp.ClientTimeout(total=custom_config.timeout or 30)
) as response:
result_data = await response.text()
return {
"success": True,
"message": "测试成功",
"status_code": response.status,
"response_data": result_data[:1000] # 限制返回数据长度
}
except Exception as e:
logger.error(f"测试自定义工具API失败: {tool_id}, 错误: {e}")
return {"success": False, "message": f"测试失败: {str(e)}"}
async def _test_builtin_connection(self, config: ToolConfig) -> Dict[str, Any]:
"""测试内置工具连接"""
try:
# 获取工具实例
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
if not tool_instance:
return {"success": False, "message": "无法创建工具实例"}
# 检查工具是否有test_connection方法
if hasattr(tool_instance, 'test_connection'):
result = await tool_instance.test_connection()
return result
else:
# 检查是否需要配置
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
if builtin_config and builtin_config.requires_config:
# 检查必需参数是否已配置
if self._is_tool_configured(builtin_config):
return {"success": True, "message": "内置工具已正确配置"}
else:
return {"success": False, "message": "工具缺少必需配置参数"}
else:
return {"success": True, "message": "内置工具无需连接测试"}
except Exception as e:
logger.error(f"测试内置工具失败: {config.id}, 错误: {e}")
return {"success": False, "message": f"测试失败: {str(e)}"}

View File

@@ -1,374 +0,0 @@
#!/usr/bin/env python3
"""
工具管理系统基础测试脚本
用于验证系统的基本功能是否正常
"""
import asyncio
import uuid
from datetime import datetime
# 测试导入
def test_imports():
"""测试模块导入"""
print("测试模块导入...")
try:
from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType
print("✓ 基础工具模块导入成功")
except ImportError as e:
print(f"✗ 基础工具模块导入失败: {e}")
return False
try:
from app.core.tools.builtin.datetime_tool import DateTimeTool
from app.core.tools.builtin.json_tool import JsonTool
print("✓ 内置工具模块导入成功")
except ImportError as e:
print(f"✗ 内置工具模块导入失败: {e}")
return False
try:
from app.core.tools.langchain_adapter import LangchainAdapter
print("✓ Langchain适配器导入成功")
except ImportError as e:
print(f"✗ Langchain适配器导入失败: {e}")
return False
try:
from app.models.tool_model import ToolConfig, ToolType, ToolStatus
print("✓ 工具模型导入成功")
except ImportError as e:
print(f"✗ 工具模型导入失败: {e}")
return False
try:
from app.core.tools.custom import CustomTool, OpenAPISchemaParser, AuthManager
print("✓ 自定义工具模块导入成功")
except ImportError as e:
print(f"✗ 自定义工具模块导入失败: {e}")
return False
try:
from app.core.tools.mcp import MCPTool, MCPClient, MCPServiceManager
print("✓ MCP工具模块导入成功")
except ImportError as e:
print(f"✗ MCP工具模块导入失败: {e}")
return False
return True
def test_tool_creation():
"""测试工具创建"""
print("\n测试工具创建...")
try:
from app.core.tools.builtin.datetime_tool import DateTimeTool
# 创建时间工具实例(全局工具)
tool_id = str(uuid.uuid4())
config = {
"parameters": {"timezone": "UTC"},
"tenant_id": None, # 全局工具
"version": "1.0.0",
"tags": ["time", "utility", "builtin"]
}
datetime_tool = DateTimeTool(tool_id, config)
# 验证工具属性
assert datetime_tool.name == "datetime_tool"
assert datetime_tool.tool_type.value == "builtin"
assert len(datetime_tool.parameters) > 0
print("✓ 时间工具创建成功(全局工具)")
return True
except Exception as e:
print(f"✗ 工具创建失败: {e}")
return False
async def test_tool_execution():
"""测试工具执行"""
print("\n测试工具执行...")
try:
from app.core.tools.builtin.datetime_tool import DateTimeTool
# 创建时间工具实例
tool_id = str(uuid.uuid4())
config = {
"parameters": {"timezone": "UTC"},
"tenant_id": None, # 全局工具
"version": "1.0.0"
}
datetime_tool = DateTimeTool(tool_id, config)
# 测试获取当前时间
result = await datetime_tool.safe_execute(operation="now")
assert result.success == True
assert "datetime" in result.data
assert result.execution_time > 0
print("✓ 工具执行成功")
print(f" 执行时间: {result.execution_time:.3f}")
print(f" 返回数据: {result.data}")
return True
except Exception as e:
print(f"✗ 工具执行失败: {e}")
return False
def test_langchain_adapter():
"""测试Langchain适配器"""
print("\n测试Langchain适配器...")
try:
from app.core.tools.builtin.json_tool import JsonTool
from app.core.tools.langchain_adapter import LangchainAdapter
# 创建JSON工具实例
tool_id = str(uuid.uuid4())
config = {
"parameters": {"indent": 2},
"tenant_id": None, # 全局工具
"version": "1.0.0"
}
json_tool = JsonTool(tool_id, config)
# 验证Langchain兼容性
is_compatible, issues = LangchainAdapter.validate_langchain_compatibility(json_tool)
if not is_compatible:
print(f"✗ Langchain兼容性验证失败: {issues}")
return False
# 创建工具描述
description = LangchainAdapter.create_tool_description(json_tool)
assert "name" in description
assert "parameters" in description
assert description["langchain_compatible"] == True
print("✓ Langchain适配器测试成功")
return True
except Exception as e:
print(f"✗ Langchain适配器测试失败: {e}")
return False
def test_config_manager():
"""测试配置管理器"""
print("\n测试配置管理器...")
try:
from app.core.tools.config_manager import ConfigManager
# 创建配置管理器
config_manager = ConfigManager()
# 获取配置摘要
summary = config_manager.get_config_summary()
assert "config_dir" in summary
assert "total_configs" in summary
print("✓ 配置管理器测试成功")
print(f" 配置目录: {summary['config_dir']}")
print(f" 总配置数: {summary['total_configs']}")
return True
except Exception as e:
print(f"✗ 配置管理器测试失败: {e}")
return False
def test_schema_parser():
"""测试OpenAPI Schema解析器"""
print("\n测试OpenAPI Schema解析器...")
try:
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
# 创建解析器
parser = OpenAPISchemaParser()
# 测试简单的OpenAPI schema
test_schema = {
"openapi": "3.0.0",
"info": {
"title": "Test API",
"version": "1.0.0",
"description": "测试API"
},
"paths": {
"/test": {
"get": {
"summary": "测试接口",
"operationId": "test_operation",
"responses": {
"200": {
"description": "成功"
}
}
}
}
}
}
# 验证schema
is_valid, error_msg = parser.validate_schema(test_schema)
assert is_valid, f"Schema验证失败: {error_msg}"
# 提取工具信息
tool_info = parser.extract_tool_info(test_schema)
assert tool_info["name"] == "Test API"
assert "test_operation" in tool_info["operations"]
print("✓ OpenAPI Schema解析器测试成功")
return True
except Exception as e:
print(f"✗ OpenAPI Schema解析器测试失败: {e}")
return False
def test_auth_manager():
"""测试认证管理器"""
print("\n测试认证管理器...")
try:
from app.core.tools.custom.auth_manager import AuthManager
from app.models.tool_model import AuthType
# 创建认证管理器
auth_manager = AuthManager()
# 测试API Key认证配置
api_key_config = {
"api_key": "test-key-123",
"key_name": "X-API-Key",
"location": "header"
}
is_valid, error_msg = auth_manager.validate_auth_config(AuthType.API_KEY, api_key_config)
assert is_valid, f"API Key配置验证失败: {error_msg}"
# 测试Bearer Token认证配置
bearer_config = {
"token": "bearer-token-123"
}
is_valid, error_msg = auth_manager.validate_auth_config(AuthType.BEARER_TOKEN, bearer_config)
assert is_valid, f"Bearer Token配置验证失败: {error_msg}"
# 测试认证应用
url = "https://api.example.com/test"
headers = {}
params = {}
new_url, new_headers, new_params = auth_manager.apply_authentication(
AuthType.API_KEY, api_key_config, url, headers, params
)
assert "X-API-Key" in new_headers
assert new_headers["X-API-Key"] == "test-key-123"
print("✓ 认证管理器测试成功")
return True
except Exception as e:
print(f"✗ 认证管理器测试失败: {e}")
return False
def test_builtin_initializer():
"""测试内置工具初始化器"""
print("\n测试内置工具初始化器...")
try:
from app.core.tools.builtin_initializer import BuiltinToolInitializer
# 注意:这里不能真正初始化,因为需要数据库连接
# 只测试类的创建和基本方法
# 模拟数据库会话(实际使用中需要真实的数据库连接)
class MockDB:
def query(self, *args):
return self
def filter(self, *args):
return self
def first(self):
return None
def all(self):
return []
mock_db = MockDB()
initializer = BuiltinToolInitializer(mock_db)
# 测试获取内置工具状态(会返回空列表,因为没有真实数据)
status = initializer.get_builtin_tools_status()
assert isinstance(status, list)
print("✓ 内置工具初始化器测试成功")
return True
except Exception as e:
print(f"✗ 内置工具初始化器测试失败: {e}")
return False
async def main():
"""主测试函数"""
print("=" * 50)
print("工具管理系统基础测试")
print("=" * 50)
tests = [
("模块导入", test_imports),
("工具创建", test_tool_creation),
("工具执行", test_tool_execution),
("Langchain适配", test_langchain_adapter),
("配置管理", test_config_manager),
("Schema解析器", test_schema_parser),
("认证管理器", test_auth_manager),
("内置工具初始化器", test_builtin_initializer)
]
passed = 0
total = len(tests)
for test_name, test_func in tests:
try:
if asyncio.iscoroutinefunction(test_func):
result = await test_func()
else:
result = test_func()
if result:
passed += 1
except Exception as e:
print(f"{test_name}测试异常: {e}")
print("\n" + "=" * 50)
print(f"测试结果: {passed}/{total} 通过")
if passed == total:
print("🎉 所有基础测试通过!工具管理系统基本功能正常。")
return True
else:
print("⚠️ 部分测试失败,请检查相关模块。")
return False
if __name__ == "__main__":
asyncio.run(main())