Merge #26 into develop from feature/20251219_xjn
feat(tool system): tool system development * feature/20251219_xjn: (25 commits) feat(apikey system): tool system development Merge #13 into develop from fix/stream-output [fix]document chunk QA [add] workflow support stream mode Merge #9 into develop from fix/memory_reflection Merge #18 into develop from fix/memory_reflection [add] migration script fix(workflow): fix run_workflow streaming issues fix(prompt-optimizer): switch to built-in system prompt feat(workflow): add conditional branch (If-Else) node perf(types): add Union type declaration for workflow nodes fix(expression-eval): fix variable extraction issue in Jinja2 templates docs(samples): add config example for If-Else node style(workflow): update condition edge comments for conditional nodes style(enums): correct enum class name spelling refactor(workflow): unify all enum classes in one file and restructure workflow... feat(workflow): add import for if-else node configuration [add] migration script Merge #19 into develop from fix/memory_reflection Merge #21 into develop from feature/emotion-engine Merge #13 into develop from fix/stream-output Merge #21 into develop from feature/emotion-engine [add] migration script Merge remote-tracking branch 'origin/develop' into develop feat(tool system): tool system development Signed-off-by: 谢俊男 <accounts_6853d0ea6f8174722fb0c8f1@mail.teambition.com> Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/26
This commit is contained in:
@@ -32,6 +32,8 @@ from . import (
|
||||
emotion_controller,
|
||||
emotion_config_controller,
|
||||
prompt_optimizer_controller,
|
||||
tool_controller,
|
||||
tool_execution_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -66,4 +68,7 @@ manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
manager_router.include_router(memory_reflection_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(tool_execution_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
585
api/app/controllers/tool_controller.py
Normal file
585
api/app/controllers/tool_controller.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""工具管理API控制器"""
|
||||
import base64
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from langfuse.api.core import jsonable_encoder
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.config import settings
|
||||
from app.core.tools.config_manager import ConfigManager
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["工具管理"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""加密敏感参数"""
|
||||
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
|
||||
cipher = Fernet(cipher_key)
|
||||
|
||||
encrypted_params = {}
|
||||
sensitive_keys = ['api_key', 'token', 'api_secret', 'password']
|
||||
|
||||
for key, value in parameters.items():
|
||||
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
|
||||
encrypted_params[key] = cipher.encrypt(str(value).encode()).decode()
|
||||
else:
|
||||
encrypted_params[key] = value
|
||||
|
||||
return encrypted_params
|
||||
|
||||
|
||||
def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""解密敏感参数"""
|
||||
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
|
||||
cipher = Fernet(cipher_key)
|
||||
|
||||
decrypted_params = {}
|
||||
sensitive_keys = ['api_key', 'token', 'secret', 'password']
|
||||
|
||||
for key, value in parameters.items():
|
||||
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
|
||||
try:
|
||||
decrypted_params[key] = cipher.decrypt(value.encode()).decode()
|
||||
except Exception as e:
|
||||
decrypted_params[key] = value
|
||||
else:
|
||||
decrypted_params[key] = value
|
||||
|
||||
return decrypted_params
|
||||
|
||||
|
||||
def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str:
|
||||
"""更新工具状态并返回新状态"""
|
||||
if tool_config.tool_type == ToolType.BUILTIN:
|
||||
if not tool_info or not tool_info.get('requires_config', False):
|
||||
new_status = ToolStatus.ACTIVE.value # 不需要配置的内置工具
|
||||
elif not builtin_config or not builtin_config.parameters:
|
||||
new_status = ToolStatus.INACTIVE.value
|
||||
else:
|
||||
# 检查是否有必要的API密钥
|
||||
has_key = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
|
||||
new_status = ToolStatus.ACTIVE.value if has_key else ToolStatus.INACTIVE.value
|
||||
else: # 自定义和MCP工具
|
||||
new_status = ToolStatus.ACTIVE.value if tool_config.config_data else ToolStatus.ERROR.value
|
||||
|
||||
# 更新数据库中的状态
|
||||
if tool_config.status != new_status:
|
||||
tool_config.status = new_status
|
||||
|
||||
return new_status
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ToolListResponse(BaseModel):
|
||||
"""工具列表响应"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
tool_type: str
|
||||
category: str
|
||||
version: str = "1.0.0"
|
||||
status: str # active inactive error loading
|
||||
requires_config: bool = False
|
||||
# is_configured: bool = False
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class BuiltinToolConfigRequest(BaseModel):
|
||||
"""内置工具配置请求"""
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
|
||||
|
||||
class CustomToolCreateRequest(BaseModel):
|
||||
"""自定义工具创建请求体模型,包含参数校验规则"""
|
||||
name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填")
|
||||
description: str = Field(None, description="工具描述")
|
||||
base_url: str = Field(None, description="工具基础URL")
|
||||
schema_url: str = Field(None, description="工具Schema URL")
|
||||
schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容,可选")
|
||||
auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型")
|
||||
auth_config: Optional[Dict[str, Any]] = Field(None, description="认证配置,默认空字典")
|
||||
timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间,1-300秒,默认30")
|
||||
|
||||
# 自定义校验:当auth_type为api_key时,auth_config必须包含api_key字段
|
||||
@field_validator("auth_config")
|
||||
def validate_auth_config(cls, v, values):
|
||||
auth_type = values.data.get("auth_type")
|
||||
if auth_type == "api_key" and (not v or "api_key" not in v):
|
||||
raise ValueError("认证类型为api_key时,auth_config必须包含api_key字段")
|
||||
if auth_type == "bearer_token" and (not v or "bearer_token" not in v):
|
||||
raise ValueError("认证类型为bearer_token时,auth_config必须包含bearer_token字段")
|
||||
return v
|
||||
|
||||
class MCPToolCreateRequest(BaseModel):
|
||||
"""MCP工具创建请求体模型,适配MCP业务特性"""
|
||||
# 基础必填字段(带长度/格式校验)
|
||||
name: str = Field(..., min_length=1, max_length=100,description="MCP工具名称")
|
||||
description: str = Field(None, description="MCP工具描述")
|
||||
# MCP核心字段:服务端URL(强制HTTP/HTTPS格式)
|
||||
server_url: str = Field(..., description="MCP服务端URL,仅支持http/https协议")
|
||||
# 连接配置:默认空字典,可自定义校验规则(根据实际业务调整)
|
||||
connection_config: Dict[str, Any] = Field({},description="MCP连接配置(如认证信息、超时、重试等),默认空字典")
|
||||
|
||||
@field_validator("connection_config")
|
||||
def validate_connection_config(cls, v):
|
||||
# 示例1:若包含timeout,必须是1-300的整数
|
||||
if "timeout" in v:
|
||||
timeout = v["timeout"]
|
||||
if not isinstance(timeout, int) or timeout < 1 or timeout > 300:
|
||||
raise ValueError("connection_config.timeout必须是1-300的整数")
|
||||
return v
|
||||
|
||||
# @field_validator("server_url")
|
||||
# def validate_server_url_protocol(cls, v):
|
||||
# if v.scheme != "https":
|
||||
# raise ValueError("MCP服务端URL仅支持HTTPS协议(安全要求)")
|
||||
# return v
|
||||
|
||||
|
||||
# ==================== API端点 ====================
|
||||
@router.get("", response_model=List[ToolListResponse])
|
||||
async def list_tools(
|
||||
name: Optional[str] = None,
|
||||
tool_type: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工具列表(包含内置工具、自定义工具和MCP工具)"""
|
||||
try:
|
||||
# 初始化内置工具(如果需要)
|
||||
config_manager = ConfigManager()
|
||||
config_manager.ensure_builtin_tools_initialized(
|
||||
current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus
|
||||
)
|
||||
|
||||
response_tools = []
|
||||
|
||||
query = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
)
|
||||
if tool_type:
|
||||
query = query.filter(ToolConfig.tool_type == tool_type)
|
||||
|
||||
if name:
|
||||
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
|
||||
|
||||
tools = query.all()
|
||||
builtin_tools = config_manager.load_builtin_tools_config()
|
||||
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
|
||||
|
||||
for tool_config in tools:
|
||||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
|
||||
tool_info = configured_tools.get(builtin_config.tool_class)
|
||||
status = _update_tool_status(tool_config, builtin_config, tool_info)
|
||||
else:
|
||||
status = _update_tool_status(tool_config)
|
||||
|
||||
response_tools.append(ToolListResponse(
|
||||
id=str(tool_config.id),
|
||||
name=tool_config.name,
|
||||
description=tool_config.description,
|
||||
tool_type=tool_config.tool_type,
|
||||
category=tool_info['category'] if tool_config.tool_type == ToolType.BUILTIN.value else tool_config.tool_type,
|
||||
version="1.0.0",
|
||||
status=status,
|
||||
requires_config=tool_info['requires_config'] if tool_config.tool_type == ToolType.BUILTIN.value else False,
|
||||
))
|
||||
|
||||
return response_tools
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/builtin/{tool_id}")
|
||||
async def get_builtin_tool_detail(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取内置工具详情"""
|
||||
try:
|
||||
config_manager = ConfigManager()
|
||||
builtin_tools = config_manager.load_builtin_tools_config()
|
||||
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
|
||||
tool_config = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id,
|
||||
ToolConfig.id == tool_id
|
||||
).first()
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
|
||||
tool_info = configured_tools.get(builtin_config.tool_class)
|
||||
|
||||
is_configured = False
|
||||
config_parameters = {}
|
||||
|
||||
if builtin_config and builtin_config.parameters:
|
||||
is_configured = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
|
||||
# 不返回敏感信息,只返回非敏感配置
|
||||
config_parameters = {k: v for k, v in builtin_config.parameters.items()
|
||||
if not any(sensitive in k.lower() for sensitive in ['key', 'secret', 'token', 'password'])}
|
||||
|
||||
return {
|
||||
"id": tool_config.id,
|
||||
"name": tool_config.name,
|
||||
"description": tool_config.description,
|
||||
"category": tool_info['category'],
|
||||
"status": tool_config.tool_type,
|
||||
"requires_config": tool_info['requires_config'],
|
||||
"is_configured": is_configured,
|
||||
"config_parameters": config_parameters
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/builtin/{tool_id}/configure")
|
||||
async def configure_builtin_tool(
|
||||
tool_id: str,
|
||||
request: BuiltinToolConfigRequest = Body(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""配置内置工具参数(租户级别)"""
|
||||
try:
|
||||
# 查询工具配置
|
||||
tool_config = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id,
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tool_type == ToolType.BUILTIN
|
||||
).first()
|
||||
|
||||
if not tool_config:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 获取内置工具配置
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if not builtin_config:
|
||||
raise HTTPException(status_code=404, detail="内置工具配置不存在")
|
||||
|
||||
# 获取全局工具信息
|
||||
config_manager = ConfigManager()
|
||||
builtin_tools_config = config_manager.load_builtin_tools_config()
|
||||
tool_info = None
|
||||
for tool_key, info in builtin_tools_config.items():
|
||||
if info['tool_class'] == builtin_config.tool_class:
|
||||
tool_info = info
|
||||
break
|
||||
|
||||
if not tool_info:
|
||||
raise HTTPException(status_code=404, detail="工具信息不存在")
|
||||
|
||||
# 加密敏感参数
|
||||
encrypted_params = _encrypt_sensitive_params(request.parameters)
|
||||
|
||||
# 更新配置
|
||||
builtin_config.parameters = encrypted_params
|
||||
|
||||
# 更新状态
|
||||
_update_tool_status(tool_config, builtin_config, tool_info)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool_config.name} 配置成功"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"配置内置工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/builtin/{tool_id}/config")
|
||||
async def get_builtin_tool_config(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取内置工具配置(用于使用)"""
|
||||
try:
|
||||
# 查询工具配置
|
||||
tool_config = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id,
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tool_type == ToolType.BUILTIN
|
||||
).first()
|
||||
|
||||
if not tool_config:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 获取内置工具配置
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if not builtin_config:
|
||||
raise HTTPException(status_code=404, detail="内置工具配置不存在")
|
||||
|
||||
# 解密参数
|
||||
decrypted_params = _decrypt_sensitive_params(builtin_config.parameters or {})
|
||||
|
||||
return {
|
||||
"tool_id": tool_id,
|
||||
"tool_class": builtin_config.tool_class,
|
||||
"name": tool_config.name,
|
||||
"parameters": decrypted_params,
|
||||
"status": tool_config.status
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/custom")
|
||||
async def create_custom_tool(
|
||||
request: CustomToolCreateRequest = Body(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建自定义工具"""
|
||||
try:
|
||||
config_data = jsonable_encoder(request.model_dump())
|
||||
config_data["tool_type"] = "custom"
|
||||
|
||||
config_manager = ConfigManager()
|
||||
is_valid, error_msg = config_manager.validate_config(config_data, "custom")
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
# 创建数据库记录
|
||||
tool_config = ToolConfig(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
tool_type=ToolType.CUSTOM,
|
||||
tenant_id=current_user.tenant_id,
|
||||
status=ToolStatus.ACTIVE.value,
|
||||
config_data=config_data
|
||||
)
|
||||
db.add(tool_config)
|
||||
db.flush()
|
||||
|
||||
# 创建CustomToolConfig记录
|
||||
custom_config = CustomToolConfig(
|
||||
id=tool_config.id,
|
||||
base_url=request.base_url,
|
||||
schema_url=request.schema_url,
|
||||
schema_content=request.schema_content,
|
||||
auth_type=request.auth_type,
|
||||
auth_config=request.auth_config,
|
||||
timeout=request.timeout
|
||||
)
|
||||
db.add(custom_config)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"自定义工具 {request.name} 创建成功",
|
||||
"tool_id": str(tool_config.id)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建自定义工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/mcp")
|
||||
async def create_mcp_tool(
|
||||
request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建MCP工具"""
|
||||
try:
|
||||
config_data = jsonable_encoder(request.model_dump())
|
||||
config_data["tool_type"] = "mcp"
|
||||
|
||||
config_manager = ConfigManager()
|
||||
is_valid, error_msg = config_manager.validate_config(config_data, "mcp")
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
# 创建数据库记录
|
||||
try:
|
||||
tool_config = ToolConfig(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
tool_type=ToolType.MCP,
|
||||
tenant_id=current_user.tenant_id,
|
||||
status=ToolStatus.ACTIVE.value,
|
||||
config_data=config_data
|
||||
)
|
||||
db.add(tool_config)
|
||||
db.flush()
|
||||
|
||||
# 创建MCPToolConfig记录
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=request.server_url,
|
||||
connection_config=request.connection_config
|
||||
)
|
||||
db.add(mcp_config)
|
||||
|
||||
db.commit()
|
||||
except SQLAlchemyError as db_e:
|
||||
db.rollback()
|
||||
logger.error(f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},工具名:{request.name}): {str(db_e)}",
|
||||
exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},"
|
||||
f"工具名:{request.name}):{str(db_e)}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"MCP工具 {request.name} 创建成功",
|
||||
"tool_id": str(tool_config.id)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建MCP工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete("/{tool_id}")
|
||||
async def delete_tool(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除工具(仅限自定义和MCP工具)"""
|
||||
try:
|
||||
tool = db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
).first()
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
raise HTTPException(status_code=403, detail="内置工具不允许删除")
|
||||
|
||||
db.delete(tool)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool.name} 删除成功"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{tool_id}")
|
||||
async def update_tool(
|
||||
tool_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新工具(仅限自定义和MCP工具)"""
|
||||
try:
|
||||
tool = db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
).first()
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
raise HTTPException(status_code=403, detail="内置工具不允许修改")
|
||||
|
||||
if config_data is not None:
|
||||
tool.config_data = config_data
|
||||
# 更新状态
|
||||
_update_tool_status(tool)
|
||||
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool.name} 更新成功",
|
||||
"status": tool.status
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{tool_id}/toggle")
|
||||
async def toggle_tool_status(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""切换工具活跃/非活跃状态"""
|
||||
try:
|
||||
tool = db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
).first()
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 在active和inactive之间切换
|
||||
if tool.status == ToolStatus.ACTIVE.value:
|
||||
tool.status = ToolStatus.INACTIVE.value
|
||||
elif tool.status == ToolStatus.INACTIVE.value:
|
||||
tool.status = ToolStatus.ACTIVE.value
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换")
|
||||
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool.name} 状态已更新为 {tool.status}",
|
||||
"status": tool.status
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"切换工具状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
430
api/app/controllers/tool_execution_controller.py
Normal file
430
api/app/controllers/tool_execution_controller.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""工具执行API控制器"""
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode
|
||||
from app.core.tools.builtin import *
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
router = APIRouter(prefix="/tools/execution", tags=["工具执行"])
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ToolExecutionRequest(BaseModel):
|
||||
"""工具执行请求"""
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
timeout: Optional[float] = Field(None, ge=1, le=300, description="超时时间(秒)")
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据")
|
||||
|
||||
|
||||
class BatchExecutionRequest(BaseModel):
|
||||
"""批量执行请求"""
|
||||
executions: List[ToolExecutionRequest] = Field(..., description="执行列表")
|
||||
max_concurrency: int = Field(5, ge=1, le=20, description="最大并发数")
|
||||
|
||||
|
||||
class ToolExecutionResponse(BaseModel):
|
||||
"""工具执行响应"""
|
||||
success: bool
|
||||
execution_id: str
|
||||
tool_id: str
|
||||
data: Any = None
|
||||
error: Optional[str] = None
|
||||
error_code: Optional[str] = None
|
||||
execution_time: float
|
||||
token_usage: Optional[Dict[str, int]] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChainStepRequest(BaseModel):
|
||||
"""链步骤请求"""
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
condition: Optional[str] = Field(None, description="执行条件")
|
||||
output_mapping: Optional[Dict[str, str]] = Field(None, description="输出映射")
|
||||
error_handling: str = Field("stop", description="错误处理策略")
|
||||
|
||||
|
||||
class ChainExecutionRequest(BaseModel):
|
||||
"""链执行请求"""
|
||||
name: str = Field(..., description="链名称")
|
||||
description: str = Field("", description="链描述")
|
||||
steps: List[ChainStepRequest] = Field(..., description="执行步骤")
|
||||
execution_mode: str = Field("sequential", description="执行模式")
|
||||
initial_variables: Optional[Dict[str, Any]] = Field(None, description="初始变量")
|
||||
global_timeout: Optional[float] = Field(None, description="全局超时")
|
||||
|
||||
|
||||
class ExecutionHistoryResponse(BaseModel):
|
||||
"""执行历史响应"""
|
||||
execution_id: str
|
||||
tool_id: str
|
||||
status: str
|
||||
started_at: Optional[str]
|
||||
completed_at: Optional[str]
|
||||
execution_time: Optional[float]
|
||||
user_id: Optional[str]
|
||||
workspace_id: Optional[str]
|
||||
input_data: Optional[Dict[str, Any]]
|
||||
output_data: Optional[Any]
|
||||
error_message: Optional[str]
|
||||
token_usage: Optional[Dict[str, int]]
|
||||
|
||||
|
||||
class ToolConnectionTestResponse(BaseModel):
|
||||
"""工具连接测试响应"""
|
||||
success: bool
|
||||
message: str
|
||||
error: Optional[str] = None
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# ==================== 依赖注入 ====================
|
||||
|
||||
def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry:
|
||||
"""获取工具注册表"""
|
||||
registry = ToolRegistry(db)
|
||||
|
||||
# 注册内置工具类
|
||||
registry.register_tool_class(DateTimeTool)
|
||||
registry.register_tool_class(JsonTool)
|
||||
registry.register_tool_class(BaiduSearchTool)
|
||||
registry.register_tool_class(MinerUTool)
|
||||
registry.register_tool_class(TextInTool)
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def get_tool_executor(
|
||||
db: Session = Depends(get_db),
|
||||
registry: ToolRegistry = Depends(get_tool_registry)
|
||||
) -> ToolExecutor:
|
||||
"""获取工具执行器"""
|
||||
return ToolExecutor(db, registry)
|
||||
|
||||
|
||||
def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager:
|
||||
"""获取链管理器"""
|
||||
return ChainManager(executor)
|
||||
|
||||
|
||||
# ==================== API端点 ====================
|
||||
|
||||
@router.post("/execute", response_model=ToolExecutionResponse)
|
||||
async def execute_tool(
|
||||
request: ToolExecutionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""执行单个工具"""
|
||||
try:
|
||||
# 生成执行ID
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 执行工具
|
||||
result = await executor.execute_tool(
|
||||
tool_id=request.tool_id,
|
||||
parameters=request.parameters,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
execution_id=execution_id,
|
||||
timeout=request.timeout,
|
||||
metadata=request.metadata
|
||||
)
|
||||
|
||||
return ToolExecutionResponse(
|
||||
success=result.success,
|
||||
execution_id=execution_id,
|
||||
tool_id=request.tool_id,
|
||||
data=result.data,
|
||||
error=result.error,
|
||||
error_code=result.error_code,
|
||||
execution_time=result.execution_time,
|
||||
token_usage=result.token_usage,
|
||||
metadata=result.metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/batch", response_model=List[ToolExecutionResponse])
|
||||
async def execute_tools_batch(
|
||||
request: BatchExecutionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""批量执行工具"""
|
||||
try:
|
||||
# 准备执行配置
|
||||
execution_configs = []
|
||||
execution_ids = []
|
||||
|
||||
for exec_request in request.executions:
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
execution_ids.append(execution_id)
|
||||
|
||||
execution_configs.append({
|
||||
"tool_id": exec_request.tool_id,
|
||||
"parameters": exec_request.parameters,
|
||||
"user_id": current_user.id,
|
||||
"workspace_id": current_user.current_workspace_id,
|
||||
"execution_id": execution_id,
|
||||
"timeout": exec_request.timeout,
|
||||
"metadata": exec_request.metadata
|
||||
})
|
||||
|
||||
# 批量执行
|
||||
results = await executor.execute_tools_batch(
|
||||
execution_configs,
|
||||
max_concurrency=request.max_concurrency
|
||||
)
|
||||
|
||||
# 转换响应格式
|
||||
responses = []
|
||||
for i, result in enumerate(results):
|
||||
responses.append(ToolExecutionResponse(
|
||||
success=result.success,
|
||||
execution_id=execution_ids[i],
|
||||
tool_id=request.executions[i].tool_id,
|
||||
data=result.data,
|
||||
error=result.error,
|
||||
error_code=result.error_code,
|
||||
execution_time=result.execution_time,
|
||||
token_usage=result.token_usage,
|
||||
metadata=result.metadata
|
||||
))
|
||||
|
||||
return responses
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量执行失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/chain", response_model=Dict[str, Any])
|
||||
async def execute_tool_chain(
|
||||
request: ChainExecutionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
chain_manager: ChainManager = Depends(get_chain_manager)
|
||||
):
|
||||
"""执行工具链"""
|
||||
try:
|
||||
# 转换步骤格式
|
||||
steps = []
|
||||
for step_request in request.steps:
|
||||
step = ChainStep(
|
||||
tool_id=step_request.tool_id,
|
||||
parameters=step_request.parameters,
|
||||
condition=step_request.condition,
|
||||
output_mapping=step_request.output_mapping,
|
||||
error_handling=step_request.error_handling
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
# 创建链定义
|
||||
chain_definition = ChainDefinition(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
steps=steps,
|
||||
execution_mode=ChainExecutionMode(request.execution_mode),
|
||||
global_timeout=request.global_timeout
|
||||
)
|
||||
|
||||
# 注册并执行链
|
||||
chain_manager.register_chain(chain_definition)
|
||||
|
||||
result = await chain_manager.execute_chain(
|
||||
chain_name=request.name,
|
||||
initial_variables=request.initial_variables
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具链执行失败: {request.name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/running", response_model=List[Dict[str, Any]])
|
||||
async def get_running_executions(
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""获取正在运行的执行"""
|
||||
try:
|
||||
running_executions = executor.get_running_executions()
|
||||
|
||||
# 过滤当前工作空间的执行
|
||||
workspace_executions = [
|
||||
exec_info for exec_info in running_executions
|
||||
if exec_info.get("workspace_id") == str(current_user.current_workspace_id)
|
||||
]
|
||||
|
||||
return workspace_executions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取运行中执行失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any])
|
||||
async def cancel_execution(
|
||||
execution_id: str = Path(..., description="执行ID"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""取消工具执行"""
|
||||
try:
|
||||
success = await executor.cancel_execution(execution_id)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "执行已取消"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="执行不存在或已完成")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"取消执行失败: {execution_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/history", response_model=List[ExecutionHistoryResponse])
|
||||
async def get_execution_history(
|
||||
tool_id: Optional[str] = Query(None, description="工具ID过滤"),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""获取执行历史"""
|
||||
try:
|
||||
history = executor.get_execution_history(
|
||||
tool_id=tool_id,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# 转换响应格式
|
||||
responses = []
|
||||
for record in history:
|
||||
responses.append(ExecutionHistoryResponse(
|
||||
execution_id=record["execution_id"],
|
||||
tool_id=record["tool_id"],
|
||||
status=record["status"],
|
||||
started_at=record["started_at"],
|
||||
completed_at=record["completed_at"],
|
||||
execution_time=record["execution_time"],
|
||||
user_id=record["user_id"],
|
||||
workspace_id=record["workspace_id"],
|
||||
input_data=record["input_data"],
|
||||
output_data=record["output_data"],
|
||||
error_message=record["error_message"],
|
||||
token_usage=record["token_usage"]
|
||||
))
|
||||
|
||||
return responses
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行历史失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/statistics", response_model=Dict[str, Any])
|
||||
async def get_execution_statistics(
|
||||
days: int = Query(7, ge=1, le=90, description="统计天数"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""获取执行统计"""
|
||||
try:
|
||||
stats = executor.get_execution_statistics(
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
days=days
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"statistics": stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/chains/running", response_model=List[Dict[str, Any]])
|
||||
async def get_running_chains(
|
||||
current_user: User = Depends(get_current_user),
|
||||
chain_manager: ChainManager = Depends(get_chain_manager)
|
||||
):
|
||||
"""获取正在运行的工具链"""
|
||||
try:
|
||||
running_chains = chain_manager.get_running_chains()
|
||||
return running_chains
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取运行中工具链失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/chains", response_model=List[Dict[str, Any]])
|
||||
async def list_tool_chains(
|
||||
current_user: User = Depends(get_current_user),
|
||||
chain_manager: ChainManager = Depends(get_chain_manager)
|
||||
):
|
||||
"""列出工具链"""
|
||||
try:
|
||||
chains = chain_manager.list_chains()
|
||||
return chains
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具链列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse)
|
||||
async def test_tool_connection(
|
||||
tool_id: str = Path(..., description="工具ID"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
result = await executor.test_tool_connection(
|
||||
tool_id=tool_id,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id
|
||||
)
|
||||
|
||||
return ToolConnectionTestResponse(
|
||||
success=result.get("success", False),
|
||||
message=result.get("message", ""),
|
||||
error=result.get("error"),
|
||||
details=result.get("details")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
|
||||
return ToolConnectionTestResponse(
|
||||
success=False,
|
||||
message="连接测试失败",
|
||||
error=str(e)
|
||||
)
|
||||
@@ -37,9 +37,10 @@ def require_api_key(
|
||||
@require_api_key(scopes=["app"])
|
||||
def chat_with_app(
|
||||
resource_id: uuid.UUID,
|
||||
api_key_auth: ApiKeyAuth = Depends(),
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str
|
||||
message: str = Query(..., description="聊天消息内容")
|
||||
):
|
||||
# api_key_auth 包含验证后的API Key 信息
|
||||
pass
|
||||
|
||||
@@ -157,6 +157,12 @@ class Settings:
|
||||
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
|
||||
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
|
||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
|
||||
37
api/app/core/tools/__init__.py
Normal file
37
api/app/core/tools/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""工具管理核心模块"""
|
||||
|
||||
from .base import BaseTool, ToolResult, ToolParameter
|
||||
from .registry import ToolRegistry
|
||||
from .executor import ToolExecutor
|
||||
from .langchain_adapter import LangchainAdapter
|
||||
from .config_manager import ConfigManager
|
||||
from .chain_manager import ChainManager
|
||||
|
||||
# 可选导入,避免导入错误
|
||||
try:
|
||||
from .custom.base import CustomTool
|
||||
except ImportError:
|
||||
CustomTool = None
|
||||
|
||||
try:
|
||||
from .mcp.base import MCPTool
|
||||
except ImportError:
|
||||
MCPTool = None
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"ToolResult",
|
||||
"ToolParameter",
|
||||
"ToolRegistry",
|
||||
"ToolExecutor",
|
||||
"LangchainAdapter",
|
||||
"ConfigManager",
|
||||
"ChainManager"
|
||||
]
|
||||
|
||||
# 只有在成功导入时才添加到__all__
|
||||
if CustomTool:
|
||||
__all__.append("CustomTool")
|
||||
|
||||
if MCPTool:
|
||||
__all__.append("MCPTool")
|
||||
302
api/app/core/tools/base.py
Normal file
302
api/app/core/tools/base.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""工具基础接口定义"""
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
from app.models.tool_model import ToolType, ToolStatus
|
||||
|
||||
|
||||
class ParameterType(str, Enum):
|
||||
"""参数类型枚举"""
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
"""工具参数定义"""
|
||||
name: str = Field(..., description="参数名称")
|
||||
type: ParameterType = Field(..., description="参数类型")
|
||||
description: str = Field("", description="参数描述")
|
||||
required: bool = Field(False, description="是否必需")
|
||||
default: Any = Field(None, description="默认值")
|
||||
enum: Optional[List[Any]] = Field(None, description="枚举值")
|
||||
minimum: Optional[Union[int, float]] = Field(None, description="最小值")
|
||||
maximum: Optional[Union[int, float]] = Field(None, description="最大值")
|
||||
pattern: Optional[str] = Field(None, description="正则表达式模式")
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""工具执行结果"""
|
||||
success: bool = Field(..., description="执行是否成功")
|
||||
data: Any = Field(None, description="返回数据")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
error_code: Optional[str] = Field(None, description="错误代码")
|
||||
execution_time: float = Field(..., description="执行时间(秒)")
|
||||
token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
|
||||
|
||||
@classmethod
|
||||
def success_result(
|
||||
cls,
|
||||
data: Any,
|
||||
execution_time: float,
|
||||
token_usage: Optional[Dict[str, int]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> "ToolResult":
|
||||
"""创建成功结果"""
|
||||
return cls(
|
||||
success=True,
|
||||
data=data,
|
||||
execution_time=execution_time,
|
||||
token_usage=token_usage,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def error_result(
|
||||
cls,
|
||||
error: str,
|
||||
execution_time: float,
|
||||
error_code: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> "ToolResult":
|
||||
"""创建错误结果"""
|
||||
return cls(
|
||||
success=False,
|
||||
error=error,
|
||||
error_code=error_code,
|
||||
execution_time=execution_time,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
|
||||
class ToolInfo(BaseModel):
|
||||
"""工具信息"""
|
||||
id: str = Field(..., description="工具ID")
|
||||
name: str = Field(..., description="工具名称")
|
||||
description: str = Field(..., description="工具描述")
|
||||
tool_type: ToolType = Field(..., description="工具类型")
|
||||
version: str = Field("1.0.0", description="工具版本")
|
||||
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
|
||||
status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态")
|
||||
tags: List[str] = Field(default_factory=list, description="工具标签")
|
||||
tenant_id: Optional[str] = Field(None, description="租户ID")
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class BaseTool(ABC):
|
||||
"""所有工具的基础抽象类"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
"""初始化工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
config: 工具配置
|
||||
"""
|
||||
self.tool_id = tool_id
|
||||
self.config = config
|
||||
self._status = ToolStatus.ACTIVE
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""工具名称"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""工具描述"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""工具版本"""
|
||||
return self.config.get("version", "1.0.0")
|
||||
|
||||
@property
|
||||
def status(self) -> ToolStatus:
|
||||
"""工具状态"""
|
||||
return self._status
|
||||
|
||||
@status.setter
|
||||
def status(self, value: ToolStatus):
|
||||
"""设置工具状态"""
|
||||
self._status = value
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""工具参数定义"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def tags(self) -> List[str]:
|
||||
"""工具标签"""
|
||||
return self.config.get("tags", [])
|
||||
|
||||
def get_info(self) -> ToolInfo:
|
||||
"""获取工具信息"""
|
||||
return ToolInfo(
|
||||
id=self.tool_id,
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
tool_type=self.tool_type,
|
||||
version=self.version,
|
||||
parameters=self.parameters,
|
||||
status=self.status,
|
||||
tags=self.tags,
|
||||
tenant_id=self.config.get("tenant_id")
|
||||
)
|
||||
|
||||
def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""验证参数
|
||||
|
||||
Args:
|
||||
parameters: 输入参数
|
||||
|
||||
Returns:
|
||||
验证错误字典,空字典表示验证通过
|
||||
"""
|
||||
errors = {}
|
||||
param_definitions = {p.name: p for p in self.parameters}
|
||||
|
||||
# 检查必需参数
|
||||
for param_def in self.parameters:
|
||||
if param_def.required and param_def.name not in parameters:
|
||||
errors[param_def.name] = f"Required parameter '{param_def.name}' is missing"
|
||||
|
||||
# 检查参数类型和约束
|
||||
for param_name, param_value in parameters.items():
|
||||
if param_name not in param_definitions:
|
||||
continue
|
||||
|
||||
param_def = param_definitions[param_name]
|
||||
|
||||
# 类型检查
|
||||
if not self._validate_parameter_type(param_value, param_def):
|
||||
errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}"
|
||||
|
||||
# 约束检查
|
||||
constraint_error = self._validate_parameter_constraints(param_value, param_def)
|
||||
if constraint_error:
|
||||
errors[param_name] = constraint_error
|
||||
|
||||
return errors
|
||||
|
||||
def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool:
|
||||
"""验证参数类型"""
|
||||
if value is None:
|
||||
return not param_def.required
|
||||
|
||||
type_mapping = {
|
||||
ParameterType.STRING: str,
|
||||
ParameterType.INTEGER: int,
|
||||
ParameterType.NUMBER: (int, float),
|
||||
ParameterType.BOOLEAN: bool,
|
||||
ParameterType.ARRAY: list,
|
||||
ParameterType.OBJECT: dict
|
||||
}
|
||||
|
||||
expected_type = type_mapping.get(param_def.type)
|
||||
if expected_type:
|
||||
return isinstance(value, expected_type)
|
||||
|
||||
return True
|
||||
|
||||
def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]:
|
||||
"""验证参数约束"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# 枚举值检查
|
||||
if param_def.enum and value not in param_def.enum:
|
||||
return f"Value must be one of {param_def.enum}"
|
||||
|
||||
# 数值范围检查
|
||||
if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER]:
|
||||
if param_def.minimum is not None and value < param_def.minimum:
|
||||
return f"Value must be >= {param_def.minimum}"
|
||||
if param_def.maximum is not None and value > param_def.maximum:
|
||||
return f"Value must be <= {param_def.maximum}"
|
||||
|
||||
# 字符串模式检查
|
||||
if param_def.type == ParameterType.STRING and param_def.pattern:
|
||||
import re
|
||||
if not re.match(param_def.pattern, str(value)):
|
||||
return f"Value must match pattern: {param_def.pattern}"
|
||||
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行工具
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
pass
|
||||
|
||||
async def safe_execute(self, **kwargs) -> ToolResult:
|
||||
"""安全执行工具(包含参数验证和异常处理)
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 参数验证
|
||||
validation_errors = self.validate_parameters(kwargs)
|
||||
if validation_errors:
|
||||
execution_time = time.time() - start_time
|
||||
error_msg = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()])
|
||||
return ToolResult.error_result(
|
||||
error=f"Parameter validation failed: {error_msg}",
|
||||
error_code="VALIDATION_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
# 执行工具
|
||||
result = await self.execute(**kwargs)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="EXECUTION_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def to_langchain_tool(self):
|
||||
"""转换为Langchain工具格式"""
|
||||
from .langchain_adapter import LangchainAdapter
|
||||
return LangchainAdapter.convert_tool(self)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"
|
||||
17
api/app/core/tools/builtin/__init__.py
Normal file
17
api/app/core/tools/builtin/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""内置工具模块"""
|
||||
|
||||
from .base import BuiltinTool
|
||||
from .datetime_tool import DateTimeTool
|
||||
from .json_tool import JsonTool
|
||||
from .baidu_search_tool import BaiduSearchTool
|
||||
from .mineru_tool import MinerUTool
|
||||
from .textin_tool import TextInTool
|
||||
|
||||
__all__ = [
|
||||
"BuiltinTool",
|
||||
"DateTimeTool",
|
||||
"JsonTool",
|
||||
"BaiduSearchTool",
|
||||
"MinerUTool",
|
||||
"TextInTool"
|
||||
]
|
||||
334
api/app/core/tools/builtin/baidu_search_tool.py
Normal file
334
api/app/core/tools/builtin/baidu_search_tool.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""百度搜索工具 - 搜索引擎服务"""
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class BaiduSearchTool(BuiltinTool):
|
||||
"""百度搜索工具 - 提供网页搜索、新闻搜索、图片搜索、实时结果"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "baidu_search_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果"
|
||||
|
||||
def get_required_config_parameters(self) -> List[str]:
|
||||
return ["api_key"]
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="query",
|
||||
type=ParameterType.STRING,
|
||||
description="搜索关键词",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="search_type",
|
||||
type=ParameterType.STRING,
|
||||
description="搜索类型",
|
||||
required=False,
|
||||
default="web",
|
||||
enum=["web", "news", "image", "video"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="page_size",
|
||||
type=ParameterType.INTEGER,
|
||||
description="每页结果数",
|
||||
required=False,
|
||||
default=10,
|
||||
minimum=1,
|
||||
maximum=50
|
||||
),
|
||||
ToolParameter(
|
||||
name="page_num",
|
||||
type=ParameterType.INTEGER,
|
||||
description="页码(从1开始)",
|
||||
required=False,
|
||||
default=1,
|
||||
minimum=1,
|
||||
maximum=10
|
||||
),
|
||||
ToolParameter(
|
||||
name="safe_search",
|
||||
type=ParameterType.BOOLEAN,
|
||||
description="是否启用安全搜索",
|
||||
required=False,
|
||||
default=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="region",
|
||||
type=ParameterType.STRING,
|
||||
description="搜索地区",
|
||||
required=False,
|
||||
default="cn",
|
||||
enum=["cn", "hk", "tw", "us", "jp", "kr"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="time_filter",
|
||||
type=ParameterType.STRING,
|
||||
description="时间过滤",
|
||||
required=False,
|
||||
enum=["all", "day", "week", "month", "year"]
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行百度搜索"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
query = kwargs.get("query")
|
||||
search_type = kwargs.get("search_type", "web")
|
||||
page_size = kwargs.get("page_size", 10)
|
||||
page_num = kwargs.get("page_num", 1)
|
||||
safe_search = kwargs.get("safe_search", True)
|
||||
region = kwargs.get("region", "cn")
|
||||
time_filter = kwargs.get("time_filter")
|
||||
|
||||
if not query:
|
||||
raise ValueError("query 参数是必需的")
|
||||
|
||||
# 根据搜索类型调用不同的API
|
||||
if search_type == "web":
|
||||
result = await self._web_search(query, page_size, page_num, safe_search, region, time_filter)
|
||||
elif search_type == "news":
|
||||
result = await self._news_search(query, page_size, page_num, region, time_filter)
|
||||
elif search_type == "image":
|
||||
result = await self._image_search(query, page_size, page_num, safe_search)
|
||||
elif search_type == "video":
|
||||
result = await self._video_search(query, page_size, page_num, safe_search)
|
||||
else:
|
||||
raise ValueError(f"不支持的搜索类型: {search_type}")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="BAIDU_SEARCH_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
async def _web_search(self, query: str, page_size: int, page_num: int,
|
||||
safe_search: bool, region: str, time_filter: str = None) -> Dict[str, Any]:
|
||||
"""网页搜索"""
|
||||
payload = {
|
||||
"messages": [{"role": "user", "content": query}],
|
||||
"edition": "standard",
|
||||
"search_source": "baidu_search_v2",
|
||||
"resource_type_filter": [{"type": "web", "top_k": min(page_size, 50)}],
|
||||
"enable_full_content": True
|
||||
}
|
||||
|
||||
if time_filter:
|
||||
time_map = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"}
|
||||
if time_filter in time_map:
|
||||
payload["search_filter"] = {"range": {"page_time": {"gte": time_map[time_filter], "lt": "now/d"}}}
|
||||
payload["search_recency_filter"] = time_filter
|
||||
|
||||
results = await self._call_baidu_ai_search_api(payload)
|
||||
|
||||
search_results = []
|
||||
if "references" in results:
|
||||
for item in results["references"]:
|
||||
search_results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("content", ""),
|
||||
"display_url": item.get("url", ""),
|
||||
"rank": len(search_results) + 1
|
||||
})
|
||||
|
||||
return {
|
||||
"search_type": "web",
|
||||
"query": query,
|
||||
"total_results": len(search_results),
|
||||
"page_num": page_num,
|
||||
"page_size": page_size,
|
||||
"results": search_results,
|
||||
"answer": results.get("result", ""),
|
||||
"references": results.get("references", [])
|
||||
}
|
||||
|
||||
async def _news_search(self, query: str, page_size: int, page_num: int,
|
||||
region: str, time_filter: str = None) -> Dict[str, Any]:
|
||||
"""新闻搜索"""
|
||||
payload = {
|
||||
"messages": [{"role": "user", "content": query}],
|
||||
"edition": "standard",
|
||||
"search_source": "baidu_search_v2",
|
||||
"resource_type_filter": [{"type": "new", "top_k": min(page_size, 50)}],
|
||||
"enable_full_content": True
|
||||
}
|
||||
|
||||
if time_filter:
|
||||
time_map = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"}
|
||||
if time_filter in time_map:
|
||||
payload["search_filter"] = {"range": {"page_time": {"gte": time_map[time_filter], "lt": "now/d"}}}
|
||||
payload["search_recency_filter"] = time_filter
|
||||
|
||||
results = await self._call_baidu_ai_search_api(payload)
|
||||
|
||||
search_results = []
|
||||
if "references" in results:
|
||||
for item in results["references"]:
|
||||
search_results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("content", ""),
|
||||
"display_url": item.get("url", ""),
|
||||
"rank": len(search_results) + 1
|
||||
})
|
||||
|
||||
return {
|
||||
"search_type": "new",
|
||||
"query": query,
|
||||
"total_results": len(search_results),
|
||||
"page_num": page_num,
|
||||
"page_size": page_size,
|
||||
"results": search_results,
|
||||
"answer": results.get("result", ""),
|
||||
"references": results.get("references", [])
|
||||
}
|
||||
|
||||
async def _image_search(self, query: str, page_size: int, page_num: int,
|
||||
safe_search: bool) -> Dict[str, Any]:
|
||||
"""图片搜索"""
|
||||
payload = {
|
||||
"messages": [{"role": "user", "content": query}],
|
||||
"edition": "standard",
|
||||
"search_source": "baidu_search_v2",
|
||||
"resource_type_filter": [{"type": "image", "top_k": min(page_size, 30)}],
|
||||
"enable_full_content": True
|
||||
}
|
||||
|
||||
results = await self._call_baidu_ai_search_api(payload)
|
||||
|
||||
search_results = []
|
||||
if "references" in results:
|
||||
for item in results["references"]:
|
||||
search_results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("content", ""),
|
||||
"display_url": item.get("url", ""),
|
||||
"rank": len(search_results) + 1
|
||||
})
|
||||
|
||||
return {
|
||||
"search_type": "image",
|
||||
"query": query,
|
||||
"total_results": len(search_results),
|
||||
"page_num": page_num,
|
||||
"page_size": page_size,
|
||||
"results": search_results,
|
||||
"answer": results.get("result", ""),
|
||||
"references": results.get("references", [])
|
||||
}
|
||||
|
||||
async def _video_search(self, query: str, page_size: int, page_num: int,
|
||||
safe_search: bool) -> Dict[str, Any]:
|
||||
"""视频搜索"""
|
||||
payload = {
|
||||
"messages": [{"role": "user", "content": query}],
|
||||
"edition": "standard",
|
||||
"search_source": "baidu_search_v2",
|
||||
"resource_type_filter": [{"type": "video", "top_k": min(page_size, 10)}],
|
||||
"enable_full_content": True
|
||||
}
|
||||
|
||||
results = await self._call_baidu_ai_search_api(payload)
|
||||
|
||||
search_results = []
|
||||
if "references" in results:
|
||||
for item in results["references"]:
|
||||
search_results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("content", ""),
|
||||
"display_url": item.get("url", ""),
|
||||
"rank": len(search_results) + 1
|
||||
})
|
||||
|
||||
return {
|
||||
"search_type": "video",
|
||||
"query": query,
|
||||
"total_results": len(search_results),
|
||||
"page_num": page_num,
|
||||
"page_size": page_size,
|
||||
"results": search_results,
|
||||
"answer": results.get("result", ""),
|
||||
"references": results.get("references", [])
|
||||
}
|
||||
|
||||
async def _call_baidu_ai_search_api(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""调用百度AI搜索API"""
|
||||
api_key = self.get_config_parameter("api_key")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("百度搜索API密钥未配置")
|
||||
|
||||
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {api_key}'
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(url, headers=headers, json=payload) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
raise Exception(f"HTTP错误: {response.status}")
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试连接"""
|
||||
try:
|
||||
api_key = self.get_config_parameter("api_key")
|
||||
|
||||
if not api_key:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "API密钥未配置"
|
||||
}
|
||||
|
||||
# 发送测试请求验证API key是否有效
|
||||
test_payload = {
|
||||
"messages": [{"role": "user", "content": "test"}],
|
||||
"edition": "standard",
|
||||
"search_source": "baidu_search_v2",
|
||||
"resource_type_filter": [{"type": "web", "top_k": 1}]
|
||||
}
|
||||
|
||||
try:
|
||||
await self._call_baidu_ai_search_api(test_payload)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接测试成功",
|
||||
"api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"API连接失败: {str(e)}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
118
api/app/core/tools/builtin/base.py
Normal file
118
api/app/core/tools/builtin/base.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""内置工具基类"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
|
||||
|
||||
|
||||
class BuiltinTool(BaseTool, ABC):
|
||||
"""内置工具基类"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
"""初始化内置工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
config: 工具配置
|
||||
"""
|
||||
super().__init__(tool_id, config)
|
||||
self.parameters_config = config.get("parameters", {})
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.BUILTIN
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""工具名称 - 子类必须实现"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""工具描述 - 子类必须实现"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""工具参数定义 - 子类必须实现"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行工具 - 子类必须实现
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""检查工具是否已正确配置"""
|
||||
required_params = self.get_required_config_parameters()
|
||||
for param in required_params:
|
||||
if not self.parameters_config.get(param):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_required_config_parameters(self) -> List[str]:
|
||||
"""获取必需的配置参数列表
|
||||
|
||||
Returns:
|
||||
必需配置参数名称列表
|
||||
"""
|
||||
return []
|
||||
|
||||
def get_config_parameter(self, name: str, default: Any = None) -> Any:
|
||||
"""获取配置参数值
|
||||
|
||||
Args:
|
||||
name: 参数名称
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
参数值
|
||||
"""
|
||||
return self.parameters_config.get(name, default)
|
||||
|
||||
def validate_configuration(self) -> tuple[bool, str]:
|
||||
"""验证工具配置
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
if not self.is_configured:
|
||||
required_params = self.get_required_config_parameters()
|
||||
missing_params = [p for p in required_params if not self.parameters_config.get(p)]
|
||||
return False, f"缺少必需的配置参数: {', '.join(missing_params)}"
|
||||
|
||||
return True, ""
|
||||
|
||||
async def safe_execute(self, **kwargs) -> ToolResult:
|
||||
"""安全执行工具(包含配置验证)
|
||||
|
||||
Args:
|
||||
**kwargs: 工具参数
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
# 首先验证配置
|
||||
is_valid, error_msg = self.validate_configuration()
|
||||
if not is_valid:
|
||||
return ToolResult.error_result(
|
||||
error=f"工具配置无效: {error_msg}",
|
||||
error_code="CONFIGURATION_ERROR",
|
||||
execution_time=0.0
|
||||
)
|
||||
|
||||
# 调用父类的安全执行
|
||||
return await super().safe_execute(**kwargs)
|
||||
307
api/app/core/tools/builtin/datetime_tool.py
Normal file
307
api/app/core/tools/builtin/datetime_tool.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""时间工具 - 日期时间处理"""
|
||||
import time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List
|
||||
import pytz
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class DateTimeTool(BuiltinTool):
|
||||
"""时间工具 - 提供时间格式转换、时区转换、时间戳转换、时间计算功能"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "datetime_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算"
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="operation",
|
||||
type=ParameterType.STRING,
|
||||
description="操作类型",
|
||||
required=True,
|
||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="from_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="源时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="UTC"
|
||||
),
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="UTC"
|
||||
),
|
||||
ToolParameter(
|
||||
name="calculation",
|
||||
type=ParameterType.STRING,
|
||||
description="时间计算表达式(如:+1d, -2h, +30m)",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行时间工具操作"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
operation = kwargs.get("operation")
|
||||
|
||||
if operation == "now":
|
||||
result = self._get_current_time(kwargs)
|
||||
elif operation == "format":
|
||||
result = self._format_datetime(kwargs)
|
||||
elif operation == "convert_timezone":
|
||||
result = self._convert_timezone(kwargs)
|
||||
elif operation == "timestamp_to_datetime":
|
||||
result = self._timestamp_to_datetime(kwargs)
|
||||
elif operation == "datetime_to_timestamp":
|
||||
result = self._datetime_to_timestamp(kwargs)
|
||||
elif operation == "calculate":
|
||||
result = self._calculate_datetime(kwargs)
|
||||
else:
|
||||
raise ValueError(f"不支持的操作类型: {operation}")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="DATETIME_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def _get_current_time(self, kwargs) -> dict:
|
||||
"""获取当前时间"""
|
||||
timezone_str = kwargs.get("to_timezone", "UTC")
|
||||
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if timezone_str == "UTC":
|
||||
tz = timezone.utc
|
||||
else:
|
||||
tz = pytz.timezone(timezone_str)
|
||||
|
||||
now = datetime.now(tz)
|
||||
|
||||
return {
|
||||
"datetime": now.strftime(output_format),
|
||||
"timestamp": int(now.timestamp()),
|
||||
"timezone": timezone_str,
|
||||
"iso_format": now.isoformat()
|
||||
}
|
||||
|
||||
def _format_datetime(self, kwargs) -> dict:
|
||||
"""格式化时间"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if not input_value:
|
||||
raise ValueError("input_value 参数是必需的")
|
||||
|
||||
# 解析输入时间
|
||||
dt = datetime.strptime(input_value, input_format)
|
||||
|
||||
return {
|
||||
"original": input_value,
|
||||
"formatted": dt.strftime(output_format),
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
def _convert_timezone(self, kwargs) -> dict:
|
||||
"""时区转换"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
|
||||
from_timezone = kwargs.get("from_timezone", "UTC")
|
||||
to_timezone = kwargs.get("to_timezone", "UTC")
|
||||
|
||||
if not input_value:
|
||||
raise ValueError("input_value 参数是必需的")
|
||||
|
||||
# 解析输入时间
|
||||
dt = datetime.strptime(input_value, input_format)
|
||||
|
||||
# 设置源时区
|
||||
if from_timezone == "UTC":
|
||||
from_tz = pytz.UTC
|
||||
else:
|
||||
from_tz = pytz.timezone(from_timezone)
|
||||
|
||||
# 设置目标时区
|
||||
if to_timezone == "UTC":
|
||||
to_tz = pytz.UTC
|
||||
else:
|
||||
to_tz = pytz.timezone(to_timezone)
|
||||
|
||||
# 本地化时间并转换时区
|
||||
if dt.tzinfo is None:
|
||||
dt = from_tz.localize(dt)
|
||||
|
||||
converted_dt = dt.astimezone(to_tz)
|
||||
|
||||
return {
|
||||
"original": input_value,
|
||||
"original_timezone": from_timezone,
|
||||
"converted": converted_dt.strftime(output_format),
|
||||
"converted_timezone": to_timezone,
|
||||
"timestamp": int(converted_dt.timestamp())
|
||||
}
|
||||
|
||||
def _timestamp_to_datetime(self, kwargs) -> dict:
|
||||
"""时间戳转日期时间"""
|
||||
input_value = kwargs.get("input_value")
|
||||
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
|
||||
timezone_str = kwargs.get("to_timezone", "UTC")
|
||||
|
||||
if not input_value:
|
||||
raise ValueError("input_value 参数是必需的")
|
||||
|
||||
# 转换时间戳
|
||||
timestamp = float(input_value)
|
||||
|
||||
# 设置时区
|
||||
if timezone_str == "UTC":
|
||||
tz = timezone.utc
|
||||
else:
|
||||
tz = pytz.timezone(timezone_str)
|
||||
|
||||
dt = datetime.fromtimestamp(timestamp, tz)
|
||||
|
||||
return {
|
||||
"timestamp": timestamp,
|
||||
"datetime": dt.strftime(output_format),
|
||||
"timezone": timezone_str,
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
def _datetime_to_timestamp(self, kwargs) -> dict:
|
||||
"""日期时间转时间戳"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
timezone_str = kwargs.get("from_timezone", "UTC")
|
||||
|
||||
if not input_value:
|
||||
raise ValueError("input_value 参数是必需的")
|
||||
|
||||
# 解析输入时间
|
||||
dt = datetime.strptime(input_value, input_format)
|
||||
|
||||
# 设置时区
|
||||
if timezone_str == "UTC":
|
||||
tz = timezone.utc
|
||||
else:
|
||||
tz = pytz.timezone(timezone_str)
|
||||
|
||||
# 本地化时间
|
||||
if dt.tzinfo is None:
|
||||
dt = tz.localize(dt)
|
||||
|
||||
return {
|
||||
"datetime": input_value,
|
||||
"timezone": timezone_str,
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
def _calculate_datetime(self, kwargs) -> dict:
|
||||
"""时间计算"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
|
||||
calculation = kwargs.get("calculation")
|
||||
timezone_str = kwargs.get("from_timezone", "UTC")
|
||||
|
||||
if not input_value:
|
||||
raise ValueError("input_value 参数是必需的")
|
||||
|
||||
if not calculation:
|
||||
raise ValueError("calculation 参数是必需的")
|
||||
|
||||
# 解析输入时间
|
||||
dt = datetime.strptime(input_value, input_format)
|
||||
|
||||
# 设置时区
|
||||
if timezone_str == "UTC":
|
||||
tz = timezone.utc
|
||||
else:
|
||||
tz = pytz.timezone(timezone_str)
|
||||
|
||||
if dt.tzinfo is None:
|
||||
dt = tz.localize(dt)
|
||||
|
||||
# 解析计算表达式
|
||||
delta = self._parse_time_delta(calculation)
|
||||
calculated_dt = dt + delta
|
||||
|
||||
return {
|
||||
"original": input_value,
|
||||
"calculation": calculation,
|
||||
"result": calculated_dt.strftime(output_format),
|
||||
"timezone": timezone_str,
|
||||
"timestamp": int(calculated_dt.timestamp())
|
||||
}
|
||||
|
||||
def _parse_time_delta(self, calculation: str) -> timedelta:
|
||||
"""解析时间计算表达式"""
|
||||
import re
|
||||
|
||||
# 支持的单位:d(天), h(小时), m(分钟), s(秒)
|
||||
pattern = r'([+-]?\d+)([dhms])'
|
||||
matches = re.findall(pattern, calculation.lower())
|
||||
|
||||
if not matches:
|
||||
raise ValueError(f"无效的时间计算表达式: {calculation}")
|
||||
|
||||
total_delta = timedelta()
|
||||
|
||||
for value_str, unit in matches:
|
||||
value = int(value_str)
|
||||
|
||||
if unit == 'd':
|
||||
total_delta += timedelta(days=value)
|
||||
elif unit == 'h':
|
||||
total_delta += timedelta(hours=value)
|
||||
elif unit == 'm':
|
||||
total_delta += timedelta(minutes=value)
|
||||
elif unit == 's':
|
||||
total_delta += timedelta(seconds=value)
|
||||
|
||||
return total_delta
|
||||
430
api/app/core/tools/builtin/json_tool.py
Normal file
430
api/app/core/tools/builtin/json_tool.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""JSON转换工具 - 数据格式转换"""
|
||||
import json
|
||||
import time
|
||||
from typing import List, Any, Dict
|
||||
import yaml
|
||||
import xml.etree.ElementTree as ET
|
||||
from xml.dom import minidom
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class JsonTool(BuiltinTool):
|
||||
"""JSON转换工具 - 提供JSON格式化、压缩、验证、格式转换功能"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "json_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "JSON转换工具 - 数据格式转换:JSON格式化、JSON压缩、JSON验证、格式转换"
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="operation",
|
||||
type=ParameterType.STRING,
|
||||
description="操作类型",
|
||||
required=True,
|
||||
enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", "extract"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_data",
|
||||
type=ParameterType.STRING,
|
||||
description="输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="indent",
|
||||
type=ParameterType.INTEGER,
|
||||
description="JSON格式化缩进空格数",
|
||||
required=False,
|
||||
default=2,
|
||||
minimum=0,
|
||||
maximum=8
|
||||
),
|
||||
ToolParameter(
|
||||
name="ensure_ascii",
|
||||
type=ParameterType.BOOLEAN,
|
||||
description="是否确保ASCII编码",
|
||||
required=False,
|
||||
default=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="sort_keys",
|
||||
type=ParameterType.BOOLEAN,
|
||||
description="是否对键进行排序",
|
||||
required=False,
|
||||
default=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="merge_data",
|
||||
type=ParameterType.STRING,
|
||||
description="要合并的JSON数据(用于merge操作)",
|
||||
required=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(用于extract操作,如:$.user.name)",
|
||||
required=False
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行JSON工具操作"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
operation = kwargs.get("operation")
|
||||
input_data = kwargs.get("input_data")
|
||||
|
||||
if not input_data:
|
||||
raise ValueError("input_data 参数是必需的")
|
||||
|
||||
if operation == "format":
|
||||
result = self._format_json(input_data, kwargs)
|
||||
elif operation == "minify":
|
||||
result = self._minify_json(input_data)
|
||||
elif operation == "validate":
|
||||
result = self._validate_json(input_data)
|
||||
elif operation == "convert":
|
||||
result = self._convert_json(input_data)
|
||||
elif operation == "to_yaml":
|
||||
result = self._json_to_yaml(input_data)
|
||||
elif operation == "from_yaml":
|
||||
result = self._yaml_to_json(input_data, kwargs)
|
||||
elif operation == "to_xml":
|
||||
result = self._json_to_xml(input_data)
|
||||
elif operation == "from_xml":
|
||||
result = self._xml_to_json(input_data, kwargs)
|
||||
elif operation == "merge":
|
||||
result = self._merge_json(input_data, kwargs)
|
||||
elif operation == "extract":
|
||||
result = self._extract_json_path(input_data, kwargs)
|
||||
else:
|
||||
raise ValueError(f"不支持的操作类型: {operation}")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="JSON_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""格式化JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
ensure_ascii = kwargs.get("ensure_ascii", False)
|
||||
sort_keys = kwargs.get("sort_keys", False)
|
||||
|
||||
# 解析JSON
|
||||
data = json.loads(input_data)
|
||||
|
||||
# 格式化输出
|
||||
formatted = json.dumps(
|
||||
data,
|
||||
indent=indent,
|
||||
ensure_ascii=ensure_ascii,
|
||||
sort_keys=sort_keys,
|
||||
separators=(',', ': ')
|
||||
)
|
||||
|
||||
return {
|
||||
"original_size": len(input_data),
|
||||
"formatted_size": len(formatted),
|
||||
"formatted_json": formatted,
|
||||
"is_valid": True,
|
||||
"settings": {
|
||||
"indent": indent,
|
||||
"ensure_ascii": ensure_ascii,
|
||||
"sort_keys": sort_keys
|
||||
}
|
||||
}
|
||||
|
||||
def _minify_json(self, input_data: str) -> Dict[str, Any]:
|
||||
"""压缩JSON"""
|
||||
# 解析并压缩
|
||||
data = json.loads(input_data)
|
||||
minified = json.dumps(data, separators=(',', ':'))
|
||||
|
||||
return {
|
||||
"original_size": len(input_data),
|
||||
"minified_size": len(minified),
|
||||
"compression_ratio": round((1 - len(minified) / len(input_data)) * 100, 2),
|
||||
"minified_json": minified,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
def _validate_json(self, input_data: str) -> Dict[str, Any]:
|
||||
"""验证JSON"""
|
||||
try:
|
||||
data = json.loads(input_data)
|
||||
|
||||
# 统计信息
|
||||
stats = self._analyze_json_structure(data)
|
||||
|
||||
return {
|
||||
"is_valid": True,
|
||||
"error": None,
|
||||
"size": len(input_data),
|
||||
"structure": stats
|
||||
}
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
return {
|
||||
"is_valid": False,
|
||||
"error": str(e),
|
||||
"error_line": getattr(e, 'lineno', None),
|
||||
"error_column": getattr(e, 'colno', None),
|
||||
"size": len(input_data)
|
||||
}
|
||||
|
||||
def _convert_json(self, input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转义"""
|
||||
data = json.loads(input_data)
|
||||
converted = json.dumps(data, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"converted_json": converted,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
def _json_to_yaml(self, input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转YAML"""
|
||||
data = json.loads(input_data)
|
||||
yaml_output = yaml.dump(data, default_flow_style=False, allow_unicode=True, indent=2)
|
||||
|
||||
return {
|
||||
"original_format": "json",
|
||||
"target_format": "yaml",
|
||||
"original_size": len(input_data),
|
||||
"converted_size": len(yaml_output),
|
||||
"converted_data": yaml_output
|
||||
}
|
||||
|
||||
def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""YAML转JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
ensure_ascii = kwargs.get("ensure_ascii", False)
|
||||
|
||||
data = yaml.safe_load(input_data)
|
||||
json_output = json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
|
||||
|
||||
return {
|
||||
"original_format": "yaml",
|
||||
"target_format": "json",
|
||||
"original_size": len(input_data),
|
||||
"converted_size": len(json_output),
|
||||
"converted_data": json_output
|
||||
}
|
||||
|
||||
def _json_to_xml(self, input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转XML"""
|
||||
data = json.loads(input_data)
|
||||
|
||||
def dict_to_xml(data, root_name="root"):
|
||||
"""递归转换字典为XML"""
|
||||
if isinstance(data, dict):
|
||||
if len(data) == 1 and not root_name == "root":
|
||||
# 如果字典只有一个键,使用该键作为根元素
|
||||
key, value = next(iter(data.items()))
|
||||
return dict_to_xml(value, key)
|
||||
|
||||
root = ET.Element(root_name)
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
child = dict_to_xml(value, key)
|
||||
root.append(child)
|
||||
else:
|
||||
child = ET.SubElement(root, key)
|
||||
child.text = str(value)
|
||||
return root
|
||||
|
||||
elif isinstance(data, list):
|
||||
root = ET.Element(root_name)
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, (dict, list)):
|
||||
child = dict_to_xml(item, f"item_{i}")
|
||||
root.append(child)
|
||||
else:
|
||||
child = ET.SubElement(root, f"item_{i}")
|
||||
child.text = str(item)
|
||||
return root
|
||||
|
||||
else:
|
||||
root = ET.Element(root_name)
|
||||
root.text = str(data)
|
||||
return root
|
||||
|
||||
xml_element = dict_to_xml(data)
|
||||
xml_string = ET.tostring(xml_element, encoding='unicode')
|
||||
|
||||
# 格式化XML
|
||||
dom = minidom.parseString(xml_string)
|
||||
formatted_xml = dom.toprettyxml(indent=" ")
|
||||
|
||||
# 移除空行
|
||||
formatted_xml = '\n'.join([line for line in formatted_xml.split('\n') if line.strip()])
|
||||
|
||||
return {
|
||||
"original_format": "json",
|
||||
"target_format": "xml",
|
||||
"original_size": len(input_data),
|
||||
"converted_size": len(formatted_xml),
|
||||
"converted_data": formatted_xml
|
||||
}
|
||||
|
||||
def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""XML转JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
|
||||
def xml_to_dict(element):
|
||||
"""递归转换XML元素为字典"""
|
||||
result = {}
|
||||
|
||||
# 处理属性
|
||||
if element.attrib:
|
||||
result.update(element.attrib)
|
||||
|
||||
# 处理文本内容
|
||||
if element.text and element.text.strip():
|
||||
if len(element) == 0: # 叶子节点
|
||||
return element.text.strip()
|
||||
else:
|
||||
result['text'] = element.text.strip()
|
||||
|
||||
# 处理子元素
|
||||
for child in element:
|
||||
child_data = xml_to_dict(child)
|
||||
if child.tag in result:
|
||||
# 如果标签已存在,转换为列表
|
||||
if not isinstance(result[child.tag], list):
|
||||
result[child.tag] = [result[child.tag]]
|
||||
result[child.tag].append(child_data)
|
||||
else:
|
||||
result[child.tag] = child_data
|
||||
|
||||
return result
|
||||
|
||||
root = ET.fromstring(input_data)
|
||||
data = {root.tag: xml_to_dict(root)}
|
||||
json_output = json.dumps(data, indent=indent, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"original_format": "xml",
|
||||
"target_format": "json",
|
||||
"original_size": len(input_data),
|
||||
"converted_size": len(json_output),
|
||||
"converted_data": json_output
|
||||
}
|
||||
|
||||
def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""合并JSON"""
|
||||
merge_data = kwargs.get("merge_data")
|
||||
if not merge_data:
|
||||
raise ValueError("merge_data 参数是必需的")
|
||||
|
||||
data1 = json.loads(input_data)
|
||||
data2 = json.loads(merge_data)
|
||||
|
||||
def deep_merge(dict1, dict2):
|
||||
"""深度合并字典"""
|
||||
result = dict1.copy()
|
||||
for key, value in dict2.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
if isinstance(data1, dict) and isinstance(data2, dict):
|
||||
merged = deep_merge(data1, data2)
|
||||
elif isinstance(data1, list) and isinstance(data2, list):
|
||||
merged = data1 + data2
|
||||
else:
|
||||
raise ValueError("无法合并不同类型的数据")
|
||||
|
||||
merged_json = json.dumps(merged, indent=2, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"operation": "merge",
|
||||
"original_size": len(input_data),
|
||||
"merge_size": len(merge_data),
|
||||
"result_size": len(merged_json),
|
||||
"merged_data": merged_json
|
||||
}
|
||||
|
||||
def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取JSON路径"""
|
||||
json_path = kwargs.get("json_path")
|
||||
if not json_path:
|
||||
raise ValueError("json_path 参数是必需的")
|
||||
|
||||
data = json.loads(input_data)
|
||||
|
||||
# 简单的JSONPath实现(支持基本的点号路径)
|
||||
try:
|
||||
result = data
|
||||
if json_path.startswith('$.'):
|
||||
path_parts = json_path[2:].split('.')
|
||||
else:
|
||||
path_parts = json_path.split('.')
|
||||
|
||||
for part in path_parts:
|
||||
if part.isdigit():
|
||||
result = result[int(part)]
|
||||
else:
|
||||
result = result[part]
|
||||
|
||||
extracted_json = json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"operation": "extract",
|
||||
"json_path": json_path,
|
||||
"found": True,
|
||||
"extracted_data": extracted_json,
|
||||
"data_type": type(result).__name__
|
||||
}
|
||||
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
return {
|
||||
"operation": "extract",
|
||||
"json_path": json_path,
|
||||
"found": False,
|
||||
"error": str(e),
|
||||
"extracted_data": None
|
||||
}
|
||||
|
||||
def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]:
|
||||
"""分析JSON结构"""
|
||||
if isinstance(data, dict):
|
||||
return {
|
||||
"type": "object",
|
||||
"keys": len(data),
|
||||
"depth": depth,
|
||||
"children": {k: self._analyze_json_structure(v, depth + 1) for k, v in data.items()}
|
||||
}
|
||||
elif isinstance(data, list):
|
||||
return {
|
||||
"type": "array",
|
||||
"length": len(data),
|
||||
"depth": depth,
|
||||
"item_types": list(set(type(item).__name__ for item in data))
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"type": type(data).__name__,
|
||||
"depth": depth,
|
||||
"value": str(data)[:100] + "..." if len(str(data)) > 100 else str(data)
|
||||
}
|
||||
327
api/app/core/tools/builtin/mineru_tool.py
Normal file
327
api/app/core/tools/builtin/mineru_tool.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""MinerU PDF解析工具"""
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class MinerUTool(BuiltinTool):
|
||||
"""MinerU PDF解析工具 - 提供PDF解析、表格提取、图片识别、文本提取功能"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "mineru_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "MinerU - PDF解析工具:PDF解析、表格提取、图片识别、文本提取"
|
||||
|
||||
def get_required_config_parameters(self) -> List[str]:
|
||||
return ["api_key", "api_url"]
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="operation",
|
||||
type=ParameterType.STRING,
|
||||
description="操作类型",
|
||||
required=True,
|
||||
enum=["parse_pdf", "extract_text", "extract_tables", "extract_images", "analyze_layout"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="file_content",
|
||||
type=ParameterType.STRING,
|
||||
description="PDF文件内容(Base64编码)",
|
||||
required=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="file_url",
|
||||
type=ParameterType.STRING,
|
||||
description="PDF文件URL",
|
||||
required=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="parse_mode",
|
||||
type=ParameterType.STRING,
|
||||
description="解析模式",
|
||||
required=False,
|
||||
default="auto",
|
||||
enum=["auto", "text_only", "table_priority", "image_priority", "layout_analysis"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="extract_images",
|
||||
type=ParameterType.BOOLEAN,
|
||||
description="是否提取图片",
|
||||
required=False,
|
||||
default=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="extract_tables",
|
||||
type=ParameterType.BOOLEAN,
|
||||
description="是否提取表格",
|
||||
required=False,
|
||||
default=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="page_range",
|
||||
type=ParameterType.STRING,
|
||||
description="页面范围(如:1-5, 1,3,5)",
|
||||
required=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出格式",
|
||||
required=False,
|
||||
default="json",
|
||||
enum=["json", "markdown", "html", "text"]
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行MinerU PDF解析"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
operation = kwargs.get("operation")
|
||||
file_content = kwargs.get("file_content")
|
||||
file_url = kwargs.get("file_url")
|
||||
|
||||
if not file_content and not file_url:
|
||||
raise ValueError("必须提供 file_content 或 file_url 参数")
|
||||
|
||||
if operation == "parse_pdf":
|
||||
result = await self._parse_pdf(kwargs)
|
||||
elif operation == "extract_text":
|
||||
result = await self._extract_text(kwargs)
|
||||
elif operation == "extract_tables":
|
||||
result = await self._extract_tables(kwargs)
|
||||
elif operation == "extract_images":
|
||||
result = await self._extract_images(kwargs)
|
||||
elif operation == "analyze_layout":
|
||||
result = await self._analyze_layout(kwargs)
|
||||
else:
|
||||
raise ValueError(f"不支持的操作类型: {operation}")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="MINERU_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
async def _parse_pdf(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""完整PDF解析"""
|
||||
parse_mode = kwargs.get("parse_mode", "auto")
|
||||
extract_images = kwargs.get("extract_images", True)
|
||||
extract_tables = kwargs.get("extract_tables", True)
|
||||
page_range = kwargs.get("page_range")
|
||||
output_format = kwargs.get("output_format", "json")
|
||||
|
||||
# 构建请求参数
|
||||
request_data = {
|
||||
"parse_mode": parse_mode,
|
||||
"extract_images": extract_images,
|
||||
"extract_tables": extract_tables,
|
||||
"output_format": output_format
|
||||
}
|
||||
|
||||
if page_range:
|
||||
request_data["page_range"] = page_range
|
||||
|
||||
# 添加文件数据
|
||||
if kwargs.get("file_content"):
|
||||
request_data["file_content"] = kwargs["file_content"]
|
||||
elif kwargs.get("file_url"):
|
||||
request_data["file_url"] = kwargs["file_url"]
|
||||
|
||||
# 调用MinerU API
|
||||
result = await self._call_mineru_api("parse", request_data)
|
||||
|
||||
return {
|
||||
"operation": "parse_pdf",
|
||||
"parse_mode": parse_mode,
|
||||
"total_pages": result.get("total_pages", 0),
|
||||
"processed_pages": result.get("processed_pages", 0),
|
||||
"text_content": result.get("text_content", ""),
|
||||
"tables": result.get("tables", []),
|
||||
"images": result.get("images", []),
|
||||
"layout_info": result.get("layout_info", {}),
|
||||
"metadata": result.get("metadata", {}),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
async def _extract_text(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取文本"""
|
||||
page_range = kwargs.get("page_range")
|
||||
output_format = kwargs.get("output_format", "text")
|
||||
|
||||
request_data = {
|
||||
"operation": "extract_text",
|
||||
"output_format": output_format
|
||||
}
|
||||
|
||||
if page_range:
|
||||
request_data["page_range"] = page_range
|
||||
|
||||
if kwargs.get("file_content"):
|
||||
request_data["file_content"] = kwargs["file_content"]
|
||||
elif kwargs.get("file_url"):
|
||||
request_data["file_url"] = kwargs["file_url"]
|
||||
|
||||
result = await self._call_mineru_api("extract_text", request_data)
|
||||
|
||||
return {
|
||||
"operation": "extract_text",
|
||||
"total_pages": result.get("total_pages", 0),
|
||||
"text_content": result.get("text_content", ""),
|
||||
"word_count": len(result.get("text_content", "").split()),
|
||||
"character_count": len(result.get("text_content", "")),
|
||||
"pages_text": result.get("pages_text", [])
|
||||
}
|
||||
|
||||
async def _extract_tables(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取表格"""
|
||||
page_range = kwargs.get("page_range")
|
||||
output_format = kwargs.get("output_format", "json")
|
||||
|
||||
request_data = {
|
||||
"operation": "extract_tables",
|
||||
"output_format": output_format
|
||||
}
|
||||
|
||||
if page_range:
|
||||
request_data["page_range"] = page_range
|
||||
|
||||
if kwargs.get("file_content"):
|
||||
request_data["file_content"] = kwargs["file_content"]
|
||||
elif kwargs.get("file_url"):
|
||||
request_data["file_url"] = kwargs["file_url"]
|
||||
|
||||
result = await self._call_mineru_api("extract_tables", request_data)
|
||||
|
||||
return {
|
||||
"operation": "extract_tables",
|
||||
"total_tables": result.get("total_tables", 0),
|
||||
"tables": result.get("tables", []),
|
||||
"table_locations": result.get("table_locations", [])
|
||||
}
|
||||
|
||||
async def _extract_images(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取图片"""
|
||||
page_range = kwargs.get("page_range")
|
||||
|
||||
request_data = {
|
||||
"operation": "extract_images"
|
||||
}
|
||||
|
||||
if page_range:
|
||||
request_data["page_range"] = page_range
|
||||
|
||||
if kwargs.get("file_content"):
|
||||
request_data["file_content"] = kwargs["file_content"]
|
||||
elif kwargs.get("file_url"):
|
||||
request_data["file_url"] = kwargs["file_url"]
|
||||
|
||||
result = await self._call_mineru_api("extract_images", request_data)
|
||||
|
||||
return {
|
||||
"operation": "extract_images",
|
||||
"total_images": result.get("total_images", 0),
|
||||
"images": result.get("images", []),
|
||||
"image_locations": result.get("image_locations", [])
|
||||
}
|
||||
|
||||
async def _analyze_layout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""分析布局"""
|
||||
page_range = kwargs.get("page_range")
|
||||
|
||||
request_data = {
|
||||
"operation": "analyze_layout"
|
||||
}
|
||||
|
||||
if page_range:
|
||||
request_data["page_range"] = page_range
|
||||
|
||||
if kwargs.get("file_content"):
|
||||
request_data["file_content"] = kwargs["file_content"]
|
||||
elif kwargs.get("file_url"):
|
||||
request_data["file_url"] = kwargs["file_url"]
|
||||
|
||||
result = await self._call_mineru_api("analyze_layout", request_data)
|
||||
|
||||
return {
|
||||
"operation": "analyze_layout",
|
||||
"layout_info": result.get("layout_info", {}),
|
||||
"page_layouts": result.get("page_layouts", []),
|
||||
"text_blocks": result.get("text_blocks", []),
|
||||
"image_blocks": result.get("image_blocks", []),
|
||||
"table_blocks": result.get("table_blocks", [])
|
||||
}
|
||||
|
||||
async def _call_mineru_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""调用MinerU API"""
|
||||
api_key = self.get_config_parameter("api_key")
|
||||
api_url = self.get_config_parameter("api_url")
|
||||
timeout_seconds = self.get_config_parameter("timeout", 60)
|
||||
|
||||
if not api_key or not api_url:
|
||||
raise ValueError("MinerU API配置未完成")
|
||||
|
||||
# 构建完整URL
|
||||
url = f"{api_url.rstrip('/')}/{endpoint}"
|
||||
|
||||
# 构建请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
timeout = aiohttp.ClientTimeout(total=timeout_seconds)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(url, json=data, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if result.get("success", True):
|
||||
return result.get("data", result)
|
||||
else:
|
||||
raise Exception(f"MinerU API错误: {result.get('message', '未知错误')}")
|
||||
else:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"HTTP错误 {response.status}: {error_text}")
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试连接"""
|
||||
try:
|
||||
api_key = self.get_config_parameter("api_key")
|
||||
api_url = self.get_config_parameter("api_url")
|
||||
|
||||
if not api_key or not api_url:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "API配置未完成"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接配置有效",
|
||||
"api_url": api_url,
|
||||
"api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
401
api/app/core/tools/builtin/textin_tool.py
Normal file
401
api/app/core/tools/builtin/textin_tool.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""TextIn OCR文字识别工具"""
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class TextInTool(BuiltinTool):
|
||||
"""TextIn OCR工具 - 提供通用OCR、手写识别、多语言支持、高精度识别"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "textin_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "TextIn - OCR文字识别:通用OCR、手写识别、多语言支持、高精度识别"
|
||||
|
||||
def get_required_config_parameters(self) -> List[str]:
|
||||
return ["app_id", "secret_key", "api_url"]
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="image_content",
|
||||
type=ParameterType.STRING,
|
||||
description="图片内容(Base64编码)",
|
||||
required=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_url",
|
||||
type=ParameterType.STRING,
|
||||
description="图片URL",
|
||||
required=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="language",
|
||||
type=ParameterType.STRING,
|
||||
description="识别语言",
|
||||
required=False,
|
||||
default="auto",
|
||||
enum=["auto", "zh-cn", "zh-tw", "en", "ja", "ko", "fr", "de", "es", "ru"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="recognition_mode",
|
||||
type=ParameterType.STRING,
|
||||
description="识别模式",
|
||||
required=False,
|
||||
default="general",
|
||||
enum=["general", "accurate", "handwriting", "formula", "table", "document"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="return_location",
|
||||
type=ParameterType.BOOLEAN,
|
||||
description="是否返回文字位置信息",
|
||||
required=False,
|
||||
default=False
|
||||
),
|
||||
ToolParameter(
|
||||
name="return_confidence",
|
||||
type=ParameterType.BOOLEAN,
|
||||
description="是否返回置信度",
|
||||
required=False,
|
||||
default=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="merge_lines",
|
||||
type=ParameterType.BOOLEAN,
|
||||
description="是否合并行",
|
||||
required=False,
|
||||
default=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出格式",
|
||||
required=False,
|
||||
default="text",
|
||||
enum=["text", "json", "structured"]
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行TextIn OCR识别"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
image_content = kwargs.get("image_content")
|
||||
image_url = kwargs.get("image_url")
|
||||
|
||||
if not image_content and not image_url:
|
||||
raise ValueError("必须提供 image_content 或 image_url 参数")
|
||||
|
||||
language = kwargs.get("language", "auto")
|
||||
recognition_mode = kwargs.get("recognition_mode", "general")
|
||||
return_location = kwargs.get("return_location", False)
|
||||
return_confidence = kwargs.get("return_confidence", True)
|
||||
merge_lines = kwargs.get("merge_lines", True)
|
||||
output_format = kwargs.get("output_format", "text")
|
||||
|
||||
# 根据识别模式调用不同的API
|
||||
if recognition_mode == "general":
|
||||
result = await self._general_ocr(kwargs)
|
||||
elif recognition_mode == "accurate":
|
||||
result = await self._accurate_ocr(kwargs)
|
||||
elif recognition_mode == "handwriting":
|
||||
result = await self._handwriting_ocr(kwargs)
|
||||
elif recognition_mode == "formula":
|
||||
result = await self._formula_ocr(kwargs)
|
||||
elif recognition_mode == "table":
|
||||
result = await self._table_ocr(kwargs)
|
||||
elif recognition_mode == "document":
|
||||
result = await self._document_ocr(kwargs)
|
||||
else:
|
||||
raise ValueError(f"不支持的识别模式: {recognition_mode}")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="TEXTIN_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
async def _general_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""通用OCR识别"""
|
||||
request_data = {
|
||||
"language": kwargs.get("language", "auto"),
|
||||
"return_location": kwargs.get("return_location", False),
|
||||
"return_confidence": kwargs.get("return_confidence", True),
|
||||
"merge_lines": kwargs.get("merge_lines", True)
|
||||
}
|
||||
|
||||
if kwargs.get("image_content"):
|
||||
request_data["image"] = kwargs["image_content"]
|
||||
elif kwargs.get("image_url"):
|
||||
request_data["image_url"] = kwargs["image_url"]
|
||||
|
||||
result = await self._call_textin_api("general_ocr", request_data)
|
||||
|
||||
return self._format_ocr_result(result, kwargs.get("output_format", "text"))
|
||||
|
||||
async def _accurate_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""高精度OCR识别"""
|
||||
request_data = {
|
||||
"language": kwargs.get("language", "auto"),
|
||||
"return_location": kwargs.get("return_location", False),
|
||||
"return_confidence": kwargs.get("return_confidence", True),
|
||||
"merge_lines": kwargs.get("merge_lines", True)
|
||||
}
|
||||
|
||||
if kwargs.get("image_content"):
|
||||
request_data["image"] = kwargs["image_content"]
|
||||
elif kwargs.get("image_url"):
|
||||
request_data["image_url"] = kwargs["image_url"]
|
||||
|
||||
result = await self._call_textin_api("accurate_ocr", request_data)
|
||||
|
||||
return self._format_ocr_result(result, kwargs.get("output_format", "text"))
|
||||
|
||||
async def _handwriting_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""手写体识别"""
|
||||
request_data = {
|
||||
"language": kwargs.get("language", "auto"),
|
||||
"return_location": kwargs.get("return_location", False),
|
||||
"return_confidence": kwargs.get("return_confidence", True)
|
||||
}
|
||||
|
||||
if kwargs.get("image_content"):
|
||||
request_data["image"] = kwargs["image_content"]
|
||||
elif kwargs.get("image_url"):
|
||||
request_data["image_url"] = kwargs["image_url"]
|
||||
|
||||
result = await self._call_textin_api("handwriting_ocr", request_data)
|
||||
|
||||
return self._format_ocr_result(result, kwargs.get("output_format", "text"))
|
||||
|
||||
async def _formula_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""公式识别"""
|
||||
request_data = {
|
||||
"return_location": kwargs.get("return_location", False),
|
||||
"return_confidence": kwargs.get("return_confidence", True),
|
||||
"output_latex": True
|
||||
}
|
||||
|
||||
if kwargs.get("image_content"):
|
||||
request_data["image"] = kwargs["image_content"]
|
||||
elif kwargs.get("image_url"):
|
||||
request_data["image_url"] = kwargs["image_url"]
|
||||
|
||||
result = await self._call_textin_api("formula_ocr", request_data)
|
||||
|
||||
return self._format_formula_result(result, kwargs.get("output_format", "text"))
|
||||
|
||||
async def _table_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""表格识别"""
|
||||
request_data = {
|
||||
"language": kwargs.get("language", "auto"),
|
||||
"return_location": kwargs.get("return_location", False),
|
||||
"return_confidence": kwargs.get("return_confidence", True),
|
||||
"output_excel": True
|
||||
}
|
||||
|
||||
if kwargs.get("image_content"):
|
||||
request_data["image"] = kwargs["image_content"]
|
||||
elif kwargs.get("image_url"):
|
||||
request_data["image_url"] = kwargs["image_url"]
|
||||
|
||||
result = await self._call_textin_api("table_ocr", request_data)
|
||||
|
||||
return self._format_table_result(result, kwargs.get("output_format", "text"))
|
||||
|
||||
async def _document_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""文档识别"""
|
||||
request_data = {
|
||||
"language": kwargs.get("language", "auto"),
|
||||
"return_location": kwargs.get("return_location", False),
|
||||
"return_confidence": kwargs.get("return_confidence", True),
|
||||
"layout_analysis": True
|
||||
}
|
||||
|
||||
if kwargs.get("image_content"):
|
||||
request_data["image"] = kwargs["image_content"]
|
||||
elif kwargs.get("image_url"):
|
||||
request_data["image_url"] = kwargs["image_url"]
|
||||
|
||||
result = await self._call_textin_api("document_ocr", request_data)
|
||||
|
||||
return self._format_document_result(result, kwargs.get("output_format", "text"))
|
||||
|
||||
def _format_ocr_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any] | None:
|
||||
"""格式化OCR结果"""
|
||||
lines = result.get("lines", [])
|
||||
|
||||
if output_format == "text":
|
||||
text_content = "\n".join([line.get("text", "") for line in lines])
|
||||
return {
|
||||
"recognition_mode": "ocr",
|
||||
"text_content": text_content,
|
||||
"line_count": len(lines),
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
elif output_format == "json":
|
||||
return {
|
||||
"recognition_mode": "ocr",
|
||||
"lines": lines,
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
elif output_format == "structured":
|
||||
return {
|
||||
"recognition_mode": "ocr",
|
||||
"text_content": "\n".join([line.get("text", "") for line in lines]),
|
||||
"structured_data": {
|
||||
"lines": lines,
|
||||
"paragraphs": self._group_lines_to_paragraphs(lines),
|
||||
"statistics": {
|
||||
"line_count": len(lines),
|
||||
"word_count": sum(len(line.get("text", "").split()) for line in lines),
|
||||
"character_count": sum(len(line.get("text", "")) for line in lines)
|
||||
}
|
||||
},
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化公式识别结果"""
|
||||
formulas = result.get("formulas", [])
|
||||
|
||||
return {
|
||||
"recognition_mode": "formula",
|
||||
"formula_count": len(formulas),
|
||||
"formulas": formulas,
|
||||
"latex_content": "\n".join([f.get("latex", "") for f in formulas]),
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化表格识别结果"""
|
||||
tables = result.get("tables", [])
|
||||
|
||||
return {
|
||||
"recognition_mode": "table",
|
||||
"table_count": len(tables),
|
||||
"tables": tables,
|
||||
"excel_data": result.get("excel_data"),
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化文档识别结果"""
|
||||
return {
|
||||
"recognition_mode": "document",
|
||||
"layout_info": result.get("layout_info", {}),
|
||||
"text_blocks": result.get("text_blocks", []),
|
||||
"image_blocks": result.get("image_blocks", []),
|
||||
"table_blocks": result.get("table_blocks", []),
|
||||
"full_text": result.get("full_text", ""),
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将行分组为段落"""
|
||||
paragraphs = []
|
||||
current_paragraph = []
|
||||
|
||||
for line in lines:
|
||||
text = line.get("text", "").strip()
|
||||
if text:
|
||||
current_paragraph.append(line)
|
||||
else:
|
||||
if current_paragraph:
|
||||
paragraphs.append({
|
||||
"text": " ".join([l.get("text", "") for l in current_paragraph]),
|
||||
"lines": current_paragraph
|
||||
})
|
||||
current_paragraph = []
|
||||
|
||||
if current_paragraph:
|
||||
paragraphs.append({
|
||||
"text": " ".join([l.get("text", "") for l in current_paragraph]),
|
||||
"lines": current_paragraph
|
||||
})
|
||||
|
||||
return paragraphs
|
||||
|
||||
async def _call_textin_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""调用TextIn API"""
|
||||
app_id = self.get_config_parameter("app_id")
|
||||
secret_key = self.get_config_parameter("secret_key")
|
||||
api_url = self.get_config_parameter("api_url")
|
||||
|
||||
if not app_id or not secret_key or not api_url:
|
||||
raise ValueError("TextIn API配置未完成")
|
||||
|
||||
# 构建完整URL
|
||||
url = f"{api_url.rstrip('/')}/{endpoint}"
|
||||
|
||||
# 构建请求头
|
||||
headers = {
|
||||
"X-App-Id": app_id,
|
||||
"X-Secret-Key": secret_key,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(url, json=data, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if result.get("code") == 200:
|
||||
return result.get("data", result)
|
||||
else:
|
||||
raise Exception(f"TextIn API错误: {result.get('message', '未知错误')}")
|
||||
else:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"HTTP错误 {response.status}: {error_text}")
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试连接"""
|
||||
try:
|
||||
app_id = self.get_config_parameter("app_id")
|
||||
secret_key = self.get_config_parameter("secret_key")
|
||||
api_url = self.get_config_parameter("api_url")
|
||||
|
||||
if not app_id or not secret_key or not api_url:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "API配置未完成"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接配置有效",
|
||||
"api_url": api_url,
|
||||
"app_id": app_id,
|
||||
"secret_key_masked": secret_key[:8] + "***" if len(secret_key) > 8 else "***"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
485
api/app/core/tools/chain_manager.py
Normal file
485
api/app/core/tools/chain_manager.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""工具链管理器 - 支持langchain的工具链模式"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from app.core.tools.base import ToolResult
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ChainExecutionMode(str, Enum):
|
||||
"""链执行模式"""
|
||||
SEQUENTIAL = "sequential" # 顺序执行
|
||||
PARALLEL = "parallel" # 并行执行
|
||||
CONDITIONAL = "conditional" # 条件执行
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChainStep:
|
||||
"""链步骤定义"""
|
||||
tool_id: str
|
||||
parameters: Dict[str, Any]
|
||||
condition: Optional[str] = None # 执行条件
|
||||
output_mapping: Optional[Dict[str, str]] = None # 输出映射
|
||||
error_handling: str = "stop" # 错误处理:stop, continue, retry
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChainDefinition:
|
||||
"""工具链定义"""
|
||||
name: str
|
||||
description: str
|
||||
steps: List[ChainStep]
|
||||
execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL
|
||||
global_timeout: Optional[float] = None
|
||||
retry_policy: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ChainExecutionContext:
|
||||
"""链执行上下文"""
|
||||
|
||||
def __init__(self, chain_id: str):
|
||||
self.chain_id = chain_id
|
||||
self.variables: Dict[str, Any] = {}
|
||||
self.step_results: Dict[int, ToolResult] = {}
|
||||
self.current_step = 0
|
||||
self.is_completed = False
|
||||
self.is_failed = False
|
||||
self.error_message: Optional[str] = None
|
||||
|
||||
|
||||
class ChainManager:
|
||||
"""工具链管理器 - 支持langchain的工具链模式"""
|
||||
|
||||
def __init__(self, executor: ToolExecutor):
|
||||
"""初始化工具链管理器
|
||||
|
||||
Args:
|
||||
executor: 工具执行器
|
||||
"""
|
||||
self.executor = executor
|
||||
self._chains: Dict[str, ChainDefinition] = {}
|
||||
self._running_chains: Dict[str, ChainExecutionContext] = {}
|
||||
|
||||
def register_chain(self, chain: ChainDefinition) -> bool:
|
||||
"""注册工具链
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
|
||||
Returns:
|
||||
注册是否成功
|
||||
"""
|
||||
try:
|
||||
# 验证工具链定义
|
||||
validation_result = self._validate_chain(chain)
|
||||
if not validation_result[0]:
|
||||
logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}")
|
||||
return False
|
||||
|
||||
self._chains[chain.name] = chain
|
||||
logger.info(f"工具链注册成功: {chain.name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具链注册失败: {chain.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def unregister_chain(self, chain_name: str) -> bool:
|
||||
"""注销工具链
|
||||
|
||||
Args:
|
||||
chain_name: 工具链名称
|
||||
|
||||
Returns:
|
||||
注销是否成功
|
||||
"""
|
||||
if chain_name in self._chains:
|
||||
del self._chains[chain_name]
|
||||
logger.info(f"工具链注销成功: {chain_name}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def list_chains(self) -> List[Dict[str, Any]]:
|
||||
"""列出所有工具链
|
||||
|
||||
Returns:
|
||||
工具链信息列表
|
||||
"""
|
||||
chains = []
|
||||
for name, chain in self._chains.items():
|
||||
chains.append({
|
||||
"name": name,
|
||||
"description": chain.description,
|
||||
"step_count": len(chain.steps),
|
||||
"execution_mode": chain.execution_mode.value,
|
||||
"global_timeout": chain.global_timeout
|
||||
})
|
||||
|
||||
return chains
|
||||
|
||||
async def execute_chain(
|
||||
self,
|
||||
chain_name: str,
|
||||
initial_variables: Optional[Dict[str, Any]] = None,
|
||||
chain_id: Optional[str] = None
|
||||
) -> Dict[str, Any] | None:
|
||||
"""执行工具链
|
||||
|
||||
Args:
|
||||
chain_name: 工具链名称
|
||||
initial_variables: 初始变量
|
||||
chain_id: 链执行ID(可选)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
if chain_name not in self._chains:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"工具链不存在: {chain_name}",
|
||||
"chain_id": chain_id
|
||||
}
|
||||
|
||||
chain = self._chains[chain_name]
|
||||
|
||||
# 生成链ID
|
||||
if not chain_id:
|
||||
import uuid
|
||||
chain_id = f"chain_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 创建执行上下文
|
||||
context = ChainExecutionContext(chain_id)
|
||||
context.variables = initial_variables or {}
|
||||
self._running_chains[chain_id] = context
|
||||
|
||||
try:
|
||||
logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})")
|
||||
|
||||
# 根据执行模式执行
|
||||
if chain.execution_mode == ChainExecutionMode.SEQUENTIAL:
|
||||
result = await self._execute_sequential(chain, context)
|
||||
elif chain.execution_mode == ChainExecutionMode.PARALLEL:
|
||||
result = await self._execute_parallel(chain, context)
|
||||
elif chain.execution_mode == ChainExecutionMode.CONDITIONAL:
|
||||
result = await self._execute_conditional(chain, context)
|
||||
else:
|
||||
raise ValueError(f"不支持的执行模式: {chain.execution_mode}")
|
||||
|
||||
logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"chain_id": chain_id,
|
||||
"completed_steps": context.current_step,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
finally:
|
||||
# 清理执行上下文
|
||||
if chain_id in self._running_chains:
|
||||
del self._running_chains[chain_id]
|
||||
|
||||
async def _execute_sequential(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""顺序执行工具链"""
|
||||
for i, step in enumerate(chain.steps):
|
||||
context.current_step = i
|
||||
|
||||
# 检查执行条件
|
||||
if step.condition and not self._evaluate_condition(step.condition, context):
|
||||
logger.debug(f"跳过步骤 {i}: 条件不满足")
|
||||
continue
|
||||
|
||||
# 准备参数
|
||||
parameters = self._prepare_parameters(step.parameters, context)
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
result = await self.executor.execute_tool(
|
||||
tool_id=step.tool_id,
|
||||
parameters=parameters
|
||||
)
|
||||
|
||||
context.step_results[i] = result
|
||||
|
||||
# 处理输出映射
|
||||
if step.output_mapping and result.success:
|
||||
self._apply_output_mapping(step.output_mapping, result.data, context)
|
||||
|
||||
# 处理执行失败
|
||||
if not result.success:
|
||||
if step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = result.error
|
||||
break
|
||||
elif step.error_handling == "continue":
|
||||
logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}")
|
||||
continue
|
||||
elif step.error_handling == "retry":
|
||||
# 简单重试逻辑
|
||||
retry_result = await self.executor.execute_tool(
|
||||
tool_id=step.tool_id,
|
||||
parameters=parameters
|
||||
)
|
||||
context.step_results[i] = retry_result
|
||||
if not retry_result.success and step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = retry_result.error
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"步骤 {i} 执行异常: {e}")
|
||||
if step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = str(e)
|
||||
break
|
||||
|
||||
context.is_completed = not context.is_failed
|
||||
|
||||
return {
|
||||
"success": context.is_completed,
|
||||
"error": context.error_message,
|
||||
"chain_id": context.chain_id,
|
||||
"completed_steps": context.current_step + 1,
|
||||
"total_steps": len(chain.steps),
|
||||
"final_variables": context.variables,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""并行执行工具链"""
|
||||
# 准备所有步骤的执行配置
|
||||
execution_configs = []
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
# 检查执行条件
|
||||
if step.condition and not self._evaluate_condition(step.condition, context):
|
||||
continue
|
||||
|
||||
parameters = self._prepare_parameters(step.parameters, context)
|
||||
execution_configs.append({
|
||||
"step_index": i,
|
||||
"tool_id": step.tool_id,
|
||||
"parameters": parameters
|
||||
})
|
||||
|
||||
# 并行执行所有步骤
|
||||
try:
|
||||
results = await self.executor.execute_tools_batch(execution_configs)
|
||||
|
||||
# 处理结果
|
||||
for i, result in enumerate(results):
|
||||
step_index = execution_configs[i]["step_index"]
|
||||
context.step_results[step_index] = result
|
||||
|
||||
# 处理输出映射
|
||||
step = chain.steps[step_index]
|
||||
if step.output_mapping and result.success:
|
||||
self._apply_output_mapping(step.output_mapping, result.data, context)
|
||||
|
||||
# 检查是否有失败的步骤
|
||||
failed_steps = [i for i, result in context.step_results.items() if not result.success]
|
||||
|
||||
context.is_completed = len(failed_steps) == 0
|
||||
if failed_steps:
|
||||
context.error_message = f"步骤 {failed_steps} 执行失败"
|
||||
|
||||
except Exception as e:
|
||||
context.is_failed = True
|
||||
context.error_message = str(e)
|
||||
|
||||
return {
|
||||
"success": context.is_completed,
|
||||
"error": context.error_message,
|
||||
"chain_id": context.chain_id,
|
||||
"completed_steps": len(context.step_results),
|
||||
"total_steps": len(chain.steps),
|
||||
"final_variables": context.variables,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
async def _execute_conditional(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""条件执行工具链"""
|
||||
# 条件执行类似于顺序执行,但更严格地检查条件
|
||||
return await self._execute_sequential(chain, context)
|
||||
|
||||
def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]:
|
||||
"""验证工具链定义
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
if not chain.name:
|
||||
return False, "工具链名称不能为空"
|
||||
|
||||
if not chain.steps:
|
||||
return False, "工具链必须包含至少一个步骤"
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
if not step.tool_id:
|
||||
return False, f"步骤 {i} 缺少工具ID"
|
||||
|
||||
if step.error_handling not in ["stop", "continue", "retry"]:
|
||||
return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}"
|
||||
|
||||
return True, None
|
||||
|
||||
def _prepare_parameters(
|
||||
self,
|
||||
parameters: Dict[str, Any],
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""准备参数(支持变量替换)
|
||||
|
||||
Args:
|
||||
parameters: 原始参数
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
处理后的参数
|
||||
"""
|
||||
prepared = {}
|
||||
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# 变量替换
|
||||
var_name = value[2:-1]
|
||||
if var_name in context.variables:
|
||||
prepared[key] = context.variables[var_name]
|
||||
else:
|
||||
prepared[key] = value # 保持原值
|
||||
else:
|
||||
prepared[key] = value
|
||||
|
||||
return prepared
|
||||
|
||||
def _evaluate_condition(
|
||||
self,
|
||||
condition: str,
|
||||
context: ChainExecutionContext
|
||||
) -> bool:
|
||||
"""评估执行条件
|
||||
|
||||
Args:
|
||||
condition: 条件表达式
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
条件是否满足
|
||||
"""
|
||||
try:
|
||||
# 简单的条件评估(可以扩展为更复杂的表达式解析)
|
||||
# 支持格式:variable == value, variable != value, variable > value 等
|
||||
|
||||
if "==" in condition:
|
||||
var_name, expected_value = condition.split("==", 1)
|
||||
var_name = var_name.strip()
|
||||
expected_value = expected_value.strip().strip('"\'')
|
||||
|
||||
return str(context.variables.get(var_name, "")) == expected_value
|
||||
|
||||
elif "!=" in condition:
|
||||
var_name, expected_value = condition.split("!=", 1)
|
||||
var_name = var_name.strip()
|
||||
expected_value = expected_value.strip().strip('"\'')
|
||||
|
||||
return str(context.variables.get(var_name, "")) != expected_value
|
||||
|
||||
elif condition in context.variables:
|
||||
# 简单的布尔检查
|
||||
return bool(context.variables[condition])
|
||||
|
||||
else:
|
||||
# 默认为真
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"条件评估失败: {condition}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def _apply_output_mapping(
|
||||
self,
|
||||
mapping: Dict[str, str],
|
||||
output_data: Any,
|
||||
context: ChainExecutionContext
|
||||
):
|
||||
"""应用输出映射
|
||||
|
||||
Args:
|
||||
mapping: 输出映射配置
|
||||
output_data: 输出数据
|
||||
context: 执行上下文
|
||||
"""
|
||||
try:
|
||||
if isinstance(output_data, dict):
|
||||
for source_key, target_var in mapping.items():
|
||||
if source_key in output_data:
|
||||
context.variables[target_var] = output_data[source_key]
|
||||
else:
|
||||
# 如果输出不是字典,将整个输出映射到指定变量
|
||||
if "result" in mapping:
|
||||
context.variables[mapping["result"]] = output_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"输出映射失败: {e}")
|
||||
|
||||
def _serialize_result(self, result: ToolResult) -> Dict[str, Any]:
|
||||
"""序列化工具结果
|
||||
|
||||
Args:
|
||||
result: 工具结果
|
||||
|
||||
Returns:
|
||||
序列化的结果
|
||||
"""
|
||||
return {
|
||||
"success": result.success,
|
||||
"data": result.data,
|
||||
"error": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time,
|
||||
"token_usage": result.token_usage,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
|
||||
def get_running_chains(self) -> List[Dict[str, Any]]:
|
||||
"""获取正在运行的工具链
|
||||
|
||||
Returns:
|
||||
运行中的工具链列表
|
||||
"""
|
||||
chains = []
|
||||
for chain_id, context in self._running_chains.items():
|
||||
chains.append({
|
||||
"chain_id": chain_id,
|
||||
"current_step": context.current_step,
|
||||
"is_completed": context.is_completed,
|
||||
"is_failed": context.is_failed,
|
||||
"variables_count": len(context.variables),
|
||||
"completed_steps": len(context.step_results)
|
||||
})
|
||||
|
||||
return chains
|
||||
264
api/app/core/tools/config_manager.py
Normal file
264
api/app/core/tools/config_manager.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""工具配置管理器 - 管理工具配置的加载和验证"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ToolConfigSchema(BaseModel):
|
||||
"""工具配置基础Schema"""
|
||||
name: str
|
||||
description: str
|
||||
tool_type: str
|
||||
version: str = "1.0.0"
|
||||
enabled: bool = True
|
||||
parameters: Dict[str, Any] = {}
|
||||
tags: list[str] = []
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class BuiltinToolConfigSchema(ToolConfigSchema):
|
||||
"""内置工具配置Schema"""
|
||||
tool_class: str
|
||||
tool_type: str = "builtin"
|
||||
|
||||
|
||||
class CustomToolConfigSchema(ToolConfigSchema):
|
||||
"""自定义工具配置Schema"""
|
||||
schema_url: Optional[str] = None
|
||||
schema_content: Optional[Dict[str, Any]] = None
|
||||
auth_type: str = "none"
|
||||
auth_config: Dict[str, Any] = {}
|
||||
base_url: Optional[str] = None
|
||||
timeout: int = 30
|
||||
tool_type: str = "custom"
|
||||
|
||||
|
||||
class MCPToolConfigSchema(ToolConfigSchema):
|
||||
"""MCP工具配置Schema"""
|
||||
server_url: str
|
||||
connection_config: Dict[str, Any] = {}
|
||||
available_tools: list[str] = []
|
||||
tool_type: str = "mcp"
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""工具配置管理器"""
|
||||
|
||||
def __init__(self, config_dir: Optional[str] = None):
|
||||
"""初始化配置管理器
|
||||
|
||||
Args:
|
||||
config_dir: 配置文件目录,默认使用系统配置
|
||||
"""
|
||||
self.config_dir = Path(config_dir or self._get_default_config_dir())
|
||||
self.config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"配置管理器初始化完成,配置目录: {self.config_dir}")
|
||||
|
||||
def _get_default_config_dir(self) -> str:
|
||||
"""获取默认配置目录"""
|
||||
# 获取tools目录下的configs子目录
|
||||
tools_dir = Path(__file__).parent
|
||||
return str(tools_dir / "configs")
|
||||
|
||||
def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]:
|
||||
"""加载内置工具配置
|
||||
|
||||
Returns:
|
||||
内置工具配置字典
|
||||
"""
|
||||
configs = {}
|
||||
builtin_dir = self.config_dir / "builtin"
|
||||
|
||||
if not builtin_dir.exists():
|
||||
logger.info("内置工具配置目录不存在,创建默认配置")
|
||||
self._create_default_builtin_configs(builtin_dir)
|
||||
|
||||
for config_file in builtin_dir.glob("*.json"):
|
||||
try:
|
||||
config_data = self._load_config_file(config_file)
|
||||
config = BuiltinToolConfigSchema(**config_data)
|
||||
configs[config.name] = config
|
||||
logger.debug(f"加载内置工具配置: {config.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}")
|
||||
|
||||
return configs
|
||||
|
||||
def load_builtin_tools_config(self) -> Dict[str, Any]:
|
||||
"""加载全局内置工具配置(兼容原有接口)
|
||||
|
||||
Returns:
|
||||
内置工具配置字典
|
||||
"""
|
||||
config_file = self.config_dir / "builtin_tools.json"
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载内置工具配置失败: {e}")
|
||||
return {}
|
||||
|
||||
def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum):
|
||||
"""确保内置工具已初始化到数据库
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
db_session: 数据库会话
|
||||
tool_config_model: ToolConfig模型类
|
||||
builtin_tool_config_model: BuiltinToolConfig模型类
|
||||
tool_type_enum: ToolType枚举
|
||||
tool_status_enum: ToolStatus枚举
|
||||
"""
|
||||
# 检查是否已初始化
|
||||
existing_count = db_session.query(tool_config_model).filter(
|
||||
tool_config_model.tenant_id == tenant_id,
|
||||
tool_config_model.tool_type == tool_type_enum.BUILTIN
|
||||
).count()
|
||||
|
||||
if existing_count > 0:
|
||||
return # 已初始化
|
||||
|
||||
# 加载全局配置
|
||||
builtin_tools = self.load_builtin_tools_config()
|
||||
|
||||
# 为租户创建内置工具记录
|
||||
for tool_key, tool_info in builtin_tools.items():
|
||||
# 设置初始状态
|
||||
initial_status = tool_status_enum.ACTIVE.value if not tool_info['requires_config'] else tool_status_enum.INACTIVE.value
|
||||
|
||||
tool_config = tool_config_model(
|
||||
name=tool_info['name'],
|
||||
description=tool_info['description'],
|
||||
tool_type=tool_type_enum.BUILTIN,
|
||||
tenant_id=tenant_id,
|
||||
status=initial_status
|
||||
)
|
||||
db_session.add(tool_config)
|
||||
db_session.flush()
|
||||
|
||||
builtin_config = builtin_tool_config_model(
|
||||
id=tool_config.id,
|
||||
tool_class=tool_info['tool_class'],
|
||||
parameters={}
|
||||
)
|
||||
db_session.add(builtin_config)
|
||||
|
||||
db_session.commit()
|
||||
logger.info(f"租户 {tenant_id} 的内置工具初始化完成")
|
||||
|
||||
def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
|
||||
"""保存工具配置
|
||||
|
||||
Args:
|
||||
config: 工具配置
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
保存是否成功
|
||||
"""
|
||||
try:
|
||||
config_dir = self.config_dir / tool_type
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config_file = config_dir / f"{config.name}.json"
|
||||
config_data = config.model_dump()
|
||||
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(config_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具配置保存失败: {config.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def delete_tool_config(self, tool_name: str, tool_type: str) -> bool:
|
||||
"""删除工具配置
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
try:
|
||||
config_file = self.config_dir / tool_type / f"{tool_name}.json"
|
||||
|
||||
if config_file.exists():
|
||||
config_file.unlink()
|
||||
logger.info(f"工具配置删除成功: {tool_name} ({tool_type})")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
|
||||
"""验证工具配置
|
||||
|
||||
Args:
|
||||
config_data: 配置数据
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
try:
|
||||
schema_map = {
|
||||
"builtin": BuiltinToolConfigSchema,
|
||||
"custom": CustomToolConfigSchema,
|
||||
"mcp": MCPToolConfigSchema
|
||||
}
|
||||
|
||||
schema_class = schema_map.get(tool_type)
|
||||
if not schema_class:
|
||||
return False, f"不支持的工具类型: {tool_type}"
|
||||
|
||||
# 验证配置
|
||||
schema_class(**config_data)
|
||||
return True, None
|
||||
|
||||
except ValidationError as e:
|
||||
error_msg = "; ".join([f"{err['loc'][0]}: {err['msg']}" for err in e.errors()])
|
||||
return False, f"配置验证失败: {error_msg}"
|
||||
except Exception as e:
|
||||
return False, f"配置验证异常: {str(e)}"
|
||||
|
||||
def _load_config_file(self, config_file: Path) -> Dict[str, Any]:
|
||||
"""加载配置文件
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径
|
||||
|
||||
Returns:
|
||||
配置数据字典
|
||||
"""
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {config_file}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def _create_default_builtin_configs(self, builtin_dir: Path):
|
||||
"""创建默认内置工具配置
|
||||
|
||||
Args:
|
||||
builtin_dir: 内置工具配置目录
|
||||
"""
|
||||
builtin_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"内置工具配置目录已创建: {builtin_dir}")
|
||||
# 配置文件已经通过其他方式创建,这里只需要确保目录存在
|
||||
14
api/app/core/tools/configs/builtin/baidu_search_tool.json
Normal file
14
api/app/core/tools/configs/builtin/baidu_search_tool.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"name": "baidu_search_tool",
|
||||
"description": "百度搜索工具 - 网络搜索:提供网页搜索、新闻搜索、图片搜索功能",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "BaiduSearchTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": "",
|
||||
"secret_key": "",
|
||||
"search_type": "web"
|
||||
},
|
||||
"tags": ["search", "web", "baidu", "builtin"]
|
||||
}
|
||||
12
api/app/core/tools/configs/builtin/datetime_tool.json
Normal file
12
api/app/core/tools/configs/builtin/datetime_tool.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"name": "datetime_tool",
|
||||
"description": "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "DateTimeTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"timezone": "UTC"
|
||||
},
|
||||
"tags": ["time", "utility", "builtin"]
|
||||
}
|
||||
12
api/app/core/tools/configs/builtin/json_tool.json
Normal file
12
api/app/core/tools/configs/builtin/json_tool.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"name": "json_tool",
|
||||
"description": "JSON工具 - 数据格式处理:提供JSON格式化、压缩、验证、格式转换",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "JsonTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"indent": 2
|
||||
},
|
||||
"tags": ["json", "data", "utility", "builtin"]
|
||||
}
|
||||
14
api/app/core/tools/configs/builtin/mineru_tool.json
Normal file
14
api/app/core/tools/configs/builtin/mineru_tool.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"name": "mineru_tool",
|
||||
"description": "MinerU PDF解析工具 - 文档处理:提供PDF解析、表格提取、图片识别、文本提取功能",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "MinerUTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": "",
|
||||
"parse_mode": "auto",
|
||||
"timeout": 60
|
||||
},
|
||||
"tags": ["pdf", "document", "ocr", "builtin"]
|
||||
}
|
||||
14
api/app/core/tools/configs/builtin/textin_tool.json
Normal file
14
api/app/core/tools/configs/builtin/textin_tool.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"name": "textin_tool",
|
||||
"description": "TextIn OCR工具 - 图像识别:提供通用OCR、手写识别、多语言支持功能",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "TextInTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"app_id": "",
|
||||
"language": "auto",
|
||||
"recognition_mode": "general"
|
||||
},
|
||||
"tags": ["ocr", "image", "text", "builtin"]
|
||||
}
|
||||
60
api/app/core/tools/configs/builtin_tools.json
Normal file
60
api/app/core/tools/configs/builtin_tools.json
Normal file
@@ -0,0 +1,60 @@
|
||||
{
|
||||
"datetime": {
|
||||
"name": "时间工具",
|
||||
"description": "获取当前时间、日期计算",
|
||||
"tool_class": "DateTimeTool",
|
||||
"category": "utility",
|
||||
"requires_config": false,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {}
|
||||
},
|
||||
"json_converter": {
|
||||
"name": "JSON转换工具",
|
||||
"description": "JSON数据格式化和转换",
|
||||
"tool_class": "JsonTool",
|
||||
"category": "utility",
|
||||
"requires_config": false,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {}
|
||||
},
|
||||
"baidu_search": {
|
||||
"name": "百度搜索",
|
||||
"description": "百度网页搜索服务",
|
||||
"tool_class": "BaiduSearchTool",
|
||||
"category": "search",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true}
|
||||
}
|
||||
},
|
||||
"mineru": {
|
||||
"name": "MinerU",
|
||||
"description": "PDF文档解析工具",
|
||||
"tool_class": "MinerUTool",
|
||||
"category": "document",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "MinerU API密钥", "sensitive": true, "required": true},
|
||||
"base_url": {"type": "string", "description": "API地址", "default": "https://api.mineru.com"}
|
||||
}
|
||||
},
|
||||
"textin": {
|
||||
"name": "TextIn",
|
||||
"description": "OCR文字识别服务",
|
||||
"tool_class": "TextInTool",
|
||||
"category": "ocr",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
|
||||
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}
|
||||
}
|
||||
}
|
||||
}
|
||||
11
api/app/core/tools/custom/__init__.py
Normal file
11
api/app/core/tools/custom/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""自定义工具模块"""
|
||||
|
||||
from .base import CustomTool
|
||||
from .schema_parser import OpenAPISchemaParser
|
||||
from .auth_manager import AuthManager
|
||||
|
||||
__all__ = [
|
||||
"CustomTool",
|
||||
"OpenAPISchemaParser",
|
||||
"AuthManager"
|
||||
]
|
||||
525
api/app/core/tools/custom/auth_manager.py
Normal file
525
api/app/core/tools/custom/auth_manager.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""认证管理器 - 处理自定义工具的认证配置"""
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from typing import Dict, Any, Tuple
|
||||
from urllib.parse import quote
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import AuthType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AuthManager:
|
||||
"""认证管理器 - 支持多种认证方式"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化认证管理器"""
|
||||
self.supported_auth_types = [
|
||||
AuthType.NONE,
|
||||
AuthType.API_KEY,
|
||||
AuthType.BEARER_TOKEN
|
||||
]
|
||||
|
||||
def validate_auth_config(self, auth_type: AuthType, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证认证配置
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
auth_config: 认证配置
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
try:
|
||||
if auth_type not in self.supported_auth_types:
|
||||
return False, f"不支持的认证类型: {auth_type}"
|
||||
|
||||
if auth_type == AuthType.NONE:
|
||||
return True, ""
|
||||
|
||||
elif auth_type == AuthType.API_KEY:
|
||||
return self._validate_api_key_config(auth_config)
|
||||
|
||||
elif auth_type == AuthType.BEARER_TOKEN:
|
||||
return self._validate_bearer_token_config(auth_config)
|
||||
|
||||
return False, "未知的认证类型"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"验证认证配置时出错: {e}"
|
||||
|
||||
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证API Key认证配置
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
api_key = auth_config.get("api_key")
|
||||
if not api_key:
|
||||
return False, "API Key不能为空"
|
||||
|
||||
if not isinstance(api_key, str):
|
||||
return False, "API Key必须是字符串"
|
||||
|
||||
# 验证key名称
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
if not isinstance(key_name, str):
|
||||
return False, "API Key名称必须是字符串"
|
||||
|
||||
# 验证位置
|
||||
key_location = auth_config.get("location", "header")
|
||||
if key_location not in ["header", "query", "cookie"]:
|
||||
return False, "API Key位置必须是 header、query 或 cookie"
|
||||
|
||||
return True, ""
|
||||
|
||||
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证Bearer Token认证配置
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
token = auth_config.get("token")
|
||||
if not token:
|
||||
return False, "Bearer Token不能为空"
|
||||
|
||||
if not isinstance(token, str):
|
||||
return False, "Bearer Token必须是字符串"
|
||||
|
||||
return True, ""
|
||||
|
||||
def apply_authentication(
|
||||
self,
|
||||
auth_type: AuthType,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
params: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
|
||||
"""应用认证到请求
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
auth_config: 认证配置
|
||||
url: 请求URL
|
||||
headers: 请求头
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
(修改后的URL, 修改后的headers, 修改后的params)
|
||||
"""
|
||||
try:
|
||||
if auth_type == AuthType.NONE:
|
||||
return url, headers, params
|
||||
|
||||
elif auth_type == AuthType.API_KEY:
|
||||
return self._apply_api_key_auth(auth_config, url, headers, params)
|
||||
|
||||
elif auth_type == AuthType.BEARER_TOKEN:
|
||||
return self._apply_bearer_token_auth(auth_config, url, headers, params)
|
||||
|
||||
else:
|
||||
logger.warning(f"不支持的认证类型: {auth_type}")
|
||||
return url, headers, params
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"应用认证时出错: {e}")
|
||||
return url, headers, params
|
||||
|
||||
def _apply_api_key_auth(
|
||||
self,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
params: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
|
||||
"""应用API Key认证
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
url: 请求URL
|
||||
headers: 请求头
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
(修改后的URL, 修改后的headers, 修改后的params)
|
||||
"""
|
||||
api_key = auth_config.get("api_key")
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
location = auth_config.get("location", "header")
|
||||
|
||||
if location == "header":
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif location == "query":
|
||||
# 添加到URL查询参数
|
||||
separator = "&" if "?" in url else "?"
|
||||
encoded_key = quote(str(api_key))
|
||||
url += f"{separator}{key_name}={encoded_key}"
|
||||
|
||||
elif location == "cookie":
|
||||
# 添加到Cookie头
|
||||
cookie_value = f"{key_name}={api_key}"
|
||||
if "Cookie" in headers:
|
||||
headers["Cookie"] += f"; {cookie_value}"
|
||||
else:
|
||||
headers["Cookie"] = cookie_value
|
||||
|
||||
return url, headers, params
|
||||
|
||||
def _apply_bearer_token_auth(
|
||||
self,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
params: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
|
||||
"""应用Bearer Token认证
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
url: 请求URL
|
||||
headers: 请求头
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
(修改后的URL, 修改后的headers, 修改后的params)
|
||||
"""
|
||||
token = auth_config.get("token")
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
return url, headers, params
|
||||
|
||||
def encrypt_auth_config(self, auth_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
|
||||
"""加密认证配置中的敏感信息
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
encryption_key: 加密密钥
|
||||
|
||||
Returns:
|
||||
加密后的认证配置
|
||||
"""
|
||||
try:
|
||||
encrypted_config = auth_config.copy()
|
||||
|
||||
# 需要加密的字段
|
||||
sensitive_fields = ["api_key", "token", "secret", "password"]
|
||||
|
||||
for field in sensitive_fields:
|
||||
if field in encrypted_config:
|
||||
value = encrypted_config[field]
|
||||
if isinstance(value, str) and value:
|
||||
encrypted_value = self._encrypt_string(value, encryption_key)
|
||||
encrypted_config[field] = encrypted_value
|
||||
encrypted_config[f"{field}_encrypted"] = True
|
||||
|
||||
return encrypted_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加密认证配置失败: {e}")
|
||||
return auth_config
|
||||
|
||||
def decrypt_auth_config(self, encrypted_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
|
||||
"""解密认证配置中的敏感信息
|
||||
|
||||
Args:
|
||||
encrypted_config: 加密的认证配置
|
||||
encryption_key: 解密密钥
|
||||
|
||||
Returns:
|
||||
解密后的认证配置
|
||||
"""
|
||||
try:
|
||||
decrypted_config = encrypted_config.copy()
|
||||
|
||||
# 需要解密的字段
|
||||
sensitive_fields = ["api_key", "token", "secret", "password"]
|
||||
|
||||
for field in sensitive_fields:
|
||||
if field in decrypted_config and decrypted_config.get(f"{field}_encrypted"):
|
||||
encrypted_value = decrypted_config[field]
|
||||
if isinstance(encrypted_value, str) and encrypted_value:
|
||||
decrypted_value = self._decrypt_string(encrypted_value, encryption_key)
|
||||
decrypted_config[field] = decrypted_value
|
||||
# 移除加密标记
|
||||
decrypted_config.pop(f"{field}_encrypted", None)
|
||||
|
||||
return decrypted_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解密认证配置失败: {e}")
|
||||
return encrypted_config
|
||||
|
||||
def _encrypt_string(self, value: str, key: str) -> str:
|
||||
"""加密字符串
|
||||
|
||||
Args:
|
||||
value: 要加密的字符串
|
||||
key: 加密密钥
|
||||
|
||||
Returns:
|
||||
加密后的字符串(Base64编码)
|
||||
"""
|
||||
try:
|
||||
# 使用HMAC-SHA256进行简单加密
|
||||
key_bytes = key.encode('utf-8')
|
||||
value_bytes = value.encode('utf-8')
|
||||
|
||||
# 生成HMAC
|
||||
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
|
||||
signature = hmac_obj.hexdigest()
|
||||
|
||||
# 组合原始值和签名,然后Base64编码
|
||||
combined = f"{value}:{signature}"
|
||||
encrypted = base64.b64encode(combined.encode('utf-8')).decode('utf-8')
|
||||
|
||||
return encrypted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加密字符串失败: {e}")
|
||||
return value
|
||||
|
||||
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
|
||||
"""解密字符串
|
||||
|
||||
Args:
|
||||
encrypted_value: 加密的字符串
|
||||
key: 解密密钥
|
||||
|
||||
Returns:
|
||||
解密后的字符串
|
||||
"""
|
||||
try:
|
||||
# Base64解码
|
||||
decoded = base64.b64decode(encrypted_value.encode('utf-8')).decode('utf-8')
|
||||
|
||||
# 分离原始值和签名
|
||||
if ':' not in decoded:
|
||||
return encrypted_value # 可能不是加密的值
|
||||
|
||||
value, signature = decoded.rsplit(':', 1)
|
||||
|
||||
# 验证签名
|
||||
key_bytes = key.encode('utf-8')
|
||||
value_bytes = value.encode('utf-8')
|
||||
|
||||
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
|
||||
expected_signature = hmac_obj.hexdigest()
|
||||
|
||||
if signature == expected_signature:
|
||||
return value
|
||||
else:
|
||||
logger.warning("解密时签名验证失败")
|
||||
return encrypted_value
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解密字符串失败: {e}")
|
||||
return encrypted_value
|
||||
|
||||
def test_authentication(
|
||||
self,
|
||||
auth_type: AuthType,
|
||||
auth_config: Dict[str, Any],
|
||||
test_url: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""测试认证配置
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
auth_config: 认证配置
|
||||
test_url: 测试URL(可选)
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
try:
|
||||
# 验证配置
|
||||
is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
|
||||
if not is_valid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_msg,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
# 如果没有测试URL,只验证配置
|
||||
if not test_url:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "认证配置有效",
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
# 构建测试请求
|
||||
headers = {"User-Agent": "AuthManager-Test/1.0"}
|
||||
params = {}
|
||||
|
||||
# 应用认证
|
||||
test_url, headers, params = self.apply_authentication(
|
||||
auth_type, auth_config, test_url, headers, params
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "认证配置测试成功",
|
||||
"auth_type": auth_type.value,
|
||||
"test_url": test_url,
|
||||
"headers": {k: v for k, v in headers.items() if k != "Authorization"}, # 不返回敏感信息
|
||||
"has_auth_header": "Authorization" in headers
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"auth_type": auth_type.value if auth_type else "unknown"
|
||||
}
|
||||
|
||||
async def test_authentication_with_request(
|
||||
self,
|
||||
auth_type: AuthType,
|
||||
auth_config: Dict[str, Any],
|
||||
test_url: str,
|
||||
timeout: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""通过实际HTTP请求测试认证
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
auth_config: 认证配置
|
||||
test_url: 测试URL
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
try:
|
||||
# 验证配置
|
||||
is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
|
||||
if not is_valid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_msg,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
# 构建请求
|
||||
headers = {"User-Agent": "AuthManager-Test/1.0"}
|
||||
params = {}
|
||||
|
||||
# 应用认证
|
||||
test_url, headers, params = self.apply_authentication(
|
||||
auth_type, auth_config, test_url, headers, params
|
||||
)
|
||||
|
||||
# 发送测试请求
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.get(test_url, headers=headers) as response:
|
||||
status_code = response.status
|
||||
|
||||
# 根据状态码判断认证是否成功
|
||||
if status_code == 200:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "认证测试成功",
|
||||
"status_code": status_code,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
elif status_code == 401:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "认证失败 - 401 Unauthorized",
|
||||
"status_code": status_code,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
elif status_code == 403:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "认证失败 - 403 Forbidden",
|
||||
"status_code": status_code,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"请求成功,状态码: {status_code}",
|
||||
"status_code": status_code,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"网络请求失败: {e}",
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"测试认证时出错: {e}",
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
|
||||
"""获取认证配置模板
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
|
||||
Returns:
|
||||
配置模板
|
||||
"""
|
||||
templates = {
|
||||
AuthType.NONE: {},
|
||||
|
||||
AuthType.API_KEY: {
|
||||
"api_key": "",
|
||||
"key_name": "X-API-Key",
|
||||
"location": "header", # header, query, cookie
|
||||
"description": "API Key认证配置"
|
||||
},
|
||||
|
||||
AuthType.BEARER_TOKEN: {
|
||||
"token": "",
|
||||
"description": "Bearer Token认证配置"
|
||||
}
|
||||
}
|
||||
|
||||
return templates.get(auth_type, {})
|
||||
|
||||
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""遮蔽认证配置中的敏感信息
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
|
||||
Returns:
|
||||
遮蔽敏感信息后的配置
|
||||
"""
|
||||
masked_config = auth_config.copy()
|
||||
|
||||
# 需要遮蔽的字段
|
||||
sensitive_fields = ["api_key", "token", "secret", "password"]
|
||||
|
||||
for field in sensitive_fields:
|
||||
if field in masked_config:
|
||||
value = masked_config[field]
|
||||
if isinstance(value, str) and len(value) > 4:
|
||||
# 只显示前2位和后2位
|
||||
masked_config[field] = f"{value[:2]}***{value[-2:]}"
|
||||
elif isinstance(value, str) and value:
|
||||
masked_config[field] = "***"
|
||||
|
||||
return masked_config
|
||||
318
api/app/core/tools/custom/base.py
Normal file
318
api/app/core/tools/custom/base.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""自定义工具基类"""
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
import aiohttp
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from app.models.tool_model import ToolType, AuthType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class CustomTool(BaseTool):
|
||||
"""自定义工具 - 基于OpenAPI schema的工具"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
"""初始化自定义工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
config: 工具配置
|
||||
"""
|
||||
super().__init__(tool_id, config)
|
||||
self.schema_content = config.get("schema_content", {})
|
||||
self.schema_url = config.get("schema_url")
|
||||
self.auth_type = AuthType(config.get("auth_type", "none"))
|
||||
self.auth_config = config.get("auth_config", {})
|
||||
self.base_url = config.get("base_url", "")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
|
||||
# 解析schema
|
||||
self._parsed_operations = self._parse_openapi_schema()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""工具名称"""
|
||||
if self.schema_content:
|
||||
info = self.schema_content.get("info", {})
|
||||
return info.get("title", f"custom_tool_{self.tool_id[:8]}")
|
||||
return f"custom_tool_{self.tool_id[:8]}"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""工具描述"""
|
||||
if self.schema_content:
|
||||
info = self.schema_content.get("info", {})
|
||||
return info.get("description", "自定义API工具")
|
||||
return "自定义API工具"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.CUSTOM
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""工具参数定义"""
|
||||
params = []
|
||||
|
||||
# 添加操作选择参数
|
||||
if len(self._parsed_operations) > 1:
|
||||
params.append(ToolParameter(
|
||||
name="operation",
|
||||
type=ParameterType.STRING,
|
||||
description="要执行的操作",
|
||||
required=True,
|
||||
enum=list(self._parsed_operations.keys())
|
||||
))
|
||||
|
||||
# 添加通用参数(基于第一个操作的参数)
|
||||
if self._parsed_operations:
|
||||
first_operation = next(iter(self._parsed_operations.values()))
|
||||
for param_name, param_info in first_operation.get("parameters", {}).items():
|
||||
params.append(ToolParameter(
|
||||
name=param_name,
|
||||
type=self._convert_openapi_type(param_info.get("type", "string")),
|
||||
description=param_info.get("description", ""),
|
||||
required=param_info.get("required", False),
|
||||
default=param_info.get("default"),
|
||||
enum=param_info.get("enum"),
|
||||
minimum=param_info.get("minimum"),
|
||||
maximum=param_info.get("maximum"),
|
||||
pattern=param_info.get("pattern")
|
||||
))
|
||||
|
||||
return params
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行自定义工具"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确定要执行的操作
|
||||
operation_name = kwargs.get("operation")
|
||||
if not operation_name and len(self._parsed_operations) == 1:
|
||||
operation_name = next(iter(self._parsed_operations.keys()))
|
||||
|
||||
if not operation_name or operation_name not in self._parsed_operations:
|
||||
raise ValueError(f"无效的操作: {operation_name}")
|
||||
|
||||
operation = self._parsed_operations[operation_name]
|
||||
|
||||
# 构建请求
|
||||
url = self._build_request_url(operation, kwargs)
|
||||
headers = self._build_request_headers(operation)
|
||||
data = self._build_request_data(operation, kwargs)
|
||||
|
||||
# 发送HTTP请求
|
||||
result = await self._send_http_request(
|
||||
method=operation["method"],
|
||||
url=url,
|
||||
headers=headers,
|
||||
data=data
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="CUSTOM_TOOL_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def _parse_openapi_schema(self) -> Dict[str, Any]:
|
||||
"""解析OpenAPI schema"""
|
||||
operations = {}
|
||||
|
||||
if not self.schema_content:
|
||||
return operations
|
||||
|
||||
paths = self.schema_content.get("paths", {})
|
||||
|
||||
for path, path_item in paths.items():
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() in ["get", "post", "put", "delete", "patch"]:
|
||||
operation_id = operation.get("operationId", f"{method}_{path.replace('/', '_')}")
|
||||
|
||||
# 解析参数
|
||||
parameters = {}
|
||||
if "parameters" in operation:
|
||||
for param in operation["parameters"]:
|
||||
param_name = param.get("name")
|
||||
param_schema = param.get("schema", {})
|
||||
parameters[param_name] = {
|
||||
"type": param_schema.get("type", "string"),
|
||||
"description": param.get("description", ""),
|
||||
"required": param.get("required", False),
|
||||
"in": param.get("in", "query"),
|
||||
**param_schema
|
||||
}
|
||||
|
||||
# 解析请求体
|
||||
request_body = None
|
||||
if "requestBody" in operation:
|
||||
content = operation["requestBody"].get("content", {})
|
||||
if "application/json" in content:
|
||||
request_body = content["application/json"].get("schema", {})
|
||||
|
||||
operations[operation_id] = {
|
||||
"method": method.upper(),
|
||||
"path": path,
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": parameters,
|
||||
"request_body": request_body
|
||||
}
|
||||
|
||||
return operations
|
||||
|
||||
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
|
||||
"""转换OpenAPI类型到内部类型"""
|
||||
type_mapping = {
|
||||
"string": ParameterType.STRING,
|
||||
"integer": ParameterType.INTEGER,
|
||||
"number": ParameterType.NUMBER,
|
||||
"boolean": ParameterType.BOOLEAN,
|
||||
"array": ParameterType.ARRAY,
|
||||
"object": ParameterType.OBJECT
|
||||
}
|
||||
return type_mapping.get(openapi_type, ParameterType.STRING)
|
||||
|
||||
def _build_request_url(self, operation: Dict[str, Any], params: Dict[str, Any]) -> str:
|
||||
"""构建请求URL"""
|
||||
path = operation["path"]
|
||||
|
||||
# 替换路径参数
|
||||
for param_name, param_info in operation.get("parameters", {}).items():
|
||||
if param_info.get("in") == "path" and param_name in params:
|
||||
path = path.replace(f"{{{param_name}}}", str(params[param_name]))
|
||||
|
||||
# 构建完整URL
|
||||
if self.base_url:
|
||||
url = urljoin(self.base_url, path.lstrip("/"))
|
||||
else:
|
||||
# 从schema中获取服务器URL
|
||||
servers = self.schema_content.get("servers", [])
|
||||
if servers:
|
||||
base_url = servers[0].get("url", "")
|
||||
url = urljoin(base_url, path.lstrip("/"))
|
||||
else:
|
||||
url = path
|
||||
|
||||
# 添加查询参数
|
||||
query_params = {}
|
||||
for param_name, param_info in operation.get("parameters", {}).items():
|
||||
if param_info.get("in") == "query" and param_name in params:
|
||||
query_params[param_name] = params[param_name]
|
||||
|
||||
if query_params:
|
||||
from urllib.parse import urlencode
|
||||
url += "?" + urlencode(query_params)
|
||||
|
||||
return url
|
||||
|
||||
def _build_request_headers(self, operation: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "CustomTool/1.0"
|
||||
}
|
||||
|
||||
# 添加认证头
|
||||
if self.auth_type == AuthType.API_KEY:
|
||||
api_key = self.auth_config.get("api_key")
|
||||
key_name = self.auth_config.get("key_name", "X-API-Key")
|
||||
if api_key:
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif self.auth_type == AuthType.BEARER_TOKEN:
|
||||
token = self.auth_config.get("token")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
return headers
|
||||
|
||||
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""构建请求数据"""
|
||||
if operation["method"] in ["POST", "PUT", "PATCH"]:
|
||||
request_body = operation.get("request_body")
|
||||
if request_body:
|
||||
# 构建请求体数据
|
||||
data = {}
|
||||
properties = request_body.get("properties", {})
|
||||
|
||||
for prop_name, prop_schema in properties.items():
|
||||
if prop_name in params:
|
||||
data[prop_name] = params[prop_name]
|
||||
|
||||
return data if data else None
|
||||
|
||||
return None
|
||||
|
||||
async def _send_http_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
) -> Any:
|
||||
"""发送HTTP请求"""
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
kwargs = {
|
||||
"headers": headers
|
||||
}
|
||||
|
||||
if data and method in ["POST", "PUT", "PATCH"]:
|
||||
kwargs["json"] = data
|
||||
|
||||
async with session.request(method, url, **kwargs) as response:
|
||||
if response.status >= 400:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"HTTP {response.status}: {error_text}")
|
||||
|
||||
# 尝试解析JSON响应
|
||||
try:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
return await response.text()
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, schema_url: str, auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
|
||||
"""从URL导入OpenAPI schema创建工具"""
|
||||
import uuid
|
||||
if not tool_id:
|
||||
tool_id = str(uuid.uuid4())
|
||||
|
||||
config = {
|
||||
"schema_url": schema_url,
|
||||
"auth_config": auth_config,
|
||||
"auth_type": auth_config.get("type", "none")
|
||||
}
|
||||
|
||||
# 这里应该异步加载schema,为了简化暂时返回空配置
|
||||
return cls(tool_id, config)
|
||||
|
||||
@classmethod
|
||||
def from_schema(cls, schema_dict: Dict[str, Any], auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
|
||||
"""从schema字典创建工具"""
|
||||
import uuid
|
||||
if not tool_id:
|
||||
tool_id = str(uuid.uuid4())
|
||||
|
||||
config = {
|
||||
"schema_content": schema_dict,
|
||||
"auth_config": auth_config,
|
||||
"auth_type": auth_config.get("type", "none")
|
||||
}
|
||||
|
||||
return cls(tool_id, config)
|
||||
477
api/app/core/tools/custom/schema_parser.py
Normal file
477
api/app/core/tools/custom/schema_parser.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""OpenAPI Schema解析器"""
|
||||
import json
|
||||
import yaml
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class OpenAPISchemaParser:
|
||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化解析器"""
|
||||
self.supported_versions = ["3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]
|
||||
|
||||
async def parse_from_url(self, schema_url: str, timeout: int = 30) -> Tuple[bool, Dict[str, Any], str]:
|
||||
"""从URL解析OpenAPI schema
|
||||
|
||||
Args:
|
||||
schema_url: Schema URL
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
(是否成功, schema内容, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 验证URL格式
|
||||
parsed_url = urlparse(schema_url)
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
return False, {}, "无效的URL格式"
|
||||
|
||||
# 下载schema
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.get(schema_url) as response:
|
||||
if response.status != 200:
|
||||
return False, {}, f"HTTP错误: {response.status}"
|
||||
|
||||
content_type = response.headers.get('content-type', '').lower()
|
||||
content = await response.text()
|
||||
|
||||
# 解析内容
|
||||
schema_dict = self._parse_content(content, content_type)
|
||||
if not schema_dict:
|
||||
return False, {}, "无法解析schema内容"
|
||||
|
||||
# 验证schema
|
||||
is_valid, error_msg = self.validate_schema(schema_dict)
|
||||
if not is_valid:
|
||||
return False, {}, error_msg
|
||||
|
||||
return True, schema_dict, ""
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, {}, "请求超时"
|
||||
except Exception as e:
|
||||
logger.error(f"从URL解析schema失败: {schema_url}, 错误: {e}")
|
||||
return False, {}, str(e)
|
||||
|
||||
def parse_from_content(self, content: str, content_type: str = "application/json") -> Tuple[bool, Dict[str, Any], str]:
|
||||
"""从内容解析OpenAPI schema
|
||||
|
||||
Args:
|
||||
content: Schema内容
|
||||
content_type: 内容类型
|
||||
|
||||
Returns:
|
||||
(是否成功, schema内容, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 解析内容
|
||||
schema_dict = self._parse_content(content, content_type)
|
||||
if not schema_dict:
|
||||
return False, {}, "无法解析schema内容"
|
||||
|
||||
# 验证schema
|
||||
is_valid, error_msg = self.validate_schema(schema_dict)
|
||||
if not is_valid:
|
||||
return False, {}, error_msg
|
||||
|
||||
return True, schema_dict, ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析schema内容失败: {e}")
|
||||
return False, {}, str(e)
|
||||
|
||||
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
|
||||
"""解析内容为字典
|
||||
|
||||
Args:
|
||||
content: 内容字符串
|
||||
content_type: 内容类型
|
||||
|
||||
Returns:
|
||||
解析后的字典,失败返回None
|
||||
"""
|
||||
try:
|
||||
# 根据内容类型解析
|
||||
if 'json' in content_type:
|
||||
return json.loads(content)
|
||||
elif 'yaml' in content_type or 'yml' in content_type:
|
||||
return yaml.safe_load(content)
|
||||
else:
|
||||
# 尝试自动检测格式
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
return yaml.safe_load(content)
|
||||
except yaml.YAMLError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"解析内容失败: {e}")
|
||||
return None
|
||||
|
||||
def validate_schema(self, schema_dict: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证OpenAPI schema
|
||||
|
||||
Args:
|
||||
schema_dict: Schema字典
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 检查基本结构
|
||||
if not isinstance(schema_dict, dict):
|
||||
return False, "Schema必须是JSON对象"
|
||||
|
||||
# 检查OpenAPI版本
|
||||
openapi_version = schema_dict.get("openapi")
|
||||
if not openapi_version:
|
||||
return False, "缺少openapi版本字段"
|
||||
|
||||
if openapi_version not in self.supported_versions:
|
||||
return False, f"不支持的OpenAPI版本: {openapi_version}"
|
||||
|
||||
# 检查必需字段
|
||||
required_fields = ["info", "paths"]
|
||||
for field in required_fields:
|
||||
if field not in schema_dict:
|
||||
return False, f"缺少必需字段: {field}"
|
||||
|
||||
# 验证info字段
|
||||
info = schema_dict.get("info", {})
|
||||
if not isinstance(info, dict):
|
||||
return False, "info字段必须是对象"
|
||||
|
||||
if "title" not in info:
|
||||
return False, "info.title字段是必需的"
|
||||
|
||||
# 验证paths字段
|
||||
paths = schema_dict.get("paths", {})
|
||||
if not isinstance(paths, dict):
|
||||
return False, "paths字段必须是对象"
|
||||
|
||||
# 验证至少有一个路径
|
||||
if not paths:
|
||||
return False, "至少需要定义一个API路径"
|
||||
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
return False, f"验证schema时出错: {e}"
|
||||
|
||||
def extract_tool_info(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""从schema提取工具信息
|
||||
|
||||
Args:
|
||||
schema_dict: Schema字典
|
||||
|
||||
Returns:
|
||||
工具信息字典
|
||||
"""
|
||||
info = schema_dict.get("info", {})
|
||||
|
||||
return {
|
||||
"name": info.get("title", "Custom API Tool"),
|
||||
"description": info.get("description", ""),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
"servers": schema_dict.get("servers", []),
|
||||
"operations": self._extract_operations(schema_dict)
|
||||
}
|
||||
|
||||
def _extract_operations(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取API操作信息
|
||||
|
||||
Args:
|
||||
schema_dict: Schema字典
|
||||
|
||||
Returns:
|
||||
操作信息字典
|
||||
"""
|
||||
operations = {}
|
||||
paths = schema_dict.get("paths", {})
|
||||
|
||||
for path, path_item in paths.items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() not in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
continue
|
||||
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
# 生成操作ID
|
||||
operation_id = operation.get("operationId")
|
||||
if not operation_id:
|
||||
operation_id = f"{method.lower()}_{path.replace('/', '_').replace('{', '').replace('}', '')}"
|
||||
|
||||
# 提取操作信息
|
||||
operations[operation_id] = {
|
||||
"method": method.upper(),
|
||||
"path": path,
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": self._extract_parameters(operation),
|
||||
"request_body": self._extract_request_body(operation),
|
||||
"responses": self._extract_responses(operation),
|
||||
"tags": operation.get("tags", [])
|
||||
}
|
||||
|
||||
return operations
|
||||
|
||||
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取操作参数
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
|
||||
Returns:
|
||||
参数信息字典
|
||||
"""
|
||||
parameters = {}
|
||||
|
||||
for param in operation.get("parameters", []):
|
||||
if not isinstance(param, dict):
|
||||
continue
|
||||
|
||||
param_name = param.get("name")
|
||||
if not param_name:
|
||||
continue
|
||||
|
||||
param_schema = param.get("schema", {})
|
||||
|
||||
parameters[param_name] = {
|
||||
"name": param_name,
|
||||
"in": param.get("in", "query"),
|
||||
"description": param.get("description", ""),
|
||||
"required": param.get("required", False),
|
||||
"type": param_schema.get("type", "string"),
|
||||
"format": param_schema.get("format"),
|
||||
"enum": param_schema.get("enum"),
|
||||
"default": param_schema.get("default"),
|
||||
"minimum": param_schema.get("minimum"),
|
||||
"maximum": param_schema.get("maximum"),
|
||||
"pattern": param_schema.get("pattern"),
|
||||
"example": param.get("example") or param_schema.get("example")
|
||||
}
|
||||
|
||||
return parameters
|
||||
|
||||
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""提取请求体信息
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
|
||||
Returns:
|
||||
请求体信息,如果没有返回None
|
||||
"""
|
||||
request_body = operation.get("requestBody")
|
||||
if not request_body:
|
||||
return None
|
||||
|
||||
content = request_body.get("content", {})
|
||||
|
||||
# 优先使用application/json
|
||||
if "application/json" in content:
|
||||
schema = content["application/json"].get("schema", {})
|
||||
elif content:
|
||||
# 使用第一个可用的内容类型
|
||||
first_content_type = next(iter(content.keys()))
|
||||
schema = content[first_content_type].get("schema", {})
|
||||
else:
|
||||
return None
|
||||
|
||||
return {
|
||||
"description": request_body.get("description", ""),
|
||||
"required": request_body.get("required", False),
|
||||
"schema": schema,
|
||||
"content_types": list(content.keys())
|
||||
}
|
||||
|
||||
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取响应信息
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
|
||||
Returns:
|
||||
响应信息字典
|
||||
"""
|
||||
responses = {}
|
||||
|
||||
for status_code, response in operation.get("responses", {}).items():
|
||||
if not isinstance(response, dict):
|
||||
continue
|
||||
|
||||
content = response.get("content", {})
|
||||
schema = None
|
||||
|
||||
# 尝试获取响应schema
|
||||
if "application/json" in content:
|
||||
schema = content["application/json"].get("schema")
|
||||
elif content:
|
||||
first_content_type = next(iter(content.keys()))
|
||||
schema = content[first_content_type].get("schema")
|
||||
|
||||
responses[status_code] = {
|
||||
"description": response.get("description", ""),
|
||||
"schema": schema,
|
||||
"content_types": list(content.keys()) if content else []
|
||||
}
|
||||
|
||||
return responses
|
||||
|
||||
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""生成工具参数定义
|
||||
|
||||
Args:
|
||||
operations: 操作信息字典
|
||||
|
||||
Returns:
|
||||
参数定义列表
|
||||
"""
|
||||
parameters = []
|
||||
|
||||
# 如果有多个操作,添加操作选择参数
|
||||
if len(operations) > 1:
|
||||
parameters.append({
|
||||
"name": "operation",
|
||||
"type": "string",
|
||||
"description": "要执行的操作",
|
||||
"required": True,
|
||||
"enum": list(operations.keys())
|
||||
})
|
||||
|
||||
# 收集所有参数(去重)
|
||||
all_params = {}
|
||||
|
||||
for operation_id, operation in operations.items():
|
||||
# 路径参数和查询参数
|
||||
for param_name, param_info in operation.get("parameters", {}).items():
|
||||
if param_name not in all_params:
|
||||
all_params[param_name] = {
|
||||
"name": param_name,
|
||||
"type": param_info.get("type", "string"),
|
||||
"description": param_info.get("description", ""),
|
||||
"required": param_info.get("required", False),
|
||||
"enum": param_info.get("enum"),
|
||||
"default": param_info.get("default"),
|
||||
"minimum": param_info.get("minimum"),
|
||||
"maximum": param_info.get("maximum"),
|
||||
"pattern": param_info.get("pattern")
|
||||
}
|
||||
|
||||
# 请求体参数
|
||||
request_body = operation.get("request_body")
|
||||
if request_body:
|
||||
schema = request_body.get("schema", {})
|
||||
properties = schema.get("properties", {})
|
||||
|
||||
for prop_name, prop_schema in properties.items():
|
||||
if prop_name not in all_params:
|
||||
all_params[prop_name] = {
|
||||
"name": prop_name,
|
||||
"type": prop_schema.get("type", "string"),
|
||||
"description": prop_schema.get("description", ""),
|
||||
"required": prop_name in schema.get("required", []),
|
||||
"enum": prop_schema.get("enum"),
|
||||
"default": prop_schema.get("default"),
|
||||
"minimum": prop_schema.get("minimum"),
|
||||
"maximum": prop_schema.get("maximum"),
|
||||
"pattern": prop_schema.get("pattern")
|
||||
}
|
||||
|
||||
# 转换为参数列表
|
||||
parameters.extend(all_params.values())
|
||||
|
||||
return parameters
|
||||
|
||||
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
|
||||
"""验证操作参数
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
params: 输入参数
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息列表)
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# 验证路径参数和查询参数
|
||||
for param_name, param_info in operation.get("parameters", {}).items():
|
||||
if param_info.get("required", False) and param_name not in params:
|
||||
errors.append(f"缺少必需参数: {param_name}")
|
||||
|
||||
if param_name in params:
|
||||
value = params[param_name]
|
||||
param_type = param_info.get("type", "string")
|
||||
|
||||
# 类型验证
|
||||
if not self._validate_parameter_type(value, param_type):
|
||||
errors.append(f"参数 {param_name} 类型错误,期望: {param_type}")
|
||||
|
||||
# 枚举验证
|
||||
enum_values = param_info.get("enum")
|
||||
if enum_values and value not in enum_values:
|
||||
errors.append(f"参数 {param_name} 值无效,必须是: {enum_values}")
|
||||
|
||||
# 验证请求体参数
|
||||
request_body = operation.get("request_body")
|
||||
if request_body:
|
||||
schema = request_body.get("schema", {})
|
||||
required_props = schema.get("required", [])
|
||||
properties = schema.get("properties", {})
|
||||
|
||||
for prop_name in required_props:
|
||||
if prop_name not in params:
|
||||
errors.append(f"缺少必需的请求体参数: {prop_name}")
|
||||
|
||||
for prop_name, value in params.items():
|
||||
if prop_name in properties:
|
||||
prop_schema = properties[prop_name]
|
||||
prop_type = prop_schema.get("type", "string")
|
||||
|
||||
if not self._validate_parameter_type(value, prop_type):
|
||||
errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}")
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
|
||||
"""验证参数类型
|
||||
|
||||
Args:
|
||||
value: 参数值
|
||||
expected_type: 期望类型
|
||||
|
||||
Returns:
|
||||
是否类型匹配
|
||||
"""
|
||||
if value is None:
|
||||
return True
|
||||
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": (int, float),
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict
|
||||
}
|
||||
|
||||
expected_python_type = type_mapping.get(expected_type)
|
||||
if expected_python_type:
|
||||
return isinstance(value, expected_python_type)
|
||||
|
||||
return True
|
||||
501
api/app/core/tools/executor.py
Normal file
501
api/app/core/tools/executor.py
Normal file
@@ -0,0 +1,501 @@
|
||||
"""工具执行器 - 负责工具的实际调用和执行管理"""
|
||||
import asyncio
|
||||
import uuid
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import ToolExecution, ExecutionStatus
|
||||
from app.core.tools.base import BaseTool, ToolResult
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ExecutionContext:
|
||||
"""执行上下文"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
execution_id: str,
|
||||
tool_id: str,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
self.execution_id = execution_id
|
||||
self.tool_id = tool_id
|
||||
self.user_id = user_id
|
||||
self.workspace_id = workspace_id
|
||||
self.timeout = timeout or 60.0 # 默认60秒超时
|
||||
self.metadata = metadata or {}
|
||||
self.started_at = datetime.now()
|
||||
self.completed_at: Optional[datetime] = None
|
||||
self.status = ExecutionStatus.PENDING
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""工具执行器 - 使用langchain标准接口执行工具"""
|
||||
|
||||
def __init__(self, db: Session, registry: ToolRegistry):
|
||||
"""初始化工具执行器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
registry: 工具注册表
|
||||
"""
|
||||
self.db = db
|
||||
self.registry = registry
|
||||
self._running_executions: Dict[str, ExecutionContext] = {}
|
||||
self._execution_lock = asyncio.Lock()
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_id: str,
|
||||
parameters: Dict[str, Any],
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
execution_id: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> ToolResult:
|
||||
"""执行工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
parameters: 工具参数
|
||||
user_id: 用户ID
|
||||
workspace_id: 工作空间ID
|
||||
execution_id: 执行ID(可选,自动生成)
|
||||
timeout: 超时时间(秒)
|
||||
metadata: 额外元数据
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
# 生成执行ID
|
||||
if not execution_id:
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 创建执行上下文
|
||||
context = ExecutionContext(
|
||||
execution_id=execution_id,
|
||||
tool_id=tool_id,
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
timeout=timeout,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取工具实例
|
||||
tool = self.registry.get_tool(tool_id)
|
||||
if not tool:
|
||||
return ToolResult.error_result(
|
||||
error=f"工具不存在: {tool_id}",
|
||||
error_code="TOOL_NOT_FOUND",
|
||||
execution_time=0.0
|
||||
)
|
||||
|
||||
# 记录执行开始
|
||||
await self._record_execution_start(context, parameters)
|
||||
|
||||
# 执行工具
|
||||
result = await self._execute_with_timeout(tool, parameters, context)
|
||||
|
||||
# 记录执行完成
|
||||
await self._record_execution_complete(context, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具执行异常: {execution_id}, 错误: {e}")
|
||||
|
||||
# 记录执行失败
|
||||
error_result = ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="EXECUTION_ERROR",
|
||||
execution_time=time.time() - context.started_at.timestamp()
|
||||
)
|
||||
await self._record_execution_complete(context, error_result)
|
||||
|
||||
return error_result
|
||||
|
||||
finally:
|
||||
# 清理执行上下文
|
||||
async with self._execution_lock:
|
||||
if execution_id in self._running_executions:
|
||||
del self._running_executions[execution_id]
|
||||
|
||||
async def execute_tools_batch(
|
||||
self,
|
||||
tool_executions: List[Dict[str, Any]],
|
||||
max_concurrency: int = 5
|
||||
) -> List[ToolResult]:
|
||||
"""批量执行工具
|
||||
|
||||
Args:
|
||||
tool_executions: 工具执行配置列表,每个包含tool_id和parameters
|
||||
max_concurrency: 最大并发数
|
||||
|
||||
Returns:
|
||||
执行结果列表
|
||||
"""
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def execute_single(exec_config: Dict[str, Any]) -> ToolResult:
|
||||
async with semaphore:
|
||||
return await self.execute_tool(
|
||||
tool_id=exec_config["tool_id"],
|
||||
parameters=exec_config.get("parameters", {}),
|
||||
user_id=exec_config.get("user_id"),
|
||||
workspace_id=exec_config.get("workspace_id"),
|
||||
timeout=exec_config.get("timeout"),
|
||||
metadata=exec_config.get("metadata")
|
||||
)
|
||||
|
||||
# 并发执行所有工具
|
||||
tasks = [execute_single(config) for config in tool_executions]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理异常结果
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
processed_results.append(
|
||||
ToolResult.error_result(
|
||||
error=str(result),
|
||||
error_code="BATCH_EXECUTION_ERROR",
|
||||
execution_time=0.0
|
||||
)
|
||||
)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
|
||||
async def cancel_execution(self, execution_id: str) -> bool:
|
||||
"""取消工具执行
|
||||
|
||||
Args:
|
||||
execution_id: 执行ID
|
||||
|
||||
Returns:
|
||||
是否成功取消
|
||||
"""
|
||||
async with self._execution_lock:
|
||||
if execution_id not in self._running_executions:
|
||||
return False
|
||||
|
||||
context = self._running_executions[execution_id]
|
||||
context.status = ExecutionStatus.FAILED
|
||||
|
||||
# 更新数据库记录
|
||||
execution_record = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.execution_id == execution_id
|
||||
).first()
|
||||
|
||||
if execution_record:
|
||||
execution_record.status = ExecutionStatus.FAILED.value
|
||||
execution_record.error_message = "执行被取消"
|
||||
execution_record.completed_at = datetime.now()
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"工具执行已取消: {execution_id}")
|
||||
return True
|
||||
|
||||
def get_running_executions(self) -> List[Dict[str, Any]]:
|
||||
"""获取正在运行的执行列表
|
||||
|
||||
Returns:
|
||||
执行信息列表
|
||||
"""
|
||||
executions = []
|
||||
for execution_id, context in self._running_executions.items():
|
||||
executions.append({
|
||||
"execution_id": execution_id,
|
||||
"tool_id": context.tool_id,
|
||||
"user_id": str(context.user_id) if context.user_id else None,
|
||||
"workspace_id": str(context.workspace_id) if context.workspace_id else None,
|
||||
"started_at": context.started_at.isoformat(),
|
||||
"status": context.status.value,
|
||||
"elapsed_time": (datetime.now() - context.started_at).total_seconds()
|
||||
})
|
||||
|
||||
return executions
|
||||
|
||||
async def _execute_with_timeout(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
parameters: Dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> ToolResult:
|
||||
"""带超时的工具执行
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
parameters: 参数
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
async with self._execution_lock:
|
||||
self._running_executions[context.execution_id] = context
|
||||
context.status = ExecutionStatus.RUNNING
|
||||
|
||||
try:
|
||||
# 使用asyncio.wait_for实现超时控制
|
||||
result = await asyncio.wait_for(
|
||||
tool.safe_execute(**parameters),
|
||||
timeout=context.timeout
|
||||
)
|
||||
|
||||
context.status = ExecutionStatus.COMPLETED
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
context.status = ExecutionStatus.TIMEOUT
|
||||
return ToolResult.error_result(
|
||||
error=f"工具执行超时({context.timeout}秒)",
|
||||
error_code="EXECUTION_TIMEOUT",
|
||||
execution_time=context.timeout
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
context.status = ExecutionStatus.FAILED
|
||||
raise
|
||||
|
||||
async def _record_execution_start(
|
||||
self,
|
||||
context: ExecutionContext,
|
||||
parameters: Dict[str, Any]
|
||||
):
|
||||
"""记录执行开始"""
|
||||
try:
|
||||
execution_record = ToolExecution(
|
||||
execution_id=context.execution_id,
|
||||
tool_config_id=uuid.UUID(context.tool_id),
|
||||
status=ExecutionStatus.RUNNING.value,
|
||||
input_data=parameters,
|
||||
started_at=context.started_at,
|
||||
user_id=context.user_id,
|
||||
workspace_id=context.workspace_id
|
||||
)
|
||||
|
||||
self.db.add(execution_record)
|
||||
self.db.commit()
|
||||
|
||||
logger.debug(f"执行记录已创建: {context.execution_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
|
||||
|
||||
async def _record_execution_complete(
|
||||
self,
|
||||
context: ExecutionContext,
|
||||
result: ToolResult
|
||||
):
|
||||
"""记录执行完成"""
|
||||
try:
|
||||
context.completed_at = datetime.now()
|
||||
|
||||
execution_record = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.execution_id == context.execution_id
|
||||
).first()
|
||||
|
||||
if execution_record:
|
||||
execution_record.status = (
|
||||
ExecutionStatus.COMPLETED.value if result.success
|
||||
else ExecutionStatus.FAILED.value
|
||||
)
|
||||
execution_record.output_data = result.data if result.success else None
|
||||
execution_record.error_message = result.error if not result.success else None
|
||||
execution_record.completed_at = context.completed_at
|
||||
execution_record.execution_time = result.execution_time
|
||||
execution_record.token_usage = result.token_usage
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.debug(f"执行记录已更新: {context.execution_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}")
|
||||
|
||||
def get_execution_history(
|
||||
self,
|
||||
tool_id: Optional[str] = None,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取执行历史
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID过滤
|
||||
user_id: 用户ID过滤
|
||||
workspace_id: 工作空间ID过滤
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
执行历史列表
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolExecution).order_by(
|
||||
ToolExecution.started_at.desc()
|
||||
)
|
||||
|
||||
if tool_id:
|
||||
query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id))
|
||||
|
||||
if user_id:
|
||||
query = query.filter(ToolExecution.user_id == user_id)
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(ToolExecution.workspace_id == workspace_id)
|
||||
|
||||
executions = query.limit(limit).all()
|
||||
|
||||
history = []
|
||||
for execution in executions:
|
||||
history.append({
|
||||
"execution_id": execution.execution_id,
|
||||
"tool_id": str(execution.tool_config_id),
|
||||
"status": execution.status,
|
||||
"started_at": execution.started_at.isoformat() if execution.started_at else None,
|
||||
"completed_at": execution.completed_at.isoformat() if execution.completed_at else None,
|
||||
"execution_time": execution.execution_time,
|
||||
"user_id": str(execution.user_id) if execution.user_id else None,
|
||||
"workspace_id": str(execution.workspace_id) if execution.workspace_id else None,
|
||||
"input_data": execution.input_data,
|
||||
"output_data": execution.output_data,
|
||||
"error_message": execution.error_message,
|
||||
"token_usage": execution.token_usage
|
||||
})
|
||||
|
||||
return history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行历史失败, 错误: {e}")
|
||||
return []
|
||||
|
||||
def get_execution_statistics(
|
||||
self,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
days: int = 7
|
||||
) -> Dict[str, Any]:
|
||||
"""获取执行统计信息
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
days: 统计天数
|
||||
|
||||
Returns:
|
||||
统计信息
|
||||
"""
|
||||
try:
|
||||
from datetime import timedelta
|
||||
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
query = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.started_at >= start_date
|
||||
)
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(ToolExecution.workspace_id == workspace_id)
|
||||
|
||||
executions = query.all()
|
||||
|
||||
# 统计数据
|
||||
total_executions = len(executions)
|
||||
successful_executions = len([e for e in executions if e.status == ExecutionStatus.COMPLETED.value])
|
||||
failed_executions = len([e for e in executions if e.status == ExecutionStatus.FAILED.value])
|
||||
|
||||
# 平均执行时间
|
||||
completed_executions = [e for e in executions if e.execution_time is not None]
|
||||
avg_execution_time = (
|
||||
sum(e.execution_time for e in completed_executions) / len(completed_executions)
|
||||
if completed_executions else 0
|
||||
)
|
||||
|
||||
# 按工具统计
|
||||
tool_stats = {}
|
||||
for execution in executions:
|
||||
tool_id = str(execution.tool_config_id)
|
||||
if tool_id not in tool_stats:
|
||||
tool_stats[tool_id] = {"total": 0, "successful": 0, "failed": 0}
|
||||
|
||||
tool_stats[tool_id]["total"] += 1
|
||||
if execution.status == ExecutionStatus.COMPLETED.value:
|
||||
tool_stats[tool_id]["successful"] += 1
|
||||
elif execution.status == ExecutionStatus.FAILED.value:
|
||||
tool_stats[tool_id]["failed"] += 1
|
||||
|
||||
return {
|
||||
"period_days": days,
|
||||
"total_executions": total_executions,
|
||||
"successful_executions": successful_executions,
|
||||
"failed_executions": failed_executions,
|
||||
"success_rate": successful_executions / total_executions if total_executions > 0 else 0,
|
||||
"average_execution_time": avg_execution_time,
|
||||
"tool_statistics": tool_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行统计失败, 错误: {e}")
|
||||
return {}
|
||||
|
||||
async def test_tool_connection(
|
||||
self,
|
||||
tool_id: str,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig
|
||||
from .mcp.client import MCPClient
|
||||
|
||||
tool_config = self.db.query(ToolConfig).filter(
|
||||
ToolConfig.id == uuid.UUID(tool_id)
|
||||
).first()
|
||||
|
||||
if not tool_config:
|
||||
return {"success": False, "message": "工具不存在"}
|
||||
|
||||
if tool_config.tool_type == ToolType.MCP.value:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if not mcp_config:
|
||||
return {"success": False, "message": "MCP配置不存在"}
|
||||
|
||||
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
|
||||
|
||||
if await client.connect():
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
await client.disconnect()
|
||||
return {
|
||||
"success": True,
|
||||
"message": "MCP连接成功",
|
||||
"details": {"server_url": mcp_config.server_url, "tools": len(tools)}
|
||||
}
|
||||
except:
|
||||
await client.disconnect()
|
||||
return {"success": False, "message": "MCP功能测试失败"}
|
||||
else:
|
||||
return {"success": False, "message": "MCP连接失败"}
|
||||
else:
|
||||
tool = self.registry.get_tool(tool_id)
|
||||
if tool and hasattr(tool, 'test_connection'):
|
||||
result = tool.test_connection()
|
||||
return {"success": result.get("success", False), "message": result.get("message", "")}
|
||||
return {"success": True, "message": "工具无需连接测试"}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": "测试失败", "error": str(e)}
|
||||
375
api/app/core/tools/langchain_adapter.py
Normal file
375
api/app/core/tools/langchain_adapter.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""Langchain适配器 - 将工具转换为langchain兼容格式"""
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.tools import BaseTool as LangchainBaseTool
|
||||
from langchain_core.tools import ToolException
|
||||
|
||||
from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class LangchainToolWrapper(LangchainBaseTool):
|
||||
"""Langchain工具包装器"""
|
||||
|
||||
name: str = Field(..., description="工具名称")
|
||||
description: str = Field(..., description="工具描述")
|
||||
args_schema: Optional[Type[BaseModel]] = Field(None, description="参数schema")
|
||||
return_direct: bool = Field(False, description="是否直接返回结果")
|
||||
|
||||
# 内部工具实例
|
||||
tool_instance: BaseTool = Field(..., description="内部工具实例")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, tool_instance: BaseTool, **kwargs):
|
||||
"""初始化Langchain工具包装器
|
||||
|
||||
Args:
|
||||
tool_instance: 内部工具实例
|
||||
"""
|
||||
# 动态创建参数schema
|
||||
args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)
|
||||
|
||||
super().__init__(
|
||||
name=tool_instance.name,
|
||||
description=tool_instance.description,
|
||||
args_schema=args_schema,
|
||||
_tool_instance=tool_instance,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
run_manager=None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""同步执行工具(Langchain要求)"""
|
||||
# 由于我们的工具是异步的,这里抛出异常提示使用异步版本
|
||||
raise NotImplementedError("请使用 _arun 方法进行异步调用")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
run_manager=None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""异步执行工具"""
|
||||
try:
|
||||
# 执行内部工具
|
||||
result = await self._tool_instance.safe_execute(**kwargs)
|
||||
|
||||
# 转换结果为Langchain格式
|
||||
return LangchainAdapter._format_result_for_langchain(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具执行失败: {self.name}, 错误: {e}")
|
||||
raise ToolException(f"工具执行失败: {str(e)}")
|
||||
|
||||
|
||||
class LangchainAdapter:
|
||||
"""Langchain适配器 - 负责工具格式转换和标准化"""
|
||||
|
||||
@staticmethod
|
||||
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
|
||||
"""将内部工具转换为Langchain工具
|
||||
|
||||
Args:
|
||||
tool: 内部工具实例
|
||||
|
||||
Returns:
|
||||
Langchain兼容的工具包装器
|
||||
"""
|
||||
try:
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||
return wrapper
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
||||
"""批量转换工具
|
||||
|
||||
Args:
|
||||
tools: 工具列表
|
||||
|
||||
Returns:
|
||||
Langchain工具列表
|
||||
"""
|
||||
converted_tools = []
|
||||
|
||||
for tool in tools:
|
||||
try:
|
||||
converted_tool = LangchainAdapter.convert_tool(tool)
|
||||
converted_tools.append(converted_tool)
|
||||
except Exception as e:
|
||||
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
|
||||
|
||||
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
|
||||
return converted_tools
|
||||
|
||||
@staticmethod
|
||||
def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
|
||||
"""根据工具参数创建Pydantic schema
|
||||
|
||||
Args:
|
||||
parameters: 工具参数列表
|
||||
|
||||
Returns:
|
||||
Pydantic模型类
|
||||
"""
|
||||
# 构建字段定义
|
||||
fields = {}
|
||||
annotations = {}
|
||||
|
||||
for param in parameters:
|
||||
# 确定Python类型
|
||||
python_type = LangchainAdapter._get_python_type(param.type)
|
||||
|
||||
# 处理可选参数
|
||||
if not param.required:
|
||||
python_type = Optional[python_type]
|
||||
|
||||
# 创建Field定义
|
||||
field_kwargs = {
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
if param.default is not None:
|
||||
field_kwargs["default"] = param.default
|
||||
elif not param.required:
|
||||
field_kwargs["default"] = None
|
||||
else:
|
||||
field_kwargs["default"] = ... # 必需字段
|
||||
|
||||
# 添加验证约束
|
||||
if param.enum:
|
||||
# 枚举值约束
|
||||
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||
|
||||
if param.minimum is not None:
|
||||
field_kwargs["ge"] = param.minimum
|
||||
|
||||
if param.maximum is not None:
|
||||
field_kwargs["le"] = param.maximum
|
||||
|
||||
if param.pattern:
|
||||
field_kwargs["regex"] = param.pattern
|
||||
|
||||
fields[param.name] = Field(**field_kwargs)
|
||||
annotations[param.name] = python_type
|
||||
|
||||
# 动态创建Pydantic模型
|
||||
schema_class = type(
|
||||
"ToolArgsSchema",
|
||||
(BaseModel,),
|
||||
{
|
||||
"__annotations__": annotations,
|
||||
**fields,
|
||||
"Config": type("Config", (), {"extra": "forbid"})
|
||||
}
|
||||
)
|
||||
|
||||
return schema_class
|
||||
|
||||
@staticmethod
|
||||
def _get_python_type(param_type: ParameterType) -> type:
|
||||
"""获取参数类型对应的Python类型
|
||||
|
||||
Args:
|
||||
param_type: 参数类型
|
||||
|
||||
Returns:
|
||||
Python类型
|
||||
"""
|
||||
type_mapping = {
|
||||
ParameterType.STRING: str,
|
||||
ParameterType.INTEGER: int,
|
||||
ParameterType.NUMBER: float,
|
||||
ParameterType.BOOLEAN: bool,
|
||||
ParameterType.ARRAY: list,
|
||||
ParameterType.OBJECT: dict
|
||||
}
|
||||
|
||||
return type_mapping.get(param_type, str)
|
||||
|
||||
@staticmethod
|
||||
def _format_result_for_langchain(result: ToolResult) -> str:
|
||||
"""将工具结果格式化为Langchain标准格式
|
||||
|
||||
Args:
|
||||
result: 工具执行结果
|
||||
|
||||
Returns:
|
||||
格式化的字符串结果
|
||||
"""
|
||||
if not result.success:
|
||||
# 错误结果
|
||||
error_info = {
|
||||
"success": False,
|
||||
"error": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
return json.dumps(error_info, ensure_ascii=False, indent=2)
|
||||
|
||||
# 成功结果
|
||||
if isinstance(result.data, str):
|
||||
# 如果数据已经是字符串,直接返回
|
||||
return result.data
|
||||
elif isinstance(result.data, (dict, list)):
|
||||
# 如果是结构化数据,转换为JSON
|
||||
return json.dumps(result.data, ensure_ascii=False, indent=2)
|
||||
else:
|
||||
# 其他类型转换为字符串
|
||||
return str(result.data)
|
||||
|
||||
@staticmethod
|
||||
def create_tool_description(tool: BaseTool) -> Dict[str, Any]:
|
||||
"""创建工具描述(用于工具发现和文档生成)
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
|
||||
Returns:
|
||||
工具描述字典
|
||||
"""
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"tool_type": tool.tool_type.value,
|
||||
"version": tool.version,
|
||||
"status": tool.status.value,
|
||||
"tags": tool.tags,
|
||||
"parameters": [
|
||||
{
|
||||
"name": param.name,
|
||||
"type": param.type.value,
|
||||
"description": param.description,
|
||||
"required": param.required,
|
||||
"default": param.default,
|
||||
"enum": param.enum,
|
||||
"minimum": param.minimum,
|
||||
"maximum": param.maximum,
|
||||
"pattern": param.pattern
|
||||
}
|
||||
for param in tool.parameters
|
||||
],
|
||||
"langchain_compatible": True
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def validate_langchain_compatibility(tool: BaseTool) -> tuple[bool, List[str]]:
|
||||
"""验证工具是否与Langchain兼容
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
|
||||
Returns:
|
||||
(是否兼容, 问题列表)
|
||||
"""
|
||||
issues = []
|
||||
|
||||
# 检查工具名称
|
||||
if not tool.name or not isinstance(tool.name, str):
|
||||
issues.append("工具名称必须是非空字符串")
|
||||
|
||||
# 检查工具描述
|
||||
if not tool.description or not isinstance(tool.description, str):
|
||||
issues.append("工具描述必须是非空字符串")
|
||||
|
||||
# 检查参数定义
|
||||
for param in tool.parameters:
|
||||
if not param.name or not isinstance(param.name, str):
|
||||
issues.append(f"参数名称无效: {param.name}")
|
||||
|
||||
if param.type not in ParameterType:
|
||||
issues.append(f"不支持的参数类型: {param.type}")
|
||||
|
||||
if param.required and param.default is not None:
|
||||
issues.append(f"必需参数不应有默认值: {param.name}")
|
||||
|
||||
# 检查是否有execute方法
|
||||
if not hasattr(tool, 'execute') or not callable(getattr(tool, 'execute')):
|
||||
issues.append("工具必须实现execute方法")
|
||||
|
||||
return len(issues) == 0, issues
|
||||
|
||||
@staticmethod
|
||||
def get_langchain_tool_schema(tool: BaseTool) -> Dict[str, Any]:
|
||||
"""获取Langchain工具的OpenAPI schema
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
|
||||
Returns:
|
||||
OpenAPI schema字典
|
||||
"""
|
||||
# 构建参数schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in tool.parameters:
|
||||
prop_schema = {
|
||||
"type": LangchainAdapter._get_openapi_type(param.type),
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
if param.enum:
|
||||
prop_schema["enum"] = param.enum
|
||||
|
||||
if param.minimum is not None:
|
||||
prop_schema["minimum"] = param.minimum
|
||||
|
||||
if param.maximum is not None:
|
||||
prop_schema["maximum"] = param.maximum
|
||||
|
||||
if param.pattern:
|
||||
prop_schema["pattern"] = param.pattern
|
||||
|
||||
if param.default is not None:
|
||||
prop_schema["default"] = param.default
|
||||
|
||||
properties[param.name] = prop_schema
|
||||
|
||||
if param.required:
|
||||
required.append(param.name)
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_openapi_type(param_type: ParameterType) -> str:
|
||||
"""获取OpenAPI类型
|
||||
|
||||
Args:
|
||||
param_type: 参数类型
|
||||
|
||||
Returns:
|
||||
OpenAPI类型字符串
|
||||
"""
|
||||
type_mapping = {
|
||||
ParameterType.STRING: "string",
|
||||
ParameterType.INTEGER: "integer",
|
||||
ParameterType.NUMBER: "number",
|
||||
ParameterType.BOOLEAN: "boolean",
|
||||
ParameterType.ARRAY: "array",
|
||||
ParameterType.OBJECT: "object"
|
||||
}
|
||||
|
||||
return type_mapping.get(param_type, "string")
|
||||
12
api/app/core/tools/mcp/__init__.py
Normal file
12
api/app/core/tools/mcp/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""MCP工具模块"""
|
||||
|
||||
from .base import MCPTool
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
from .service_manager import MCPServiceManager
|
||||
|
||||
__all__ = [
|
||||
"MCPTool",
|
||||
"MCPClient",
|
||||
"MCPConnectionPool",
|
||||
"MCPServiceManager"
|
||||
]
|
||||
258
api/app/core/tools/mcp/base.py
Normal file
258
api/app/core/tools/mcp/base.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""MCP工具基类"""
|
||||
import time
|
||||
from typing import Dict, Any, List
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPTool(BaseTool):
|
||||
"""MCP工具 - Model Context Protocol工具"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
"""初始化MCP工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
config: 工具配置
|
||||
"""
|
||||
super().__init__(tool_id, config)
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.connection_config = config.get("connection_config", {})
|
||||
self.available_tools = config.get("available_tools", [])
|
||||
self._client = None
|
||||
self._connected = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""工具名称"""
|
||||
return f"mcp_tool_{self.tool_id[:8]}"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""工具描述"""
|
||||
return f"MCP工具 - 连接到 {self.server_url}"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.MCP
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""工具参数定义"""
|
||||
params = []
|
||||
|
||||
# 添加工具选择参数
|
||||
if len(self.available_tools) > 1:
|
||||
params.append(ToolParameter(
|
||||
name="tool_name",
|
||||
type=ParameterType.STRING,
|
||||
description="要调用的MCP工具名称",
|
||||
required=True,
|
||||
enum=self.available_tools
|
||||
))
|
||||
|
||||
# 添加通用参数
|
||||
params.extend([
|
||||
ToolParameter(
|
||||
name="arguments",
|
||||
type=ParameterType.OBJECT,
|
||||
description="工具参数(JSON对象)",
|
||||
required=False,
|
||||
default={}
|
||||
),
|
||||
ToolParameter(
|
||||
name="timeout",
|
||||
type=ParameterType.INTEGER,
|
||||
description="超时时间(秒)",
|
||||
required=False,
|
||||
default=30,
|
||||
minimum=1,
|
||||
maximum=300
|
||||
)
|
||||
])
|
||||
|
||||
return params
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行MCP工具"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保连接
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 确定要调用的工具
|
||||
tool_name = kwargs.get("tool_name")
|
||||
if not tool_name and len(self.available_tools) == 1:
|
||||
tool_name = self.available_tools[0]
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("必须指定要调用的MCP工具名称")
|
||||
|
||||
if tool_name not in self.available_tools:
|
||||
raise ValueError(f"MCP工具不存在: {tool_name}")
|
||||
|
||||
# 获取参数
|
||||
arguments = kwargs.get("arguments", {})
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
|
||||
# 调用MCP工具
|
||||
result = await self._call_mcp_tool(tool_name, arguments, timeout)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="MCP_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器"""
|
||||
try:
|
||||
# 这里应该实现实际的MCP连接逻辑
|
||||
# 为了简化,这里只是模拟连接
|
||||
|
||||
# 测试服务器连接
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
# 尝试获取服务器信息
|
||||
async with session.get(f"{self.server_url}/info") as response:
|
||||
if response.status == 200:
|
||||
server_info = await response.json()
|
||||
self.available_tools = server_info.get("tools", [])
|
||||
self._connected = True
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
raise Exception(f"服务器响应错误: {response.status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接"""
|
||||
try:
|
||||
if self._client:
|
||||
# 这里应该实现实际的断开逻辑
|
||||
self._client = None
|
||||
|
||||
self._connected = False
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
def get_health_status(self) -> Dict[str, Any]:
|
||||
"""获取MCP服务健康状态"""
|
||||
return {
|
||||
"connected": self._connected,
|
||||
"server_url": self.server_url,
|
||||
"available_tools": self.available_tools,
|
||||
"last_check": time.time()
|
||||
}
|
||||
|
||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
|
||||
"""调用MCP工具"""
|
||||
# 构建MCP请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
|
||||
|
||||
result = await response.json()
|
||||
|
||||
# 检查MCP响应
|
||||
if "error" in result:
|
||||
error = result["error"]
|
||||
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
|
||||
|
||||
return result.get("result", {})
|
||||
|
||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""列出可用的MCP工具"""
|
||||
try:
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 获取工具列表
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if "result" in result:
|
||||
tools = result["result"].get("tools", [])
|
||||
self.available_tools = [tool.get("name") for tool in tools]
|
||||
return tools
|
||||
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取MCP工具列表失败: {e}")
|
||||
return []
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试MCP连接"""
|
||||
try:
|
||||
# 这里应该实现同步的连接测试
|
||||
# 为了简化,返回基本信息
|
||||
return {
|
||||
"success": bool(self.server_url),
|
||||
"server_url": self.server_url,
|
||||
"connected": self._connected,
|
||||
"available_tools_count": len(self.available_tools),
|
||||
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
626
api/app/core/tools/mcp/client.py
Normal file
626
api/app/core/tools/mcp/client.py
Normal file
@@ -0,0 +1,626 @@
|
||||
"""MCP客户端 - Model Context Protocol客户端实现"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPConnectionError(Exception):
|
||||
"""MCP连接错误"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPProtocolError(Exception):
|
||||
"""MCP协议错误"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP客户端 - 支持HTTP和WebSocket连接"""
|
||||
|
||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||
"""初始化MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: MCP服务器URL
|
||||
connection_config: 连接配置
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.connection_config = connection_config or {}
|
||||
|
||||
# 解析URL确定连接类型
|
||||
parsed_url = urlparse(server_url)
|
||||
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
|
||||
|
||||
# 连接状态
|
||||
self._connected = False
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
|
||||
# 请求管理
|
||||
self._request_id = 0
|
||||
self._pending_requests: Dict[str, asyncio.Future] = {}
|
||||
|
||||
# 连接池配置
|
||||
self.max_connections = self.connection_config.get("max_connections", 10)
|
||||
self.connection_timeout = self.connection_config.get("timeout", 30)
|
||||
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
|
||||
self.retry_delay = self.connection_config.get("retry_delay", 1)
|
||||
|
||||
# 健康检查
|
||||
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
|
||||
self._health_check_task = None
|
||||
self._last_health_check = None
|
||||
|
||||
# 事件回调
|
||||
self._on_connect_callbacks: List[Callable] = []
|
||||
self._on_disconnect_callbacks: List[Callable] = []
|
||||
self._on_error_callbacks: List[Callable] = []
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器
|
||||
|
||||
Returns:
|
||||
连接是否成功
|
||||
"""
|
||||
try:
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"连接MCP服务器: {self.server_url}")
|
||||
|
||||
if self.connection_type == "websocket":
|
||||
success = await self._connect_websocket()
|
||||
else:
|
||||
success = await self._connect_http()
|
||||
|
||||
if success:
|
||||
self._connected = True
|
||||
await self._start_health_check()
|
||||
await self._notify_connect_callbacks()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
|
||||
await self._notify_error_callbacks(e)
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接
|
||||
|
||||
Returns:
|
||||
断开是否成功
|
||||
"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"断开MCP服务器连接: {self.server_url}")
|
||||
|
||||
# 停止健康检查
|
||||
await self._stop_health_check()
|
||||
|
||||
# 取消所有待处理的请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
# 断开连接
|
||||
if self.connection_type == "websocket" and self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
elif self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
async def _connect_websocket(self) -> bool:
|
||||
"""建立WebSocket连接"""
|
||||
try:
|
||||
# WebSocket连接配置
|
||||
extra_headers = self.connection_config.get("headers", {})
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=extra_headers,
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
# 启动消息监听
|
||||
asyncio.create_task(self._websocket_message_handler())
|
||||
|
||||
# 发送初始化消息
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "ToolManagementSystem",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
init_response = json.loads(response)
|
||||
if "error" in init_response:
|
||||
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket连接失败: {e}")
|
||||
return False
|
||||
|
||||
async def _connect_http(self) -> bool:
|
||||
"""建立HTTP连接"""
|
||||
try:
|
||||
# HTTP会话配置
|
||||
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
|
||||
headers = self.connection_config.get("headers", {})
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 测试连接
|
||||
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
|
||||
|
||||
async with self._session.get(test_url) as response:
|
||||
if response.status == 200:
|
||||
return True
|
||||
else:
|
||||
# 尝试根路径
|
||||
async with self._session.get(self.server_url) as root_response:
|
||||
return root_response.status < 400
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"HTTP连接失败: {e}")
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
return False
|
||||
|
||||
async def _websocket_message_handler(self):
|
||||
"""WebSocket消息处理器"""
|
||||
try:
|
||||
while self._websocket and not self._websocket.closed:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
await self._handle_message(json.loads(message))
|
||||
except ConnectionClosed:
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析WebSocket消息失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket消息失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket消息处理器异常: {e}")
|
||||
finally:
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
|
||||
async def _handle_message(self, message: Dict[str, Any]):
|
||||
"""处理收到的消息"""
|
||||
try:
|
||||
# 检查是否是响应消息
|
||||
if "id" in message:
|
||||
request_id = str(message["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
# 处理通知消息
|
||||
elif "method" in message:
|
||||
await self._handle_notification(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
async def _handle_notification(self, message: Dict[str, Any]):
|
||||
"""处理通知消息"""
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
|
||||
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
|
||||
|
||||
# 这里可以根据需要处理特定的通知
|
||||
# 例如:工具列表更新、服务器状态变化等
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
|
||||
"""调用MCP工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError(f"工具调用超时: {tool_name}")
|
||||
|
||||
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if not response["error"] is None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError("获取工具列表超时")
|
||||
|
||||
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送请求并等待响应
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
request_id = str(request_data["id"])
|
||||
|
||||
if self.connection_type == "websocket":
|
||||
return await self._send_websocket_request(request_data, request_id, timeout)
|
||||
else:
|
||||
return await self._send_http_request(request_data, timeout)
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
if not self._websocket or self._websocket.closed:
|
||||
raise MCPConnectionError("WebSocket连接已断开")
|
||||
|
||||
# 创建Future等待响应
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
await self._websocket.send(json.dumps(request_data))
|
||||
|
||||
# 等待响应
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
if not self._session:
|
||||
raise MCPConnectionError("HTTP会话未建立")
|
||||
|
||||
try:
|
||||
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
|
||||
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=request_data,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""执行健康检查
|
||||
|
||||
Returns:
|
||||
健康状态信息
|
||||
"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": "未连接",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
# 发送ping请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "ping"
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._send_request(request_data, timeout=5)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
self._last_health_check = time.time()
|
||||
|
||||
return {
|
||||
"healthy": True,
|
||||
"response_time": response_time,
|
||||
"timestamp": self._last_health_check,
|
||||
"server_info": response.get("result", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
async def _start_health_check(self):
|
||||
"""启动健康检查任务"""
|
||||
if self.health_check_interval > 0:
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
async def _stop_health_check(self):
|
||||
"""停止健康检查任务"""
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._health_check_task = None
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""健康检查循环"""
|
||||
try:
|
||||
while self._connected:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
if self._connected:
|
||||
health_status = await self.health_check()
|
||||
if not health_status["healthy"]:
|
||||
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
|
||||
# 可以在这里实现重连逻辑
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查循环异常: {e}")
|
||||
|
||||
def _get_next_request_id(self) -> str:
|
||||
"""获取下一个请求ID"""
|
||||
self._request_id += 1
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
|
||||
# 事件回调管理
|
||||
def on_connect(self, callback: Callable):
|
||||
"""注册连接回调"""
|
||||
self._on_connect_callbacks.append(callback)
|
||||
|
||||
def on_disconnect(self, callback: Callable):
|
||||
"""注册断开连接回调"""
|
||||
self._on_disconnect_callbacks.append(callback)
|
||||
|
||||
def on_error(self, callback: Callable):
|
||||
"""注册错误回调"""
|
||||
self._on_error_callbacks.append(callback)
|
||||
|
||||
async def _notify_connect_callbacks(self):
|
||||
"""通知连接回调"""
|
||||
for callback in self._on_connect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_disconnect_callbacks(self):
|
||||
"""通知断开连接回调"""
|
||||
for callback in self._on_disconnect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_error_callbacks(self, error: Exception):
|
||||
"""通知错误回调"""
|
||||
for callback in self._on_error_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(error)
|
||||
else:
|
||||
callback(error)
|
||||
except Exception as e:
|
||||
logger.error(f"错误回调执行失败: {e}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def last_health_check(self) -> Optional[float]:
|
||||
"""获取最后一次健康检查时间"""
|
||||
return self._last_health_check
|
||||
|
||||
def get_connection_info(self) -> Dict[str, Any]:
|
||||
"""获取连接信息"""
|
||||
return {
|
||||
"server_url": self.server_url,
|
||||
"connection_type": self.connection_type,
|
||||
"connected": self._connected,
|
||||
"last_health_check": self._last_health_check,
|
||||
"pending_requests": len(self._pending_requests),
|
||||
"config": self.connection_config
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
class MCPConnectionPool:
|
||||
"""MCP连接池 - 管理多个MCP客户端连接"""
|
||||
|
||||
def __init__(self, max_connections: int = 10):
|
||||
"""初始化连接池
|
||||
|
||||
Args:
|
||||
max_connections: 最大连接数
|
||||
"""
|
||||
self.max_connections = max_connections
|
||||
self._clients: Dict[str, MCPClient] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
|
||||
"""获取或创建MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
|
||||
Returns:
|
||||
MCP客户端实例
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_url in self._clients:
|
||||
client = self._clients[server_url]
|
||||
if client.is_connected:
|
||||
return client
|
||||
else:
|
||||
# 尝试重连
|
||||
if await client.connect():
|
||||
return client
|
||||
else:
|
||||
# 移除失效的客户端
|
||||
del self._clients[server_url]
|
||||
|
||||
# 检查连接数限制
|
||||
if len(self._clients) >= self.max_connections:
|
||||
# 移除最旧的连接
|
||||
oldest_url = next(iter(self._clients))
|
||||
await self._clients[oldest_url].disconnect()
|
||||
del self._clients[oldest_url]
|
||||
|
||||
# 创建新客户端
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if await client.connect():
|
||||
self._clients[server_url] = client
|
||||
return client
|
||||
else:
|
||||
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
|
||||
|
||||
async def disconnect_all(self):
|
||||
"""断开所有连接"""
|
||||
async with self._lock:
|
||||
for client in self._clients.values():
|
||||
await client.disconnect()
|
||||
self._clients.clear()
|
||||
|
||||
def get_pool_status(self) -> Dict[str, Any]:
|
||||
"""获取连接池状态"""
|
||||
return {
|
||||
"total_connections": len(self._clients),
|
||||
"max_connections": self.max_connections,
|
||||
"connections": {
|
||||
url: client.get_connection_info()
|
||||
for url, client in self._clients.items()
|
||||
}
|
||||
}
|
||||
604
api/app/core/tools/mcp/service_manager.py
Normal file
604
api/app/core/tools/mcp/service_manager.py
Normal file
@@ -0,0 +1,604 @@
|
||||
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPServiceManager:
|
||||
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化MCP服务管理器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||
|
||||
# 服务状态管理
|
||||
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
||||
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
|
||||
|
||||
# 配置
|
||||
self.health_check_interval = 60 # 健康检查间隔(秒)
|
||||
self.max_retry_attempts = 3 # 最大重试次数
|
||||
self.retry_delay = 5 # 重试延迟(秒)
|
||||
|
||||
# 状态
|
||||
self._running = False
|
||||
self._manager_task = None
|
||||
|
||||
async def start(self):
|
||||
"""启动服务管理器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info("MCP服务管理器启动")
|
||||
|
||||
# 加载现有服务
|
||||
await self._load_existing_services()
|
||||
|
||||
# 启动管理任务
|
||||
self._manager_task = asyncio.create_task(self._management_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务管理器"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
logger.info("MCP服务管理器停止")
|
||||
|
||||
# 停止管理任务
|
||||
if self._manager_task:
|
||||
self._manager_task.cancel()
|
||||
try:
|
||||
await self._manager_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 停止所有监控任务
|
||||
for task in self._monitoring_tasks.values():
|
||||
task.cancel()
|
||||
|
||||
if self._monitoring_tasks:
|
||||
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
|
||||
|
||||
self._monitoring_tasks.clear()
|
||||
|
||||
# 断开所有连接
|
||||
await self.connection_pool.disconnect_all()
|
||||
|
||||
async def register_service(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any],
|
||||
tenant_id: uuid.UUID,
|
||||
service_name: str = None
|
||||
) -> Tuple[bool, str, Optional[str]]:
|
||||
"""注册MCP服务
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
tenant_id: 租户ID
|
||||
service_name: 服务名称(可选)
|
||||
|
||||
Returns:
|
||||
(是否成功, 服务ID或错误信息, 错误详情)
|
||||
"""
|
||||
try:
|
||||
# 检查服务是否已存在
|
||||
existing_service = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.server_url == server_url
|
||||
).first()
|
||||
|
||||
if existing_service:
|
||||
return False, "服务已存在", f"URL {server_url} 已被注册"
|
||||
|
||||
# 测试连接
|
||||
try:
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if not await client.connect():
|
||||
return False, "连接测试失败", "无法连接到MCP服务器"
|
||||
|
||||
# 获取可用工具
|
||||
available_tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
except Exception as e:
|
||||
return False, "连接测试失败", str(e)
|
||||
|
||||
# 创建工具配置
|
||||
if not service_name:
|
||||
service_name = f"mcp_service_{server_url.split('/')[-1]}"
|
||||
|
||||
tool_config = ToolConfig(
|
||||
name=service_name,
|
||||
description=f"MCP服务 - {server_url}",
|
||||
tool_type=ToolType.MCP.value,
|
||||
tenant_id=tenant_id,
|
||||
version="1.0.0",
|
||||
config_data={
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config
|
||||
}
|
||||
)
|
||||
|
||||
self.db.add(tool_config)
|
||||
self.db.flush()
|
||||
|
||||
# 创建MCP特定配置
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=server_url,
|
||||
connection_config=connection_config,
|
||||
available_tools=tool_names,
|
||||
health_status="healthy",
|
||||
last_health_check=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.db.add(mcp_config)
|
||||
self.db.commit()
|
||||
|
||||
service_id = str(tool_config.id)
|
||||
|
||||
# 添加到内存管理
|
||||
self._services[service_id] = {
|
||||
"id": service_id,
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config,
|
||||
"tenant_id": tenant_id,
|
||||
"available_tools": tool_names,
|
||||
"status": "healthy",
|
||||
"last_health_check": time.time(),
|
||||
"retry_count": 0,
|
||||
"created_at": time.time()
|
||||
}
|
||||
|
||||
# 启动监控
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
|
||||
return True, service_id, None
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
|
||||
return False, "注册失败", str(e)
|
||||
|
||||
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
|
||||
"""注销MCP服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
(是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 从数据库删除
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
||||
if not tool_config:
|
||||
return False, "服务不存在"
|
||||
|
||||
self.db.delete(tool_config)
|
||||
self.db.commit()
|
||||
|
||||
# 停止监控
|
||||
await self._stop_service_monitoring(service_id)
|
||||
|
||||
# 从内存移除
|
||||
if service_id in self._services:
|
||||
del self._services[service_id]
|
||||
|
||||
logger.info(f"MCP服务注销成功: {service_id}")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def update_service(
|
||||
self,
|
||||
service_id: str,
|
||||
connection_config: Dict[str, Any] = None,
|
||||
enabled: bool = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""更新MCP服务配置
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
connection_config: 新的连接配置
|
||||
enabled: 是否启用
|
||||
|
||||
Returns:
|
||||
(是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 更新数据库
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if not mcp_config:
|
||||
return False, "服务不存在"
|
||||
|
||||
tool_config = mcp_config.base_config
|
||||
|
||||
if connection_config is not None:
|
||||
mcp_config.connection_config = connection_config
|
||||
tool_config.config_data["connection_config"] = connection_config
|
||||
|
||||
if enabled is not None:
|
||||
tool_config.is_enabled = enabled
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 更新内存状态
|
||||
if service_id in self._services:
|
||||
if connection_config is not None:
|
||||
self._services[service_id]["connection_config"] = connection_config
|
||||
|
||||
# 如果配置有变化,重启监控
|
||||
if connection_config is not None:
|
||||
await self._restart_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"MCP服务更新成功: {service_id}")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取服务状态
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
服务状态信息
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
return None
|
||||
|
||||
service_info = self._services[service_id].copy()
|
||||
|
||||
# 添加实时健康检查
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
health_status = await client.health_check()
|
||||
service_info["real_time_health"] = health_status
|
||||
|
||||
except Exception as e:
|
||||
service_info["real_time_health"] = {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
return service_info
|
||||
|
||||
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
|
||||
"""列出所有服务
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID过滤
|
||||
|
||||
Returns:
|
||||
服务列表
|
||||
"""
|
||||
services = []
|
||||
|
||||
for service_id, service_info in self._services.items():
|
||||
if tenant_id and service_info["tenant_id"] != tenant_id:
|
||||
continue
|
||||
|
||||
services.append(service_info.copy())
|
||||
|
||||
return services
|
||||
|
||||
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
|
||||
"""获取服务的可用工具
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
return []
|
||||
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
tools = await client.list_tools()
|
||||
|
||||
# 更新缓存的工具列表
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
service_info["available_tools"] = tool_names
|
||||
|
||||
# 更新数据库
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.available_tools = tool_names
|
||||
self.db.commit()
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
|
||||
return []
|
||||
|
||||
async def call_service_tool(
|
||||
self,
|
||||
service_id: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
timeout: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""调用服务工具
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
raise ValueError(f"服务不存在: {service_id}")
|
||||
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
result = await client.call_tool(tool_name, arguments, timeout)
|
||||
|
||||
# 更新服务状态为健康
|
||||
service_info["status"] = "healthy"
|
||||
service_info["last_health_check"] = time.time()
|
||||
service_info["retry_count"] = 0
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# 更新服务状态为错误
|
||||
service_info["status"] = "error"
|
||||
service_info["last_error"] = str(e)
|
||||
service_info["retry_count"] += 1
|
||||
|
||||
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def _load_existing_services(self):
|
||||
"""加载现有服务"""
|
||||
try:
|
||||
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
|
||||
ToolConfig.is_enabled == True
|
||||
).all()
|
||||
|
||||
for mcp_config in mcp_configs:
|
||||
tool_config = mcp_config.base_config
|
||||
service_id = str(mcp_config.id)
|
||||
|
||||
self._services[service_id] = {
|
||||
"id": service_id,
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config or {},
|
||||
"tenant_id": tool_config.tenant_id,
|
||||
"available_tools": mcp_config.available_tools or [],
|
||||
"status": mcp_config.health_status or "unknown",
|
||||
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
|
||||
"retry_count": 0,
|
||||
"created_at": tool_config.created_at.timestamp()
|
||||
}
|
||||
|
||||
# 启动监控
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载现有服务失败: {e}")
|
||||
|
||||
async def _start_service_monitoring(self, service_id: str):
|
||||
"""启动服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
return
|
||||
|
||||
task = asyncio.create_task(self._monitor_service(service_id))
|
||||
self._monitoring_tasks[service_id] = task
|
||||
|
||||
async def _stop_service_monitoring(self, service_id: str):
|
||||
"""停止服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
task = self._monitoring_tasks.pop(service_id)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _restart_service_monitoring(self, service_id: str):
|
||||
"""重启服务监控"""
|
||||
await self._stop_service_monitoring(service_id)
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
async def _monitor_service(self, service_id: str):
|
||||
"""监控单个服务"""
|
||||
try:
|
||||
while self._running and service_id in self._services:
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
# 执行健康检查
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
health_status = await client.health_check()
|
||||
|
||||
if health_status["healthy"]:
|
||||
# 服务健康
|
||||
service_info["status"] = "healthy"
|
||||
service_info["retry_count"] = 0
|
||||
|
||||
# 更新工具列表
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
service_info["available_tools"] = tool_names
|
||||
except Exception as e:
|
||||
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
|
||||
|
||||
else:
|
||||
# 服务不健康
|
||||
service_info["status"] = "unhealthy"
|
||||
service_info["last_error"] = health_status.get("error", "健康检查失败")
|
||||
service_info["retry_count"] += 1
|
||||
|
||||
service_info["last_health_check"] = time.time()
|
||||
|
||||
# 更新数据库
|
||||
await self._update_service_health_in_db(service_id, health_status)
|
||||
|
||||
except Exception as e:
|
||||
# 监控异常
|
||||
service_info["status"] = "error"
|
||||
service_info["last_error"] = str(e)
|
||||
service_info["retry_count"] += 1
|
||||
service_info["last_health_check"] = time.time()
|
||||
|
||||
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
|
||||
|
||||
# 如果重试次数过多,暂停监控
|
||||
if service_info["retry_count"] >= self.max_retry_attempts:
|
||||
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
|
||||
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
|
||||
service_info["retry_count"] = 0 # 重置重试计数
|
||||
|
||||
# 等待下次检查
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
|
||||
|
||||
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
|
||||
"""更新数据库中的服务健康状态"""
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
|
||||
mcp_config.last_health_check = datetime.utcnow()
|
||||
|
||||
if not health_status["healthy"]:
|
||||
mcp_config.error_message = health_status.get("error", "")
|
||||
else:
|
||||
mcp_config.error_message = None
|
||||
|
||||
self.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
|
||||
self.db.rollback()
|
||||
|
||||
async def _management_loop(self):
|
||||
"""管理循环 - 处理服务清理等任务"""
|
||||
try:
|
||||
while self._running:
|
||||
# 清理失效的服务
|
||||
await self._cleanup_failed_services()
|
||||
|
||||
# 等待下次循环
|
||||
await asyncio.sleep(300) # 5分钟
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"管理循环异常: {e}")
|
||||
|
||||
async def _cleanup_failed_services(self):
|
||||
"""清理长期失效的服务"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
cleanup_threshold = 24 * 60 * 60 # 24小时
|
||||
|
||||
services_to_cleanup = []
|
||||
|
||||
for service_id, service_info in self._services.items():
|
||||
# 检查服务是否长期失效
|
||||
if (service_info["status"] in ["error", "unhealthy"] and
|
||||
current_time - service_info["last_health_check"] > cleanup_threshold):
|
||||
|
||||
services_to_cleanup.append(service_id)
|
||||
|
||||
for service_id in services_to_cleanup:
|
||||
logger.warning(f"清理长期失效的服务: {service_id}")
|
||||
|
||||
# 停止监控但不删除数据库记录
|
||||
await self._stop_service_monitoring(service_id)
|
||||
|
||||
# 标记为禁用
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
||||
if tool_config:
|
||||
tool_config.is_enabled = False
|
||||
self.db.commit()
|
||||
|
||||
# 从内存移除
|
||||
del self._services[service_id]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理失效服务失败: {e}")
|
||||
|
||||
def get_manager_status(self) -> Dict[str, Any]:
|
||||
"""获取管理器状态"""
|
||||
return {
|
||||
"running": self._running,
|
||||
"total_services": len(self._services),
|
||||
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
|
||||
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
|
||||
"monitoring_tasks": len(self._monitoring_tasks),
|
||||
"connection_pool_status": self.connection_pool.get_pool_status()
|
||||
}
|
||||
436
api/app/core/tools/registry.py
Normal file
436
api/app/core/tools/registry.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""工具注册表 - 管理所有工具的元数据和状态"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Type, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from app.models.tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolType, ToolStatus, ToolExecution, ExecutionStatus
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .base import BaseTool, ToolInfo
|
||||
from .custom.base import CustomTool
|
||||
from .mcp.base import MCPTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表 - 管理所有工具的元数据和实例"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化工具注册表
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self._tools: Dict[str, BaseTool] = {} # 工具实例缓存
|
||||
self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表
|
||||
self._lock = asyncio.Lock() # 异步锁
|
||||
|
||||
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
|
||||
"""注册工具类
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
class_name: 类名(可选,默认使用类的__name__)
|
||||
"""
|
||||
class_name = class_name or tool_class.__name__
|
||||
self._tool_classes[class_name] = tool_class
|
||||
logger.info(f"工具类已注册: {class_name}")
|
||||
|
||||
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
|
||||
"""注册工具实例到系统
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
tenant_id: 租户ID(内置工具可以为None,表示全局工具)
|
||||
|
||||
Returns:
|
||||
注册是否成功
|
||||
"""
|
||||
async with self._lock:
|
||||
try:
|
||||
# 检查工具是否已存在
|
||||
if tenant_id:
|
||||
existing_config = self.db.query(ToolConfig).filter(
|
||||
and_(
|
||||
ToolConfig.name == tool.name,
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tool_type == tool.tool_type.value
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
# 全局工具(内置工具)
|
||||
existing_config = self.db.query(ToolConfig).filter(
|
||||
and_(
|
||||
ToolConfig.name == tool.name,
|
||||
ToolConfig.tenant_id.is_(None),
|
||||
ToolConfig.tool_type == tool.tool_type.value
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing_config:
|
||||
logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
|
||||
return False
|
||||
|
||||
# 创建工具配置
|
||||
tool_config = ToolConfig(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
tool_type=tool.tool_type.value,
|
||||
tenant_id=tenant_id,
|
||||
version=tool.version,
|
||||
tags=tool.tags,
|
||||
config_data=tool.config
|
||||
)
|
||||
|
||||
self.db.add(tool_config)
|
||||
self.db.flush() # 获取ID
|
||||
|
||||
# 根据工具类型创建特定配置
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
builtin_config = BuiltinToolConfig(
|
||||
id=tool_config.id,
|
||||
tool_class=tool.__class__.__name__,
|
||||
parameters=tool.config.get("parameters", {})
|
||||
)
|
||||
self.db.add(builtin_config)
|
||||
|
||||
elif tool.tool_type == ToolType.CUSTOM:
|
||||
custom_config = CustomToolConfig(
|
||||
id=tool_config.id,
|
||||
schema_url=tool.config.get("schema_url"),
|
||||
schema_content=tool.config.get("schema_content"),
|
||||
auth_type=tool.config.get("auth_type", "none"),
|
||||
auth_config=tool.config.get("auth_config", {}),
|
||||
base_url=tool.config.get("base_url"),
|
||||
timeout=tool.config.get("timeout", 30)
|
||||
)
|
||||
self.db.add(custom_config)
|
||||
|
||||
elif tool.tool_type == ToolType.MCP:
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=tool.config.get("server_url"),
|
||||
connection_config=tool.config.get("connection_config", {}),
|
||||
available_tools=tool.config.get("available_tools", [])
|
||||
)
|
||||
self.db.add(mcp_config)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 缓存工具实例
|
||||
tool.tool_id = str(tool_config.id)
|
||||
self._tools[str(tool_config.id)] = tool
|
||||
|
||||
logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> bool:
|
||||
"""从系统注销工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
|
||||
Returns:
|
||||
注销是否成功
|
||||
"""
|
||||
async with self._lock:
|
||||
try:
|
||||
# 检查工具是否存在
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config:
|
||||
logger.warning(f"工具不存在: {tool_id}")
|
||||
return False
|
||||
|
||||
# 检查是否有正在执行的任务
|
||||
running_executions = self.db.query(ToolExecution).filter(
|
||||
and_(
|
||||
ToolExecution.tool_config_id == uuid.UUID(tool_id),
|
||||
ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
|
||||
)
|
||||
).count()
|
||||
|
||||
if running_executions > 0:
|
||||
logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
|
||||
return False
|
||||
|
||||
# 删除工具配置(级联删除相关记录)
|
||||
self.db.delete(tool_config)
|
||||
self.db.commit()
|
||||
|
||||
# 从缓存中移除
|
||||
if tool_id in self._tools:
|
||||
del self._tools[tool_id]
|
||||
|
||||
logger.info(f"工具注销成功: {tool_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
|
||||
"""获取工具实例
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
|
||||
Returns:
|
||||
工具实例,如果不存在返回None
|
||||
"""
|
||||
# 先从缓存获取
|
||||
if tool_id in self._tools:
|
||||
return self._tools[tool_id]
|
||||
|
||||
# 从数据库加载
|
||||
try:
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
|
||||
return None
|
||||
|
||||
# 根据工具类型加载实例
|
||||
tool_instance = self._load_tool_instance(tool_config)
|
||||
if tool_instance:
|
||||
self._tools[tool_id] = tool_instance
|
||||
return tool_instance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def list_tools(
|
||||
self,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
tool_type: Optional[ToolType] = None,
|
||||
status: Optional[ToolStatus] = None,
|
||||
tags: Optional[List[str]] = None
|
||||
) -> List[ToolInfo]:
|
||||
"""列出工具
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID过滤
|
||||
tool_type: 工具类型过滤
|
||||
status: 工具状态过滤
|
||||
tags: 标签过滤
|
||||
|
||||
Returns:
|
||||
工具信息列表
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolConfig)
|
||||
|
||||
# 应用过滤条件
|
||||
if tenant_id:
|
||||
# 返回全局工具(tenant_id为空)和该租户的工具
|
||||
query = query.filter(
|
||||
or_(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tenant_id.is_(None)
|
||||
)
|
||||
)
|
||||
|
||||
if tool_type:
|
||||
query = query.filter(ToolConfig.tool_type == tool_type.value)
|
||||
|
||||
if status == ToolStatus.ACTIVE:
|
||||
query = query.filter(ToolConfig.is_enabled == True)
|
||||
elif status == ToolStatus.INACTIVE:
|
||||
query = query.filter(ToolConfig.is_enabled == False)
|
||||
|
||||
if tags:
|
||||
for tag in tags:
|
||||
query = query.filter(ToolConfig.tags.contains([tag]))
|
||||
|
||||
tool_configs = query.all()
|
||||
|
||||
# 转换为ToolInfo
|
||||
tool_infos = []
|
||||
for config in tool_configs:
|
||||
tool_info = ToolInfo(
|
||||
id=str(config.id),
|
||||
name=config.name,
|
||||
description=config.description or "",
|
||||
tool_type=ToolType(config.tool_type),
|
||||
version=config.version,
|
||||
status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
|
||||
tags=config.tags or [],
|
||||
tenant_id=str(config.tenant_id) if config.tenant_id else None
|
||||
)
|
||||
|
||||
# 尝试获取参数信息
|
||||
tool_instance = self.get_tool(str(config.id))
|
||||
if tool_instance:
|
||||
tool_info.parameters = tool_instance.parameters
|
||||
|
||||
tool_infos.append(tool_info)
|
||||
|
||||
return tool_infos
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"列出工具失败, 错误: {e}")
|
||||
return []
|
||||
|
||||
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
|
||||
"""更新工具状态
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
status: 新状态
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
try:
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config:
|
||||
logger.warning(f"工具不存在: {tool_id}")
|
||||
return False
|
||||
|
||||
# 更新状态
|
||||
if status == ToolStatus.ACTIVE:
|
||||
tool_config.is_enabled = True
|
||||
elif status == ToolStatus.INACTIVE:
|
||||
tool_config.is_enabled = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 更新缓存中的工具状态
|
||||
if tool_id in self._tools:
|
||||
self._tools[tool_id].status = status
|
||||
|
||||
logger.info(f"工具状态更新成功: {tool_id} -> {status}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]:
|
||||
"""从配置加载工具实例
|
||||
|
||||
Args:
|
||||
tool_config: 工具配置
|
||||
|
||||
Returns:
|
||||
工具实例
|
||||
"""
|
||||
try:
|
||||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||||
# 加载内置工具
|
||||
builtin_config = self.db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if builtin_config and builtin_config.tool_class in self._tool_classes:
|
||||
tool_class = self._tool_classes[builtin_config.tool_class]
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"parameters": builtin_config.parameters,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return tool_class(str(tool_config.id), config)
|
||||
|
||||
elif tool_config.tool_type == ToolType.CUSTOM.value:
|
||||
# 加载自定义工具
|
||||
try:
|
||||
custom_config = self.db.query(CustomToolConfig).filter(
|
||||
CustomToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if custom_config:
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"schema_url": custom_config.schema_url,
|
||||
"schema_content": custom_config.schema_content,
|
||||
"auth_type": custom_config.auth_type,
|
||||
"auth_config": custom_config.auth_config,
|
||||
"base_url": custom_config.base_url,
|
||||
"timeout": custom_config.timeout,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return CustomTool(str(tool_config.id), config)
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入自定义工具模块: {e}")
|
||||
|
||||
elif tool_config.tool_type == ToolType.MCP.value:
|
||||
# 加载MCP工具
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config,
|
||||
"available_tools": mcp_config.available_tools,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return MCPTool(str(tool_config.id), config)
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入MCP工具模块: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
|
||||
"""获取工具统计信息
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolConfig)
|
||||
if tenant_id:
|
||||
query = query.filter(ToolConfig.tenant_id == tenant_id)
|
||||
|
||||
total_tools = query.count()
|
||||
active_tools = query.filter(ToolConfig.is_enabled == True).count()
|
||||
|
||||
# 按类型统计
|
||||
type_stats = {}
|
||||
for tool_type in ToolType:
|
||||
count = query.filter(ToolConfig.tool_type == tool_type.value).count()
|
||||
type_stats[tool_type.value] = count
|
||||
|
||||
return {
|
||||
"total_tools": total_tools,
|
||||
"active_tools": active_tools,
|
||||
"inactive_tools": total_tools - active_tools,
|
||||
"by_type": type_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具统计失败, 错误: {e}")
|
||||
return {}
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空工具缓存"""
|
||||
self._tools.clear()
|
||||
logger.info("工具缓存已清空")
|
||||
@@ -4,8 +4,9 @@
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
# import uuid
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -15,6 +16,11 @@ from langgraph.graph.state import CompiledStateGraph
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
# from app.core.tools.registry import ToolRegistry
|
||||
# from app.core.tools.executor import ToolExecutor
|
||||
# from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
# TOOL_MANAGEMENT_AVAILABLE = True
|
||||
# from app.db import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -457,3 +463,179 @@ async def execute_workflow_stream(
|
||||
)
|
||||
async for event in executor.execute_stream(input_data):
|
||||
yield event
|
||||
|
||||
|
||||
# ==================== 工具管理系统集成 ====================
|
||||
|
||||
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
||||
# """获取工作流可用的工具列表
|
||||
#
|
||||
# Args:
|
||||
# workspace_id: 工作空间ID
|
||||
# user_id: 用户ID
|
||||
#
|
||||
# Returns:
|
||||
# 可用工具列表
|
||||
# """
|
||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
# logger.warning("工具管理系统不可用")
|
||||
# return []
|
||||
#
|
||||
# try:
|
||||
# db = next(get_db())
|
||||
#
|
||||
# # 创建工具注册表
|
||||
# registry = ToolRegistry(db)
|
||||
#
|
||||
# # 注册内置工具类
|
||||
# from app.core.tools.builtin import (
|
||||
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
|
||||
# )
|
||||
# registry.register_tool_class(DateTimeTool)
|
||||
# registry.register_tool_class(JsonTool)
|
||||
# registry.register_tool_class(BaiduSearchTool)
|
||||
# registry.register_tool_class(MinerUTool)
|
||||
# registry.register_tool_class(TextInTool)
|
||||
#
|
||||
# # 获取活跃的工具
|
||||
# import uuid
|
||||
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
|
||||
# active_tools = [tool for tool in tools if tool.status.value == "active"]
|
||||
#
|
||||
# # 转换为Langchain工具
|
||||
# langchain_tools = []
|
||||
# for tool_info in active_tools:
|
||||
# try:
|
||||
# tool_instance = registry.get_tool(tool_info.id)
|
||||
# if tool_instance:
|
||||
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
|
||||
# langchain_tools.append(langchain_tool)
|
||||
# except Exception as e:
|
||||
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
|
||||
#
|
||||
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
|
||||
# return langchain_tools
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流工具失败: {e}")
|
||||
# return []
|
||||
#
|
||||
#
|
||||
# class ToolWorkflowNode:
|
||||
# """工具工作流节点 - 在工作流中执行工具"""
|
||||
#
|
||||
# def __init__(self, node_config: dict, workflow_config: dict):
|
||||
# """初始化工具节点
|
||||
#
|
||||
# Args:
|
||||
# node_config: 节点配置
|
||||
# workflow_config: 工作流配置
|
||||
# """
|
||||
# self.node_config = node_config
|
||||
# self.workflow_config = workflow_config
|
||||
# self.tool_id = node_config.get("tool_id")
|
||||
# self.tool_parameters = node_config.get("parameters", {})
|
||||
#
|
||||
# async def run(self, state: WorkflowState) -> WorkflowState:
|
||||
# """执行工具节点"""
|
||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
# logger.error("工具管理系统不可用")
|
||||
# state["error"] = "工具管理系统不可用"
|
||||
# return state
|
||||
#
|
||||
# try:
|
||||
# from sqlalchemy.orm import Session
|
||||
# db = next(get_db())
|
||||
#
|
||||
# # 创建工具执行器
|
||||
# registry = ToolRegistry(db)
|
||||
# executor = ToolExecutor(db, registry)
|
||||
#
|
||||
# # 准备参数(支持变量替换)
|
||||
# parameters = self._prepare_parameters(state)
|
||||
#
|
||||
# # 执行工具
|
||||
# result = await executor.execute_tool(
|
||||
# tool_id=self.tool_id,
|
||||
# parameters=parameters,
|
||||
# user_id=uuid.UUID(state["user_id"]),
|
||||
# workspace_id=uuid.UUID(state["workspace_id"])
|
||||
# )
|
||||
#
|
||||
# # 更新状态
|
||||
# node_id = self.node_config.get("id")
|
||||
# if result.success:
|
||||
# state["node_outputs"][node_id] = {
|
||||
# "type": "tool",
|
||||
# "tool_id": self.tool_id,
|
||||
# "output": result.data,
|
||||
# "execution_time": result.execution_time,
|
||||
# "token_usage": result.token_usage
|
||||
# }
|
||||
#
|
||||
# # 更新运行时变量
|
||||
# if isinstance(result.data, dict):
|
||||
# for key, value in result.data.items():
|
||||
# state["runtime_vars"][f"{node_id}.{key}"] = value
|
||||
# else:
|
||||
# state["runtime_vars"][f"{node_id}.result"] = result.data
|
||||
# else:
|
||||
# state["error"] = result.error
|
||||
# state["error_node"] = node_id
|
||||
# state["node_outputs"][node_id] = {
|
||||
# "type": "tool",
|
||||
# "tool_id": self.tool_id,
|
||||
# "error": result.error,
|
||||
# "execution_time": result.execution_time
|
||||
# }
|
||||
#
|
||||
# return state
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"工具节点执行失败: {e}")
|
||||
# state["error"] = str(e)
|
||||
# state["error_node"] = self.node_config.get("id")
|
||||
# return state
|
||||
#
|
||||
# def _prepare_parameters(self, state: WorkflowState) -> dict:
|
||||
# """准备工具参数(支持变量替换)"""
|
||||
# parameters = {}
|
||||
#
|
||||
# for key, value in self.tool_parameters.items():
|
||||
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# # 变量替换
|
||||
# var_path = value[2:-1]
|
||||
#
|
||||
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
|
||||
# if "." in var_path:
|
||||
# parts = var_path.split(".")
|
||||
# current = state.get("variables", {})
|
||||
#
|
||||
# for part in parts:
|
||||
# if isinstance(current, dict) and part in current:
|
||||
# current = current[part]
|
||||
# else:
|
||||
# # 尝试从运行时变量获取
|
||||
# runtime_key = ".".join(parts)
|
||||
# current = state.get("runtime_vars", {}).get(runtime_key, value)
|
||||
# break
|
||||
#
|
||||
# parameters[key] = current
|
||||
# else:
|
||||
# # 简单变量
|
||||
# variables = state.get("variables", {})
|
||||
# parameters[key] = variables.get(var_path, value)
|
||||
# else:
|
||||
# parameters[key] = value
|
||||
#
|
||||
# return parameters
|
||||
#
|
||||
#
|
||||
# # 注册工具节点到NodeFactory(如果存在)
|
||||
# try:
|
||||
# from app.core.workflow.nodes import NodeFactory
|
||||
# if hasattr(NodeFactory, 'register_node_type'):
|
||||
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
|
||||
# logger.info("工具节点已注册到工作流系统")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"注册工具节点失败: {e}")
|
||||
@@ -21,6 +21,10 @@ from .multi_agent_model import MultiAgentConfig, AgentInvocation
|
||||
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
||||
from .retrieval_info import RetrievalInfo
|
||||
from .prompt_optimizer_model import PromptOptimizerSession, PromptOptimizerSessionHistory
|
||||
from .tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Tenants",
|
||||
@@ -57,5 +61,15 @@ __all__ = [
|
||||
"WorkflowNodeExecution",
|
||||
"RetrievalInfo",
|
||||
"PromptOptimizerSession",
|
||||
"PromptOptimizerSessionHistory"
|
||||
"PromptOptimizerSessionHistory",
|
||||
"RetrievalInfo",
|
||||
"ToolConfig",
|
||||
"BuiltinToolConfig",
|
||||
"CustomToolConfig",
|
||||
"MCPToolConfig",
|
||||
"ToolExecution",
|
||||
"ToolType",
|
||||
"ToolStatus",
|
||||
"AuthType",
|
||||
"ExecutionStatus"
|
||||
]
|
||||
|
||||
@@ -21,3 +21,6 @@ class Tenants(Base):
|
||||
|
||||
# Relationship to workspaces owned by the tenant
|
||||
owned_workspaces = relationship("Workspace", back_populates="tenant")
|
||||
|
||||
# Relationship to tool configs owned by the tenant
|
||||
tool_configs = relationship("ToolConfig", back_populates="tenant")
|
||||
|
||||
226
api/app/models/tool_model.py
Normal file
226
api/app/models/tool_model.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""工具管理相关数据模型"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class ToolType(StrEnum):
|
||||
"""工具类型枚举"""
|
||||
BUILTIN = "builtin"
|
||||
CUSTOM = "custom"
|
||||
MCP = "mcp"
|
||||
|
||||
|
||||
class ToolStatus(StrEnum):
|
||||
"""工具状态枚举"""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
ERROR = "error"
|
||||
LOADING = "loading"
|
||||
|
||||
|
||||
class AuthType(StrEnum):
|
||||
"""认证类型枚举"""
|
||||
NONE = "none"
|
||||
API_KEY = "api_key"
|
||||
BEARER_TOKEN = "bearer_token"
|
||||
|
||||
|
||||
class ExecutionStatus(StrEnum):
|
||||
"""执行状态枚举"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
class ToolConfig(Base):
|
||||
"""工具配置基础模型"""
|
||||
__tablename__ = "tool_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
description = Column(Text)
|
||||
tool_type = Column(String(50), nullable=False, index=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True) # 必须属于租户
|
||||
status = Column(String(50), default=ToolStatus.INACTIVE.value, nullable=False, index=True) # 工具状态
|
||||
|
||||
# 工具特定配置(JSON格式存储)
|
||||
config_data = Column(JSON, default=dict)
|
||||
|
||||
# 元数据
|
||||
version = Column(String(50), default="1.0.0")
|
||||
tags = Column(JSON, default=list) # 标签列表
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.now, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
|
||||
|
||||
# 关联关系
|
||||
tenant = relationship("Tenants", back_populates="tool_configs")
|
||||
executions = relationship("ToolExecution", back_populates="tool_config", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ToolConfig(id={self.id}, name={self.name}, type={self.tool_type}, status={self.status})>"
|
||||
|
||||
|
||||
class BuiltinToolConfig(Base):
|
||||
"""内置工具配置模型"""
|
||||
__tablename__ = "builtin_tool_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
||||
tool_class = Column(String(255), nullable=False) # 工具类名
|
||||
parameters = Column(JSON, default=dict) # 工具参数配置
|
||||
|
||||
# 关联关系
|
||||
base_config = relationship("ToolConfig", foreign_keys=[id])
|
||||
|
||||
def __repr__(self):
|
||||
return f"<BuiltinToolConfig(id={self.id}, tool_class={self.tool_class})>"
|
||||
|
||||
|
||||
class CustomToolConfig(Base):
|
||||
"""自定义工具配置模型"""
|
||||
__tablename__ = "custom_tool_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
||||
schema_url = Column(String(1000)) # OpenAPI schema URL
|
||||
schema_content = Column(JSON) # OpenAPI schema 内容
|
||||
|
||||
# 认证配置
|
||||
auth_type = Column(String(50), default=AuthType.NONE.value, nullable=False)
|
||||
auth_config = Column(JSON, default=dict) # 认证配置(加密存储)
|
||||
|
||||
# API配置
|
||||
base_url = Column(String(1000)) # API基础URL
|
||||
timeout = Column(Integer, default=30) # 超时时间(秒)
|
||||
|
||||
# 关联关系
|
||||
base_config = relationship("ToolConfig", foreign_keys=[id])
|
||||
|
||||
def __repr__(self):
|
||||
return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
|
||||
|
||||
|
||||
class MCPToolConfig(Base):
|
||||
"""MCP工具配置模型"""
|
||||
__tablename__ = "mcp_tool_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
||||
server_url = Column(String(1000), nullable=False) # MCP服务器URL
|
||||
connection_config = Column(JSON, default=dict) # 连接配置
|
||||
|
||||
# 服务状态
|
||||
last_health_check = Column(DateTime)
|
||||
health_status = Column(String(50), default="unknown")
|
||||
error_message = Column(Text)
|
||||
|
||||
# 可用工具列表
|
||||
available_tools = Column(JSON, default=list)
|
||||
|
||||
# 关联关系
|
||||
base_config = relationship("ToolConfig", foreign_keys=[id])
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MCPToolConfig(id={self.id}, server_url={self.server_url})>"
|
||||
|
||||
|
||||
class ToolExecution(Base):
|
||||
"""工具执行记录模型"""
|
||||
__tablename__ = "tool_executions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tool_config_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False, index=True)
|
||||
|
||||
# 执行信息
|
||||
execution_id = Column(String(255), nullable=False, index=True) # 执行ID(可用于关联工作流等)
|
||||
status = Column(String(50), default=ExecutionStatus.PENDING.value, nullable=False, index=True)
|
||||
|
||||
# 输入输出
|
||||
input_data = Column(JSON) # 输入参数
|
||||
output_data = Column(JSON) # 输出结果
|
||||
error_message = Column(Text) # 错误信息
|
||||
|
||||
# 性能指标
|
||||
started_at = Column(DateTime, nullable=False, index=True)
|
||||
completed_at = Column(DateTime)
|
||||
execution_time = Column(Float) # 执行时间(秒)
|
||||
|
||||
# Token使用情况(如果适用)
|
||||
token_usage = Column(JSON)
|
||||
|
||||
# 用户信息
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True)
|
||||
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False, index=True)
|
||||
|
||||
# 关联关系
|
||||
tool_config = relationship("ToolConfig", back_populates="executions")
|
||||
user = relationship("User")
|
||||
workspace = relationship("Workspace")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ToolExecution(id={self.id}, status={self.status}, tool={self.tool_config_id})>"
|
||||
|
||||
|
||||
# class ToolDependency(Base):
|
||||
# """工具依赖关系模型"""
|
||||
# __tablename__ = "tool_dependencies"
|
||||
#
|
||||
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False)
|
||||
# depends_on_tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False)
|
||||
#
|
||||
# # 依赖类型和版本要求
|
||||
# dependency_type = Column(String(50), default="required") # required, optional
|
||||
# version_constraint = Column(String(100)) # 版本约束,如 ">=1.0.0"
|
||||
#
|
||||
# # 时间戳
|
||||
# created_at = Column(DateTime, default=datetime.now, nullable=False)
|
||||
#
|
||||
# # 关联关系
|
||||
# tool = relationship("ToolConfig", foreign_keys=[tool_id])
|
||||
# depends_on_tool = relationship("ToolConfig", foreign_keys=[depends_on_tool_id])
|
||||
#
|
||||
# def __repr__(self):
|
||||
# return f"<ToolDependency(tool={self.tool_id}, depends_on={self.depends_on_tool_id})>"
|
||||
|
||||
|
||||
# class PluginConfig(Base):
|
||||
# """插件配置模型"""
|
||||
# __tablename__ = "plugin_configs"
|
||||
#
|
||||
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# name = Column(String(255), nullable=False, unique=True, index=True)
|
||||
# description = Column(Text)
|
||||
#
|
||||
# # 插件信息
|
||||
# plugin_path = Column(String(1000), nullable=False) # 插件文件路径
|
||||
# entry_point = Column(String(255), nullable=False) # 入口点
|
||||
# version = Column(String(50), default="1.0.0")
|
||||
#
|
||||
# # 状态
|
||||
# is_enabled = Column(Boolean, default=True, nullable=False)
|
||||
# is_loaded = Column(Boolean, default=False, nullable=False)
|
||||
# load_error = Column(Text) # 加载错误信息
|
||||
#
|
||||
# # 配置
|
||||
# config_schema = Column(JSON) # 配置schema
|
||||
# config_data = Column(JSON, default=dict) # 配置数据
|
||||
#
|
||||
# # 依赖
|
||||
# dependencies = Column(JSON, default=list) # 依赖的其他插件
|
||||
#
|
||||
# # 时间戳
|
||||
# created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
# updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
# last_loaded_at = Column(DateTime)
|
||||
#
|
||||
# def __repr__(self):
|
||||
# return f"<PluginConfig(id={self.id}, name={self.name}, version={self.version})>"
|
||||
@@ -14,6 +14,7 @@ from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories import workspace_repository, knowledge_repository
|
||||
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -328,4 +329,4 @@ def create_agent_invocation_tool(
|
||||
)
|
||||
return f"调用 Agent 失败: {str(e)}"
|
||||
|
||||
return invoke_agent
|
||||
return invoke_agent
|
||||
374
api/test_tool_system.py
Normal file
374
api/test_tool_system.py
Normal file
@@ -0,0 +1,374 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
工具管理系统基础测试脚本
|
||||
用于验证系统的基本功能是否正常
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# 测试导入
|
||||
def test_imports():
|
||||
"""测试模块导入"""
|
||||
print("测试模块导入...")
|
||||
|
||||
try:
|
||||
from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType
|
||||
print("✓ 基础工具模块导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ 基础工具模块导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin.datetime_tool import DateTimeTool
|
||||
from app.core.tools.builtin.json_tool import JsonTool
|
||||
print("✓ 内置工具模块导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ 内置工具模块导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
print("✓ Langchain适配器导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ Langchain适配器导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.models.tool_model import ToolConfig, ToolType, ToolStatus
|
||||
print("✓ 工具模型导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ 工具模型导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.tools.custom import CustomTool, OpenAPISchemaParser, AuthManager
|
||||
print("✓ 自定义工具模块导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ 自定义工具模块导入失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.tools.mcp import MCPTool, MCPClient, MCPServiceManager
|
||||
print("✓ MCP工具模块导入成功")
|
||||
except ImportError as e:
|
||||
print(f"✗ MCP工具模块导入失败: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_tool_creation():
|
||||
"""测试工具创建"""
|
||||
print("\n测试工具创建...")
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin.datetime_tool import DateTimeTool
|
||||
|
||||
# 创建时间工具实例(全局工具)
|
||||
tool_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"parameters": {"timezone": "UTC"},
|
||||
"tenant_id": None, # 全局工具
|
||||
"version": "1.0.0",
|
||||
"tags": ["time", "utility", "builtin"]
|
||||
}
|
||||
|
||||
datetime_tool = DateTimeTool(tool_id, config)
|
||||
|
||||
# 验证工具属性
|
||||
assert datetime_tool.name == "datetime_tool"
|
||||
assert datetime_tool.tool_type.value == "builtin"
|
||||
assert len(datetime_tool.parameters) > 0
|
||||
|
||||
print("✓ 时间工具创建成功(全局工具)")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 工具创建失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def test_tool_execution():
|
||||
"""测试工具执行"""
|
||||
print("\n测试工具执行...")
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin.datetime_tool import DateTimeTool
|
||||
|
||||
# 创建时间工具实例
|
||||
tool_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"parameters": {"timezone": "UTC"},
|
||||
"tenant_id": None, # 全局工具
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
datetime_tool = DateTimeTool(tool_id, config)
|
||||
|
||||
# 测试获取当前时间
|
||||
result = await datetime_tool.safe_execute(operation="now")
|
||||
|
||||
assert result.success == True
|
||||
assert "datetime" in result.data
|
||||
assert result.execution_time > 0
|
||||
|
||||
print("✓ 工具执行成功")
|
||||
print(f" 执行时间: {result.execution_time:.3f}秒")
|
||||
print(f" 返回数据: {result.data}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 工具执行失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_langchain_adapter():
|
||||
"""测试Langchain适配器"""
|
||||
print("\n测试Langchain适配器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin.json_tool import JsonTool
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
|
||||
# 创建JSON工具实例
|
||||
tool_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"parameters": {"indent": 2},
|
||||
"tenant_id": None, # 全局工具
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
json_tool = JsonTool(tool_id, config)
|
||||
|
||||
# 验证Langchain兼容性
|
||||
is_compatible, issues = LangchainAdapter.validate_langchain_compatibility(json_tool)
|
||||
|
||||
if not is_compatible:
|
||||
print(f"✗ Langchain兼容性验证失败: {issues}")
|
||||
return False
|
||||
|
||||
# 创建工具描述
|
||||
description = LangchainAdapter.create_tool_description(json_tool)
|
||||
|
||||
assert "name" in description
|
||||
assert "parameters" in description
|
||||
assert description["langchain_compatible"] == True
|
||||
|
||||
print("✓ Langchain适配器测试成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Langchain适配器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_config_manager():
|
||||
"""测试配置管理器"""
|
||||
print("\n测试配置管理器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.config_manager import ConfigManager
|
||||
|
||||
# 创建配置管理器
|
||||
config_manager = ConfigManager()
|
||||
|
||||
# 获取配置摘要
|
||||
summary = config_manager.get_config_summary()
|
||||
|
||||
assert "config_dir" in summary
|
||||
assert "total_configs" in summary
|
||||
|
||||
print("✓ 配置管理器测试成功")
|
||||
print(f" 配置目录: {summary['config_dir']}")
|
||||
print(f" 总配置数: {summary['total_configs']}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 配置管理器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_schema_parser():
|
||||
"""测试OpenAPI Schema解析器"""
|
||||
print("\n测试OpenAPI Schema解析器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||||
|
||||
# 创建解析器
|
||||
parser = OpenAPISchemaParser()
|
||||
|
||||
# 测试简单的OpenAPI schema
|
||||
test_schema = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": "Test API",
|
||||
"version": "1.0.0",
|
||||
"description": "测试API"
|
||||
},
|
||||
"paths": {
|
||||
"/test": {
|
||||
"get": {
|
||||
"summary": "测试接口",
|
||||
"operationId": "test_operation",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "成功"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 验证schema
|
||||
is_valid, error_msg = parser.validate_schema(test_schema)
|
||||
assert is_valid, f"Schema验证失败: {error_msg}"
|
||||
|
||||
# 提取工具信息
|
||||
tool_info = parser.extract_tool_info(test_schema)
|
||||
assert tool_info["name"] == "Test API"
|
||||
assert "test_operation" in tool_info["operations"]
|
||||
|
||||
print("✓ OpenAPI Schema解析器测试成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ OpenAPI Schema解析器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_auth_manager():
|
||||
"""测试认证管理器"""
|
||||
print("\n测试认证管理器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.custom.auth_manager import AuthManager
|
||||
from app.models.tool_model import AuthType
|
||||
|
||||
# 创建认证管理器
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# 测试API Key认证配置
|
||||
api_key_config = {
|
||||
"api_key": "test-key-123",
|
||||
"key_name": "X-API-Key",
|
||||
"location": "header"
|
||||
}
|
||||
|
||||
is_valid, error_msg = auth_manager.validate_auth_config(AuthType.API_KEY, api_key_config)
|
||||
assert is_valid, f"API Key配置验证失败: {error_msg}"
|
||||
|
||||
# 测试Bearer Token认证配置
|
||||
bearer_config = {
|
||||
"token": "bearer-token-123"
|
||||
}
|
||||
|
||||
is_valid, error_msg = auth_manager.validate_auth_config(AuthType.BEARER_TOKEN, bearer_config)
|
||||
assert is_valid, f"Bearer Token配置验证失败: {error_msg}"
|
||||
|
||||
# 测试认证应用
|
||||
url = "https://api.example.com/test"
|
||||
headers = {}
|
||||
params = {}
|
||||
|
||||
new_url, new_headers, new_params = auth_manager.apply_authentication(
|
||||
AuthType.API_KEY, api_key_config, url, headers, params
|
||||
)
|
||||
|
||||
assert "X-API-Key" in new_headers
|
||||
assert new_headers["X-API-Key"] == "test-key-123"
|
||||
|
||||
print("✓ 认证管理器测试成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 认证管理器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_builtin_initializer():
|
||||
"""测试内置工具初始化器"""
|
||||
print("\n测试内置工具初始化器...")
|
||||
|
||||
try:
|
||||
from app.core.tools.builtin_initializer import BuiltinToolInitializer
|
||||
|
||||
# 注意:这里不能真正初始化,因为需要数据库连接
|
||||
# 只测试类的创建和基本方法
|
||||
|
||||
# 模拟数据库会话(实际使用中需要真实的数据库连接)
|
||||
class MockDB:
|
||||
def query(self, *args):
|
||||
return self
|
||||
def filter(self, *args):
|
||||
return self
|
||||
def first(self):
|
||||
return None
|
||||
def all(self):
|
||||
return []
|
||||
|
||||
mock_db = MockDB()
|
||||
initializer = BuiltinToolInitializer(mock_db)
|
||||
|
||||
# 测试获取内置工具状态(会返回空列表,因为没有真实数据)
|
||||
status = initializer.get_builtin_tools_status()
|
||||
assert isinstance(status, list)
|
||||
|
||||
print("✓ 内置工具初始化器测试成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 内置工具初始化器测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("=" * 50)
|
||||
print("工具管理系统基础测试")
|
||||
print("=" * 50)
|
||||
|
||||
tests = [
|
||||
("模块导入", test_imports),
|
||||
("工具创建", test_tool_creation),
|
||||
("工具执行", test_tool_execution),
|
||||
("Langchain适配", test_langchain_adapter),
|
||||
("配置管理", test_config_manager),
|
||||
("Schema解析器", test_schema_parser),
|
||||
("认证管理器", test_auth_manager),
|
||||
("内置工具初始化器", test_builtin_initializer)
|
||||
]
|
||||
|
||||
passed = 0
|
||||
total = len(tests)
|
||||
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(test_func):
|
||||
result = await test_func()
|
||||
else:
|
||||
result = test_func()
|
||||
|
||||
if result:
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"✗ {test_name}测试异常: {e}")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print(f"测试结果: {passed}/{total} 通过")
|
||||
|
||||
if passed == total:
|
||||
print("🎉 所有基础测试通过!工具管理系统基本功能正常。")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ 部分测试失败,请检查相关模块。")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user