Merge #26 into develop from feature/20251219_xjn

feat(tool system): tool system development

* feature/20251219_xjn: (25 commits)
  feat(apikey system): tool system development
  Merge #13 into develop from fix/stream-output
  [fix]document chunk QA
  [add] workflow support stream mode
  Merge #9 into develop from fix/memory_reflection
  Merge #18 into develop from fix/memory_reflection
  [add] migration script
  fix(workflow): fix run_workflow streaming issues
  fix(prompt-optimizer): switch to built-in system prompt
  feat(workflow): add conditional branch (If-Else) node
  perf(types): add Union type declaration for workflow nodes
  fix(expression-eval): fix variable extraction issue in Jinja2 templates
  docs(samples): add config example for If-Else node
  style(workflow): update condition edge comments for conditional nodes
  style(enums): correct enum class name spelling
  refactor(workflow): unify all enum classes in one file and restructure workflow...
  feat(workflow): add import for if-else node configuration
  [add] migration script
  Merge #19 into develop from fix/memory_reflection
  Merge #21 into develop from feature/emotion-engine
  Merge #13 into develop from fix/stream-output
  Merge #21 into develop from feature/emotion-engine
  [add] migration script
  Merge remote-tracking branch 'origin/develop' into develop
  feat(tool system): tool system development

Signed-off-by: 谢俊男 <accounts_6853d0ea6f8174722fb0c8f1@mail.teambition.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/26
This commit is contained in:
朱文辉
2025-12-20 16:12:25 +08:00
39 changed files with 9123 additions and 5 deletions

View File

@@ -32,6 +32,8 @@ from . import (
emotion_controller,
emotion_config_controller,
prompt_optimizer_controller,
tool_controller,
tool_execution_controller,
)
# 创建管理端 API 路由器
@@ -66,4 +68,7 @@ manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(tool_controller.router)
manager_router.include_router(tool_execution_controller.router)
__all__ = ["manager_router"]

View File

@@ -0,0 +1,585 @@
"""工具管理API控制器"""
import base64
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Body
from langfuse.api.core import jsonable_encoder
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field, PositiveInt, field_validator
from cryptography.fernet import Fernet
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig
from app.core.logging_config import get_business_logger
from app.core.config import settings
from app.core.tools.config_manager import ConfigManager
logger = get_business_logger()
router = APIRouter(prefix="/tools", tags=["工具管理"])
# ==================== 辅助函数 ====================
def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``parameters`` with sensitive values Fernet-encrypted.

    A value is encrypted when it is truthy and its key contains, ignoring
    case, one of the sensitive markers. All other entries are copied as-is.

    Args:
        parameters: Tool parameters keyed by parameter name.

    Returns:
        A new dict. Encrypted values are stored as ``str`` tokens; note that
        the plaintext is ``str(value)``, so non-string values come back as
        strings after decryption.
    """
    # Fernet expects the url-safe base64 encoding of the key material, which
    # here is the first 32 characters of SECRET_KEY right-padded with '0'.
    cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
    cipher = Fernet(cipher_key)
    # Use the same markers as _decrypt_sensitive_params. The previous list had
    # 'api_secret' instead of 'secret', so keys such as 'client_secret' were
    # stored in plaintext even though the decrypt side treats them as secret.
    sensitive_keys = ('api_key', 'token', 'secret', 'password')
    encrypted_params = {}
    for key, value in parameters.items():
        lowered = key.lower()
        if value and any(marker in lowered for marker in sensitive_keys):
            encrypted_params[key] = cipher.encrypt(str(value).encode()).decode()
        else:
            encrypted_params[key] = value
    return encrypted_params
def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``parameters`` with sensitive values decrypted.

    Entries whose key contains a sensitive marker (case-insensitive) and whose
    value is truthy are passed through Fernet decryption. If decryption fails
    for any reason, the stored value is returned unchanged.
    """
    fernet = Fernet(
        base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
    )
    markers = ['api_key', 'token', 'secret', 'password']

    def _reveal(name: str, stored: Any) -> Any:
        # Non-sensitive or empty entries are returned as stored.
        if not (any(marker in name.lower() for marker in markers) and stored):
            return stored
        try:
            return fernet.decrypt(stored.encode()).decode()
        except Exception:
            # Best effort: fall back to the stored value when it cannot be decrypted.
            return stored

    return {name: _reveal(name, stored) for name, stored in parameters.items()}
def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str:
    """Derive the tool's status, store it on ``tool_config`` if changed, and return it.

    Built-in tools are active when they need no configuration, or when their
    stored parameters contain an ``api_key`` or ``token``; otherwise inactive.
    Other tools are active when ``config_data`` is truthy, otherwise in error.
    The caller is responsible for committing the session.
    """

    def _derive() -> str:
        if tool_config.tool_type == ToolType.BUILTIN:
            # Built-in tool that needs no configuration.
            if not (tool_info and tool_info.get('requires_config', False)):
                return ToolStatus.ACTIVE.value
            params = builtin_config.parameters if builtin_config else None
            if not params:
                return ToolStatus.INACTIVE.value
            # A credential must be present for the tool to be usable.
            if params.get('api_key') or params.get('token'):
                return ToolStatus.ACTIVE.value
            return ToolStatus.INACTIVE.value
        # Custom and MCP tools.
        if tool_config.config_data:
            return ToolStatus.ACTIVE.value
        return ToolStatus.ERROR.value

    new_status = _derive()
    # Only touch the mapped attribute when the value actually changes.
    if tool_config.status != new_status:
        tool_config.status = new_status
    return new_status
# ==================== 请求/响应模型 ====================
class ToolListResponse(BaseModel):
    """One entry of the tool list returned by the list endpoint."""
    id: str
    name: str
    # Nullable: the custom/MCP create requests in this module default
    # ``description`` to None, so a required ``str`` here would make the
    # listing fail validation for such tools.
    description: Optional[str] = None
    tool_type: str
    category: str
    version: str = "1.0.0"
    status: str  # active / inactive / error / loading
    requires_config: bool = False

    class Config:
        from_attributes = True
class BuiltinToolConfigRequest(BaseModel):
    """Request body for configuring a built-in tool."""
    # Free-form tool parameters; defaults to an empty dict when omitted.
    parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
class CustomToolCreateRequest(BaseModel):
    """Request body for creating a custom tool, including validation rules."""
    name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填")
    # These fields default to None, so they are annotated as Optional to match.
    description: Optional[str] = Field(None, description="工具描述")
    base_url: Optional[str] = Field(None, description="工具基础URL")
    schema_url: Optional[str] = Field(None, description="工具Schema URL")
    schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容可选")
    auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型")
    # validate_default=True makes the validator below run even when
    # auth_config is omitted; otherwise a request with auth_type="api_key" and
    # no auth_config would skip the credential check entirely.
    auth_config: Optional[Dict[str, Any]] = Field(
        None, validate_default=True, description="认证配置,默认空字典"
    )
    timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间1-300秒默认30")

    @field_validator("auth_config")
    @classmethod
    def validate_auth_config(cls, v, info):
        """Require the credential entry that matches the selected ``auth_type``.

        Raises:
            ValueError: if ``auth_type`` is ``api_key``/``bearer_token`` and
                ``auth_config`` lacks the key of the same name.
        """
        # info.data holds the already-validated fields declared above this one.
        auth_type = info.data.get("auth_type")
        if auth_type == "api_key" and (not v or "api_key" not in v):
            raise ValueError("认证类型为api_key时auth_config必须包含api_key字段")
        if auth_type == "bearer_token" and (not v or "bearer_token" not in v):
            raise ValueError("认证类型为bearer_token时auth_config必须包含bearer_token字段")
        return v
class MCPToolCreateRequest(BaseModel):
    """Request body for creating an MCP tool."""
    # Required basic fields (with length validation).
    name: str = Field(..., min_length=1, max_length=100, description="MCP工具名称")
    # Defaults to None, so the annotation is Optional.
    description: Optional[str] = Field(None, description="MCP工具描述")
    # NOTE(review): the description promises http/https only, but no scheme
    # check is enforced in this model - confirm whether validation happens
    # elsewhere before adding one here.
    server_url: str = Field(..., description="MCP服务端URL仅支持http/https协议")
    # Connection settings; a fresh empty dict per instance when omitted.
    connection_config: Dict[str, Any] = Field(
        default_factory=dict,
        description="MCP连接配置如认证信息、超时、重试等默认空字典",
    )

    @field_validator("connection_config")
    @classmethod
    def validate_connection_config(cls, v):
        """Check that an optional ``timeout`` entry is an integer in 1..300.

        Raises:
            ValueError: if ``timeout`` is present but not an int in range.
        """
        if "timeout" in v:
            timeout = v["timeout"]
            # bool is a subclass of int, so True would otherwise pass as 1.
            is_int = isinstance(timeout, int) and not isinstance(timeout, bool)
            if not is_int or timeout < 1 or timeout > 300:
                raise ValueError("connection_config.timeout必须是1-300的整数")
        return v
# ==================== API端点 ====================
@router.get("", response_model=List[ToolListResponse])
async def list_tools(
    name: Optional[str] = None,
    tool_type: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the current tenant's tools (built-in, custom and MCP).

    Args:
        name: Optional case-insensitive substring filter on the tool name.
        tool_type: Optional exact filter on the tool type.

    Returns:
        A list of ``ToolListResponse`` items with a freshly derived status.

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        # Make sure the tenant's built-in tool rows exist before querying.
        config_manager = ConfigManager()
        config_manager.ensure_builtin_tools_initialized(
            current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus
        )
        query = db.query(ToolConfig).filter(
            ToolConfig.tenant_id == current_user.tenant_id
        )
        if tool_type:
            query = query.filter(ToolConfig.tool_type == tool_type)
        if name:
            query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
        tools = query.all()

        # Index the built-in tool definitions by their tool class.
        builtin_tools = config_manager.load_builtin_tools_config()
        configured_tools = {info["tool_class"]: info for info in builtin_tools.values()}

        response_tools = []
        for tool_config in tools:
            # Reset per iteration so metadata never leaks between tools.
            tool_info = None
            if tool_config.tool_type == ToolType.BUILTIN.value:
                builtin_config = db.query(BuiltinToolConfig).filter(
                    BuiltinToolConfig.id == tool_config.id
                ).first()
                # The detail row or its definition may be missing; degrade
                # gracefully instead of failing the whole listing.
                if builtin_config is not None:
                    tool_info = configured_tools.get(builtin_config.tool_class)
                status = _update_tool_status(tool_config, builtin_config, tool_info)
            else:
                status = _update_tool_status(tool_config)
            response_tools.append(ToolListResponse(
                id=str(tool_config.id),
                name=tool_config.name,
                # Tools may be created without a description.
                description=tool_config.description or "",
                tool_type=tool_config.tool_type,
                category=tool_info.get('category', tool_config.tool_type) if tool_info else tool_config.tool_type,
                version="1.0.0",
                status=status,
                requires_config=bool(tool_info.get('requires_config', False)) if tool_info else False,
            ))
        return response_tools
    except Exception as e:
        logger.error(f"获取工具列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/builtin/{tool_id}")
async def get_builtin_tool_detail(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the details of one built-in tool belonging to the current tenant.

    Sensitive parameters (keys containing key/secret/token/password) are never
    included in ``config_parameters``.

    Raises:
        HTTPException: 404 if the tool, its built-in config row, or its
            definition cannot be found; 500 on unexpected failures.
    """
    try:
        # Restrict the lookup to built-in tools, as the other /builtin endpoints do.
        tool_config = db.query(ToolConfig).filter(
            ToolConfig.tenant_id == current_user.tenant_id,
            ToolConfig.id == tool_id,
            ToolConfig.tool_type == ToolType.BUILTIN
        ).first()
        if not tool_config:
            raise HTTPException(status_code=404, detail="工具不存在")

        builtin_config = db.query(BuiltinToolConfig).filter(
            BuiltinToolConfig.id == tool_config.id
        ).first()
        if not builtin_config:
            raise HTTPException(status_code=404, detail="内置工具配置不存在")

        config_manager = ConfigManager()
        builtin_tools = config_manager.load_builtin_tools_config()
        configured_tools = {info["tool_class"]: info for info in builtin_tools.values()}
        tool_info = configured_tools.get(builtin_config.tool_class)
        if not tool_info:
            raise HTTPException(status_code=404, detail="工具信息不存在")

        is_configured = False
        config_parameters = {}
        if builtin_config.parameters:
            is_configured = bool(
                builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token')
            )
            # Only expose non-sensitive settings.
            hidden_markers = ('key', 'secret', 'token', 'password')
            config_parameters = {
                k: v for k, v in builtin_config.parameters.items()
                if not any(marker in k.lower() for marker in hidden_markers)
            }
        return {
            "id": tool_config.id,
            "name": tool_config.name,
            "description": tool_config.description,
            "category": tool_info['category'],
            # Previously this returned tool_config.tool_type by mistake.
            "status": tool_config.status,
            "requires_config": tool_info['requires_config'],
            "is_configured": is_configured,
            "config_parameters": config_parameters
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/builtin/{tool_id}/configure")
async def configure_builtin_tool(
    tool_id: str,
    request: BuiltinToolConfigRequest = Body(...),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Store the parameters of a built-in tool for the current tenant.

    Sensitive parameters are encrypted before being saved, and the tool status
    is re-derived from the new configuration.

    Raises:
        HTTPException: 404 if the tool, its built-in config row, or its
            definition is missing; 500 on unexpected failures.
    """
    try:
        # Look up the tool record.
        tool_config = db.query(ToolConfig).filter(
            ToolConfig.tenant_id == current_user.tenant_id,
            ToolConfig.id == tool_id,
            ToolConfig.tool_type == ToolType.BUILTIN
        ).first()
        if not tool_config:
            raise HTTPException(status_code=404, detail="工具不存在")
        # Look up the built-in specific configuration row.
        builtin_config = db.query(BuiltinToolConfig).filter(
            BuiltinToolConfig.id == tool_config.id
        ).first()
        if not builtin_config:
            raise HTTPException(status_code=404, detail="内置工具配置不存在")
        # Find the definition whose tool class matches this tool.
        builtin_tools_config = ConfigManager().load_builtin_tools_config()
        tool_info = next(
            (info for info in builtin_tools_config.values()
             if info['tool_class'] == builtin_config.tool_class),
            None,
        )
        if not tool_info:
            raise HTTPException(status_code=404, detail="工具信息不存在")
        # Encrypt sensitive values, then replace the stored parameters.
        builtin_config.parameters = _encrypt_sensitive_params(request.parameters)
        # Re-derive the status from the new configuration.
        _update_tool_status(tool_config, builtin_config, tool_info)
        db.commit()
        return {
            "success": True,
            "message": f"工具 {tool_config.name} 配置成功"
        }
    except HTTPException:
        raise
    except Exception as e:
        # Discard the pending changes, matching create_mcp_tool's handling.
        db.rollback()
        logger.error(f"配置内置工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/builtin/{tool_id}/config")
async def get_builtin_tool_config(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return a built-in tool's configuration with sensitive values decrypted.

    Raises:
        HTTPException: 404 if the tool or its built-in config row is missing;
            500 on unexpected failures.
    """
    try:
        tenant_tool = (
            db.query(ToolConfig)
            .filter(
                ToolConfig.tenant_id == current_user.tenant_id,
                ToolConfig.id == tool_id,
                ToolConfig.tool_type == ToolType.BUILTIN
            )
            .first()
        )
        if not tenant_tool:
            raise HTTPException(status_code=404, detail="工具不存在")

        detail_row = (
            db.query(BuiltinToolConfig)
            .filter(BuiltinToolConfig.id == tenant_tool.id)
            .first()
        )
        if not detail_row:
            raise HTTPException(status_code=404, detail="内置工具配置不存在")

        # Decrypt whatever is stored; an unset value is treated as empty.
        stored = detail_row.parameters or {}
        payload = {
            "tool_id": tool_id,
            "tool_class": detail_row.tool_class,
            "name": tenant_tool.name,
            "parameters": _decrypt_sensitive_params(stored),
            "status": tenant_tool.status
        }
        return payload
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取工具配置失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/custom")
async def create_custom_tool(
    request: CustomToolCreateRequest = Body(...),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a custom tool for the current tenant.

    The request is validated through ``ConfigManager.validate_config`` and then
    stored as a ``ToolConfig`` row plus a ``CustomToolConfig`` row sharing its id.

    Raises:
        HTTPException: 400 if the configuration is invalid; 500 on
            unexpected failures (pending changes are rolled back).
    """
    try:
        config_data = jsonable_encoder(request.model_dump())
        config_data["tool_type"] = "custom"
        config_manager = ConfigManager()
        is_valid, error_msg = config_manager.validate_config(config_data, "custom")
        if not is_valid:
            raise HTTPException(status_code=400, detail=error_msg)
        # Create the generic tool record.
        tool_config = ToolConfig(
            name=request.name,
            description=request.description,
            tool_type=ToolType.CUSTOM,
            tenant_id=current_user.tenant_id,
            status=ToolStatus.ACTIVE.value,
            config_data=config_data
        )
        db.add(tool_config)
        # Flush so tool_config.id is available for the detail row.
        db.flush()
        # Create the custom-tool specific record.
        custom_config = CustomToolConfig(
            id=tool_config.id,
            base_url=request.base_url,
            schema_url=request.schema_url,
            schema_content=request.schema_content,
            auth_type=request.auth_type,
            auth_config=request.auth_config,
            timeout=request.timeout
        )
        db.add(custom_config)
        db.commit()
        return {
            "success": True,
            "message": f"自定义工具 {request.name} 创建成功",
            "tool_id": str(tool_config.id)
        }
    except HTTPException:
        raise
    except Exception as e:
        # Undo the flushed ToolConfig row if the second insert or commit failed,
        # matching create_mcp_tool's handling.
        db.rollback()
        logger.error(f"创建自定义工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/mcp")
async def create_mcp_tool(
    request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create an MCP tool for the current tenant.

    Raises:
        HTTPException: 400 if the configuration is invalid; 500 if the
            database operation fails (after rollback) or on other errors.
    """
    try:
        payload = jsonable_encoder(request.model_dump())
        payload["tool_type"] = "mcp"
        is_valid, error_msg = ConfigManager().validate_config(payload, "mcp")
        if not is_valid:
            raise HTTPException(status_code=400, detail=error_msg)

        # Persist the generic record and the MCP-specific record together.
        try:
            tool_config = ToolConfig(
                name=request.name,
                description=request.description,
                tool_type=ToolType.MCP,
                tenant_id=current_user.tenant_id,
                status=ToolStatus.ACTIVE.value,
                config_data=payload
            )
            db.add(tool_config)
            # Flush so the generated id can be reused for the MCP row.
            db.flush()
            db.add(
                MCPToolConfig(
                    id=tool_config.id,
                    server_url=request.server_url,
                    connection_config=request.connection_config
                )
            )
            db.commit()
        except SQLAlchemyError as db_e:
            db.rollback()
            logger.error(f"创建MCP工具数据库操作失败租户ID{current_user.tenant_id},工具名:{request.name}: {str(db_e)}",
                         exc_info=True)
            raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败租户ID{current_user.tenant_id}"
                                                        f"工具名:{request.name}{str(db_e)}")

        created_id = str(tool_config.id)
        return {
            "success": True,
            "message": f"MCP工具 {request.name} 创建成功",
            "tool_id": created_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"创建MCP工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{tool_id}")
async def delete_tool(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete a tool of the current tenant (custom and MCP tools only).

    Raises:
        HTTPException: 404 if the tool does not exist; 403 for built-in
            tools; 500 on unexpected failures (pending changes are rolled back).
    """
    try:
        tool = db.query(ToolConfig).filter(
            ToolConfig.id == tool_id,
            ToolConfig.tenant_id == current_user.tenant_id
        ).first()
        if not tool:
            raise HTTPException(status_code=404, detail="工具不存在")
        if tool.tool_type == ToolType.BUILTIN:
            raise HTTPException(status_code=403, detail="内置工具不允许删除")
        # Read the name before deleting so the response does not depend on
        # attribute access to an instance that has been deleted and committed.
        tool_name = tool.name
        db.delete(tool)
        db.commit()
        return {
            "success": True,
            "message": f"工具 {tool_name} 删除成功"
        }
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"删除工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.put("/{tool_id}")
async def update_tool(
    tool_id: str,
    config_data: Optional[Dict[str, Any]] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Replace a tool's ``config_data`` (custom and MCP tools only).

    When ``config_data`` is provided, it overwrites the stored value and the
    tool status is re-derived from it.

    Raises:
        HTTPException: 404 if the tool does not exist; 403 for built-in
            tools; 500 on unexpected failures (pending changes are rolled back).
    """
    try:
        tool = db.query(ToolConfig).filter(
            ToolConfig.id == tool_id,
            ToolConfig.tenant_id == current_user.tenant_id
        ).first()
        if not tool:
            raise HTTPException(status_code=404, detail="工具不存在")
        if tool.tool_type == ToolType.BUILTIN:
            raise HTTPException(status_code=403, detail="内置工具不允许修改")
        if config_data is not None:
            tool.config_data = config_data
            # Re-derive the status from the new configuration.
            _update_tool_status(tool)
        db.commit()
        db.refresh(tool)
        return {
            "success": True,
            "message": f"工具 {tool.name} 更新成功",
            "status": tool.status
        }
    except HTTPException:
        raise
    except Exception as e:
        # Discard the pending changes, matching create_mcp_tool's handling.
        db.rollback()
        logger.error(f"更新工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/{tool_id}/toggle")
async def toggle_tool_status(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Flip a tool between the active and inactive states.

    Raises:
        HTTPException: 404 if the tool does not exist; 400 if its current
            status is neither active nor inactive; 500 on unexpected
            failures (pending changes are rolled back).
    """
    try:
        tool = db.query(ToolConfig).filter(
            ToolConfig.id == tool_id,
            ToolConfig.tenant_id == current_user.tenant_id
        ).first()
        if not tool:
            raise HTTPException(status_code=404, detail="工具不存在")
        # Only active <-> inactive transitions are allowed.
        if tool.status == ToolStatus.ACTIVE.value:
            tool.status = ToolStatus.INACTIVE.value
        elif tool.status == ToolStatus.INACTIVE.value:
            tool.status = ToolStatus.ACTIVE.value
        else:
            raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换")
        db.commit()
        db.refresh(tool)
        return {
            "success": True,
            "message": f"工具 {tool.name} 状态已更新为 {tool.status}",
            "status": tool.status
        }
    except HTTPException:
        raise
    except Exception as e:
        # Discard the pending change, matching create_mcp_tool's handling.
        db.rollback()
        logger.error(f"切换工具状态失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,430 @@
"""工具执行API控制器"""
import uuid
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Path, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.core.tools.registry import ToolRegistry
from app.core.tools.executor import ToolExecutor
from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode
from app.core.tools.builtin import *
from app.core.logging_config import get_business_logger
logger = get_business_logger()
router = APIRouter(prefix="/tools/execution", tags=["工具执行"])
# ==================== 请求/响应模型 ====================
class ToolExecutionRequest(BaseModel):
    """Request body for executing a single tool."""
    # Identifier of the tool to run (required).
    tool_id: str = Field(..., description="工具ID")
    # Arguments passed to the tool; empty dict when omitted.
    parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
    # Optional timeout in seconds, constrained to 1..300.
    timeout: Optional[float] = Field(None, ge=1, le=300, description="超时时间(秒)")
    # Optional extra metadata forwarded with the execution.
    metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据")
class BatchExecutionRequest(BaseModel):
    """Request body for executing several tools in one call."""
    # The individual executions to run (required).
    executions: List[ToolExecutionRequest] = Field(..., description="执行列表")
    # Maximum number of concurrent executions, constrained to 1..20 (default 5).
    max_concurrency: int = Field(5, ge=1, le=20, description="最大并发数")
class ToolExecutionResponse(BaseModel):
    """Response describing the outcome of one tool execution."""
    success: bool
    execution_id: str
    tool_id: str
    data: Any = None
    error: Optional[str] = None
    error_code: Optional[str] = None
    execution_time: float
    token_usage: Optional[Dict[str, int]] = None
    # Empty dict when no metadata is supplied.
    metadata: Dict[str, Any] = Field(default_factory=dict)
class ChainStepRequest(BaseModel):
    """One step of a tool chain request."""
    # Identifier of the tool to run in this step (required).
    tool_id: str = Field(..., description="工具ID")
    # Arguments for the tool; empty dict when omitted.
    parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
    # Optional execution condition, passed through to ChainStep as-is.
    condition: Optional[str] = Field(None, description="执行条件")
    # Optional output mapping, passed through to ChainStep as-is.
    output_mapping: Optional[Dict[str, str]] = Field(None, description="输出映射")
    # Error handling strategy; defaults to "stop".
    error_handling: str = Field("stop", description="错误处理策略")
class ChainExecutionRequest(BaseModel):
    """Request body for executing a tool chain."""
    # Chain name (required); also used to register and look up the chain.
    name: str = Field(..., description="链名称")
    description: str = Field("", description="链描述")
    # Ordered steps of the chain (required).
    steps: List[ChainStepRequest] = Field(..., description="执行步骤")
    # Execution mode; converted to ChainExecutionMode by the endpoint.
    execution_mode: str = Field("sequential", description="执行模式")
    initial_variables: Optional[Dict[str, Any]] = Field(None, description="初始变量")
    global_timeout: Optional[float] = Field(None, description="全局超时")
class ExecutionHistoryResponse(BaseModel):
    """One record of the execution history."""
    execution_id: str
    tool_id: str
    status: str
    # The Optional fields below have no default, so each must be supplied
    # explicitly (None is accepted).
    started_at: Optional[str]
    completed_at: Optional[str]
    execution_time: Optional[float]
    user_id: Optional[str]
    workspace_id: Optional[str]
    input_data: Optional[Dict[str, Any]]
    output_data: Optional[Any]
    error_message: Optional[str]
    token_usage: Optional[Dict[str, int]]
class ToolConnectionTestResponse(BaseModel):
    """Response of a tool connection test."""
    success: bool
    message: str
    error: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
# ==================== 依赖注入 ====================
def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry:
    """FastAPI dependency: build a ToolRegistry with the built-in tool classes registered."""
    registry = ToolRegistry(db)
    # Register the built-in tool classes in a fixed order.
    for tool_cls in (DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool):
        registry.register_tool_class(tool_cls)
    return registry
def get_tool_executor(
    db: Session = Depends(get_db),
    registry: ToolRegistry = Depends(get_tool_registry)
) -> ToolExecutor:
    """FastAPI dependency: build a ToolExecutor from the DB session and the tool registry."""
    return ToolExecutor(db, registry)
def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager:
    """FastAPI dependency: build a ChainManager on top of the tool executor."""
    return ChainManager(executor)
# ==================== API端点 ====================
@router.post("/execute", response_model=ToolExecutionResponse)
async def execute_tool(
    request: ToolExecutionRequest,
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Execute a single tool and return its result.

    Raises:
        HTTPException: 500 if the execution raises.
    """
    try:
        # A fresh id identifies this execution in the response.
        run_id = f"exec_{uuid.uuid4().hex[:16]}"
        outcome = await executor.execute_tool(
            tool_id=request.tool_id,
            parameters=request.parameters,
            user_id=current_user.id,
            workspace_id=current_user.current_workspace_id,
            execution_id=run_id,
            timeout=request.timeout,
            metadata=request.metadata
        )
        response = ToolExecutionResponse(
            success=outcome.success,
            execution_id=run_id,
            tool_id=request.tool_id,
            data=outcome.data,
            error=outcome.error,
            error_code=outcome.error_code,
            execution_time=outcome.execution_time,
            token_usage=outcome.token_usage,
            metadata=outcome.metadata
        )
        return response
    except Exception as e:
        logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=List[ToolExecutionResponse])
async def execute_tools_batch(
    request: BatchExecutionRequest,
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Execute several tools in one call and return one response per result.

    Raises:
        HTTPException: 500 if preparing or running the batch raises.
    """
    try:
        # One execution id per requested execution, in request order.
        execution_ids = [f"exec_{uuid.uuid4().hex[:16]}" for _ in request.executions]
        execution_configs = [
            {
                "tool_id": item.tool_id,
                "parameters": item.parameters,
                "user_id": current_user.id,
                "workspace_id": current_user.current_workspace_id,
                "execution_id": run_id,
                "timeout": item.timeout,
                "metadata": item.metadata
            }
            for run_id, item in zip(execution_ids, request.executions)
        ]
        results = await executor.execute_tools_batch(
            execution_configs,
            max_concurrency=request.max_concurrency
        )
        # Results are matched to their request by position.
        return [
            ToolExecutionResponse(
                success=outcome.success,
                execution_id=execution_ids[position],
                tool_id=request.executions[position].tool_id,
                data=outcome.data,
                error=outcome.error,
                error_code=outcome.error_code,
                execution_time=outcome.execution_time,
                token_usage=outcome.token_usage,
                metadata=outcome.metadata
            )
            for position, outcome in enumerate(results)
        ]
    except Exception as e:
        logger.error(f"批量执行失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/chain", response_model=Dict[str, Any])
async def execute_tool_chain(
    request: ChainExecutionRequest,
    current_user: User = Depends(get_current_user),
    chain_manager: ChainManager = Depends(get_chain_manager)
):
    """Register the requested chain and execute it.

    Raises:
        HTTPException: 400 if ``execution_mode`` is not a valid
            ``ChainExecutionMode`` value; 500 on any other failure.
    """
    try:
        # An unknown mode is a client error, not a server failure.
        try:
            execution_mode = ChainExecutionMode(request.execution_mode)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"无效的执行模式: {request.execution_mode}"
            )
        # Convert the request steps into chain steps.
        steps = [
            ChainStep(
                tool_id=step_request.tool_id,
                parameters=step_request.parameters,
                condition=step_request.condition,
                output_mapping=step_request.output_mapping,
                error_handling=step_request.error_handling
            )
            for step_request in request.steps
        ]
        chain_definition = ChainDefinition(
            name=request.name,
            description=request.description,
            steps=steps,
            execution_mode=execution_mode,
            global_timeout=request.global_timeout
        )
        # Register the chain, then execute it by name.
        chain_manager.register_chain(chain_definition)
        result = await chain_manager.execute_chain(
            chain_name=request.name,
            initial_variables=request.initial_variables
        )
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"工具链执行失败: {request.name}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/running", response_model=List[Dict[str, Any]])
async def get_running_executions(
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Return the running executions that belong to the user's current workspace.

    Raises:
        HTTPException: 500 if the lookup raises.
    """
    try:
        # Keep only entries whose workspace id matches the current workspace.
        return list(filter(
            lambda entry: entry.get("workspace_id") == str(current_user.current_workspace_id),
            executor.get_running_executions(),
        ))
    except Exception as e:
        logger.error(f"获取运行中执行失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any])
async def cancel_execution(
    execution_id: str = Path(..., description="执行ID"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Cancel a tool execution by id.

    Raises:
        HTTPException: 404 if the executor reports that nothing was
            cancelled; 500 on unexpected failures.
    """
    try:
        cancelled = await executor.cancel_execution(execution_id)
        # Guard clause: a falsy result means there was nothing to cancel.
        if not cancelled:
            raise HTTPException(status_code=404, detail="执行不存在或已完成")
        return {
            "success": True,
            "message": "执行已取消"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"取消执行失败: {execution_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/history", response_model=List[ExecutionHistoryResponse])
async def get_execution_history(
    tool_id: Optional[str] = Query(None, description="工具ID过滤"),
    limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Return the execution history for the current user and workspace.

    Raises:
        HTTPException: 500 if fetching or converting the records raises.
    """
    try:
        records = executor.get_execution_history(
            tool_id=tool_id,
            user_id=current_user.id,
            workspace_id=current_user.current_workspace_id,
            limit=limit
        )
        # Keys copied from each record into the response model, in this order.
        field_names = (
            "execution_id",
            "tool_id",
            "status",
            "started_at",
            "completed_at",
            "execution_time",
            "user_id",
            "workspace_id",
            "input_data",
            "output_data",
            "error_message",
            "token_usage",
        )
        return [
            ExecutionHistoryResponse(**{field: entry[field] for field in field_names})
            for entry in records
        ]
    except Exception as e:
        logger.error(f"获取执行历史失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/statistics", response_model=Dict[str, Any])
async def get_execution_statistics(
    days: int = Query(7, ge=1, le=90, description="统计天数"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Return execution statistics for the user's current workspace.

    Raises:
        HTTPException: 500 if the executor raises.
    """
    try:
        workspace_stats = executor.get_execution_statistics(
            workspace_id=current_user.current_workspace_id,
            days=days
        )
        payload = {"success": True, "statistics": workspace_stats}
        return payload
    except Exception as e:
        logger.error(f"获取执行统计失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/chains/running", response_model=List[Dict[str, Any]])
async def get_running_chains(
    current_user: User = Depends(get_current_user),
    chain_manager: ChainManager = Depends(get_chain_manager)
):
    """Return the chains the chain manager reports as running.

    Raises:
        HTTPException: 500 if the chain manager raises.
    """
    try:
        return chain_manager.get_running_chains()
    except Exception as e:
        logger.error(f"获取运行中工具链失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/chains", response_model=List[Dict[str, Any]])
async def list_tool_chains(
    current_user: User = Depends(get_current_user),
    chain_manager: ChainManager = Depends(get_chain_manager)
):
    """Return the chains known to the chain manager.

    Raises:
        HTTPException: 500 if the chain manager raises.
    """
    try:
        return chain_manager.list_chains()
    except Exception as e:
        logger.error(f"获取工具链列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse)
async def test_tool_connection(
    tool_id: str = Path(..., description="工具ID"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Test the connection of a tool.

    Failures are reported in the response body (``success=False``) rather than
    as an HTTP error.
    """
    try:
        probe = await executor.test_tool_connection(
            tool_id=tool_id,
            user_id=current_user.id,
            workspace_id=current_user.current_workspace_id
        )
        succeeded = probe.get("success", False)
        text = probe.get("message", "")
        return ToolConnectionTestResponse(
            success=succeeded,
            message=text,
            error=probe.get("error"),
            details=probe.get("details")
        )
    except Exception as e:
        logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
        failure = ToolConnectionTestResponse(
            success=False,
            message="连接测试失败",
            error=str(e)
        )
        return failure

View File

@@ -37,9 +37,10 @@ def require_api_key(
@require_api_key(scopes=["app"])
def chat_with_app(
resource_id: uuid.UUID,
api_key_auth: ApiKeyAuth = Depends(),
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str
message: str = Query(..., description="聊天消息内容")
):
# api_key_auth 包含验证后的API Key 信息
pass

View File

@@ -157,6 +157,12 @@ class Settings:
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
# Tool Management Configuration
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.

View File

@@ -0,0 +1,37 @@
"""工具管理核心模块"""
from .base import BaseTool, ToolResult, ToolParameter
from .registry import ToolRegistry
from .executor import ToolExecutor
from .langchain_adapter import LangchainAdapter
from .config_manager import ConfigManager
from .chain_manager import ChainManager
# Optional imports: fall back to None when the submodule cannot be imported.
try:
    from .custom.base import CustomTool
except ImportError:
    CustomTool = None
try:
    from .mcp.base import MCPTool
except ImportError:
    MCPTool = None

__all__ = [
    "BaseTool",
    "ToolResult",
    "ToolParameter",
    "ToolRegistry",
    "ToolExecutor",
    "LangchainAdapter",
    "ConfigManager",
    "ChainManager"
]
# Export the optional names only when their import succeeded.
for _optional_name, _optional_obj in (("CustomTool", CustomTool), ("MCPTool", MCPTool)):
    if _optional_obj:
        __all__.append(_optional_name)
del _optional_name, _optional_obj

302
api/app/core/tools/base.py Normal file
View File

@@ -0,0 +1,302 @@
"""工具基础接口定义"""
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from enum import Enum
from app.models.tool_model import ToolType, ToolStatus
class ParameterType(str, Enum):
    """Enumeration of tool parameter types.

    Members are also ``str`` instances, so they compare equal to their value.
    """
    STRING = "string"
    INTEGER = "integer"
    NUMBER = "number"
    BOOLEAN = "boolean"
    ARRAY = "array"
    OBJECT = "object"
class ToolParameter(BaseModel):
    """Declaration of a single tool parameter: name, type and constraints."""
    name: str = Field(..., description="参数名称")
    type: ParameterType = Field(..., description="参数类型")
    description: str = Field("", description="参数描述")
    required: bool = Field(False, description="是否必需")
    default: Any = Field(None, description="默认值")
    # Optional constraints; ``None`` means "not constrained".
    enum: Optional[List[Any]] = Field(None, description="枚举值")
    minimum: Optional[Union[int, float]] = Field(None, description="最小值")
    maximum: Optional[Union[int, float]] = Field(None, description="最大值")
    pattern: Optional[str] = Field(None, description="正则表达式模式")

    class Config:
        # Store enum-typed fields as their underlying values.
        use_enum_values = True
class ToolResult(BaseModel):
    """Outcome of one tool execution, successful or not."""
    success: bool = Field(..., description="执行是否成功")
    data: Any = Field(None, description="返回数据")
    error: Optional[str] = Field(None, description="错误信息")
    error_code: Optional[str] = Field(None, description="错误代码")
    execution_time: float = Field(..., description="执行时间(秒)")
    token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")

    @classmethod
    def success_result(
        cls,
        data: Any,
        execution_time: float,
        token_usage: Optional[Dict[str, int]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a result with ``success=True``.

        Args:
            data: Payload returned by the tool.
            execution_time: Elapsed time in seconds.
            token_usage: Optional token accounting.
            metadata: Optional extra metadata; a falsy value becomes ``{}``.

        Returns:
            The populated ``ToolResult``.
        """
        extra = metadata if metadata else {}
        fields = {
            "success": True,
            "data": data,
            "execution_time": execution_time,
            "token_usage": token_usage,
            "metadata": extra,
        }
        return cls(**fields)

    @classmethod
    def error_result(
        cls,
        error: str,
        execution_time: float,
        error_code: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a result with ``success=False``.

        Args:
            error: Human-readable error message.
            execution_time: Elapsed time in seconds.
            error_code: Optional machine-readable error code.
            metadata: Optional extra metadata; a falsy value becomes ``{}``.

        Returns:
            The populated ``ToolResult``.
        """
        extra = metadata if metadata else {}
        fields = {
            "success": False,
            "error": error,
            "error_code": error_code,
            "execution_time": execution_time,
            "metadata": extra,
        }
        return cls(**fields)
class ToolInfo(BaseModel):
    """Descriptive metadata about a tool (identity, type, parameters, status)."""
    id: str = Field(..., description="工具ID")
    name: str = Field(..., description="工具名称")
    description: str = Field(..., description="工具描述")
    tool_type: ToolType = Field(..., description="工具类型")
    version: str = Field("1.0.0", description="工具版本")
    parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
    status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态")
    tags: List[str] = Field(default_factory=list, description="工具标签")
    tenant_id: Optional[str] = Field(None, description="租户ID")

    class Config:
        # Store enum-typed fields as their underlying values.
        use_enum_values = True
class BaseTool(ABC):
    """Abstract base class for every tool.

    Subclasses supply ``name``, ``description``, ``tool_type``,
    ``parameters`` and ``execute``. This class adds config-backed metadata,
    parameter validation and an exception-safe execution wrapper.
    """

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialize the tool.

        Args:
            tool_id: Tool ID.
            config: Tool configuration.
        """
        self.tool_id = tool_id
        self.config = config
        self._status = ToolStatus.ACTIVE

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description."""
        pass

    @property
    @abstractmethod
    def tool_type(self) -> ToolType:
        """Tool type."""
        pass

    @property
    def version(self) -> str:
        """Tool version, read from ``config["version"]`` (default ``"1.0.0"``)."""
        return self.config.get("version", "1.0.0")

    @property
    def status(self) -> ToolStatus:
        """Current tool status."""
        return self._status

    @status.setter
    def status(self, value: ToolStatus):
        """Set the tool status."""
        self._status = value

    @property
    @abstractmethod
    def parameters(self) -> List[ToolParameter]:
        """Tool parameter definitions."""
        pass

    @property
    def tags(self) -> List[str]:
        """Tool tags, read from ``config["tags"]`` (default empty list)."""
        return self.config.get("tags", [])

    def get_info(self) -> ToolInfo:
        """Return a ``ToolInfo`` snapshot of this tool's metadata."""
        return ToolInfo(
            id=self.tool_id,
            name=self.name,
            description=self.description,
            tool_type=self.tool_type,
            version=self.version,
            parameters=self.parameters,
            status=self.status,
            tags=self.tags,
            tenant_id=self.config.get("tenant_id")
        )

    def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
        """Validate call arguments against the declared parameters.

        Args:
            parameters: Input arguments keyed by parameter name.

        Returns:
            A mapping of parameter name to error message; an empty dict
            means validation passed. Arguments that are not declared are
            ignored.
        """
        errors: Dict[str, str] = {}
        # ``parameters`` is a property that subclasses typically rebuild on
        # every access, so read it once.
        definitions = self.parameters
        by_name = {p.name: p for p in definitions}

        # Required parameters must be present.
        for param_def in definitions:
            if param_def.required and param_def.name not in parameters:
                errors[param_def.name] = f"Required parameter '{param_def.name}' is missing"

        # Type first, then constraints.
        for param_name, param_value in parameters.items():
            param_def = by_name.get(param_name)
            if param_def is None:
                continue

            if not self._validate_parameter_type(param_value, param_def):
                errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}"
                # Constraint checks assume a correctly typed value (e.g.
                # ``value < minimum``); running them on a mistyped value
                # would raise TypeError and hide the real validation error.
                continue

            constraint_error = self._validate_parameter_constraints(param_value, param_def)
            if constraint_error:
                errors[param_name] = constraint_error

        return errors

    def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool:
        """Return True when *value* matches the declared parameter type.

        ``None`` is accepted only for non-required parameters. Unknown
        declared types are accepted.
        """
        if value is None:
            return not param_def.required

        type_mapping = {
            ParameterType.STRING: str,
            ParameterType.INTEGER: int,
            ParameterType.NUMBER: (int, float),
            ParameterType.BOOLEAN: bool,
            ParameterType.ARRAY: list,
            ParameterType.OBJECT: dict
        }
        expected_type = type_mapping.get(param_def.type)
        if not expected_type:
            return True

        # ``bool`` is a subclass of ``int``; without this guard True/False
        # would be accepted for INTEGER and NUMBER parameters.
        if isinstance(value, bool):
            return expected_type is bool

        return isinstance(value, expected_type)

    def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]:
        """Check enum, numeric range and pattern constraints.

        Returns:
            An error message for the first violated constraint, or ``None``.
        """
        if value is None:
            return None

        # Enum membership.
        if param_def.enum and value not in param_def.enum:
            return f"Value must be one of {param_def.enum}"

        # Numeric range.
        if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER]:
            if param_def.minimum is not None and value < param_def.minimum:
                return f"Value must be >= {param_def.minimum}"
            if param_def.maximum is not None and value > param_def.maximum:
                return f"Value must be <= {param_def.maximum}"

        # String pattern; ``re.match`` anchors at the start of the string only.
        if param_def.type == ParameterType.STRING and param_def.pattern:
            import re
            if not re.match(param_def.pattern, str(value)):
                return f"Value must match pattern: {param_def.pattern}"

        return None

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The execution result.
        """
        pass

    async def safe_execute(self, **kwargs) -> ToolResult:
        """Execute the tool with parameter validation and exception handling.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The result of ``execute``; a ``VALIDATION_ERROR`` result when
            validation fails; an ``EXECUTION_ERROR`` result when any
            exception is raised.
        """
        start_time = time.time()

        try:
            validation_errors = self.validate_parameters(kwargs)
            if validation_errors:
                execution_time = time.time() - start_time
                error_msg = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()])
                return ToolResult.error_result(
                    error=f"Parameter validation failed: {error_msg}",
                    error_code="VALIDATION_ERROR",
                    execution_time=execution_time
                )

            return await self.execute(**kwargs)

        except Exception as e:
            # Boundary handler: callers always receive a ToolResult.
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="EXECUTION_ERROR",
                execution_time=execution_time
            )

    def to_langchain_tool(self):
        """Convert this tool to the Langchain tool format."""
        # Imported lazily inside the method (presumably to avoid a circular
        # import with langchain_adapter -- confirm).
        from .langchain_adapter import LangchainAdapter
        return LangchainAdapter.convert_tool(self)

    def __repr__(self):
        return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"

View File

@@ -0,0 +1,17 @@
"""内置工具模块"""
from .base import BuiltinTool
from .datetime_tool import DateTimeTool
from .json_tool import JsonTool
from .baidu_search_tool import BaiduSearchTool
from .mineru_tool import MinerUTool
from .textin_tool import TextInTool
__all__ = [
"BuiltinTool",
"DateTimeTool",
"JsonTool",
"BaiduSearchTool",
"MinerUTool",
"TextInTool"
]

View File

@@ -0,0 +1,334 @@
"""百度搜索工具 - 搜索引擎服务"""
import time
from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class BaiduSearchTool(BuiltinTool):
    """Baidu search tool: web, news, image and video search via the Baidu AI Search API."""

    # Endpoint used for every search type.
    _API_URL = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"

    # Lower bound of the ``page_time`` range for each supported time filter.
    _TIME_RANGES = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"}

    @property
    def name(self) -> str:
        """Tool name."""
        return "baidu_search_tool"

    @property
    def description(self) -> str:
        """Tool description."""
        return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果"

    def get_required_config_parameters(self) -> List[str]:
        """Return the configuration keys this tool cannot run without."""
        return ["api_key"]

    @property
    def parameters(self) -> List[ToolParameter]:
        """Declared call parameters of the search tool."""
        return [
            ToolParameter(
                name="query",
                type=ParameterType.STRING,
                description="搜索关键词",
                required=True
            ),
            ToolParameter(
                name="search_type",
                type=ParameterType.STRING,
                description="搜索类型",
                required=False,
                default="web",
                enum=["web", "news", "image", "video"]
            ),
            ToolParameter(
                name="page_size",
                type=ParameterType.INTEGER,
                description="每页结果数",
                required=False,
                default=10,
                minimum=1,
                maximum=50
            ),
            ToolParameter(
                name="page_num",
                type=ParameterType.INTEGER,
                description="页码从1开始",
                required=False,
                default=1,
                minimum=1,
                maximum=10
            ),
            ToolParameter(
                name="safe_search",
                type=ParameterType.BOOLEAN,
                description="是否启用安全搜索",
                required=False,
                default=True
            ),
            ToolParameter(
                name="region",
                type=ParameterType.STRING,
                description="搜索地区",
                required=False,
                default="cn",
                enum=["cn", "hk", "tw", "us", "jp", "kr"]
            ),
            ToolParameter(
                name="time_filter",
                type=ParameterType.STRING,
                description="时间过滤",
                required=False,
                enum=["all", "day", "week", "month", "year"]
            )
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run a Baidu search.

        Args:
            **kwargs: ``query`` (required), ``search_type``, ``page_size``,
                ``page_num``, ``safe_search``, ``region``, ``time_filter``.

        Returns:
            A success result carrying the search payload, or an error result
            with code ``BAIDU_SEARCH_ERROR`` when anything fails.
        """
        start_time = time.time()

        try:
            query = kwargs.get("query")
            search_type = kwargs.get("search_type", "web")
            page_size = kwargs.get("page_size", 10)
            page_num = kwargs.get("page_num", 1)
            safe_search = kwargs.get("safe_search", True)
            region = kwargs.get("region", "cn")
            time_filter = kwargs.get("time_filter")

            if not query:
                raise ValueError("query 参数是必需的")

            # Route to the API wrapper for the requested search type.
            if search_type == "web":
                result = await self._web_search(query, page_size, page_num, safe_search, region, time_filter)
            elif search_type == "news":
                result = await self._news_search(query, page_size, page_num, region, time_filter)
            elif search_type == "image":
                result = await self._image_search(query, page_size, page_num, safe_search)
            elif search_type == "video":
                result = await self._video_search(query, page_size, page_num, safe_search)
            else:
                raise ValueError(f"不支持的搜索类型: {search_type}")

            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )

        except Exception as e:
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="BAIDU_SEARCH_ERROR",
                execution_time=execution_time
            )

    async def _search(self, result_type: str, resource_type: str, top_k_limit: int,
                      query: str, page_size: int, page_num: int,
                      time_filter: str = None) -> Dict[str, Any]:
        """Shared implementation behind the four search-type wrappers.

        Args:
            result_type: Value reported as ``search_type`` in the result.
            resource_type: ``type`` sent in ``resource_type_filter``.
            top_k_limit: Upper bound applied to ``page_size`` for ``top_k``.
            query: Search keywords.
            page_size: Requested number of results.
            page_num: Echoed back in the result; not sent to the API.
            time_filter: Optional recency filter; values outside
                ``_TIME_RANGES`` (e.g. ``"all"``) send no filter.

        Returns:
            A dict with the normalized ``results`` list plus the raw
            ``answer`` and ``references`` from the API response.
        """
        payload = {
            "messages": [{"role": "user", "content": query}],
            "edition": "standard",
            "search_source": "baidu_search_v2",
            "resource_type_filter": [{"type": resource_type, "top_k": min(page_size, top_k_limit)}],
            "enable_full_content": True
        }

        time_range = self._TIME_RANGES.get(time_filter) if time_filter else None
        if time_range:
            payload["search_filter"] = {"range": {"page_time": {"gte": time_range, "lt": "now/d"}}}
            payload["search_recency_filter"] = time_filter

        results = await self._call_baidu_ai_search_api(payload)

        search_results = [
            {
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "snippet": item.get("content", ""),
                "display_url": item.get("url", ""),
                "rank": rank
            }
            for rank, item in enumerate(results.get("references") or [], start=1)
        ]

        return {
            "search_type": result_type,
            "query": query,
            "total_results": len(search_results),
            "page_num": page_num,
            "page_size": page_size,
            "results": search_results,
            "answer": results.get("result", ""),
            "references": results.get("references", [])
        }

    async def _web_search(self, query: str, page_size: int, page_num: int,
                          safe_search: bool, region: str, time_filter: str = None) -> Dict[str, Any]:
        """Web search. ``safe_search`` and ``region`` are accepted but not sent to the API."""
        return await self._search("web", "web", 50, query, page_size, page_num, time_filter)

    async def _news_search(self, query: str, page_size: int, page_num: int,
                           region: str, time_filter: str = None) -> Dict[str, Any]:
        """News search. ``region`` is accepted but not sent to the API."""
        # NOTE(review): the resource type sent to the API is "new" while the
        # public parameter value is "news" -- confirm the accepted value
        # against the Baidu AI Search documentation. The reported
        # ``search_type`` now matches the public enum value "news".
        return await self._search("news", "new", 50, query, page_size, page_num, time_filter)

    async def _image_search(self, query: str, page_size: int, page_num: int,
                            safe_search: bool) -> Dict[str, Any]:
        """Image search. ``safe_search`` is accepted but not sent to the API."""
        return await self._search("image", "image", 30, query, page_size, page_num)

    async def _video_search(self, query: str, page_size: int, page_num: int,
                            safe_search: bool) -> Dict[str, Any]:
        """Video search. ``safe_search`` is accepted but not sent to the API."""
        return await self._search("video", "video", 10, query, page_size, page_num)

    async def _call_baidu_ai_search_api(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* to the Baidu AI Search endpoint and return the JSON body.

        Raises:
            ValueError: If no ``api_key`` is configured.
            Exception: If the HTTP status is not 200.
        """
        api_key = self.get_config_parameter("api_key")

        if not api_key:
            raise ValueError("百度搜索API密钥未配置")

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {api_key}'
        }

        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(self._API_URL, headers=headers, json=payload) as response:
                if response.status != 200:
                    raise Exception(f"HTTP错误: {response.status}")
                return await response.json()

    async def test_connection(self) -> Dict[str, Any]:
        """Check that the configured API key can reach the search API.

        Returns:
            A dict with ``success`` and either ``message`` plus a masked key,
            or ``error``. This method never raises.
        """
        try:
            api_key = self.get_config_parameter("api_key")
            if not api_key:
                return {
                    "success": False,
                    "error": "API密钥未配置"
                }

            # Send a minimal request to verify the key.
            test_payload = {
                "messages": [{"role": "user", "content": "test"}],
                "edition": "standard",
                "search_source": "baidu_search_v2",
                "resource_type_filter": [{"type": "web", "top_k": 1}]
            }

            try:
                await self._call_baidu_ai_search_api(test_payload)
            except Exception as e:
                return {
                    "success": False,
                    "error": f"API连接失败: {str(e)}"
                }

            return {
                "success": True,
                "message": "连接测试成功",
                "api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***"
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

View File

@@ -0,0 +1,118 @@
"""内置工具基类"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
class BuiltinTool(BaseTool, ABC):
    """Base class for built-in tools."""

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialize the built-in tool.

        Args:
            tool_id: Tool ID.
            config: Tool configuration; its ``"parameters"`` entry holds the
                tool's configuration parameters.
        """
        super().__init__(tool_id, config)
        self.parameters_config = config.get("parameters", {})

    @property
    def tool_type(self) -> ToolType:
        """Tool type; always ``ToolType.BUILTIN``."""
        return ToolType.BUILTIN

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name; subclasses must implement."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description; subclasses must implement."""
        pass

    @property
    @abstractmethod
    def parameters(self) -> List[ToolParameter]:
        """Tool parameter definitions; subclasses must implement."""
        pass

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool; subclasses must implement.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The execution result.
        """
        pass

    @property
    def is_configured(self) -> bool:
        """True when every required configuration parameter has a truthy value."""
        required_keys = self.get_required_config_parameters()
        return all(self.parameters_config.get(key) for key in required_keys)

    def get_required_config_parameters(self) -> List[str]:
        """Return the names of required configuration parameters.

        Returns:
            An empty list by default; subclasses override to declare theirs.
        """
        return []

    def get_config_parameter(self, name: str, default: Any = None) -> Any:
        """Look up one configuration parameter.

        Args:
            name: Parameter name.
            default: Value returned when the parameter is absent.

        Returns:
            The configured value, or *default*.
        """
        return self.parameters_config.get(name, default)

    def validate_configuration(self) -> tuple[bool, str]:
        """Validate the tool configuration.

        Returns:
            ``(is_valid, error_message)``; the message is empty when valid.
        """
        if self.is_configured:
            return True, ""

        missing = [
            key
            for key in self.get_required_config_parameters()
            if not self.parameters_config.get(key)
        ]
        return False, f"缺少必需的配置参数: {', '.join(missing)}"

    async def safe_execute(self, **kwargs) -> ToolResult:
        """Validate the configuration, then run the parent's safe execution.

        Args:
            **kwargs: Tool arguments.

        Returns:
            A ``CONFIGURATION_ERROR`` result when the configuration is
            invalid, otherwise the parent's ``safe_execute`` result.
        """
        config_ok, reason = self.validate_configuration()
        if config_ok:
            return await super().safe_execute(**kwargs)

        return ToolResult.error_result(
            error=f"工具配置无效: {reason}",
            error_code="CONFIGURATION_ERROR",
            execution_time=0.0
        )

View File

@@ -0,0 +1,307 @@
"""时间工具 - 日期时间处理"""
import time
from datetime import datetime, timezone, timedelta
from typing import List
import pytz
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class DateTimeTool(BuiltinTool):
    """Date/time tool: formatting, timezone conversion, timestamp conversion and arithmetic."""

    @property
    def name(self) -> str:
        """Tool name."""
        return "datetime_tool"

    @property
    def description(self) -> str:
        """Tool description."""
        return "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算"

    @property
    def parameters(self) -> List[ToolParameter]:
        """Declared call parameters of the date/time tool."""
        return [
            ToolParameter(
                name="operation",
                type=ParameterType.STRING,
                description="操作类型",
                required=True,
                enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"]
            ),
            ToolParameter(
                name="input_value",
                type=ParameterType.STRING,
                description="输入值(时间字符串或时间戳)",
                required=False
            ),
            ToolParameter(
                name="input_format",
                type=ParameterType.STRING,
                description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
                required=False,
                default="%Y-%m-%d %H:%M:%S"
            ),
            ToolParameter(
                name="output_format",
                type=ParameterType.STRING,
                description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
                required=False,
                default="%Y-%m-%d %H:%M:%S"
            ),
            ToolParameter(
                name="from_timezone",
                type=ParameterType.STRING,
                description="源时区UTC, Asia/Shanghai",
                required=False,
                default="UTC"
            ),
            ToolParameter(
                name="to_timezone",
                type=ParameterType.STRING,
                description="目标时区UTC, Asia/Shanghai",
                required=False,
                default="UTC"
            ),
            ToolParameter(
                name="calculation",
                type=ParameterType.STRING,
                description="时间计算表达式(如:+1d, -2h, +30m",
                required=False
            )
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run the requested date/time operation.

        Args:
            **kwargs: ``operation`` plus the arguments that operation reads.

        Returns:
            A success result with the operation's dict payload, or an error
            result with code ``DATETIME_ERROR`` when anything fails.
        """
        start_time = time.time()

        try:
            operation = kwargs.get("operation")

            handlers = {
                "now": self._get_current_time,
                "format": self._format_datetime,
                "convert_timezone": self._convert_timezone,
                "timestamp_to_datetime": self._timestamp_to_datetime,
                "datetime_to_timestamp": self._datetime_to_timestamp,
                "calculate": self._calculate_datetime,
            }
            handler = handlers.get(operation)
            if handler is None:
                raise ValueError(f"不支持的操作类型: {operation}")

            result = handler(kwargs)

            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )

        except Exception as e:
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="DATETIME_ERROR",
                execution_time=execution_time
            )

    @staticmethod
    def _resolve_timezone(timezone_str: str):
        """Return a pytz timezone object for *timezone_str*.

        Always returns a pytz object, so ``.localize()`` is available on the
        result. ``datetime.timezone.utc`` has no such method, which made the
        default "UTC" path fail in the operations that localize naive values.
        """
        if timezone_str == "UTC":
            return pytz.UTC
        return pytz.timezone(timezone_str)

    @staticmethod
    def _attach_timezone(dt: datetime, tz) -> datetime:
        """Localize a naive *dt* to *tz*; return an aware *dt* unchanged."""
        if dt.tzinfo is None:
            return tz.localize(dt)
        return dt

    def _get_current_time(self, kwargs) -> dict:
        """Return the current time in ``to_timezone`` (default UTC)."""
        timezone_str = kwargs.get("to_timezone", "UTC")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")

        now = datetime.now(self._resolve_timezone(timezone_str))

        return {
            "datetime": now.strftime(output_format),
            "timestamp": int(now.timestamp()),
            "timezone": timezone_str,
            "iso_format": now.isoformat()
        }

    def _format_datetime(self, kwargs) -> dict:
        """Re-format ``input_value`` from ``input_format`` to ``output_format``.

        Raises:
            ValueError: If ``input_value`` is missing or does not match
                ``input_format``.
        """
        input_value = kwargs.get("input_value")
        input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        dt = datetime.strptime(input_value, input_format)

        return {
            "original": input_value,
            "formatted": dt.strftime(output_format),
            # NOTE: for a naive value ``timestamp()`` interprets the time in
            # the process's local timezone.
            "timestamp": int(dt.timestamp()),
            "iso_format": dt.isoformat()
        }

    def _convert_timezone(self, kwargs) -> dict:
        """Convert ``input_value`` from ``from_timezone`` to ``to_timezone``.

        Raises:
            ValueError: If ``input_value`` is missing or cannot be parsed.
        """
        input_value = kwargs.get("input_value")
        input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
        from_timezone = kwargs.get("from_timezone", "UTC")
        to_timezone = kwargs.get("to_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        dt = datetime.strptime(input_value, input_format)

        from_tz = self._resolve_timezone(from_timezone)
        to_tz = self._resolve_timezone(to_timezone)

        # Interpret a naive value in the source zone, then convert.
        dt = self._attach_timezone(dt, from_tz)
        converted_dt = dt.astimezone(to_tz)

        return {
            "original": input_value,
            "original_timezone": from_timezone,
            "converted": converted_dt.strftime(output_format),
            "converted_timezone": to_timezone,
            "timestamp": int(converted_dt.timestamp())
        }

    def _timestamp_to_datetime(self, kwargs) -> dict:
        """Convert a Unix timestamp in ``input_value`` to a formatted datetime.

        Raises:
            ValueError: If ``input_value`` is missing or not numeric.
        """
        input_value = kwargs.get("input_value")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
        timezone_str = kwargs.get("to_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        timestamp = float(input_value)
        dt = datetime.fromtimestamp(timestamp, self._resolve_timezone(timezone_str))

        return {
            "timestamp": timestamp,
            "datetime": dt.strftime(output_format),
            "timezone": timezone_str,
            "iso_format": dt.isoformat()
        }

    def _datetime_to_timestamp(self, kwargs) -> dict:
        """Convert ``input_value`` (interpreted in ``from_timezone``) to a Unix timestamp.

        Raises:
            ValueError: If ``input_value`` is missing or cannot be parsed.
        """
        input_value = kwargs.get("input_value")
        input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
        timezone_str = kwargs.get("from_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        dt = datetime.strptime(input_value, input_format)
        dt = self._attach_timezone(dt, self._resolve_timezone(timezone_str))

        return {
            "datetime": input_value,
            "timezone": timezone_str,
            "timestamp": int(dt.timestamp()),
            "iso_format": dt.isoformat()
        }

    def _calculate_datetime(self, kwargs) -> dict:
        """Add the offset described by ``calculation`` to ``input_value``.

        Raises:
            ValueError: If ``input_value`` or ``calculation`` is missing, or
                either cannot be parsed.
        """
        input_value = kwargs.get("input_value")
        input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
        calculation = kwargs.get("calculation")
        timezone_str = kwargs.get("from_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        if not calculation:
            raise ValueError("calculation 参数是必需的")

        dt = datetime.strptime(input_value, input_format)
        dt = self._attach_timezone(dt, self._resolve_timezone(timezone_str))

        delta = self._parse_time_delta(calculation)
        calculated_dt = dt + delta

        return {
            "original": input_value,
            "calculation": calculation,
            "result": calculated_dt.strftime(output_format),
            "timezone": timezone_str,
            "timestamp": int(calculated_dt.timestamp())
        }

    def _parse_time_delta(self, calculation: str) -> timedelta:
        """Parse an expression such as ``+1d``, ``-2h`` or ``+1d -30m`` into a timedelta.

        Every ``<signed integer><unit>`` token found in the lower-cased
        expression is summed. Units: d (days), h (hours), m (minutes),
        s (seconds).

        Raises:
            ValueError: If the expression contains no such token.
        """
        import re

        pattern = r'([+-]?\d+)([dhms])'
        matches = re.findall(pattern, calculation.lower())

        if not matches:
            raise ValueError(f"无效的时间计算表达式: {calculation}")

        unit_keywords = {'d': 'days', 'h': 'hours', 'm': 'minutes', 's': 'seconds'}

        total_delta = timedelta()
        for value_str, unit in matches:
            total_delta += timedelta(**{unit_keywords[unit]: int(value_str)})

        return total_delta

View File

@@ -0,0 +1,430 @@
"""JSON转换工具 - 数据格式转换"""
import json
import time
from typing import List, Any, Dict
import yaml
import xml.etree.ElementTree as ET
from xml.dom import minidom
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class JsonTool(BuiltinTool):
"""JSON转换工具 - 提供JSON格式化、压缩、验证、格式转换功能"""
@property
def name(self) -> str:
    """Tool name."""
    return "json_tool"
@property
def description(self) -> str:
    """Tool description."""
    return "JSON转换工具 - 数据格式转换JSON格式化、JSON压缩、JSON验证、格式转换"
@property
def parameters(self) -> List[ToolParameter]:
    """Declared call parameters of the JSON tool.

    ``operation`` and ``input_data`` are required; the remaining
    parameters are optional.
    """
    return [
        ToolParameter(
            name="operation",
            type=ParameterType.STRING,
            description="操作类型",
            required=True,
            enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", "extract"]
        ),
        ToolParameter(
            name="input_data",
            type=ParameterType.STRING,
            description="输入数据JSON字符串、YAML字符串或XML字符串",
            required=True
        ),
        ToolParameter(
            name="indent",
            type=ParameterType.INTEGER,
            description="JSON格式化缩进空格数",
            required=False,
            default=2,
            minimum=0,
            maximum=8
        ),
        ToolParameter(
            name="ensure_ascii",
            type=ParameterType.BOOLEAN,
            description="是否确保ASCII编码",
            required=False,
            default=False
        ),
        ToolParameter(
            name="sort_keys",
            type=ParameterType.BOOLEAN,
            description="是否对键进行排序",
            required=False,
            default=False
        ),
        ToolParameter(
            name="merge_data",
            type=ParameterType.STRING,
            description="要合并的JSON数据用于merge操作",
            required=False
        ),
        ToolParameter(
            name="json_path",
            type=ParameterType.STRING,
            description="JSON路径表达式用于extract操作$.user.name",
            required=False
        )
    ]
async def execute(self, **kwargs) -> ToolResult:
    """Run the requested JSON operation.

    Args:
        **kwargs: ``operation`` and ``input_data`` plus any options the
            chosen operation reads.

    Returns:
        A success result with the operation's dict payload, or an error
        result with code ``JSON_ERROR`` when anything fails.
    """
    started = time.time()

    try:
        operation = kwargs.get("operation")
        input_data = kwargs.get("input_data")

        if not input_data:
            raise ValueError("input_data 参数是必需的")

        # Operations that only need the input text.
        plain_handlers = {
            "minify": self._minify_json,
            "validate": self._validate_json,
            "convert": self._convert_json,
            "to_yaml": self._json_to_yaml,
            "to_xml": self._json_to_xml,
        }
        # Operations that also read options from kwargs.
        option_handlers = {
            "format": self._format_json,
            "from_yaml": self._yaml_to_json,
            "from_xml": self._xml_to_json,
            "merge": self._merge_json,
            "extract": self._extract_json_path,
        }

        if operation in plain_handlers:
            payload = plain_handlers[operation](input_data)
        elif operation in option_handlers:
            payload = option_handlers[operation](input_data, kwargs)
        else:
            raise ValueError(f"不支持的操作类型: {operation}")

        return ToolResult.success_result(
            data=payload,
            execution_time=time.time() - started
        )

    except Exception as e:
        return ToolResult.error_result(
            error=str(e),
            error_code="JSON_ERROR",
            execution_time=time.time() - started
        )
def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""格式化JSON"""
indent = kwargs.get("indent", 2)
ensure_ascii = kwargs.get("ensure_ascii", False)
sort_keys = kwargs.get("sort_keys", False)
# 解析JSON
data = json.loads(input_data)
# 格式化输出
formatted = json.dumps(
data,
indent=indent,
ensure_ascii=ensure_ascii,
sort_keys=sort_keys,
separators=(',', ': ')
)
return {
"original_size": len(input_data),
"formatted_size": len(formatted),
"formatted_json": formatted,
"is_valid": True,
"settings": {
"indent": indent,
"ensure_ascii": ensure_ascii,
"sort_keys": sort_keys
}
}
def _minify_json(self, input_data: str) -> Dict[str, Any]:
"""压缩JSON"""
# 解析并压缩
data = json.loads(input_data)
minified = json.dumps(data, separators=(',', ':'))
return {
"original_size": len(input_data),
"minified_size": len(minified),
"compression_ratio": round((1 - len(minified) / len(input_data)) * 100, 2),
"minified_json": minified,
"is_valid": True
}
def _validate_json(self, input_data: str) -> Dict[str, Any]:
"""验证JSON"""
try:
data = json.loads(input_data)
# 统计信息
stats = self._analyze_json_structure(data)
return {
"is_valid": True,
"error": None,
"size": len(input_data),
"structure": stats
}
except json.JSONDecodeError as e:
return {
"is_valid": False,
"error": str(e),
"error_line": getattr(e, 'lineno', None),
"error_column": getattr(e, 'colno', None),
"size": len(input_data)
}
def _convert_json(self, input_data: str) -> Dict[str, Any]:
"""JSON转义"""
data = json.loads(input_data)
converted = json.dumps(data, ensure_ascii=False)
return {
"converted_json": converted,
"is_valid": True
}
def _json_to_yaml(self, input_data: str) -> Dict[str, Any]:
    """Convert a JSON document to block-style YAML.

    Args:
        input_data: JSON text.

    Returns:
        A dict with the YAML text under ``converted_data`` and the sizes
        before/after.

    Raises:
        json.JSONDecodeError: If *input_data* is not valid JSON.
    """
    parsed = json.loads(input_data)
    dump_options = {"default_flow_style": False, "allow_unicode": True, "indent": 2}
    as_yaml = yaml.dump(parsed, **dump_options)

    summary = {"original_format": "json", "target_format": "yaml"}
    summary["original_size"] = len(input_data)
    summary["converted_size"] = len(as_yaml)
    summary["converted_data"] = as_yaml
    return summary
def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a YAML document to JSON.

    Args:
        input_data: YAML text (parsed with ``yaml.safe_load``).
        kwargs: Options ``indent`` (default 2) and ``ensure_ascii``
            (default False).

    Returns:
        A dict with the JSON text under ``converted_data`` and the sizes
        before/after.
    """
    parsed = yaml.safe_load(input_data)
    as_json = json.dumps(
        parsed,
        indent=kwargs.get("indent", 2),
        ensure_ascii=kwargs.get("ensure_ascii", False),
    )

    summary = {"original_format": "yaml", "target_format": "json"}
    summary["original_size"] = len(input_data)
    summary["converted_size"] = len(as_json)
    summary["converted_data"] = as_json
    return summary
def _json_to_xml(self, input_data: str) -> Dict[str, Any]:
"""JSON转XML"""
data = json.loads(input_data)
def dict_to_xml(data, root_name="root"):
"""递归转换字典为XML"""
if isinstance(data, dict):
if len(data) == 1 and not root_name == "root":
# 如果字典只有一个键,使用该键作为根元素
key, value = next(iter(data.items()))
return dict_to_xml(value, key)
root = ET.Element(root_name)
for key, value in data.items():
if isinstance(value, (dict, list)):
child = dict_to_xml(value, key)
root.append(child)
else:
child = ET.SubElement(root, key)
child.text = str(value)
return root
elif isinstance(data, list):
root = ET.Element(root_name)
for i, item in enumerate(data):
if isinstance(item, (dict, list)):
child = dict_to_xml(item, f"item_{i}")
root.append(child)
else:
child = ET.SubElement(root, f"item_{i}")
child.text = str(item)
return root
else:
root = ET.Element(root_name)
root.text = str(data)
return root
xml_element = dict_to_xml(data)
xml_string = ET.tostring(xml_element, encoding='unicode')
# 格式化XML
dom = minidom.parseString(xml_string)
formatted_xml = dom.toprettyxml(indent=" ")
# 移除空行
formatted_xml = '\n'.join([line for line in formatted_xml.split('\n') if line.strip()])
return {
"original_format": "json",
"target_format": "xml",
"original_size": len(input_data),
"converted_size": len(formatted_xml),
"converted_data": formatted_xml
}
def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""XML转JSON"""
indent = kwargs.get("indent", 2)
def xml_to_dict(element):
"""递归转换XML元素为字典"""
result = {}
# 处理属性
if element.attrib:
result.update(element.attrib)
# 处理文本内容
if element.text and element.text.strip():
if len(element) == 0: # 叶子节点
return element.text.strip()
else:
result['text'] = element.text.strip()
# 处理子元素
for child in element:
child_data = xml_to_dict(child)
if child.tag in result:
# 如果标签已存在,转换为列表
if not isinstance(result[child.tag], list):
result[child.tag] = [result[child.tag]]
result[child.tag].append(child_data)
else:
result[child.tag] = child_data
return result
root = ET.fromstring(input_data)
data = {root.tag: xml_to_dict(root)}
json_output = json.dumps(data, indent=indent, ensure_ascii=False)
return {
"original_format": "xml",
"target_format": "json",
"original_size": len(input_data),
"converted_size": len(json_output),
"converted_data": json_output
}
def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""合并JSON"""
merge_data = kwargs.get("merge_data")
if not merge_data:
raise ValueError("merge_data 参数是必需的")
data1 = json.loads(input_data)
data2 = json.loads(merge_data)
def deep_merge(dict1, dict2):
"""深度合并字典"""
result = dict1.copy()
for key, value in dict2.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = deep_merge(result[key], value)
else:
result[key] = value
return result
if isinstance(data1, dict) and isinstance(data2, dict):
merged = deep_merge(data1, data2)
elif isinstance(data1, list) and isinstance(data2, list):
merged = data1 + data2
else:
raise ValueError("无法合并不同类型的数据")
merged_json = json.dumps(merged, indent=2, ensure_ascii=False)
return {
"operation": "merge",
"original_size": len(input_data),
"merge_size": len(merge_data),
"result_size": len(merged_json),
"merged_data": merged_json
}
def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""提取JSON路径"""
json_path = kwargs.get("json_path")
if not json_path:
raise ValueError("json_path 参数是必需的")
data = json.loads(input_data)
# 简单的JSONPath实现支持基本的点号路径
try:
result = data
if json_path.startswith('$.'):
path_parts = json_path[2:].split('.')
else:
path_parts = json_path.split('.')
for part in path_parts:
if part.isdigit():
result = result[int(part)]
else:
result = result[part]
extracted_json = json.dumps(result, indent=2, ensure_ascii=False)
return {
"operation": "extract",
"json_path": json_path,
"found": True,
"extracted_data": extracted_json,
"data_type": type(result).__name__
}
except (KeyError, IndexError, TypeError) as e:
return {
"operation": "extract",
"json_path": json_path,
"found": False,
"error": str(e),
"extracted_data": None
}
def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]:
"""分析JSON结构"""
if isinstance(data, dict):
return {
"type": "object",
"keys": len(data),
"depth": depth,
"children": {k: self._analyze_json_structure(v, depth + 1) for k, v in data.items()}
}
elif isinstance(data, list):
return {
"type": "array",
"length": len(data),
"depth": depth,
"item_types": list(set(type(item).__name__ for item in data))
}
else:
return {
"type": type(data).__name__,
"depth": depth,
"value": str(data)[:100] + "..." if len(str(data)) > 100 else str(data)
}

View File

@@ -0,0 +1,327 @@
"""MinerU PDF解析工具"""
import time
from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class MinerUTool(BuiltinTool):
    """MinerU PDF parsing tool.

    Offers five operations (full parse, text / table / image extraction and
    layout analysis). Each one POSTs a JSON request to the configured MinerU
    API endpoint and reshapes the response into a plain dict.
    """

    @property
    def name(self) -> str:
        """Identifier of this tool."""
        return "mineru_tool"

    @property
    def description(self) -> str:
        """Short human-readable description of this tool."""
        return "MinerU - PDF解析工具PDF解析、表格提取、图片识别、文本提取"

    def get_required_config_parameters(self) -> List[str]:
        """Return the configuration keys that must be set before use."""
        return ["api_key", "api_url"]

    @property
    def parameters(self) -> List[ToolParameter]:
        """Declare the call parameters accepted by :meth:`execute`."""
        return [
            ToolParameter(
                name="operation",
                type=ParameterType.STRING,
                description="操作类型",
                required=True,
                enum=["parse_pdf", "extract_text", "extract_tables", "extract_images", "analyze_layout"],
            ),
            ToolParameter(
                name="file_content",
                type=ParameterType.STRING,
                description="PDF文件内容Base64编码",
                required=False,
            ),
            ToolParameter(
                name="file_url",
                type=ParameterType.STRING,
                description="PDF文件URL",
                required=False,
            ),
            ToolParameter(
                name="parse_mode",
                type=ParameterType.STRING,
                description="解析模式",
                required=False,
                default="auto",
                enum=["auto", "text_only", "table_priority", "image_priority", "layout_analysis"],
            ),
            ToolParameter(
                name="extract_images",
                type=ParameterType.BOOLEAN,
                description="是否提取图片",
                required=False,
                default=True,
            ),
            ToolParameter(
                name="extract_tables",
                type=ParameterType.BOOLEAN,
                description="是否提取表格",
                required=False,
                default=True,
            ),
            ToolParameter(
                name="page_range",
                type=ParameterType.STRING,
                description="页面范围1-5, 1,3,5",
                required=False,
            ),
            ToolParameter(
                name="output_format",
                type=ParameterType.STRING,
                description="输出格式",
                required=False,
                default="json",
                enum=["json", "markdown", "html", "text"],
            ),
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run the requested operation and wrap the outcome in a ToolResult.

        Either ``file_content`` or ``file_url`` must be supplied. Every
        exception is converted into an error result with code
        ``MINERU_ERROR``; nothing is raised to the caller.
        """
        start_time = time.time()
        try:
            operation = kwargs.get("operation")
            if not kwargs.get("file_content") and not kwargs.get("file_url"):
                raise ValueError("必须提供 file_content 或 file_url 参数")

            handlers = {
                "parse_pdf": self._parse_pdf,
                "extract_text": self._extract_text,
                "extract_tables": self._extract_tables,
                "extract_images": self._extract_images,
                "analyze_layout": self._analyze_layout,
            }
            # Guard the lookup so an unhashable value still gets the
            # "unsupported operation" message instead of a TypeError.
            handler = handlers.get(operation) if isinstance(operation, str) else None
            if handler is None:
                raise ValueError(f"不支持的操作类型: {operation}")

            result = await handler(kwargs)
            return ToolResult.success_result(
                data=result,
                execution_time=time.time() - start_time,
            )
        except Exception as e:
            return ToolResult.error_result(
                error=str(e),
                error_code="MINERU_ERROR",
                execution_time=time.time() - start_time,
            )

    @staticmethod
    def _build_request(kwargs: Dict[str, Any], base: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``base`` plus the optional page range and the file source.

        ``file_content`` takes precedence over ``file_url`` when both are set.
        """
        request_data = dict(base)
        page_range = kwargs.get("page_range")
        if page_range:
            request_data["page_range"] = page_range
        if kwargs.get("file_content"):
            request_data["file_content"] = kwargs["file_content"]
        elif kwargs.get("file_url"):
            request_data["file_url"] = kwargs["file_url"]
        return request_data

    async def _parse_pdf(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Run a full parse via the ``parse`` endpoint."""
        parse_mode = kwargs.get("parse_mode", "auto")
        request_data = self._build_request(kwargs, {
            "parse_mode": parse_mode,
            "extract_images": kwargs.get("extract_images", True),
            "extract_tables": kwargs.get("extract_tables", True),
            "output_format": kwargs.get("output_format", "json"),
        })
        result = await self._call_mineru_api("parse", request_data)
        return {
            "operation": "parse_pdf",
            "parse_mode": parse_mode,
            "total_pages": result.get("total_pages", 0),
            "processed_pages": result.get("processed_pages", 0),
            "text_content": result.get("text_content", ""),
            "tables": result.get("tables", []),
            "images": result.get("images", []),
            "layout_info": result.get("layout_info", {}),
            "metadata": result.get("metadata", {}),
            "processing_time": result.get("processing_time", 0),
        }

    async def _extract_text(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Extract text via the ``extract_text`` endpoint and add counts."""
        request_data = self._build_request(kwargs, {
            "operation": "extract_text",
            "output_format": kwargs.get("output_format", "text"),
        })
        result = await self._call_mineru_api("extract_text", request_data)
        # A present-but-null text_content must not break the word count.
        text_content = result.get("text_content") or ""
        return {
            "operation": "extract_text",
            "total_pages": result.get("total_pages", 0),
            "text_content": text_content,
            "word_count": len(text_content.split()),
            "character_count": len(text_content),
            "pages_text": result.get("pages_text", []),
        }

    async def _extract_tables(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Extract tables via the ``extract_tables`` endpoint."""
        request_data = self._build_request(kwargs, {
            "operation": "extract_tables",
            "output_format": kwargs.get("output_format", "json"),
        })
        result = await self._call_mineru_api("extract_tables", request_data)
        return {
            "operation": "extract_tables",
            "total_tables": result.get("total_tables", 0),
            "tables": result.get("tables", []),
            "table_locations": result.get("table_locations", []),
        }

    async def _extract_images(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Extract images via the ``extract_images`` endpoint."""
        request_data = self._build_request(kwargs, {"operation": "extract_images"})
        result = await self._call_mineru_api("extract_images", request_data)
        return {
            "operation": "extract_images",
            "total_images": result.get("total_images", 0),
            "images": result.get("images", []),
            "image_locations": result.get("image_locations", []),
        }

    async def _analyze_layout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Analyze page layout via the ``analyze_layout`` endpoint."""
        request_data = self._build_request(kwargs, {"operation": "analyze_layout"})
        result = await self._call_mineru_api("analyze_layout", request_data)
        return {
            "operation": "analyze_layout",
            "layout_info": result.get("layout_info", {}),
            "page_layouts": result.get("page_layouts", []),
            "text_blocks": result.get("text_blocks", []),
            "image_blocks": result.get("image_blocks", []),
            "table_blocks": result.get("table_blocks", []),
        }

    async def _call_mineru_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` as JSON to ``<api_url>/<endpoint>``.

        Returns:
            The ``data`` member of the response when present, otherwise the
            whole response body.

        Raises:
            ValueError: If ``api_key`` or ``api_url`` is not configured.
            RuntimeError: On a non-200 status or when the body reports
                ``success`` as false.
        """
        api_key = self.get_config_parameter("api_key")
        api_url = self.get_config_parameter("api_url")
        timeout_seconds = self.get_config_parameter("timeout", 60)
        if not api_key or not api_url:
            raise ValueError("MinerU API配置未完成")

        url = f"{api_url.rstrip('/')}/{endpoint}"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        timeout = aiohttp.ClientTimeout(total=timeout_seconds)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=data, headers=headers) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"HTTP错误 {response.status}: {error_text}")
                result = await response.json()

        # A missing "success" flag is treated as success.
        if result.get("success", True):
            return result.get("data", result)
        raise RuntimeError(f"MinerU API错误: {result.get('message', '未知错误')}")

    def test_connection(self) -> Dict[str, Any]:
        """Check that the API configuration is present.

        No network request is made; only the configured values are
        inspected. The API key is returned masked.
        """
        try:
            api_key = self.get_config_parameter("api_key")
            api_url = self.get_config_parameter("api_url")
            if not api_key or not api_url:
                return {
                    "success": False,
                    "error": "API配置未完成",
                }
            return {
                "success": True,
                "message": "连接配置有效",
                "api_url": api_url,
                "api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***",
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
            }

View File

@@ -0,0 +1,401 @@
"""TextIn OCR文字识别工具"""
import time
from typing import List, Dict, Any, Optional
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class TextInTool(BuiltinTool):
    """TextIn OCR tool.

    Dispatches on ``recognition_mode`` to one of six OCR endpoints, POSTs a
    JSON request to the configured TextIn API and formats the response.
    """

    # Output formats understood by the plain-OCR formatter.
    _OCR_OUTPUT_FORMATS = ("text", "json", "structured")

    @property
    def name(self) -> str:
        """Identifier of this tool."""
        return "textin_tool"

    @property
    def description(self) -> str:
        """Short human-readable description of this tool."""
        return "TextIn - OCR文字识别通用OCR、手写识别、多语言支持、高精度识别"

    def get_required_config_parameters(self) -> List[str]:
        """Return the configuration keys that must be set before use."""
        return ["app_id", "secret_key", "api_url"]

    @property
    def parameters(self) -> List[ToolParameter]:
        """Declare the call parameters accepted by :meth:`execute`."""
        return [
            ToolParameter(
                name="image_content",
                type=ParameterType.STRING,
                description="图片内容Base64编码",
                required=False,
            ),
            ToolParameter(
                name="image_url",
                type=ParameterType.STRING,
                description="图片URL",
                required=False,
            ),
            ToolParameter(
                name="language",
                type=ParameterType.STRING,
                description="识别语言",
                required=False,
                default="auto",
                enum=["auto", "zh-cn", "zh-tw", "en", "ja", "ko", "fr", "de", "es", "ru"],
            ),
            ToolParameter(
                name="recognition_mode",
                type=ParameterType.STRING,
                description="识别模式",
                required=False,
                default="general",
                enum=["general", "accurate", "handwriting", "formula", "table", "document"],
            ),
            ToolParameter(
                name="return_location",
                type=ParameterType.BOOLEAN,
                description="是否返回文字位置信息",
                required=False,
                default=False,
            ),
            ToolParameter(
                name="return_confidence",
                type=ParameterType.BOOLEAN,
                description="是否返回置信度",
                required=False,
                default=True,
            ),
            ToolParameter(
                name="merge_lines",
                type=ParameterType.BOOLEAN,
                description="是否合并行",
                required=False,
                default=True,
            ),
            ToolParameter(
                name="output_format",
                type=ParameterType.STRING,
                description="输出格式",
                required=False,
                default="text",
                enum=["text", "json", "structured"],
            ),
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run OCR in the requested mode and wrap the outcome in a ToolResult.

        Either ``image_content`` or ``image_url`` must be supplied. Every
        exception is converted into an error result with code
        ``TEXTIN_ERROR``; nothing is raised to the caller.
        """
        start_time = time.time()
        try:
            if not kwargs.get("image_content") and not kwargs.get("image_url"):
                raise ValueError("必须提供 image_content 或 image_url 参数")

            recognition_mode = kwargs.get("recognition_mode", "general")
            handlers = {
                "general": self._general_ocr,
                "accurate": self._accurate_ocr,
                "handwriting": self._handwriting_ocr,
                "formula": self._formula_ocr,
                "table": self._table_ocr,
                "document": self._document_ocr,
            }
            # Guard the lookup so an unhashable value still gets the
            # "unsupported mode" message instead of a TypeError.
            handler = handlers.get(recognition_mode) if isinstance(recognition_mode, str) else None
            if handler is None:
                raise ValueError(f"不支持的识别模式: {recognition_mode}")

            result = await handler(kwargs)
            return ToolResult.success_result(
                data=result,
                execution_time=time.time() - start_time,
            )
        except Exception as e:
            return ToolResult.error_result(
                error=str(e),
                error_code="TEXTIN_ERROR",
                execution_time=time.time() - start_time,
            )

    @staticmethod
    def _build_request(
        kwargs: Dict[str, Any],
        extra: Optional[Dict[str, Any]] = None,
        include_language: bool = True,
    ) -> Dict[str, Any]:
        """Assemble the request body shared by all OCR endpoints.

        ``image_content`` (sent as ``image``) takes precedence over
        ``image_url`` when both are set.
        """
        request_data: Dict[str, Any] = {}
        if include_language:
            request_data["language"] = kwargs.get("language", "auto")
        request_data["return_location"] = kwargs.get("return_location", False)
        request_data["return_confidence"] = kwargs.get("return_confidence", True)
        if extra:
            request_data.update(extra)
        if kwargs.get("image_content"):
            request_data["image"] = kwargs["image_content"]
        elif kwargs.get("image_url"):
            request_data["image_url"] = kwargs["image_url"]
        return request_data

    async def _plain_ocr(
        self,
        endpoint: str,
        kwargs: Dict[str, Any],
        extra: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Call a line-based OCR endpoint and format its result."""
        output_format = kwargs.get("output_format", "text")
        # Reject an unknown format before spending an API call on it.
        if output_format not in self._OCR_OUTPUT_FORMATS:
            raise ValueError(f"不支持的输出格式: {output_format}")
        result = await self._call_textin_api(endpoint, self._build_request(kwargs, extra))
        return self._format_ocr_result(result, output_format)

    async def _general_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """General-purpose OCR (``general_ocr`` endpoint)."""
        return await self._plain_ocr(
            "general_ocr", kwargs, {"merge_lines": kwargs.get("merge_lines", True)}
        )

    async def _accurate_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """High-accuracy OCR (``accurate_ocr`` endpoint)."""
        return await self._plain_ocr(
            "accurate_ocr", kwargs, {"merge_lines": kwargs.get("merge_lines", True)}
        )

    async def _handwriting_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Handwriting OCR (``handwriting_ocr`` endpoint)."""
        return await self._plain_ocr("handwriting_ocr", kwargs)

    async def _formula_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Formula recognition (``formula_ocr`` endpoint); no language is sent."""
        request_data = self._build_request(
            kwargs, {"output_latex": True}, include_language=False
        )
        result = await self._call_textin_api("formula_ocr", request_data)
        return self._format_formula_result(result, kwargs.get("output_format", "text"))

    async def _table_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Table recognition (``table_ocr`` endpoint)."""
        request_data = self._build_request(kwargs, {"output_excel": True})
        result = await self._call_textin_api("table_ocr", request_data)
        return self._format_table_result(result, kwargs.get("output_format", "text"))

    async def _document_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Document recognition (``document_ocr`` endpoint)."""
        request_data = self._build_request(kwargs, {"layout_analysis": True})
        result = await self._call_textin_api("document_ocr", request_data)
        return self._format_document_result(result, kwargs.get("output_format", "text"))

    def _format_ocr_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
        """Shape a line-based OCR response as ``text``, ``json`` or ``structured``.

        Raises:
            ValueError: If ``output_format`` is not one of the three
                supported values (previously this silently produced ``None``).
        """
        lines = result.get("lines", [])
        common = {
            "total_confidence": result.get("confidence", 0),
            "processing_time": result.get("processing_time", 0),
        }
        if output_format == "text":
            return {
                "recognition_mode": "ocr",
                "text_content": "\n".join(line.get("text", "") for line in lines),
                "line_count": len(lines),
                **common,
            }
        if output_format == "json":
            return {
                "recognition_mode": "ocr",
                "lines": lines,
                **common,
            }
        if output_format == "structured":
            return {
                "recognition_mode": "ocr",
                "text_content": "\n".join(line.get("text", "") for line in lines),
                "structured_data": {
                    "lines": lines,
                    "paragraphs": self._group_lines_to_paragraphs(lines),
                    "statistics": {
                        "line_count": len(lines),
                        "word_count": sum(len(line.get("text", "").split()) for line in lines),
                        "character_count": sum(len(line.get("text", "")) for line in lines),
                    },
                },
                **common,
            }
        raise ValueError(f"不支持的输出格式: {output_format}")

    def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
        """Shape a formula response; ``output_format`` is currently unused."""
        formulas = result.get("formulas", [])
        return {
            "recognition_mode": "formula",
            "formula_count": len(formulas),
            "formulas": formulas,
            "latex_content": "\n".join(f.get("latex", "") for f in formulas),
            "total_confidence": result.get("confidence", 0),
            "processing_time": result.get("processing_time", 0),
        }

    def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
        """Shape a table response; ``output_format`` is currently unused."""
        tables = result.get("tables", [])
        return {
            "recognition_mode": "table",
            "table_count": len(tables),
            "tables": tables,
            "excel_data": result.get("excel_data"),
            "total_confidence": result.get("confidence", 0),
            "processing_time": result.get("processing_time", 0),
        }

    def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
        """Shape a document response; ``output_format`` is currently unused."""
        return {
            "recognition_mode": "document",
            "layout_info": result.get("layout_info", {}),
            "text_blocks": result.get("text_blocks", []),
            "image_blocks": result.get("image_blocks", []),
            "table_blocks": result.get("table_blocks", []),
            "full_text": result.get("full_text", ""),
            "total_confidence": result.get("confidence", 0),
            "processing_time": result.get("processing_time", 0),
        }

    def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Group consecutive non-blank lines into paragraphs.

        A line whose text is empty or whitespace ends the current paragraph.
        Each paragraph holds its member lines and their texts joined by a
        single space.
        """
        paragraphs: List[Dict[str, Any]] = []
        current: List[Dict[str, Any]] = []

        def flush() -> None:
            """Close the paragraph being built, if it has any lines."""
            if current:
                paragraphs.append({
                    "text": " ".join(item.get("text", "") for item in current),
                    "lines": list(current),
                })
                current.clear()

        for line in lines:
            if line.get("text", "").strip():
                current.append(line)
            else:
                flush()
        flush()
        return paragraphs

    async def _call_textin_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` as JSON to ``<api_url>/<endpoint>``.

        The request timeout is read from the optional ``timeout``
        configuration parameter and defaults to 30 seconds.

        Returns:
            The ``data`` member of the response when present, otherwise the
            whole response body.

        Raises:
            ValueError: If ``app_id``, ``secret_key`` or ``api_url`` is not
                configured.
            RuntimeError: On a non-200 status or when the body's ``code`` is
                not ``200``.
        """
        app_id = self.get_config_parameter("app_id")
        secret_key = self.get_config_parameter("secret_key")
        api_url = self.get_config_parameter("api_url")
        if not app_id or not secret_key or not api_url:
            raise ValueError("TextIn API配置未完成")
        timeout_seconds = self.get_config_parameter("timeout", 30)

        url = f"{api_url.rstrip('/')}/{endpoint}"
        headers = {
            "X-App-Id": app_id,
            "X-Secret-Key": secret_key,
            "Content-Type": "application/json",
        }
        timeout = aiohttp.ClientTimeout(total=timeout_seconds)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=data, headers=headers) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"HTTP错误 {response.status}: {error_text}")
                result = await response.json()

        if result.get("code") == 200:
            return result.get("data", result)
        raise RuntimeError(f"TextIn API错误: {result.get('message', '未知错误')}")

    def test_connection(self) -> Dict[str, Any]:
        """Check that the API configuration is present.

        No network request is made; only the configured values are
        inspected. The secret key is returned masked.
        """
        try:
            app_id = self.get_config_parameter("app_id")
            secret_key = self.get_config_parameter("secret_key")
            api_url = self.get_config_parameter("api_url")
            if not app_id or not secret_key or not api_url:
                return {
                    "success": False,
                    "error": "API配置未完成",
                }
            return {
                "success": True,
                "message": "连接配置有效",
                "api_url": api_url,
                "app_id": app_id,
                "secret_key_masked": secret_key[:8] + "***" if len(secret_key) > 8 else "***",
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
            }

View File

@@ -0,0 +1,485 @@
"""工具链管理器 - 支持langchain的工具链模式"""
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
from app.core.tools.base import ToolResult
from app.core.tools.executor import ToolExecutor
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ChainExecutionMode(str, Enum):
    """Strategy used to run the steps of a tool chain."""

    # Run the steps one after another.
    SEQUENTIAL = "sequential"
    # Run all steps in parallel.
    PARALLEL = "parallel"
    # Run steps conditionally.
    CONDITIONAL = "conditional"
@dataclass
class ChainStep:
    """One step of a tool chain: which tool to call and how."""

    # Identifier of the tool to execute.
    tool_id: str
    # Parameters passed to the tool.
    parameters: Dict[str, Any]
    # Optional execution condition.
    condition: Optional[str] = None
    # Optional output mapping.
    output_mapping: Optional[Dict[str, str]] = None
    # Error handling strategy: "stop", "continue" or "retry".
    error_handling: str = "stop"
@dataclass
class ChainDefinition:
    """Declarative description of a tool chain."""

    # Name of the chain.
    name: str
    # Human-readable description of the chain.
    description: str
    # Steps that make up the chain.
    steps: List[ChainStep]
    # Execution mode of the chain; sequential unless stated otherwise.
    execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL
    # Optional timeout for the whole chain.
    global_timeout: Optional[float] = None
    # Optional retry policy settings.
    retry_policy: Optional[Dict[str, Any]] = None
class ChainExecutionContext:
    """Mutable state tracked while one chain execution is in flight."""

    def __init__(self, chain_id: str):
        """Start an empty context for the execution identified by ``chain_id``."""
        self.chain_id = chain_id
        # Progress markers.
        self.current_step = 0
        self.is_completed = False
        self.is_failed = False
        self.error_message: Optional[str] = None
        # Data produced so far: chain variables and per-step results,
        # the latter keyed by step index.
        self.variables: Dict[str, Any] = {}
        self.step_results: Dict[int, ToolResult] = {}
class ChainManager:
    """Registry and runner for tool chains.

    Chains are registered by name and executed through the injected tool
    executor in sequential, parallel or conditional mode.
    """

    def __init__(self, executor: ToolExecutor):
        """Initialize the manager.

        Args:
            executor: Tool executor used to run individual steps.
        """
        self.executor = executor
        self._chains: Dict[str, ChainDefinition] = {}
        self._running_chains: Dict[str, ChainExecutionContext] = {}

    def register_chain(self, chain: ChainDefinition) -> bool:
        """Validate and register a chain under ``chain.name``.

        An existing chain with the same name is replaced.

        Returns:
            ``True`` on success, ``False`` if validation or registration fails.
        """
        try:
            is_valid, reason = self._validate_chain(chain)
            if not is_valid:
                logger.error(f"工具链验证失败: {chain.name}, 错误: {reason}")
                return False
            self._chains[chain.name] = chain
            logger.info(f"工具链注册成功: {chain.name}")
            return True
        except Exception as e:
            logger.error(f"工具链注册失败: {chain.name}, 错误: {e}")
            return False

    def unregister_chain(self, chain_name: str) -> bool:
        """Remove a registered chain.

        Returns:
            ``True`` if the chain existed and was removed, else ``False``.
        """
        if chain_name not in self._chains:
            return False
        del self._chains[chain_name]
        logger.info(f"工具链注销成功: {chain_name}")
        return True

    def list_chains(self) -> List[Dict[str, Any]]:
        """Return a summary dict for every registered chain."""
        return [
            {
                "name": name,
                "description": chain.description,
                "step_count": len(chain.steps),
                "execution_mode": chain.execution_mode.value,
                "global_timeout": chain.global_timeout,
            }
            for name, chain in self._chains.items()
        ]

    async def execute_chain(
        self,
        chain_name: str,
        initial_variables: Optional[Dict[str, Any]] = None,
        chain_id: Optional[str] = None
    ) -> Dict[str, Any] | None:
        """Execute a registered chain.

        Args:
            chain_name: Name of the chain to run.
            initial_variables: Starting variables; copied, never mutated.
            chain_id: Execution id; generated when omitted.

        Returns:
            A result dict with at least ``success``, ``error`` and
            ``chain_id``. Failures are reported in the dict, not raised.
        """
        if chain_name not in self._chains:
            return {
                "success": False,
                "error": f"工具链不存在: {chain_name}",
                "chain_id": chain_id,
            }
        chain = self._chains[chain_name]

        if not chain_id:
            import uuid
            chain_id = f"chain_{uuid.uuid4().hex[:16]}"

        context = ChainExecutionContext(chain_id)
        # Copy so output mappings never write into the caller's dict.
        context.variables = dict(initial_variables) if initial_variables else {}
        self._running_chains[chain_id] = context

        # NOTE(review): chain.global_timeout and chain.retry_policy are not
        # consulted anywhere in this class — confirm whether they should be.
        try:
            logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})")
            if chain.execution_mode == ChainExecutionMode.SEQUENTIAL:
                result = await self._execute_sequential(chain, context)
            elif chain.execution_mode == ChainExecutionMode.PARALLEL:
                result = await self._execute_parallel(chain, context)
            elif chain.execution_mode == ChainExecutionMode.CONDITIONAL:
                result = await self._execute_conditional(chain, context)
            else:
                raise ValueError(f"不支持的执行模式: {chain.execution_mode}")
            logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})")
            return result
        except Exception as e:
            logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}")
            return {
                "success": False,
                "error": str(e),
                "chain_id": chain_id,
                "completed_steps": context.current_step,
                "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()},
            }
        finally:
            # Always drop the context, whether the run succeeded or not.
            self._running_chains.pop(chain_id, None)

    async def _execute_sequential(
        self,
        chain: ChainDefinition,
        context: ChainExecutionContext
    ) -> Dict[str, Any]:
        """Run the steps in order, honoring each step's error handling.

        * ``stop``: a failed step fails the chain and ends the run.
        * ``continue``: a failed step is logged and skipped.
        * ``retry``: a failed step is executed once more; if it fails again
          the chain fails and the run ends.
        """
        for i, step in enumerate(chain.steps):
            context.current_step = i
            if step.condition and not self._evaluate_condition(step.condition, context):
                logger.debug(f"跳过步骤 {i}: 条件不满足")
                continue

            parameters = self._prepare_parameters(step.parameters, context)
            try:
                result = await self.executor.execute_tool(
                    tool_id=step.tool_id,
                    parameters=parameters
                )
                context.step_results[i] = result

                if not result.success and step.error_handling == "retry":
                    # Single retry; its result replaces the first attempt.
                    result = await self.executor.execute_tool(
                        tool_id=step.tool_id,
                        parameters=parameters
                    )
                    context.step_results[i] = result

                if result.success:
                    if step.output_mapping:
                        self._apply_output_mapping(step.output_mapping, result.data, context)
                    continue

                if step.error_handling == "continue":
                    logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}")
                    continue

                # "stop", or "retry" whose second attempt also failed.
                context.is_failed = True
                context.error_message = result.error
                break
            except Exception as e:
                logger.error(f"步骤 {i} 执行异常: {e}")
                if step.error_handling == "stop":
                    context.is_failed = True
                    context.error_message = str(e)
                    break

        context.is_completed = not context.is_failed
        return {
            "success": context.is_completed,
            "error": context.error_message,
            "chain_id": context.chain_id,
            "completed_steps": context.current_step + 1,
            "total_steps": len(chain.steps),
            "final_variables": context.variables,
            "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()},
        }

    async def _execute_parallel(
        self,
        chain: ChainDefinition,
        context: ChainExecutionContext
    ) -> Dict[str, Any]:
        """Run all eligible steps as one batch.

        Conditions and parameters are evaluated against the initial
        variables, before any step has run. The chain succeeds only if
        every executed step succeeds.
        """
        execution_configs = []
        for i, step in enumerate(chain.steps):
            if step.condition and not self._evaluate_condition(step.condition, context):
                continue
            execution_configs.append({
                "step_index": i,
                "tool_id": step.tool_id,
                "parameters": self._prepare_parameters(step.parameters, context),
            })

        try:
            results = await self.executor.execute_tools_batch(execution_configs)
            # Results are matched to their configs by position.
            for config, result in zip(execution_configs, results):
                step_index = config["step_index"]
                context.step_results[step_index] = result
                step = chain.steps[step_index]
                if step.output_mapping and result.success:
                    self._apply_output_mapping(step.output_mapping, result.data, context)

            failed_steps = [i for i, result in context.step_results.items() if not result.success]
            context.is_completed = len(failed_steps) == 0
            if failed_steps:
                context.is_failed = True
                context.error_message = f"步骤 {failed_steps} 执行失败"
        except Exception as e:
            context.is_failed = True
            context.error_message = str(e)

        return {
            "success": context.is_completed,
            "error": context.error_message,
            "chain_id": context.chain_id,
            "completed_steps": len(context.step_results),
            "total_steps": len(chain.steps),
            "final_variables": context.variables,
            "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()},
        }

    async def _execute_conditional(
        self,
        chain: ChainDefinition,
        context: ChainExecutionContext
    ) -> Dict[str, Any]:
        """Run in conditional mode; currently identical to sequential mode."""
        return await self._execute_sequential(chain, context)

    def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]:
        """Check a chain definition.

        Returns:
            ``(True, None)`` when valid, otherwise ``(False, reason)``.
        """
        if not chain.name:
            return False, "工具链名称不能为空"
        if not chain.steps:
            return False, "工具链必须包含至少一个步骤"
        for i, step in enumerate(chain.steps):
            if not step.tool_id:
                return False, f"步骤 {i} 缺少工具ID"
            if step.error_handling not in ("stop", "continue", "retry"):
                return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}"
        return True, None

    def _prepare_parameters(
        self,
        parameters: Dict[str, Any],
        context: ChainExecutionContext
    ) -> Dict[str, Any]:
        """Resolve ``${name}`` placeholders against the context variables.

        Only top-level string values that are exactly ``${name}`` are
        substituted; an unknown name leaves the placeholder unchanged.
        """
        prepared = {}
        for key, value in parameters.items():
            is_placeholder = (
                isinstance(value, str) and value.startswith("${") and value.endswith("}")
            )
            if is_placeholder and value[2:-1] in context.variables:
                prepared[key] = context.variables[value[2:-1]]
            else:
                prepared[key] = value
        return prepared

    def _evaluate_condition(
        self,
        condition: str,
        context: ChainExecutionContext
    ) -> bool:
        """Evaluate a step condition against the context variables.

        Supported forms: ``variable == value``, ``variable != value`` (both
        compared as strings, surrounding quotes on the value are stripped)
        and a bare variable name (truthiness check). Anything else
        evaluates to ``True``; an evaluation error yields ``False``.
        """
        try:
            eq_pos = condition.find("==")
            ne_pos = condition.find("!=")
            # Use whichever operator appears first, so a literal containing
            # the other operator (e.g. status != "a==b") parses correctly.
            if eq_pos != -1 and (ne_pos == -1 or eq_pos < ne_pos):
                var_name, expected_value = condition.split("==", 1)
                negate = False
            elif ne_pos != -1:
                var_name, expected_value = condition.split("!=", 1)
                negate = True
            elif condition in context.variables:
                return bool(context.variables[condition])
            else:
                # Unrecognized conditions default to true.
                return True

            actual = str(context.variables.get(var_name.strip(), ""))
            matches = actual == expected_value.strip().strip('"\'')
            return not matches if negate else matches
        except Exception as e:
            logger.error(f"条件评估失败: {condition}, 错误: {e}")
            return False

    def _apply_output_mapping(
        self,
        mapping: Dict[str, str],
        output_data: Any,
        context: ChainExecutionContext
    ):
        """Copy step output into context variables.

        For dict output, each ``source_key -> variable`` pair in ``mapping``
        is copied when the key is present. For any other output, the whole
        value is stored under the variable named by ``mapping["result"]``,
        if that entry exists. Errors are logged, not raised.
        """
        try:
            if isinstance(output_data, dict):
                for source_key, target_var in mapping.items():
                    if source_key in output_data:
                        context.variables[target_var] = output_data[source_key]
            elif "result" in mapping:
                context.variables[mapping["result"]] = output_data
        except Exception as e:
            logger.error(f"输出映射失败: {e}")

    def _serialize_result(self, result: ToolResult) -> Dict[str, Any]:
        """Convert a tool result into a plain dict."""
        return {
            "success": result.success,
            "data": result.data,
            "error": result.error,
            "error_code": result.error_code,
            "execution_time": result.execution_time,
            "token_usage": result.token_usage,
            "metadata": result.metadata,
        }

    def get_running_chains(self) -> List[Dict[str, Any]]:
        """Return a status summary for every chain currently executing."""
        return [
            {
                "chain_id": chain_id,
                "current_step": context.current_step,
                "is_completed": context.is_completed,
                "is_failed": context.is_failed,
                "variables_count": len(context.variables),
                "completed_steps": len(context.step_results),
            }
            for chain_id, context in self._running_chains.items()
        ]

View File

@@ -0,0 +1,264 @@
"""工具配置管理器 - 管理工具配置的加载和验证"""
import json
from pathlib import Path
from typing import Dict, Any, Optional
from pydantic import BaseModel, ValidationError
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ToolConfigSchema(BaseModel):
    """Base schema shared by every tool configuration.

    ``name``, ``description`` and ``tool_type`` are required; the remaining
    fields have defaults. Keys not declared here are kept on the model
    instead of being rejected (see ``Config.extra``).
    """
    name: str
    description: str
    tool_type: str
    version: str = "1.0.0"
    enabled: bool = True
    parameters: Dict[str, Any] = {}
    tags: list[str] = []

    class Config:
        # Accept and retain undeclared keys.
        extra = "allow"
class BuiltinToolConfigSchema(ToolConfigSchema):
    """Schema for a built-in tool configuration.

    Adds the required ``tool_class`` field and defaults ``tool_type`` to
    ``"builtin"``.
    """
    tool_class: str
    tool_type: str = "builtin"
class CustomToolConfigSchema(ToolConfigSchema):
    """Schema for a custom tool configuration.

    Every field added here is optional or has a default; ``tool_type``
    defaults to ``"custom"``.
    """
    schema_url: Optional[str] = None
    schema_content: Optional[Dict[str, Any]] = None
    auth_type: str = "none"
    auth_config: Dict[str, Any] = {}
    base_url: Optional[str] = None
    timeout: int = 30
    tool_type: str = "custom"
class MCPToolConfigSchema(ToolConfigSchema):
    """Schema for an MCP tool configuration.

    Requires ``server_url``; ``tool_type`` defaults to ``"mcp"``.
    """
    server_url: str
    connection_config: Dict[str, Any] = {}
    available_tools: list[str] = []
    tool_type: str = "mcp"
class ConfigManager:
    """Loads, saves, deletes and validates tool configuration files."""

    def __init__(self, config_dir: Optional[str] = None):
        """Set up the manager and make sure its directory exists.

        Args:
            config_dir: Directory holding the configuration files. When
                empty or omitted, the default directory is used.
        """
        chosen_dir = config_dir if config_dir else self._get_default_config_dir()
        self.config_dir = Path(chosen_dir)
        # Create the directory (and missing parents) up front.
        self.config_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"配置管理器初始化完成,配置目录: {self.config_dir}")
def _get_default_config_dir(self) -> str:
"""获取默认配置目录"""
# 获取tools目录下的configs子目录
tools_dir = Path(__file__).parent
return str(tools_dir / "configs")
def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]:
    """Load every built-in tool configuration from ``<config_dir>/builtin``.

    The directory is created when missing. A file that cannot be read or
    fails schema validation is logged and skipped.

    Returns:
        Parsed configurations keyed by tool name.
    """
    builtin_dir = self.config_dir / "builtin"
    if not builtin_dir.exists():
        logger.info("内置工具配置目录不存在,创建默认配置")
        self._create_default_builtin_configs(builtin_dir)

    loaded = {}
    for json_path in builtin_dir.glob("*.json"):
        try:
            schema = BuiltinToolConfigSchema(**self._load_config_file(json_path))
            loaded[schema.name] = schema
            logger.debug(f"加载内置工具配置: {schema.name}")
        except Exception as exc:
            logger.error(f"加载内置工具配置失败: {json_path}, 错误: {exc}")
    return loaded
def load_builtin_tools_config(self) -> Dict[str, Any]:
    """Load the global ``builtin_tools.json`` file (legacy interface).

    Returns:
        The parsed JSON content, or an empty dict when the file cannot be
        read or parsed (the error is logged).
    """
    registry_path = self.config_dir / "builtin_tools.json"
    try:
        with registry_path.open('r', encoding='utf-8') as handle:
            return json.load(handle)
    except Exception as exc:
        logger.error(f"加载内置工具配置失败: {exc}")
        return {}
def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum):
    """Create the tenant's built-in tool rows if none exist yet.

    Args:
        tenant_id: Tenant ID.
        db_session: Database session used for the query, inserts and commit.
        tool_config_model: ToolConfig model class.
        builtin_tool_config_model: BuiltinToolConfig model class.
        tool_type_enum: ToolType enum (``BUILTIN`` is used).
        tool_status_enum: ToolStatus enum (``ACTIVE`` / ``INACTIVE`` are used).

    Raises:
        Exception: Any error raised while creating the rows is re-raised
            after the session has been rolled back.
    """
    # Nothing to do when the tenant already has at least one built-in tool.
    existing_count = db_session.query(tool_config_model).filter(
        tool_config_model.tenant_id == tenant_id,
        tool_config_model.tool_type == tool_type_enum.BUILTIN
    ).count()
    if existing_count > 0:
        return

    builtin_tools = self.load_builtin_tools_config()
    try:
        for tool_info in builtin_tools.values():
            # Tools that need configuration start inactive.
            if tool_info['requires_config']:
                initial_status = tool_status_enum.INACTIVE.value
            else:
                initial_status = tool_status_enum.ACTIVE.value
            tool_config = tool_config_model(
                name=tool_info['name'],
                description=tool_info['description'],
                tool_type=tool_type_enum.BUILTIN,
                tenant_id=tenant_id,
                status=initial_status
            )
            db_session.add(tool_config)
            # Flush so tool_config.id is populated for the child row.
            db_session.flush()
            builtin_config = builtin_tool_config_model(
                id=tool_config.id,
                tool_class=tool_info['tool_class'],
                parameters={}
            )
            db_session.add(builtin_config)
        db_session.commit()
    except Exception:
        # Discard rows already added/flushed so a malformed entry does not
        # leave a half-initialized tenant pending in the session.
        db_session.rollback()
        raise
    logger.info(f"租户 {tenant_id} 的内置工具初始化完成")
def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
    """Write a tool configuration to ``<config_dir>/<tool_type>/<name>.json``.

    Args:
        config: Tool configuration to persist.
        tool_type: Tool type; used as the sub-directory name.

    Returns:
        True when the file was written, False on any error (logged).
    """
    try:
        target_dir = self.config_dir / tool_type
        target_dir.mkdir(parents=True, exist_ok=True)
        target_file = target_dir / f"{config.name}.json"
        payload = config.model_dump()
        with target_file.open('w', encoding='utf-8') as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)
    except Exception as exc:
        logger.error(f"工具配置保存失败: {config.name}, 错误: {exc}")
        return False
    else:
        try:
            logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
            return True
        except Exception as exc:
            logger.error(f"工具配置保存失败: {config.name}, 错误: {exc}")
            return False
def delete_tool_config(self, tool_name: str, tool_type: str) -> bool:
    """Delete ``<config_dir>/<tool_type>/<tool_name>.json``.

    Args:
        tool_name: Tool name (file stem).
        tool_type: Tool type (sub-directory).

    Returns:
        True when the file existed and was removed; False when it did not
        exist or removal failed (both cases are logged).
    """
    try:
        target_file = self.config_dir / tool_type / f"{tool_name}.json"
        if not target_file.exists():
            logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})")
            return False
        target_file.unlink()
        logger.info(f"工具配置删除成功: {tool_name} ({tool_type})")
        return True
    except Exception as exc:
        logger.error(f"工具配置删除失败: {tool_name}, 错误: {exc}")
        return False
def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
    """Validate raw configuration data against the schema for its type.

    Args:
        config_data: Configuration data to check.
        tool_type: One of ``"builtin"``, ``"custom"`` or ``"mcp"``.

    Returns:
        ``(True, None)`` when valid, otherwise ``(False, message)``.
    """
    try:
        schema_class = {
            "builtin": BuiltinToolConfigSchema,
            "custom": CustomToolConfigSchema,
            "mcp": MCPToolConfigSchema
        }.get(tool_type)
        if not schema_class:
            return False, f"不支持的工具类型: {tool_type}"
        # Instantiating the model runs the validation; the instance is discarded.
        schema_class(**config_data)
    except ValidationError as exc:
        # Report each failing field by the first element of its location.
        details = "; ".join(
            f"{issue['loc'][0]}: {issue['msg']}" for issue in exc.errors()
        )
        return False, f"配置验证失败: {details}"
    except Exception as exc:
        return False, f"配置验证异常: {str(exc)}"
    return True, None
def _load_config_file(self, config_file: Path) -> Dict[str, Any]:
"""加载配置文件
Args:
config_file: 配置文件路径
Returns:
配置数据字典
"""
try:
with open(config_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"加载配置文件失败: {config_file}, 错误: {e}")
raise
def _create_default_builtin_configs(self, builtin_dir: Path):
    """Ensure the built-in tool configuration directory exists.

    Only the directory is created here; no configuration files are written
    by this method.

    Args:
        builtin_dir: Directory for built-in tool configurations.
    """
    Path.mkdir(builtin_dir, parents=True, exist_ok=True)
    logger.info(f"内置工具配置目录已创建: {builtin_dir}")

View File

@@ -0,0 +1,14 @@
{
"name": "baidu_search_tool",
"description": "百度搜索工具 - 网络搜索:提供网页搜索、新闻搜索、图片搜索功能",
"tool_type": "builtin",
"tool_class": "BaiduSearchTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": "",
"secret_key": "",
"search_type": "web"
},
"tags": ["search", "web", "baidu", "builtin"]
}

View File

@@ -0,0 +1,12 @@
{
"name": "datetime_tool",
"description": "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算",
"tool_type": "builtin",
"tool_class": "DateTimeTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"timezone": "UTC"
},
"tags": ["time", "utility", "builtin"]
}

View File

@@ -0,0 +1,12 @@
{
"name": "json_tool",
"description": "JSON工具 - 数据格式处理:提供JSON格式化、压缩、验证、格式转换",
"tool_type": "builtin",
"tool_class": "JsonTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"indent": 2
},
"tags": ["json", "data", "utility", "builtin"]
}

View File

@@ -0,0 +1,14 @@
{
"name": "mineru_tool",
"description": "MinerU PDF解析工具 - 文档处理:提供PDF解析、表格提取、图片识别、文本提取功能",
"tool_type": "builtin",
"tool_class": "MinerUTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": "",
"parse_mode": "auto",
"timeout": 60
},
"tags": ["pdf", "document", "ocr", "builtin"]
}

View File

@@ -0,0 +1,14 @@
{
"name": "textin_tool",
"description": "TextIn OCR工具 - 图像识别:提供通用OCR、手写识别、多语言支持功能",
"tool_type": "builtin",
"tool_class": "TextInTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"app_id": "",
"language": "auto",
"recognition_mode": "general"
},
"tags": ["ocr", "image", "text", "builtin"]
}

View File

@@ -0,0 +1,60 @@
{
"datetime": {
"name": "时间工具",
"description": "获取当前时间、日期计算",
"tool_class": "DateTimeTool",
"category": "utility",
"requires_config": false,
"version": "1.0.0",
"enabled": true,
"parameters": {}
},
"json_converter": {
"name": "JSON转换工具",
"description": "JSON数据格式化和转换",
"tool_class": "JsonTool",
"category": "utility",
"requires_config": false,
"version": "1.0.0",
"enabled": true,
"parameters": {}
},
"baidu_search": {
"name": "百度搜索",
"description": "百度网页搜索服务",
"tool_class": "BaiduSearchTool",
"category": "search",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true}
}
},
"mineru": {
"name": "MinerU",
"description": "PDF文档解析工具",
"tool_class": "MinerUTool",
"category": "document",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "MinerU API密钥", "sensitive": true, "required": true},
"base_url": {"type": "string", "description": "API地址", "default": "https://api.mineru.com"}
}
},
"textin": {
"name": "TextIn",
"description": "OCR文字识别服务",
"tool_class": "TextInTool",
"category": "ocr",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
"api_secret": {"type": "string", "description": "TextIn API Secret", "sensitive": true, "required": true}
}
}
}

View File

@@ -0,0 +1,11 @@
"""自定义工具模块"""
from .base import CustomTool
from .schema_parser import OpenAPISchemaParser
from .auth_manager import AuthManager
__all__ = [
"CustomTool",
"OpenAPISchemaParser",
"AuthManager"
]

View File

@@ -0,0 +1,525 @@
"""认证管理器 - 处理自定义工具的认证配置"""
import base64
import hashlib
import hmac
import time
from typing import Dict, Any, Tuple
from urllib.parse import quote
import aiohttp
from app.models.tool_model import AuthType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class AuthManager:
    """Validates, applies, encodes and masks auth settings for custom tools."""

    def __init__(self):
        """Register the authentication types this manager accepts."""
        supported = []
        supported.append(AuthType.NONE)
        supported.append(AuthType.API_KEY)
        supported.append(AuthType.BEARER_TOKEN)
        self.supported_auth_types = supported
def validate_auth_config(self, auth_type: AuthType, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
    """Check that an auth configuration is usable for the given type.

    Args:
        auth_type: Authentication type.
        auth_config: Authentication configuration.

    Returns:
        ``(is_valid, error_message)``; the message is empty when valid.
    """
    try:
        if auth_type not in self.supported_auth_types:
            return False, f"不支持的认证类型: {auth_type}"
        if auth_type == AuthType.NONE:
            # Nothing to validate when no authentication is used.
            return True, ""
        if auth_type == AuthType.API_KEY:
            return self._validate_api_key_config(auth_config)
        if auth_type == AuthType.BEARER_TOKEN:
            return self._validate_bearer_token_config(auth_config)
        return False, "未知的认证类型"
    except Exception as exc:
        return False, f"验证认证配置时出错: {exc}"
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证API Key认证配置
Args:
auth_config: 认证配置
Returns:
(是否有效, 错误信息)
"""
api_key = auth_config.get("api_key")
if not api_key:
return False, "API Key不能为空"
if not isinstance(api_key, str):
return False, "API Key必须是字符串"
# 验证key名称
key_name = auth_config.get("key_name", "X-API-Key")
if not isinstance(key_name, str):
return False, "API Key名称必须是字符串"
# 验证位置
key_location = auth_config.get("location", "header")
if key_location not in ["header", "query", "cookie"]:
return False, "API Key位置必须是 header、query 或 cookie"
return True, ""
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证Bearer Token认证配置
Args:
auth_config: 认证配置
Returns:
(是否有效, 错误信息)
"""
token = auth_config.get("token")
if not token:
return False, "Bearer Token不能为空"
if not isinstance(token, str):
return False, "Bearer Token必须是字符串"
return True, ""
def apply_authentication(
    self,
    auth_type: AuthType,
    auth_config: Dict[str, Any],
    url: str,
    headers: Dict[str, str],
    params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
    """Attach credentials to a request according to the auth type.

    Best effort: an unsupported type or any error is logged and the
    request parts are returned as they were passed in.

    Args:
        auth_type: Authentication type.
        auth_config: Authentication configuration.
        url: Request URL.
        headers: Request headers.
        params: Request parameters.

    Returns:
        ``(url, headers, params)`` after authentication was applied.
    """
    try:
        if auth_type == AuthType.NONE:
            return url, headers, params
        if auth_type == AuthType.API_KEY:
            return self._apply_api_key_auth(auth_config, url, headers, params)
        if auth_type == AuthType.BEARER_TOKEN:
            return self._apply_bearer_token_auth(auth_config, url, headers, params)
        logger.warning(f"不支持的认证类型: {auth_type}")
    except Exception as exc:
        logger.error(f"应用认证时出错: {exc}")
    return url, headers, params
def _apply_api_key_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
"""应用API Key认证
Args:
auth_config: 认证配置
url: 请求URL
headers: 请求头
params: 请求参数
Returns:
(修改后的URL, 修改后的headers, 修改后的params)
"""
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
location = auth_config.get("location", "header")
if location == "header":
headers[key_name] = api_key
elif location == "query":
# 添加到URL查询参数
separator = "&" if "?" in url else "?"
encoded_key = quote(str(api_key))
url += f"{separator}{key_name}={encoded_key}"
elif location == "cookie":
# 添加到Cookie头
cookie_value = f"{key_name}={api_key}"
if "Cookie" in headers:
headers["Cookie"] += f"; {cookie_value}"
else:
headers["Cookie"] = cookie_value
return url, headers, params
def _apply_bearer_token_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
"""应用Bearer Token认证
Args:
auth_config: 认证配置
url: 请求URL
headers: 请求头
params: 请求参数
Returns:
(修改后的URL, 修改后的headers, 修改后的params)
"""
token = auth_config.get("token")
headers["Authorization"] = f"Bearer {token}"
return url, headers, params
def encrypt_auth_config(self, auth_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
    """Encode the sensitive fields of an auth configuration.

    Non-empty string values of ``api_key``, ``token``, ``secret`` and
    ``password`` are passed through ``_encrypt_string`` and flagged with
    ``<field>_encrypted = True``. A field that already carries that flag is
    left untouched, so calling this twice does not wrap the value twice
    (``decrypt_auth_config`` only unwraps one layer).

    Args:
        auth_config: Authentication configuration (not modified).
        encryption_key: Key handed to ``_encrypt_string``.

    Returns:
        A shallow copy with sensitive fields encoded; on any error the
        original configuration is returned unchanged (and the error logged).
    """
    try:
        encrypted_config = auth_config.copy()
        sensitive_fields = ["api_key", "token", "secret", "password"]
        for field in sensitive_fields:
            if field not in encrypted_config:
                continue
            flag = f"{field}_encrypted"
            if encrypted_config.get(flag):
                # Already encoded by an earlier call; do not double-wrap.
                continue
            value = encrypted_config[field]
            if isinstance(value, str) and value:
                encrypted_config[field] = self._encrypt_string(value, encryption_key)
                encrypted_config[flag] = True
        return encrypted_config
    except Exception as e:
        logger.error(f"加密认证配置失败: {e}")
        return auth_config
def decrypt_auth_config(self, encrypted_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
    """Decode the sensitive fields of an auth configuration.

    Only fields flagged with a truthy ``<field>_encrypted`` entry are
    processed; the flag is removed afterwards.

    Args:
        encrypted_config: Encoded authentication configuration (not modified).
        encryption_key: Key handed to ``_decrypt_string``.

    Returns:
        A shallow copy with flagged fields decoded; on any error the input
        configuration is returned unchanged (and the error logged).
    """
    try:
        result = encrypted_config.copy()
        for name in ("api_key", "token", "secret", "password"):
            marker = f"{name}_encrypted"
            if name not in result or not result.get(marker):
                continue
            stored = result[name]
            if isinstance(stored, str) and stored:
                result[name] = self._decrypt_string(stored, encryption_key)
            # Drop the marker once the field has been handled.
            result.pop(marker, None)
        return result
    except Exception as exc:
        logger.error(f"解密认证配置失败: {exc}")
        return encrypted_config
def _encrypt_string(self, value: str, key: str) -> str:
"""加密字符串
Args:
value: 要加密的字符串
key: 加密密钥
Returns:
加密后的字符串Base64编码
"""
try:
# 使用HMAC-SHA256进行简单加密
key_bytes = key.encode('utf-8')
value_bytes = value.encode('utf-8')
# 生成HMAC
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
signature = hmac_obj.hexdigest()
# 组合原始值和签名然后Base64编码
combined = f"{value}:{signature}"
encrypted = base64.b64encode(combined.encode('utf-8')).decode('utf-8')
return encrypted
except Exception as e:
logger.error(f"加密字符串失败: {e}")
return value
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
"""解密字符串
Args:
encrypted_value: 加密的字符串
key: 解密密钥
Returns:
解密后的字符串
"""
try:
# Base64解码
decoded = base64.b64decode(encrypted_value.encode('utf-8')).decode('utf-8')
# 分离原始值和签名
if ':' not in decoded:
return encrypted_value # 可能不是加密的值
value, signature = decoded.rsplit(':', 1)
# 验证签名
key_bytes = key.encode('utf-8')
value_bytes = value.encode('utf-8')
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
expected_signature = hmac_obj.hexdigest()
if signature == expected_signature:
return value
else:
logger.warning("解密时签名验证失败")
return encrypted_value
except Exception as e:
logger.error(f"解密字符串失败: {e}")
return encrypted_value
def test_authentication(
self,
auth_type: AuthType,
auth_config: Dict[str, Any],
test_url: str = None
) -> Dict[str, Any]:
"""测试认证配置
Args:
auth_type: 认证类型
auth_config: 认证配置
test_url: 测试URL可选
Returns:
测试结果
"""
try:
# 验证配置
is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
if not is_valid:
return {
"success": False,
"error": error_msg,
"auth_type": auth_type.value
}
# 如果没有测试URL只验证配置
if not test_url:
return {
"success": True,
"message": "认证配置有效",
"auth_type": auth_type.value
}
# 构建测试请求
headers = {"User-Agent": "AuthManager-Test/1.0"}
params = {}
# 应用认证
test_url, headers, params = self.apply_authentication(
auth_type, auth_config, test_url, headers, params
)
return {
"success": True,
"message": "认证配置测试成功",
"auth_type": auth_type.value,
"test_url": test_url,
"headers": {k: v for k, v in headers.items() if k != "Authorization"}, # 不返回敏感信息
"has_auth_header": "Authorization" in headers
}
except Exception as e:
return {
"success": False,
"error": str(e),
"auth_type": auth_type.value if auth_type else "unknown"
}
async def test_authentication_with_request(
    self,
    auth_type: AuthType,
    auth_config: Dict[str, Any],
    test_url: str,
    timeout: int = 10
) -> Dict[str, Any]:
    """Test an auth configuration by sending a real GET request.

    Only 401 and 403 responses are reported as failures; every other
    status code is reported as success.

    Args:
        auth_type: Authentication type.
        auth_config: Authentication configuration.
        test_url: URL to request.
        timeout: Total request timeout in seconds.

    Returns:
        A result dict with ``success``, ``auth_type`` and either ``error``
        or ``message``; ``status_code`` is included once a response arrived.
    """
    try:
        config_ok, reason = self.validate_auth_config(auth_type, auth_config)
        if not config_ok:
            return {
                "success": False,
                "error": reason,
                "auth_type": auth_type.value
            }

        request_headers = {"User-Agent": "AuthManager-Test/1.0"}
        test_url, request_headers, _ = self.apply_authentication(
            auth_type, auth_config, test_url, request_headers, {}
        )

        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(test_url, headers=request_headers) as response:
                status_code = response.status

        # Status codes that mean the credentials were rejected.
        rejected = {
            401: "认证失败 - 401 Unauthorized",
            403: "认证失败 - 403 Forbidden",
        }
        if status_code in rejected:
            return {
                "success": False,
                "error": rejected[status_code],
                "status_code": status_code,
                "auth_type": auth_type.value
            }
        if status_code == 200:
            message = "认证测试成功"
        else:
            message = f"请求成功,状态码: {status_code}"
        return {
            "success": True,
            "message": message,
            "status_code": status_code,
            "auth_type": auth_type.value
        }
    except aiohttp.ClientError as exc:
        return {
            "success": False,
            "error": f"网络请求失败: {exc}",
            "auth_type": auth_type.value
        }
    except Exception as exc:
        return {
            "success": False,
            "error": f"测试认证时出错: {exc}",
            "auth_type": auth_type.value
        }
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
    """Return a blank configuration template for an auth type.

    Args:
        auth_type: Authentication type.

    Returns:
        The template dict, or an empty dict for a type without a template.
    """
    api_key_template = {
        "api_key": "",
        "key_name": "X-API-Key",
        "location": "header",  # header, query, cookie
        "description": "API Key认证配置"
    }
    bearer_template = {
        "token": "",
        "description": "Bearer Token认证配置"
    }
    templates = {AuthType.NONE: {}}
    templates[AuthType.API_KEY] = api_key_template
    templates[AuthType.BEARER_TOKEN] = bearer_template
    return templates.get(auth_type, {})
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the configuration with secrets masked.

    For ``api_key``, ``token``, ``secret`` and ``password``: a string longer
    than four characters keeps its first and last two characters around
    ``***``; a shorter non-empty string becomes ``***``. Empty and
    non-string values are left as they are.

    Args:
        auth_config: Authentication configuration (not modified).

    Returns:
        The masked shallow copy.
    """
    masked = auth_config.copy()
    for name in ("api_key", "token", "secret", "password"):
        if name not in masked:
            continue
        secret = masked[name]
        if not isinstance(secret, str) or not secret:
            continue
        masked[name] = f"{secret[:2]}***{secret[-2:]}" if len(secret) > 4 else "***"
    return masked

View File

@@ -0,0 +1,318 @@
"""自定义工具基类"""
import time
from typing import Dict, Any, List, Optional
import aiohttp
from urllib.parse import urljoin
from app.models.tool_model import ToolType, AuthType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class CustomTool(BaseTool):
    """Custom tool driven by an OpenAPI schema."""

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Read the tool settings from ``config`` and parse the schema.

        Args:
            tool_id: Tool ID.
            config: Tool configuration; reads ``schema_content``,
                ``schema_url``, ``auth_type``, ``auth_config``, ``base_url``
                and ``timeout``, each with a default when missing.
        """
        super().__init__(tool_id, config)
        read = config.get
        self.schema_content = read("schema_content", {})
        self.schema_url = read("schema_url")
        self.auth_type = AuthType(read("auth_type", "none"))
        self.auth_config = read("auth_config", {})
        self.base_url = read("base_url", "")
        self.timeout = read("timeout", 30)
        # Operations are derived once from the schema at construction time.
        self._parsed_operations = self._parse_openapi_schema()
@property
def name(self) -> str:
    """Tool name: the schema's ``info.title``, else ``custom_tool_<id prefix>``."""
    fallback = f"custom_tool_{self.tool_id[:8]}"
    if not self.schema_content:
        return fallback
    return self.schema_content.get("info", {}).get("title", fallback)
@property
def description(self) -> str:
    """Tool description: the schema's ``info.description``, else a default."""
    default_text = "自定义API工具"
    if not self.schema_content:
        return "自定义API工具"
    return self.schema_content.get("info", {}).get("description", default_text)
@property
def tool_type(self) -> ToolType:
    """Tool type; always ``ToolType.CUSTOM`` for this class."""
    kind = ToolType.CUSTOM
    return kind
@property
def parameters(self) -> List[ToolParameter]:
    """Parameter definitions exposed by this tool.

    When more than one operation was parsed, a required ``operation``
    selector comes first. The remaining parameters are taken from the
    first parsed operation only.
    """
    operations = self._parsed_operations
    if not operations:
        return []

    definitions = []
    if len(operations) > 1:
        # Let the caller choose which operation to run.
        definitions.append(ToolParameter(
            name="operation",
            type=ParameterType.STRING,
            description="要执行的操作",
            required=True,
            enum=list(operations)
        ))

    first_operation = next(iter(operations.values()))
    for param_name, spec in first_operation.get("parameters", {}).items():
        definitions.append(ToolParameter(
            name=param_name,
            type=self._convert_openapi_type(spec.get("type", "string")),
            description=spec.get("description", ""),
            required=spec.get("required", False),
            default=spec.get("default"),
            enum=spec.get("enum"),
            minimum=spec.get("minimum"),
            maximum=spec.get("maximum"),
            pattern=spec.get("pattern")
        ))
    return definitions
async def execute(self, **kwargs) -> ToolResult:
    """Run one operation of the custom tool.

    The operation is taken from the ``operation`` keyword, or chosen
    automatically when exactly one operation exists. Any failure is
    returned as an error result rather than raised.

    Returns:
        A success result carrying the HTTP response data, or an error
        result with code ``CUSTOM_TOOL_ERROR``; both include the elapsed
        execution time.
    """
    started_at = time.time()
    try:
        operations = self._parsed_operations
        operation_name = kwargs.get("operation")
        if not operation_name and len(operations) == 1:
            operation_name = next(iter(operations))
        if not operation_name or operation_name not in operations:
            raise ValueError(f"无效的操作: {operation_name}")
        operation = operations[operation_name]

        # Assemble the request pieces, then send it.
        url = self._build_request_url(operation, kwargs)
        headers = self._build_request_headers(operation)
        data = self._build_request_data(operation, kwargs)
        response_data = await self._send_http_request(
            method=operation["method"],
            url=url,
            headers=headers,
            data=data
        )
        elapsed = time.time() - started_at
        return ToolResult.success_result(
            data=response_data,
            execution_time=elapsed
        )
    except Exception as exc:
        elapsed = time.time() - started_at
        return ToolResult.error_result(
            error=str(exc),
            error_code="CUSTOM_TOOL_ERROR",
            execution_time=elapsed
        )
def _parse_openapi_schema(self) -> Dict[str, Any]:
"""解析OpenAPI schema"""
operations = {}
if not self.schema_content:
return operations
paths = self.schema_content.get("paths", {})
for path, path_item in paths.items():
for method, operation in path_item.items():
if method.lower() in ["get", "post", "put", "delete", "patch"]:
operation_id = operation.get("operationId", f"{method}_{path.replace('/', '_')}")
# 解析参数
parameters = {}
if "parameters" in operation:
for param in operation["parameters"]:
param_name = param.get("name")
param_schema = param.get("schema", {})
parameters[param_name] = {
"type": param_schema.get("type", "string"),
"description": param.get("description", ""),
"required": param.get("required", False),
"in": param.get("in", "query"),
**param_schema
}
# 解析请求体
request_body = None
if "requestBody" in operation:
content = operation["requestBody"].get("content", {})
if "application/json" in content:
request_body = content["application/json"].get("schema", {})
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": parameters,
"request_body": request_body
}
return operations
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
    """Map an OpenAPI ``type`` value to the internal parameter type.

    Args:
        openapi_type: The schema's ``type``. OpenAPI 3.1 also allows a list
            of types (e.g. ``["string", "null"]``); the first non-"null"
            entry is used in that case.

    Returns:
        The matching ``ParameterType``; ``ParameterType.STRING`` for
        unknown or unusable values.
    """
    if isinstance(openapi_type, (list, tuple)):
        # A list is unhashable and would make the dict lookup below raise.
        openapi_type = next(
            (entry for entry in openapi_type if entry != "null"), "string"
        )
    if not isinstance(openapi_type, str):
        return ParameterType.STRING
    type_mapping = {
        "string": ParameterType.STRING,
        "integer": ParameterType.INTEGER,
        "number": ParameterType.NUMBER,
        "boolean": ParameterType.BOOLEAN,
        "array": ParameterType.ARRAY,
        "object": ParameterType.OBJECT
    }
    return type_mapping.get(openapi_type, ParameterType.STRING)
def _build_request_url(self, operation: Dict[str, Any], params: Dict[str, Any]) -> str:
"""构建请求URL"""
path = operation["path"]
# 替换路径参数
for param_name, param_info in operation.get("parameters", {}).items():
if param_info.get("in") == "path" and param_name in params:
path = path.replace(f"{{{param_name}}}", str(params[param_name]))
# 构建完整URL
if self.base_url:
url = urljoin(self.base_url, path.lstrip("/"))
else:
# 从schema中获取服务器URL
servers = self.schema_content.get("servers", [])
if servers:
base_url = servers[0].get("url", "")
url = urljoin(base_url, path.lstrip("/"))
else:
url = path
# 添加查询参数
query_params = {}
for param_name, param_info in operation.get("parameters", {}).items():
if param_info.get("in") == "query" and param_name in params:
query_params[param_name] = params[param_name]
if query_params:
from urllib.parse import urlencode
url += "?" + urlencode(query_params)
return url
def _build_request_headers(self, operation: Dict[str, Any]) -> Dict[str, str]:
    """Build the request headers, including credentials when configured.

    An API key is always sent as a header here (named by ``key_name``,
    default ``X-API-Key``); a bearer token is sent as ``Authorization``.
    Empty credentials add no header.

    Args:
        operation: Parsed operation (not used by this method).

    Returns:
        The header dict.
    """
    request_headers = {
        "Content-Type": "application/json",
        "User-Agent": "CustomTool/1.0"
    }
    settings = self.auth_config
    if self.auth_type == AuthType.API_KEY:
        api_key = settings.get("api_key")
        if api_key:
            request_headers[settings.get("key_name", "X-API-Key")] = api_key
    elif self.auth_type == AuthType.BEARER_TOKEN:
        bearer = settings.get("token")
        if bearer:
            request_headers["Authorization"] = f"Bearer {bearer}"
    return request_headers
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""构建请求数据"""
if operation["method"] in ["POST", "PUT", "PATCH"]:
request_body = operation.get("request_body")
if request_body:
# 构建请求体数据
data = {}
properties = request_body.get("properties", {})
for prop_name, prop_schema in properties.items():
if prop_name in params:
data[prop_name] = params[prop_name]
return data if data else None
return None
async def _send_http_request(
    self,
    method: str,
    url: str,
    headers: Dict[str, str],
    data: Optional[Dict[str, Any]] = None
) -> Any:
    """Send the HTTP request and return the decoded response.

    Args:
        method: HTTP method (upper case).
        url: Request URL.
        headers: Request headers.
        data: JSON body; only sent for POST, PUT and PATCH.

    Returns:
        The response parsed as JSON, or its text when JSON parsing fails.

    Raises:
        Exception: When the response status is 400 or higher; the message
            contains the status and the response text.
    """
    request_kwargs = {
        "headers": headers
    }
    if data and method in ["POST", "PUT", "PATCH"]:
        request_kwargs["json"] = data

    client_timeout = aiohttp.ClientTimeout(total=self.timeout)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.request(method, url, **request_kwargs) as response:
            if response.status >= 400:
                error_text = await response.text()
                raise Exception(f"HTTP {response.status}: {error_text}")
            # Prefer JSON; fall back to the raw text for other content.
            try:
                return await response.json()
            except Exception:
                return await response.text()
@classmethod
def from_url(cls, schema_url: str, auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
    """Create a tool that references an OpenAPI schema by URL.

    The schema is not downloaded here; the tool is created with only the
    URL and auth settings in its configuration.

    Args:
        schema_url: URL of the OpenAPI schema.
        auth_config: Auth configuration; its ``type`` entry (default
            ``"none"``) becomes the tool's auth type.
        tool_id: Tool ID; a random UUID4 string is generated when empty.
    """
    import uuid
    resolved_id = tool_id or str(uuid.uuid4())
    return cls(resolved_id, {
        "schema_url": schema_url,
        "auth_config": auth_config,
        "auth_type": auth_config.get("type", "none")
    })
@classmethod
def from_schema(cls, schema_dict: Dict[str, Any], auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
    """Create a tool from an already-parsed OpenAPI schema.

    Args:
        schema_dict: OpenAPI schema content.
        auth_config: Auth configuration; its ``type`` entry (default
            ``"none"``) becomes the tool's auth type.
        tool_id: Tool ID; a random UUID4 string is generated when empty.
    """
    import uuid
    resolved_id = tool_id or str(uuid.uuid4())
    return cls(resolved_id, {
        "schema_content": schema_dict,
        "auth_config": auth_config,
        "auth_type": auth_config.get("type", "none")
    })

View File

@@ -0,0 +1,477 @@
"""OpenAPI Schema解析器"""
import json
import yaml
from typing import Dict, Any, List, Optional, Tuple
from urllib.parse import urlparse
import aiohttp
import asyncio
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class OpenAPISchemaParser:
    """Parses and validates OpenAPI 3.x schemas."""

    def __init__(self):
        """Record the OpenAPI versions accepted by ``validate_schema``."""
        openapi_3_0 = [f"3.0.{patch}" for patch in range(4)]
        self.supported_versions = openapi_3_0 + ["3.1.0"]
async def parse_from_url(self, schema_url: str, timeout: int = 30) -> Tuple[bool, Dict[str, Any], str]:
    """Download an OpenAPI schema from a URL, then parse and validate it.

    Args:
        schema_url: Schema URL; must have both a scheme and a host.
        timeout: Total request timeout in seconds.

    Returns:
        ``(success, schema, error_message)``; the schema is empty on
        failure and the message is empty on success.
    """
    try:
        url_parts = urlparse(schema_url)
        if not (url_parts.scheme and url_parts.netloc):
            return False, {}, "无效的URL格式"

        request_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=request_timeout) as session:
            async with session.get(schema_url) as response:
                if response.status != 200:
                    return False, {}, f"HTTP错误: {response.status}"
                content_type = response.headers.get('content-type', '').lower()
                body = await response.text()

        # The Content-Type header selects JSON vs YAML parsing.
        parsed = self._parse_content(body, content_type)
        if not parsed:
            return False, {}, "无法解析schema内容"

        valid, problem = self.validate_schema(parsed)
        if not valid:
            return False, {}, problem
        return True, parsed, ""
    except asyncio.TimeoutError:
        return False, {}, "请求超时"
    except Exception as exc:
        logger.error(f"从URL解析schema失败: {schema_url}, 错误: {exc}")
        return False, {}, str(exc)
def parse_from_content(self, content: str, content_type: str = "application/json") -> Tuple[bool, Dict[str, Any], str]:
    """Parse and validate an OpenAPI schema supplied as text.

    Args:
        content: Schema text.
        content_type: Content type used to choose the parser.

    Returns:
        ``(success, schema, error_message)``; the schema is empty on
        failure and the message is empty on success.
    """
    try:
        parsed = self._parse_content(content, content_type)
        if not parsed:
            return False, {}, "无法解析schema内容"
        valid, problem = self.validate_schema(parsed)
        return (True, parsed, "") if valid else (False, {}, problem)
    except Exception as exc:
        logger.error(f"解析schema内容失败: {exc}")
        return False, {}, str(exc)
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
"""解析内容为字典
Args:
content: 内容字符串
content_type: 内容类型
Returns:
解析后的字典失败返回None
"""
try:
# 根据内容类型解析
if 'json' in content_type:
return json.loads(content)
elif 'yaml' in content_type or 'yml' in content_type:
return yaml.safe_load(content)
else:
# 尝试自动检测格式
try:
return json.loads(content)
except json.JSONDecodeError:
try:
return yaml.safe_load(content)
except yaml.YAMLError:
return None
except Exception as e:
logger.error(f"解析内容失败: {e}")
return None
def validate_schema(self, schema_dict: Dict[str, Any]) -> Tuple[bool, str]:
    """Check the basic structure of an OpenAPI schema.

    The ``openapi`` version is accepted when it is listed in
    ``self.supported_versions`` or shares its ``major.minor`` release line
    with a listed version, so later patch releases such as 3.0.4 or 3.1.1
    are not rejected. ``info`` (with a ``title``) and a non-empty ``paths``
    mapping are required.

    Args:
        schema_dict: Parsed schema.

    Returns:
        ``(is_valid, error_message)``; the message is empty when valid.
    """
    try:
        if not isinstance(schema_dict, dict):
            return False, "Schema必须是JSON对象"

        openapi_version = schema_dict.get("openapi")
        if not openapi_version:
            return False, "缺少openapi版本字段"
        supported = self.supported_versions
        accepted = openapi_version in supported
        if not accepted and isinstance(openapi_version, str):
            # Patch releases only clarify the spec, so match on major.minor.
            release_line = openapi_version.split(".")[:2]
            accepted = len(release_line) == 2 and any(
                str(known).split(".")[:2] == release_line for known in supported
            )
        if not accepted:
            return False, f"不支持的OpenAPI版本: {openapi_version}"

        required_fields = ["info", "paths"]
        for field in required_fields:
            if field not in schema_dict:
                return False, f"缺少必需字段: {field}"

        info = schema_dict.get("info", {})
        if not isinstance(info, dict):
            return False, "info字段必须是对象"
        if "title" not in info:
            return False, "info.title字段是必需的"

        paths = schema_dict.get("paths", {})
        if not isinstance(paths, dict):
            return False, "paths字段必须是对象"
        # At least one path must be defined.
        if not paths:
            return False, "至少需要定义一个API路径"
        return True, ""
    except Exception as e:
        return False, f"验证schema时出错: {e}"
def extract_tool_info(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize a schema as tool information.

    Args:
        schema_dict: Parsed schema.

    Returns:
        A dict with ``name``, ``description`` and ``version`` taken from
        the ``info`` section (with defaults), the ``servers`` list and the
        extracted ``operations``.
    """
    meta = schema_dict.get("info", {})
    summary = {}
    summary["name"] = meta.get("title", "Custom API Tool")
    summary["description"] = meta.get("description", "")
    summary["version"] = meta.get("version", "1.0.0")
    summary["servers"] = schema_dict.get("servers", [])
    summary["operations"] = self._extract_operations(schema_dict)
    return summary
def _extract_operations(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
"""提取API操作信息
Args:
schema_dict: Schema字典
Returns:
操作信息字典
"""
operations = {}
paths = schema_dict.get("paths", {})
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
for method, operation in path_item.items():
if method.lower() not in ["get", "post", "put", "delete", "patch", "head", "options"]:
continue
if not isinstance(operation, dict):
continue
# 生成操作ID
operation_id = operation.get("operationId")
if not operation_id:
operation_id = f"{method.lower()}_{path.replace('/', '_').replace('{', '').replace('}', '')}"
# 提取操作信息
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": self._extract_parameters(operation),
"request_body": self._extract_request_body(operation),
"responses": self._extract_responses(operation),
"tags": operation.get("tags", [])
}
return operations
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取操作参数
Args:
operation: 操作定义
Returns:
参数信息字典
"""
parameters = {}
for param in operation.get("parameters", []):
if not isinstance(param, dict):
continue
param_name = param.get("name")
if not param_name:
continue
param_schema = param.get("schema", {})
parameters[param_name] = {
"name": param_name,
"in": param.get("in", "query"),
"description": param.get("description", ""),
"required": param.get("required", False),
"type": param_schema.get("type", "string"),
"format": param_schema.get("format"),
"enum": param_schema.get("enum"),
"default": param_schema.get("default"),
"minimum": param_schema.get("minimum"),
"maximum": param_schema.get("maximum"),
"pattern": param_schema.get("pattern"),
"example": param.get("example") or param_schema.get("example")
}
return parameters
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""提取请求体信息
Args:
operation: 操作定义
Returns:
请求体信息如果没有返回None
"""
request_body = operation.get("requestBody")
if not request_body:
return None
content = request_body.get("content", {})
# 优先使用application/json
if "application/json" in content:
schema = content["application/json"].get("schema", {})
elif content:
# 使用第一个可用的内容类型
first_content_type = next(iter(content.keys()))
schema = content[first_content_type].get("schema", {})
else:
return None
return {
"description": request_body.get("description", ""),
"required": request_body.get("required", False),
"schema": schema,
"content_types": list(content.keys())
}
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取响应信息
Args:
operation: 操作定义
Returns:
响应信息字典
"""
responses = {}
for status_code, response in operation.get("responses", {}).items():
if not isinstance(response, dict):
continue
content = response.get("content", {})
schema = None
# 尝试获取响应schema
if "application/json" in content:
schema = content["application/json"].get("schema")
elif content:
first_content_type = next(iter(content.keys()))
schema = content[first_content_type].get("schema")
responses[status_code] = {
"description": response.get("description", ""),
"schema": schema,
"content_types": list(content.keys()) if content else []
}
return responses
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Build a flat parameter list for a tool that wraps these operations.

    Args:
        operations: Mapping of operation id to extracted operation info.

    Returns:
        List of parameter definitions. When more than one operation exists
        the list starts with an ``operation`` selector whose enum is the
        operation ids. Each remaining name appears once; the first
        definition encountered wins.
    """
    constraint_keys = ("enum", "default", "minimum", "maximum", "pattern")

    def describe(name, spec, is_required):
        # Normalise one parameter or body-property schema into the tool format.
        entry = {
            "name": name,
            "type": spec.get("type", "string"),
            "description": spec.get("description", ""),
            "required": is_required,
        }
        for key in constraint_keys:
            entry[key] = spec.get(key)
        return entry

    collected = {}
    for op_info in operations.values():
        # Path and query parameters.
        for name, spec in op_info.get("parameters", {}).items():
            if name in collected:
                continue
            collected[name] = describe(name, spec, spec.get("required", False))

        # Top-level properties of the request body schema.
        body = op_info.get("request_body")
        if not body:
            continue
        body_schema = body.get("schema", {})
        for name, spec in body_schema.get("properties", {}).items():
            if name in collected:
                continue
            collected[name] = describe(name, spec, name in body_schema.get("required", []))

    result = []
    if len(operations) > 1:
        result.append({
            "name": "operation",
            "type": "string",
            "description": "要执行的操作",
            "required": True,
            "enum": list(operations.keys()),
        })
    result.extend(collected.values())
    return result
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Validate caller-supplied arguments against one operation.

    Args:
        operation: Extracted operation info.
        params: Arguments supplied by the caller.

    Returns:
        ``(is_valid, messages)`` where ``messages`` lists every problem found.
    """
    problems = []

    # Path / query parameters: presence, type, then enum membership.
    for name, spec in operation.get("parameters", {}).items():
        if name not in params:
            if spec.get("required", False):
                problems.append(f"缺少必需参数: {name}")
            continue

        supplied = params[name]
        expected = spec.get("type", "string")
        if not self._validate_parameter_type(supplied, expected):
            problems.append(f"参数 {name} 类型错误,期望: {expected}")

        allowed = spec.get("enum")
        if allowed and supplied not in allowed:
            problems.append(f"参数 {name} 值无效,必须是: {allowed}")

    # Request body: required properties first, then types of supplied ones.
    body = operation.get("request_body")
    if body:
        body_schema = body.get("schema", {})
        required_names = body_schema.get("required", [])
        declared = body_schema.get("properties", {})

        problems.extend(
            f"缺少必需的请求体参数: {name}"
            for name in required_names
            if name not in params
        )
        for name, supplied in params.items():
            if name not in declared:
                continue
            expected = declared[name].get("type", "string")
            if not self._validate_parameter_type(supplied, expected):
                problems.append(f"请求体参数 {name} 类型错误,期望: {expected}")

    return not problems, problems
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
"""验证参数类型
Args:
value: 参数值
expected_type: 期望类型
Returns:
是否类型匹配
"""
if value is None:
return True
type_mapping = {
"string": str,
"integer": int,
"number": (int, float),
"boolean": bool,
"array": list,
"object": dict
}
expected_python_type = type_mapping.get(expected_type)
if expected_python_type:
return isinstance(value, expected_python_type)
return True

View File

@@ -0,0 +1,501 @@
"""工具执行器 - 负责工具的实际调用和执行管理"""
import asyncio
import uuid
import time
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import ToolExecution, ExecutionStatus
from app.core.tools.base import BaseTool, ToolResult
from app.core.tools.registry import ToolRegistry
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ExecutionContext:
    """In-memory state of one tool execution.

    Holds the identifiers of the run, its timeout and metadata, the start
    and completion timestamps, and the current ``ExecutionStatus``.
    """

    def __init__(
        self,
        execution_id: str,
        tool_id: str,
        user_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        # Identity of the run.
        self.execution_id = execution_id
        self.tool_id = tool_id
        self.user_id = user_id
        self.workspace_id = workspace_id

        # A missing (or falsy) timeout falls back to 60 seconds.
        self.timeout = timeout if timeout else 60.0
        self.metadata = metadata if metadata else {}

        # Lifecycle: the clock starts at construction time.
        self.started_at = datetime.now()
        self.completed_at: Optional[datetime] = None
        self.status = ExecutionStatus.PENDING
class ToolExecutor:
"""工具执行器 - 使用langchain标准接口执行工具"""
def __init__(self, db: Session, registry: ToolRegistry):
    """Create an executor bound to a database session and a tool registry.

    Args:
        db: Database session used to persist execution records.
        registry: Registry used to look tools up by id.
    """
    self.registry = registry
    self.db = db
    # Executions currently in flight, keyed by execution id; the lock
    # guards mutations of this map.
    self._execution_lock = asyncio.Lock()
    self._running_executions: Dict[str, ExecutionContext] = {}
async def execute_tool(
    self,
    tool_id: str,
    parameters: Dict[str, Any],
    user_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None,
    execution_id: Optional[str] = None,
    timeout: Optional[float] = None,
    metadata: Optional[Dict[str, Any]] = None
) -> ToolResult:
    """Run one tool and persist the outcome.

    Args:
        tool_id: Id of the tool to run.
        parameters: Keyword arguments passed to the tool.
        user_id: Id of the calling user.
        workspace_id: Id of the workspace.
        execution_id: Execution id; generated when omitted.
        timeout: Timeout in seconds.
        metadata: Extra metadata stored on the execution context.

    Returns:
        The tool's result, or an error result when the tool is unknown or
        an exception is raised while executing it.
    """
    run_id = execution_id or f"exec_{uuid.uuid4().hex[:16]}"
    context = ExecutionContext(
        execution_id=run_id,
        tool_id=tool_id,
        user_id=user_id,
        workspace_id=workspace_id,
        timeout=timeout,
        metadata=metadata
    )

    try:
        tool = self.registry.get_tool(tool_id)
        if tool:
            # Persist start, run with the timeout, persist the outcome.
            await self._record_execution_start(context, parameters)
            outcome = await self._execute_with_timeout(tool, parameters, context)
            await self._record_execution_complete(context, outcome)
            return outcome

        return ToolResult.error_result(
            error=f"工具不存在: {tool_id}",
            error_code="TOOL_NOT_FOUND",
            execution_time=0.0
        )
    except Exception as exc:
        logger.error(f"工具执行异常: {run_id}, 错误: {exc}")
        elapsed = time.time() - context.started_at.timestamp()
        failure = ToolResult.error_result(
            error=str(exc),
            error_code="EXECUTION_ERROR",
            execution_time=elapsed
        )
        await self._record_execution_complete(context, failure)
        return failure
    finally:
        # Always drop the run from the in-flight map, whatever happened.
        async with self._execution_lock:
            self._running_executions.pop(run_id, None)
async def execute_tools_batch(
    self,
    tool_executions: List[Dict[str, Any]],
    max_concurrency: int = 5
) -> List[ToolResult]:
    """Run several tools concurrently.

    Args:
        tool_executions: Execution configs; each needs ``tool_id`` and may
            carry ``parameters``, ``user_id``, ``workspace_id``, ``timeout``
            and ``metadata``.
        max_concurrency: Maximum number of executions in flight at once.
            Values below 1 are treated as 1.

    Returns:
        One result per config, in input order. A config whose execution
        raised (or was cancelled) yields an error result with code
        ``BATCH_EXECUTION_ERROR``.
    """
    # Semaphore(0) would block every task forever and a negative value
    # raises, so clamp to at least one slot.
    semaphore = asyncio.Semaphore(max(1, max_concurrency))

    async def execute_single(exec_config: Dict[str, Any]) -> ToolResult:
        async with semaphore:
            return await self.execute_tool(
                tool_id=exec_config["tool_id"],
                parameters=exec_config.get("parameters", {}),
                user_id=exec_config.get("user_id"),
                workspace_id=exec_config.get("workspace_id"),
                timeout=exec_config.get("timeout"),
                metadata=exec_config.get("metadata")
            )

    results = await asyncio.gather(
        *(execute_single(config) for config in tool_executions),
        return_exceptions=True
    )

    # gather() may hand back CancelledError, which derives from
    # BaseException rather than Exception, so test the wider base class.
    return [
        ToolResult.error_result(
            error=str(result),
            error_code="BATCH_EXECUTION_ERROR",
            execution_time=0.0
        )
        if isinstance(result, BaseException)
        else result
        for result in results
    ]
async def cancel_execution(self, execution_id: str) -> bool:
    """Mark a running execution as cancelled.

    This updates the in-memory context and the database record only; the
    code here does not interrupt the coroutine that is running the tool.

    Args:
        execution_id: Id of the execution to cancel.

    Returns:
        True when the execution was running and has been marked, False when
        no such execution is in flight.
    """
    async with self._execution_lock:
        context = self._running_executions.get(execution_id)
        if context is None:
            return False

        context.status = ExecutionStatus.FAILED

        # Mirror the cancellation on the persisted record, if there is one.
        record = (
            self.db.query(ToolExecution)
            .filter(ToolExecution.execution_id == execution_id)
            .first()
        )
        if record:
            record.status = ExecutionStatus.FAILED.value
            record.error_message = "执行被取消"
            record.completed_at = datetime.now()
            self.db.commit()

        logger.info(f"工具执行已取消: {execution_id}")
        return True
def get_running_executions(self) -> List[Dict[str, Any]]:
    """List the executions currently registered as in flight.

    Returns:
        One dict per execution with its ids, start time (ISO format),
        status value and elapsed seconds.
    """

    def as_text(identifier):
        # Ids are stringified; missing ones stay None.
        return str(identifier) if identifier else None

    return [
        {
            "execution_id": run_id,
            "tool_id": ctx.tool_id,
            "user_id": as_text(ctx.user_id),
            "workspace_id": as_text(ctx.workspace_id),
            "started_at": ctx.started_at.isoformat(),
            "status": ctx.status.value,
            "elapsed_time": (datetime.now() - ctx.started_at).total_seconds(),
        }
        for run_id, ctx in self._running_executions.items()
    ]
async def _execute_with_timeout(
    self,
    tool: BaseTool,
    parameters: Dict[str, Any],
    context: ExecutionContext
) -> ToolResult:
    """Run a tool's ``safe_execute`` under the context's timeout.

    Args:
        tool: Tool instance to run.
        parameters: Keyword arguments for the tool.
        context: Execution context; its status is updated as the run
            progresses.

    Returns:
        The tool's result, or an ``EXECUTION_TIMEOUT`` error result when
        the timeout elapses. Other exceptions are re-raised after the
        context is marked FAILED.
    """
    # Register the run as in flight before starting it.
    async with self._execution_lock:
        self._running_executions[context.execution_id] = context
        context.status = ExecutionStatus.RUNNING

    limit = context.timeout
    try:
        outcome = await asyncio.wait_for(tool.safe_execute(**parameters), timeout=limit)
    except asyncio.TimeoutError:
        context.status = ExecutionStatus.TIMEOUT
        return ToolResult.error_result(
            error=f"工具执行超时({context.timeout}秒)",
            error_code="EXECUTION_TIMEOUT",
            execution_time=context.timeout
        )
    except Exception:
        context.status = ExecutionStatus.FAILED
        raise

    context.status = ExecutionStatus.COMPLETED
    return outcome
async def _record_execution_start(
    self,
    context: ExecutionContext,
    parameters: Dict[str, Any]
):
    """Insert the execution record for a run that is about to start.

    Failures are logged and swallowed so that bookkeeping problems never
    prevent the tool from running.

    Args:
        context: Execution context of the run.
        parameters: Tool arguments, stored as the record's input data.
    """
    try:
        execution_record = ToolExecution(
            execution_id=context.execution_id,
            tool_config_id=uuid.UUID(context.tool_id),
            status=ExecutionStatus.RUNNING.value,
            input_data=parameters,
            started_at=context.started_at,
            user_id=context.user_id,
            workspace_id=context.workspace_id
        )
        self.db.add(execution_record)
        self.db.commit()
        logger.debug(f"执行记录已创建: {context.execution_id}")
    except Exception as e:
        logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
        # After a failed flush/commit the session must be rolled back before
        # it can be used again; otherwise later queries on it also fail.
        try:
            self.db.rollback()
        except Exception as rollback_error:
            logger.error(f"回滚执行记录事务失败: {context.execution_id}, 错误: {rollback_error}")
async def _record_execution_complete(
    self,
    context: ExecutionContext,
    result: ToolResult
):
    """Update the execution record with the outcome of a run.

    Sets the context's completion time and, when a record exists for the
    execution id, stores status, output or error, timing and token usage.
    Failures are logged and swallowed.

    Args:
        context: Execution context of the run.
        result: Result produced by the tool (or an error result).
    """
    try:
        context.completed_at = datetime.now()
        execution_record = self.db.query(ToolExecution).filter(
            ToolExecution.execution_id == context.execution_id
        ).first()
        if execution_record:
            execution_record.status = (
                ExecutionStatus.COMPLETED.value if result.success
                else ExecutionStatus.FAILED.value
            )
            execution_record.output_data = result.data if result.success else None
            execution_record.error_message = result.error if not result.success else None
            execution_record.completed_at = context.completed_at
            execution_record.execution_time = result.execution_time
            execution_record.token_usage = result.token_usage
            self.db.commit()
            logger.debug(f"执行记录已更新: {context.execution_id}")
    except Exception as e:
        logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}")
        # Roll back so a failed commit does not leave the shared session in
        # a state where every later operation fails as well.
        try:
            self.db.rollback()
        except Exception as rollback_error:
            logger.error(f"回滚执行记录事务失败: {context.execution_id}, 错误: {rollback_error}")
def get_execution_history(
    self,
    tool_id: Optional[str] = None,
    user_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None,
    limit: int = 100
) -> List[Dict[str, Any]]:
    """Return recent execution records, newest first.

    Args:
        tool_id: Only include executions of this tool.
        user_id: Only include executions by this user.
        workspace_id: Only include executions in this workspace.
        limit: Maximum number of records returned.

    Returns:
        List of serialised execution records; an empty list when the query
        fails.
    """

    def iso(moment):
        return moment.isoformat() if moment else None

    def text(identifier):
        return str(identifier) if identifier else None

    def serialise(record):
        # Flatten one ORM row into plain JSON-friendly values.
        return {
            "execution_id": record.execution_id,
            "tool_id": str(record.tool_config_id),
            "status": record.status,
            "started_at": iso(record.started_at),
            "completed_at": iso(record.completed_at),
            "execution_time": record.execution_time,
            "user_id": text(record.user_id),
            "workspace_id": text(record.workspace_id),
            "input_data": record.input_data,
            "output_data": record.output_data,
            "error_message": record.error_message,
            "token_usage": record.token_usage,
        }

    try:
        query = self.db.query(ToolExecution).order_by(ToolExecution.started_at.desc())
        if tool_id:
            query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id))
        if user_id:
            query = query.filter(ToolExecution.user_id == user_id)
        if workspace_id:
            query = query.filter(ToolExecution.workspace_id == workspace_id)
        return [serialise(record) for record in query.limit(limit).all()]
    except Exception as exc:
        logger.error(f"获取执行历史失败, 错误: {exc}")
        return []
def get_execution_statistics(
    self,
    workspace_id: Optional[uuid.UUID] = None,
    days: int = 7
) -> Dict[str, Any]:
    """Aggregate execution counts and timings over a recent window.

    Args:
        workspace_id: Only include executions in this workspace.
        days: Size of the window, counted back from now.

    Returns:
        Dict with totals, success rate, average execution time and
        per-tool counters; an empty dict when the query fails.
    """
    try:
        from datetime import timedelta

        window_start = datetime.now() - timedelta(days=days)
        query = self.db.query(ToolExecution).filter(ToolExecution.started_at >= window_start)
        if workspace_id:
            query = query.filter(ToolExecution.workspace_id == workspace_id)
        rows = query.all()

        completed_value = ExecutionStatus.COMPLETED.value
        failed_value = ExecutionStatus.FAILED.value

        # One pass collects the global counters, durations and per-tool buckets.
        succeeded = 0
        failed = 0
        durations = []
        per_tool = {}
        for row in rows:
            bucket = per_tool.setdefault(
                str(row.tool_config_id),
                {"total": 0, "successful": 0, "failed": 0},
            )
            bucket["total"] += 1
            if row.status == completed_value:
                succeeded += 1
                bucket["successful"] += 1
            elif row.status == failed_value:
                failed += 1
                bucket["failed"] += 1
            if row.execution_time is not None:
                durations.append(row.execution_time)

        total = len(rows)
        average = sum(durations) / len(durations) if durations else 0
        return {
            "period_days": days,
            "total_executions": total,
            "successful_executions": succeeded,
            "failed_executions": failed,
            "success_rate": succeeded / total if total > 0 else 0,
            "average_execution_time": average,
            "tool_statistics": per_tool,
        }
    except Exception as exc:
        logger.error(f"获取执行统计失败, 错误: {exc}")
        return {}
async def test_tool_connection(
    self,
    tool_id: str,
    user_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
    """Check whether a tool can be reached.

    MCP tools are tested by connecting a client and listing the server's
    tools; other tools delegate to their own ``test_connection`` when they
    have one.

    Args:
        tool_id: Id of the tool configuration to test.
        user_id: Id of the calling user (not used by this method).
        workspace_id: Id of the workspace (not used by this method).

    Returns:
        Dict with ``success`` and ``message``; MCP successes add
        ``details``, unexpected failures add ``error``.
    """
    try:
        from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig
        from .mcp.client import MCPClient

        tool_config = self.db.query(ToolConfig).filter(
            ToolConfig.id == uuid.UUID(tool_id)
        ).first()
        if not tool_config:
            return {"success": False, "message": "工具不存在"}

        if tool_config.tool_type != ToolType.MCP.value:
            tool = self.registry.get_tool(tool_id)
            if tool and hasattr(tool, 'test_connection'):
                result = tool.test_connection()
                return {"success": result.get("success", False), "message": result.get("message", "")}
            return {"success": True, "message": "工具无需连接测试"}

        mcp_config = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.id == tool_config.id
        ).first()
        if not mcp_config:
            return {"success": False, "message": "MCP配置不存在"}

        client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
        if not await client.connect():
            return {"success": False, "message": "MCP连接失败"}

        try:
            tools = await client.list_tools()
        except Exception as e:
            # Catch Exception only, so task cancellation and interpreter
            # exit signals are not swallowed; keep the cause in the log.
            logger.error(f"MCP功能测试失败: {tool_id}, 错误: {e}")
            return {"success": False, "message": "MCP功能测试失败"}
        finally:
            # Release the connection on every path.
            await client.disconnect()

        return {
            "success": True,
            "message": "MCP连接成功",
            "details": {"server_url": mcp_config.server_url, "tools": len(tools)}
        }
    except Exception as e:
        return {"success": False, "message": "测试失败", "error": str(e)}

View File

@@ -0,0 +1,375 @@
"""Langchain适配器 - 将工具转换为langchain兼容格式"""
import json
from typing import Dict, Any, List, Optional, Type
from pydantic import BaseModel, Field
from langchain.tools import BaseTool as LangchainBaseTool
from langchain_core.tools import ToolException
from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class LangchainToolWrapper(LangchainBaseTool):
    """Expose an internal tool through the Langchain tool interface."""

    name: str = Field(..., description="工具名称")
    description: str = Field(..., description="工具描述")
    args_schema: Optional[Type[BaseModel]] = Field(None, description="参数schema")
    return_direct: bool = Field(False, description="是否直接返回结果")
    # The wrapped internal tool.
    tool_instance: BaseTool = Field(..., description="内部工具实例")

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, tool_instance: BaseTool, **kwargs):
        """Build the wrapper from an internal tool.

        Args:
            tool_instance: Internal tool to wrap; its name, description and
                parameters populate the Langchain fields.
            **kwargs: Extra field values forwarded to the base class.
        """
        # The argument schema is generated from the tool's parameter list.
        args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)
        # Pass the tool under the declared field name ``tool_instance``. The
        # original passed ``_tool_instance``, which is not a declared field,
        # leaving the required field unset.
        super().__init__(
            name=tool_instance.name,
            description=tool_instance.description,
            args_schema=args_schema,
            tool_instance=tool_instance,
            **kwargs
        )

    def _run(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Synchronous entry point; unsupported because tools are async.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("请使用 _arun 方法进行异步调用")

    async def _arun(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Run the wrapped tool and return its result as a string.

        Raises:
            ToolException: When executing or formatting raises.
        """
        try:
            result = await self.tool_instance.safe_execute(**kwargs)
            return LangchainAdapter._format_result_for_langchain(result)
        except Exception as e:
            logger.error(f"工具执行失败: {self.name}, 错误: {e}")
            raise ToolException(f"工具执行失败: {str(e)}")
class LangchainAdapter:
"""Langchain适配器 - 负责工具格式转换和标准化"""
@staticmethod
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
    """Wrap one internal tool for use with Langchain.

    Args:
        tool: Internal tool instance.

    Returns:
        The Langchain-compatible wrapper.

    Raises:
        Exception: Whatever wrapping raised, after logging it.
    """
    try:
        wrapped = LangchainToolWrapper(tool_instance=tool)
        logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
    except Exception as exc:
        logger.error(f"工具转换失败: {tool.name}, 错误: {exc}")
        raise
    return wrapped
@staticmethod
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
    """Wrap several internal tools, skipping the ones that fail.

    Args:
        tools: Internal tool instances.

    Returns:
        Wrappers for the tools that converted successfully, in input order.
    """
    wrapped_tools = []
    for candidate in tools:
        try:
            wrapped_tools.append(LangchainAdapter.convert_tool(candidate))
        except Exception as exc:
            # A failing tool is logged and left out; the batch continues.
            logger.error(f"跳过工具转换: {candidate.name}, 错误: {exc}")

    logger.info(f"批量转换完成: {len(wrapped_tools)}/{len(tools)} 个工具")
    return wrapped_tools
@staticmethod
def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
    """Build a Pydantic model describing a tool's arguments.

    Enumerations of plain scalars are expressed as a ``Literal`` type
    instead of a hand-built regex, so values containing regex
    metacharacters and non-string enums validate correctly. Numeric bounds
    are applied only to numeric fields and patterns only to string fields.
    When a parameter declares both an enum and a pattern, the enum wins.

    Args:
        parameters: Tool parameter definitions.

    Returns:
        A dynamically created Pydantic model class that forbids extra
        fields.
    """
    from typing import Literal

    import pydantic

    # pydantic v2 renamed Field's ``regex`` constraint to ``pattern``.
    major_version = int(str(getattr(pydantic, "VERSION", "1")).split(".")[0])
    pattern_key = "pattern" if major_version >= 2 else "regex"

    fields = {}
    annotations = {}
    for param in parameters:
        base_type = LangchainAdapter._get_python_type(param.type)
        python_type = base_type

        field_kwargs = {"description": param.description}
        if param.default is not None:
            field_kwargs["default"] = param.default
        elif not param.required:
            field_kwargs["default"] = None
        else:
            field_kwargs["default"] = ...  # required field

        uses_literal = False
        if param.enum:
            enum_values = tuple(param.enum)
            if all(isinstance(value, (str, int, bool)) for value in enum_values):
                python_type = Literal[enum_values]
                uses_literal = True

        if not uses_literal:
            if base_type in (int, float):
                if param.minimum is not None:
                    field_kwargs["ge"] = param.minimum
                if param.maximum is not None:
                    field_kwargs["le"] = param.maximum
            if param.pattern and base_type is str:
                field_kwargs[pattern_key] = param.pattern

        # Optional parameters also accept None.
        if not param.required:
            python_type = Optional[python_type]

        fields[param.name] = Field(**field_kwargs)
        annotations[param.name] = python_type

    schema_class = type(
        "ToolArgsSchema",
        (BaseModel,),
        {
            "__annotations__": annotations,
            **fields,
            "Config": type("Config", (), {"extra": "forbid"})
        }
    )
    return schema_class
@staticmethod
def _get_python_type(param_type: ParameterType) -> type:
    """Map a tool parameter type to the Python type used in schemas.

    Args:
        param_type: Tool parameter type.

    Returns:
        The matching Python type; ``str`` for unmapped values.
    """
    python_types = dict((
        (ParameterType.STRING, str),
        (ParameterType.INTEGER, int),
        (ParameterType.NUMBER, float),
        (ParameterType.BOOLEAN, bool),
        (ParameterType.ARRAY, list),
        (ParameterType.OBJECT, dict),
    ))
    if param_type in python_types:
        return python_types[param_type]
    return str
@staticmethod
def _format_result_for_langchain(result: ToolResult) -> str:
    """Render a tool result as the string Langchain expects.

    Args:
        result: Tool execution result.

    Returns:
        For failures, a JSON document with the error details. For
        successes, the data itself when it is a string, JSON when it is a
        dict or list, and ``str(data)`` otherwise.
    """
    if result.success:
        payload = result.data
        if isinstance(payload, str):
            return payload
        if isinstance(payload, (dict, list)):
            return json.dumps(payload, ensure_ascii=False, indent=2)
        return str(payload)

    failure = {
        "success": False,
        "error": result.error,
        "error_code": result.error_code,
        "execution_time": result.execution_time,
    }
    return json.dumps(failure, ensure_ascii=False, indent=2)
@staticmethod
def create_tool_description(tool: BaseTool) -> Dict[str, Any]:
    """Describe a tool as plain data.

    Args:
        tool: Tool instance.

    Returns:
        Dict with the tool's name, description, type, version, status,
        tags, parameter definitions and a ``langchain_compatible`` flag.
    """

    def describe_parameter(param):
        # Plain-data view of one parameter definition.
        return {
            "name": param.name,
            "type": param.type.value,
            "description": param.description,
            "required": param.required,
            "default": param.default,
            "enum": param.enum,
            "minimum": param.minimum,
            "maximum": param.maximum,
            "pattern": param.pattern,
        }

    return {
        "name": tool.name,
        "description": tool.description,
        "tool_type": tool.tool_type.value,
        "version": tool.version,
        "status": tool.status.value,
        "tags": tool.tags,
        "parameters": [describe_parameter(param) for param in tool.parameters],
        "langchain_compatible": True,
    }
@staticmethod
def validate_langchain_compatibility(tool: BaseTool) -> tuple[bool, List[str]]:
    """Check whether a tool can be exposed through Langchain.

    Args:
        tool: Tool instance.

    Returns:
        ``(is_compatible, issues)`` where ``issues`` lists every problem
        found.
    """
    issues = []

    if not tool.name or not isinstance(tool.name, str):
        issues.append("工具名称必须是非空字符串")

    if not tool.description or not isinstance(tool.description, str):
        issues.append("工具描述必须是非空字符串")

    for param in tool.parameters:
        if not param.name or not isinstance(param.name, str):
            issues.append(f"参数名称无效: {param.name}")

        # ``x in EnumClass`` raises TypeError for non-members before Python
        # 3.12, so resolve the value through the enum constructor instead.
        # It accepts both members and member values on every version.
        try:
            ParameterType(param.type)
        except (ValueError, TypeError):
            issues.append(f"不支持的参数类型: {param.type}")

        if param.required and param.default is not None:
            issues.append(f"必需参数不应有默认值: {param.name}")

    if not hasattr(tool, 'execute') or not callable(getattr(tool, 'execute')):
        issues.append("工具必须实现execute方法")

    return len(issues) == 0, issues
@staticmethod
def get_langchain_tool_schema(tool: BaseTool) -> Dict[str, Any]:
    """Build a function-style schema for a tool.

    Args:
        tool: Tool instance.

    Returns:
        Dict of the form ``{"type": "function", "function": {...}}`` whose
        ``parameters`` entry is an object schema listing every tool
        parameter and the names of the required ones.
    """
    properties = {}
    required = []
    for param in tool.parameters:
        entry = {
            "type": LangchainAdapter._get_openapi_type(param.type),
            "description": param.description,
        }
        # Optional constraints are emitted only when they are set.
        optional_entries = (
            ("enum", param.enum, bool(param.enum)),
            ("minimum", param.minimum, param.minimum is not None),
            ("maximum", param.maximum, param.maximum is not None),
            ("pattern", param.pattern, bool(param.pattern)),
            ("default", param.default, param.default is not None),
        )
        for key, value, is_set in optional_entries:
            if is_set:
                entry[key] = value

        properties[param.name] = entry
        if param.required:
            required.append(param.name)

    parameter_schema = {
        "type": "object",
        "properties": properties,
        "required": required,
    }
    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": parameter_schema,
        },
    }
@staticmethod
def _get_openapi_type(param_type: ParameterType) -> str:
    """Map a tool parameter type to its schema type name.

    Args:
        param_type: Tool parameter type.

    Returns:
        The schema type string; ``"string"`` for unmapped values.
    """
    schema_names = dict((
        (ParameterType.STRING, "string"),
        (ParameterType.INTEGER, "integer"),
        (ParameterType.NUMBER, "number"),
        (ParameterType.BOOLEAN, "boolean"),
        (ParameterType.ARRAY, "array"),
        (ParameterType.OBJECT, "object"),
    ))
    if param_type in schema_names:
        return schema_names[param_type]
    return "string"

View File

@@ -0,0 +1,12 @@
"""MCP工具模块"""
from .base import MCPTool
from .client import MCPClient, MCPConnectionPool
from .service_manager import MCPServiceManager
__all__ = [
"MCPTool",
"MCPClient",
"MCPConnectionPool",
"MCPServiceManager"
]

View File

@@ -0,0 +1,258 @@
"""MCP工具基类"""
import time
from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class MCPTool(BaseTool):
"""MCP工具 - Model Context Protocol工具"""
def __init__(self, tool_id: str, config: Dict[str, Any]):
    """Create an MCP tool from its stored configuration.

    Args:
        tool_id: Tool id.
        config: Tool configuration; ``server_url``, ``connection_config``
            and ``available_tools`` are read from it.
    """
    super().__init__(tool_id, config)

    # Connection state: nothing is opened until connect() is called.
    self._connected = False
    self._client = None

    # Settings taken from the configuration, with empty defaults.
    self.server_url = config.get("server_url", "")
    self.connection_config = config.get("connection_config", {})
    self.available_tools = config.get("available_tools", [])
@property
def name(self) -> str:
    """Tool name: ``mcp_tool_`` followed by the first 8 characters of the tool id."""
    short_id = self.tool_id[:8]
    return f"mcp_tool_{short_id}"
@property
def description(self) -> str:
    """Tool description, naming the server this tool connects to."""
    target = self.server_url
    return f"MCP工具 - 连接到 {target}"
@property
def tool_type(self) -> ToolType:
    """Tool type; always ``ToolType.MCP``."""
    kind = ToolType.MCP
    return kind
@property
def parameters(self) -> List[ToolParameter]:
    """Parameter definitions accepted by ``execute``.

    A required ``tool_name`` selector is included only when more than one
    MCP tool is available; ``arguments`` and ``timeout`` are always present.
    """
    selector = []
    if len(self.available_tools) > 1:
        selector.append(
            ToolParameter(
                name="tool_name",
                type=ParameterType.STRING,
                description="要调用的MCP工具名称",
                required=True,
                enum=self.available_tools
            )
        )

    arguments_param = ToolParameter(
        name="arguments",
        type=ParameterType.OBJECT,
        description="工具参数JSON对象",
        required=False,
        default={}
    )
    timeout_param = ToolParameter(
        name="timeout",
        type=ParameterType.INTEGER,
        description="超时时间(秒)",
        required=False,
        default=30,
        minimum=1,
        maximum=300
    )
    return selector + [arguments_param, timeout_param]
async def execute(self, **kwargs) -> ToolResult:
    """Call one tool on the MCP server.

    Keyword Args:
        tool_name: Name of the MCP tool; may be omitted when exactly one
            tool is available.
        arguments: Arguments for the MCP tool; missing or None means ``{}``.
        timeout: Request timeout in seconds; missing or None means 30.

    Returns:
        A success result carrying the server's ``result`` payload, or an
        error result with code ``MCP_ERROR`` when anything fails.
    """
    start_time = time.time()
    try:
        if not self._connected:
            await self.connect()

        tool_name = kwargs.get("tool_name")
        if not tool_name and len(self.available_tools) == 1:
            tool_name = self.available_tools[0]
        if not tool_name:
            raise ValueError("必须指定要调用的MCP工具名称")
        if tool_name not in self.available_tools:
            raise ValueError(f"MCP工具不存在: {tool_name}")

        # Callers may pass an explicit None for optional arguments; fall
        # back to the defaults rather than forwarding None to the server
        # or to the HTTP timeout.
        arguments = kwargs.get("arguments")
        if arguments is None:
            arguments = {}
        timeout = kwargs.get("timeout")
        if timeout is None:
            timeout = 30

        result = await self._call_mcp_tool(tool_name, arguments, timeout)
        execution_time = time.time() - start_time
        return ToolResult.success_result(
            data=result,
            execution_time=execution_time
        )
    except Exception as e:
        execution_time = time.time() - start_time
        return ToolResult.error_result(
            error=str(e),
            error_code="MCP_ERROR",
            execution_time=execution_time
        )
async def connect(self) -> bool:
    """Probe the server's ``/info`` endpoint and mark the tool connected.

    On a 200 response the ``tools`` entry of the returned JSON replaces
    ``available_tools``. No persistent connection is kept.

    Returns:
        True when the probe succeeded, False otherwise.
    """
    try:
        probe_timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=probe_timeout) as session:
            async with session.get(f"{self.server_url}/info") as response:
                if response.status != 200:
                    raise Exception(f"服务器响应错误: {response.status}")

                server_info = await response.json()
                self.available_tools = server_info.get("tools", [])
                self._connected = True
                logger.info(f"MCP服务器连接成功: {self.server_url}")
                return True
    except Exception as exc:
        logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {exc}")
        self._connected = False
        return False
async def disconnect(self) -> bool:
    """Drop the client reference and mark the tool disconnected.

    Returns:
        True on success, False when an exception was raised.
    """
    try:
        # Only the reference is cleared; there is no real teardown yet.
        if self._client:
            self._client = None
        self._connected = False
        logger.info(f"MCP服务器连接已断开: {self.server_url}")
    except Exception as exc:
        logger.error(f"断开MCP服务器连接失败: {exc}")
        return False
    return True
def get_health_status(self) -> Dict[str, Any]:
    """Report the tool's connection state.

    Returns:
        Dict with the connected flag, server URL, available tools and the
        current time as ``last_check``.
    """
    status = {}
    status["connected"] = self._connected
    status["server_url"] = self.server_url
    status["available_tools"] = self.available_tools
    status["last_check"] = time.time()
    return status
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
    """Send a JSON-RPC ``tools/call`` request to the server's ``/mcp`` endpoint.

    Args:
        tool_name: Name of the MCP tool.
        arguments: Arguments for the tool.
        timeout: Total request timeout in seconds.

    Returns:
        The ``result`` entry of the response, or ``{}`` when absent.

    Raises:
        Exception: On a non-200 response or when the response carries an
            ``error`` entry.
    """
    payload = {
        "jsonrpc": "2.0",
        # Request ids are derived from the current time in milliseconds.
        "id": f"req_{int(time.time() * 1000)}",
        "method": "tools/call",
        "params": {"name": tool_name, "arguments": arguments},
    }
    endpoint = f"{self.server_url}/mcp"
    json_headers = {"Content-Type": "application/json"}

    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
        async with session.post(endpoint, json=payload, headers=json_headers) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"MCP请求失败 {response.status}: {error_text}")

            reply = await response.json()
            if "error" in reply:
                failure = reply["error"]
                raise Exception(f"MCP工具错误: {failure.get('message', '未知错误')}")
            return reply.get("result", {})
async def list_available_tools(self) -> List[Dict[str, Any]]:
    """Ask the server for its tool list via JSON-RPC ``tools/list``.

    On success ``available_tools`` is replaced with the returned tool
    names.

    Returns:
        The tool descriptions from the server, or an empty list when the
        request fails or the response carries no ``result``.
    """
    try:
        if not self._connected:
            await self.connect()

        payload = {
            "jsonrpc": "2.0",
            "id": f"req_{int(time.time() * 1000)}",
            "method": "tools/list",
        }
        request_timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=request_timeout) as session:
            async with session.post(
                f"{self.server_url}/mcp",
                json=payload,
                headers={"Content-Type": "application/json"},
            ) as response:
                if response.status != 200:
                    return []
                reply = await response.json()
                if "result" not in reply:
                    return []
                tools = reply["result"].get("tools", [])
                self.available_tools = [entry.get("name") for entry in tools]
                return tools
    except Exception as exc:
        logger.error(f"获取MCP工具列表失败: {exc}")
        return []
def test_connection(self) -> Dict[str, Any]:
"""测试MCP连接"""
try:
# 这里应该实现同步的连接测试
# 为了简化,返回基本信息
return {
"success": bool(self.server_url),
"server_url": self.server_url,
"connected": self._connected,
"available_tools_count": len(self.available_tools),
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}

View File

@@ -0,0 +1,626 @@
"""MCP客户端 - Model Context Protocol客户端实现"""
import asyncio
import json
import time
from typing import Dict, Any, List, Optional, Callable
from urllib.parse import urlparse
import aiohttp
import websockets
from websockets.exceptions import ConnectionClosed
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class MCPConnectionError(Exception):
    """Raised when the MCP client is not connected or a transport call fails."""
class MCPProtocolError(Exception):
    """Raised when the MCP server answers with an error or a request times out."""
class MCPClient:
"""MCP客户端 - 支持HTTP和WebSocket连接"""
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
"""初始化MCP客户端
Args:
server_url: MCP服务器URL
connection_config: 连接配置
"""
self.server_url = server_url
self.connection_config = connection_config or {}
# 解析URL确定连接类型
parsed_url = urlparse(server_url)
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
# 连接状态
self._connected = False
self._websocket = None
self._session = None
# 请求管理
self._request_id = 0
self._pending_requests: Dict[str, asyncio.Future] = {}
# 连接池配置
self.max_connections = self.connection_config.get("max_connections", 10)
self.connection_timeout = self.connection_config.get("timeout", 30)
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
self.retry_delay = self.connection_config.get("retry_delay", 1)
# 健康检查
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
self._health_check_task = None
self._last_health_check = None
# 事件回调
self._on_connect_callbacks: List[Callable] = []
self._on_disconnect_callbacks: List[Callable] = []
self._on_error_callbacks: List[Callable] = []
async def connect(self) -> bool:
    """Connect to the MCP server using the transport chosen at construction.

    After a successful connect the client is marked connected, the health
    check task is started and the connect callbacks are notified. Any
    exception is logged, passed to the error callbacks and reported as
    ``False``.

    Returns:
        True when already connected or the connection succeeded, otherwise
        False.
    """
    try:
        if self._connected:
            return True
        logger.info(f"连接MCP服务器: {self.server_url}")
        use_websocket = self.connection_type == "websocket"
        opener = self._connect_websocket if use_websocket else self._connect_http
        success = await opener()
        if not success:
            return success
        self._connected = True
        await self._start_health_check()
        await self._notify_connect_callbacks()
        logger.info(f"MCP服务器连接成功: {self.server_url}")
        return success
    except Exception as e:
        logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
        await self._notify_error_callbacks(e)
        return False
async def disconnect(self) -> bool:
    """Close the connection to the MCP server.

    Stops the health check task, cancels every pending request future,
    closes the WebSocket (or the HTTP session), clears the connected flag
    and notifies the disconnect callbacks.

    Returns:
        True when already disconnected or the shutdown completed, False
        when an exception occurred (it is logged, not raised).
    """
    try:
        if not self._connected:
            return True
        logger.info(f"断开MCP服务器连接: {self.server_url}")
        await self._stop_health_check()
        # Nobody will answer outstanding requests any more.
        for waiter in self._pending_requests.values():
            if waiter.done():
                continue
            waiter.cancel()
        self._pending_requests.clear()
        if self.connection_type == "websocket" and self._websocket:
            await self._websocket.close()
            self._websocket = None
        elif self._session:
            await self._session.close()
            self._session = None
        self._connected = False
        await self._notify_disconnect_callbacks()
        logger.info(f"MCP服务器连接已断开: {self.server_url}")
        return True
    except Exception as e:
        logger.error(f"断开MCP服务器连接失败: {e}")
        return False
async def _connect_websocket(self) -> bool:
    """Open the WebSocket and perform the MCP ``initialize`` handshake.

    The handshake response is read directly from the socket, so the
    background message listener is started only after the handshake has
    finished; starting it earlier would make two coroutines wait on
    ``recv()`` at the same time and the listener could consume the
    handshake response.

    Returns:
        True when the socket is open and the server accepted the
        ``initialize`` request; False otherwise (the error is logged and
        the socket, if any, is closed).
    """
    try:
        extra_headers = self.connection_config.get("headers", {})
        # ``open_timeout`` bounds the opening handshake.
        # NOTE(review): ``extra_headers`` is the keyword of the legacy
        # websockets client; newer releases call it ``additional_headers``
        # - confirm against the pinned websockets version.
        self._websocket = await websockets.connect(
            self.server_url,
            extra_headers=extra_headers,
            open_timeout=self.connection_timeout
        )
        init_message = {
            "jsonrpc": "2.0",
            "id": self._get_next_request_id(),
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {}
                },
                "clientInfo": {
                    "name": "ToolManagementSystem",
                    "version": "1.0.0"
                }
            }
        }
        await self._websocket.send(json.dumps(init_message))
        # Wait for the server's answer to the initialize request.
        response = await asyncio.wait_for(
            self._websocket.recv(),
            timeout=self.connection_timeout
        )
        init_response = json.loads(response)
        if "error" in init_response:
            raise MCPProtocolError(f"初始化失败: {init_response['error']}")
        # Keep a reference so the listener task cannot be garbage collected
        # while it is still running.
        self._message_handler_task = asyncio.create_task(
            self._websocket_message_handler()
        )
        return True
    except Exception as e:
        logger.error(f"WebSocket连接失败: {e}")
        # Do not leave a half-open socket behind after a failed handshake.
        websocket, self._websocket = self._websocket, None
        if websocket is not None:
            try:
                await websocket.close()
            except Exception as close_error:
                logger.debug(f"关闭WebSocket连接失败: {close_error}")
        return False
async def _connect_http(self) -> bool:
    """Create the HTTP session and probe the server for reachability.

    ``<server_url>/health`` is tried first; when it does not answer 200 the
    server URL itself is requested and any status below 400 counts as
    reachable. The session is closed again whenever the probe fails, so a
    failed connect does not leak an open ``ClientSession``.

    Returns:
        True when the server is reachable, False otherwise.
    """
    try:
        timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
        headers = self.connection_config.get("headers", {})
        self._session = aiohttp.ClientSession(
            timeout=timeout,
            headers=headers
        )
        # Avoid a double slash when the configured URL already ends in '/'.
        test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
        async with self._session.get(test_url) as response:
            reachable = response.status == 200
        if not reachable:
            # Fall back to the root path.
            async with self._session.get(self.server_url) as root_response:
                reachable = root_response.status < 400
        if not reachable:
            await self._session.close()
            self._session = None
        return reachable
    except Exception as e:
        logger.error(f"HTTP连接失败: {e}")
        if self._session:
            await self._session.close()
            self._session = None
        return False
async def _websocket_message_handler(self):
    """Receive WebSocket messages until the connection closes.

    Each message is JSON-decoded and passed to ``_handle_message``. Decode
    and handling errors are logged and the loop continues. When the loop
    ends the client is marked disconnected, every still-pending request is
    failed with ``MCPConnectionError`` (so callers do not wait for their
    full timeout on a dead connection) and the disconnect callbacks run.
    """
    try:
        while self._websocket and not self._websocket.closed:
            try:
                message = await self._websocket.recv()
                await self._handle_message(json.loads(message))
            except ConnectionClosed:
                break
            except json.JSONDecodeError as e:
                logger.error(f"解析WebSocket消息失败: {e}")
            except Exception as e:
                logger.error(f"处理WebSocket消息失败: {e}")
    except Exception as e:
        logger.error(f"WebSocket消息处理器异常: {e}")
    finally:
        self._connected = False
        # No response can arrive any more; wake up the waiters now.
        for future in list(self._pending_requests.values()):
            if not future.done():
                future.set_exception(MCPConnectionError("WebSocket连接已断开"))
        self._pending_requests.clear()
        await self._notify_disconnect_callbacks()
async def _handle_message(self, message: Dict[str, Any]):
"""处理收到的消息"""
try:
# 检查是否是响应消息
if "id" in message:
request_id = str(message["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
# 处理通知消息
elif "method" in message:
await self._handle_notification(message)
except Exception as e:
logger.error(f"处理消息失败: {e}")
async def _handle_notification(self, message: Dict[str, Any]):
    """Log a server notification at debug level.

    Only logging happens here; this is the place to react to specific
    notifications (for example tool list or server state changes).
    """
    notification_method = message.get("method")
    notification_params = message.get("params", {})
    logger.debug(f"收到MCP通知: {notification_method}, 参数: {notification_params}")
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
    """Invoke a tool on the MCP server via a JSON-RPC ``tools/call`` request.

    Args:
        tool_name: Name of the tool to run.
        arguments: Arguments passed to the tool.
        timeout: Seconds to wait for the response.

    Returns:
        The ``result`` member of the response, or ``{}`` when absent.

    Raises:
        MCPConnectionError: The client is not connected.
        MCPProtocolError: The response contains ``error`` or the request
            timed out.
    """
    if not self._connected:
        raise MCPConnectionError("MCP客户端未连接")
    call_params = {
        "name": tool_name,
        "arguments": arguments
    }
    request_data = {
        "jsonrpc": "2.0",
        "id": self._get_next_request_id(),
        "method": "tools/call",
        "params": call_params
    }
    try:
        response = await self._send_request(request_data, timeout)
    except asyncio.TimeoutError:
        raise MCPProtocolError(f"工具调用超时: {tool_name}")
    if "error" in response:
        error = response["error"]
        raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
    return response.get("result", {})
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
    """Fetch the list of tools offered by the MCP server.

    Args:
        timeout: Seconds to wait for the response.

    Returns:
        The ``tools`` list of the response ``result``; empty when the
        response has no result or no tools.

    Raises:
        MCPConnectionError: The client is not connected.
        MCPProtocolError: The response carries a non-null ``error`` or the
            request timed out.
    """
    if not self._connected:
        raise MCPConnectionError("MCP客户端未连接")
    request_data = {
        "jsonrpc": "2.0",
        "id": self._get_next_request_id(),
        "method": "tools/list"
    }
    try:
        response = await self._send_request(request_data, timeout)
    except asyncio.TimeoutError:
        raise MCPProtocolError("获取工具列表超时")
    # A successful JSON-RPC response has no "error" member at all, so the
    # key must be looked up with .get() rather than indexed.
    error = response.get("error")
    if error is not None:
        raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
    result = response.get("result") or {}
    return result.get("tools", [])
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""发送请求并等待响应
Args:
request_data: 请求数据
timeout: 超时时间(秒)
Returns:
响应数据
"""
request_id = str(request_data["id"])
if self.connection_type == "websocket":
return await self._send_websocket_request(request_data, request_id, timeout)
else:
return await self._send_http_request(request_data, timeout)
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
"""发送WebSocket请求"""
if not self._websocket or self._websocket.closed:
raise MCPConnectionError("WebSocket连接已断开")
# 创建Future等待响应
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
# 发送请求
await self._websocket.send(json.dumps(request_data))
# 等待响应
response = await asyncio.wait_for(future, timeout=timeout)
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise
except Exception as e:
self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
    """POST a JSON-RPC request to ``<server_url>/mcp`` and decode the reply.

    Args:
        request_data: JSON-serialisable request body.
        timeout: Total request timeout in seconds.

    Returns:
        The JSON-decoded response body.

    Raises:
        MCPConnectionError: No session exists, the status is not 200, or
            aiohttp reports a client error.
    """
    if not self._session:
        raise MCPConnectionError("HTTP会话未建立")
    try:
        # Avoid a double slash when the configured URL already ends in '/'.
        if self.server_url.endswith('/'):
            url = f"{self.server_url}mcp"
        else:
            url = f"{self.server_url}/mcp"
        request_timeout = aiohttp.ClientTimeout(total=timeout)
        async with self._session.post(url, json=request_data, timeout=request_timeout) as response:
            if response.status == 200:
                return await response.json()
            error_text = await response.text()
            raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
    except aiohttp.ClientError as e:
        raise MCPConnectionError(f"HTTP请求失败: {e}")
async def health_check(self) -> Dict[str, Any]:
    """Ping the server and report whether it answered.

    Sends a JSON-RPC ``ping`` with a 5 second timeout and records the time
    of the successful check on ``self._last_health_check``.

    Returns:
        ``{"healthy": True, "response_time", "timestamp", "server_info"}``
        when a response arrived, otherwise ``{"healthy": False, "error",
        "timestamp"}``. Never raises.
    """

    def unhealthy(reason: str) -> Dict[str, Any]:
        return {
            "healthy": False,
            "error": reason,
            "timestamp": time.time()
        }

    try:
        if not self._connected:
            return unhealthy("未连接")
        ping_request = {
            "jsonrpc": "2.0",
            "id": self._get_next_request_id(),
            "method": "ping"
        }
        started_at = time.time()
        response = await self._send_request(ping_request, timeout=5)
        elapsed = time.time() - started_at
        self._last_health_check = time.time()
        return {
            "healthy": True,
            "response_time": elapsed,
            "timestamp": self._last_health_check,
            "server_info": response.get("result", {})
        }
    except Exception as e:
        return unhealthy(str(e))
async def _start_health_check(self):
"""启动健康检查任务"""
if self.health_check_interval > 0:
self._health_check_task = asyncio.create_task(self._health_check_loop())
async def _stop_health_check(self):
"""停止健康检查任务"""
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
self._health_check_task = None
async def _health_check_loop(self):
"""健康检查循环"""
try:
while self._connected:
await asyncio.sleep(self.health_check_interval)
if self._connected:
health_status = await self.health_check()
if not health_status["healthy"]:
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
# 可以在这里实现重连逻辑
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"健康检查循环异常: {e}")
def _get_next_request_id(self) -> str:
"""获取下一个请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"
# 事件回调管理
def on_connect(self, callback: Callable):
    """Register a callable to be invoked after a successful connect."""
    listeners = self._on_connect_callbacks
    listeners.append(callback)
def on_disconnect(self, callback: Callable):
    """Register a callable to be invoked when the connection is closed."""
    listeners = self._on_disconnect_callbacks
    listeners.append(callback)
def on_error(self, callback: Callable):
    """Register a callable to be invoked with the exception when connecting fails."""
    listeners = self._on_error_callbacks
    listeners.append(callback)
async def _notify_connect_callbacks(self):
"""通知连接回调"""
for callback in self._on_connect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"连接回调执行失败: {e}")
async def _notify_disconnect_callbacks(self):
"""通知断开连接回调"""
for callback in self._on_disconnect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"断开连接回调执行失败: {e}")
async def _notify_error_callbacks(self, error: Exception):
"""通知错误回调"""
for callback in self._on_error_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
except Exception as e:
logger.error(f"错误回调执行失败: {e}")
@property
def is_connected(self) -> bool:
    """Whether the client currently considers itself connected."""
    connected = self._connected
    return connected
@property
def last_health_check(self) -> Optional[float]:
    """Timestamp of the last successful health check, or None if there was none."""
    checked_at = self._last_health_check
    return checked_at
def get_connection_info(self) -> Dict[str, Any]:
    """Return a snapshot of the client's connection state.

    The dict holds the server URL, transport type, connected flag, last
    health check timestamp, number of pending requests and the connection
    configuration (the same object, not a copy).
    """
    info: Dict[str, Any] = {}
    info["server_url"] = self.server_url
    info["connection_type"] = self.connection_type
    info["connected"] = self._connected
    info["last_health_check"] = self._last_health_check
    info["pending_requests"] = len(self._pending_requests)
    info["config"] = self.connection_config
    return info
async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect()
class MCPConnectionPool:
    """Keeps at most ``max_connections`` MCP clients, one per server URL."""

    def __init__(self, max_connections: int = 10):
        """Create an empty pool.

        Args:
            max_connections: Maximum number of clients held at once.
        """
        self.max_connections = max_connections
        self._clients: Dict[str, MCPClient] = {}
        # Serialises get_client/disconnect_all so the client map stays consistent.
        self._lock = asyncio.Lock()

    async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
        """Return a connected client for ``server_url``, creating it if needed.

        A pooled client that lost its connection is reconnected; if that
        fails it is dropped and replaced. When the pool is full the entry
        that was inserted first is disconnected and evicted.

        Args:
            server_url: URL of the MCP server (also the pool key).
            connection_config: Configuration used only when a new client is
                created.

        Returns:
            A connected ``MCPClient``.

        Raises:
            MCPConnectionError: A new client could not connect.
        """
        async with self._lock:
            pooled = self._clients.get(server_url)
            if pooled is not None:
                if pooled.is_connected:
                    return pooled
                # Try to revive the stale client before giving up on it.
                if await pooled.connect():
                    return pooled
                del self._clients[server_url]
            if len(self._clients) >= self.max_connections:
                # Dicts keep insertion order, so the first key is the oldest.
                oldest_url = next(iter(self._clients))
                await self._clients[oldest_url].disconnect()
                del self._clients[oldest_url]
            fresh = MCPClient(server_url, connection_config)
            if not await fresh.connect():
                raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
            self._clients[server_url] = fresh
            return fresh

    async def disconnect_all(self):
        """Disconnect every pooled client and empty the pool."""
        async with self._lock:
            for pooled in self._clients.values():
                await pooled.disconnect()
            self._clients.clear()

    def get_pool_status(self) -> Dict[str, Any]:
        """Return the pool size, its limit and per-URL connection info."""
        connections = {}
        for url, pooled in self._clients.items():
            connections[url] = pooled.get_connection_info()
        return {
            "total_connections": len(self._clients),
            "max_connections": self.max_connections,
            "connections": connections
        }

View File

@@ -0,0 +1,604 @@
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
import asyncio
import time
import uuid
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool
logger = get_business_logger()
class MCPServiceManager:
"""MCP服务管理器 - 管理MCP服务的生命周期"""
def __init__(self, db: Session):
    """Create a stopped service manager.

    Args:
        db: Database session used for all reads and writes.
    """
    self.db = db
    self.connection_pool = MCPConnectionPool(max_connections=20)
    # In-memory service state
    self._services: Dict[str, Dict[str, Any]] = {}  # service_id -> service info
    self._monitoring_tasks: Dict[str, asyncio.Task] = {}  # service_id -> monitor task
    # Tunables
    self.health_check_interval = 60  # seconds between health checks
    self.max_retry_attempts = 3  # failures tolerated before backing off
    self.retry_delay = 5  # seconds
    # Lifecycle
    self._running = False
    self._manager_task = None
async def start(self):
    """Start the manager: load enabled services and launch the management loop.

    Calling it while already running does nothing.
    """
    if self._running:
        return
    self._running = True
    logger.info("MCP服务管理器启动")
    await self._load_existing_services()
    management = self._management_loop()
    self._manager_task = asyncio.create_task(management)
async def stop(self):
    """Stop the manager: cancel all background tasks and close all connections.

    Calling it while not running does nothing.
    """
    if not self._running:
        return
    self._running = False
    logger.info("MCP服务管理器停止")
    manager_task = self._manager_task
    if manager_task:
        manager_task.cancel()
        try:
            await manager_task
        except asyncio.CancelledError:
            pass
    # Cancel every monitor first, then wait for all of them together.
    monitors = list(self._monitoring_tasks.values())
    for monitor in monitors:
        monitor.cancel()
    if monitors:
        await asyncio.gather(*monitors, return_exceptions=True)
    self._monitoring_tasks.clear()
    await self.connection_pool.disconnect_all()
async def register_service(
    self,
    server_url: str,
    connection_config: Dict[str, Any],
    tenant_id: uuid.UUID,
    service_name: Optional[str] = None
) -> Tuple[bool, str, Optional[str]]:
    """Register an MCP service after verifying it is reachable.

    The server is probed with a temporary client (connect + ``tools/list``),
    then a ``ToolConfig`` and a ``MCPToolConfig`` row are written, the
    service is added to the in-memory table and monitoring starts.

    Args:
        server_url: URL of the MCP server.
        connection_config: Connection settings stored with the service.
        tenant_id: Owning tenant.
        service_name: Display name; derived from the last URL path segment
            when omitted.

    Returns:
        ``(True, service_id, None)`` on success, otherwise
        ``(False, short_reason, detail)``.
    """
    try:
        # Reject URLs that are already registered.
        existing_service = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.server_url == server_url
        ).first()
        if existing_service:
            return False, "服务已存在", f"URL {server_url} 已被注册"
        # Probe the server with a throw-away client.
        try:
            client = MCPClient(server_url, connection_config)
            if not await client.connect():
                return False, "连接测试失败", "无法连接到MCP服务器"
            try:
                available_tools = await client.list_tools()
            finally:
                # Release the probe connection even when listing tools fails.
                await client.disconnect()
            tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
        except Exception as e:
            return False, "连接测试失败", str(e)
        if not service_name:
            # Ignore a trailing slash so the derived name is not empty.
            service_name = f"mcp_service_{server_url.rstrip('/').split('/')[-1]}"
        tool_config = ToolConfig(
            name=service_name,
            description=f"MCP服务 - {server_url}",
            tool_type=ToolType.MCP.value,
            tenant_id=tenant_id,
            version="1.0.0",
            config_data={
                "server_url": server_url,
                "connection_config": connection_config
            }
        )
        self.db.add(tool_config)
        self.db.flush()  # assigns tool_config.id
        # The MCP-specific row shares the id of the base tool row.
        mcp_config = MCPToolConfig(
            id=tool_config.id,
            server_url=server_url,
            connection_config=connection_config,
            available_tools=tool_names,
            health_status="healthy",
            last_health_check=datetime.utcnow()
        )
        self.db.add(mcp_config)
        self.db.commit()
        service_id = str(tool_config.id)
        self._services[service_id] = {
            "id": service_id,
            "server_url": server_url,
            "connection_config": connection_config,
            "tenant_id": tenant_id,
            "available_tools": tool_names,
            "status": "healthy",
            "last_health_check": time.time(),
            "retry_count": 0,
            "created_at": time.time()
        }
        await self._start_service_monitoring(service_id)
        logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
        return True, service_id, None
    except Exception as e:
        self.db.rollback()
        logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
        return False, "注册失败", str(e)
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
    """Delete an MCP service and stop monitoring it.

    Args:
        service_id: Id of the service (string form of the ``ToolConfig`` id).

    Returns:
        ``(True, "")`` on success, otherwise ``(False, reason)``. Failures
        roll the session back and are logged.
    """
    try:
        record = self.db.get(ToolConfig, uuid.UUID(service_id))
        if not record:
            return False, "服务不存在"
        self.db.delete(record)
        self.db.commit()
        await self._stop_service_monitoring(service_id)
        # Forget the in-memory state, if any.
        self._services.pop(service_id, None)
        logger.info(f"MCP服务注销成功: {service_id}")
        return True, ""
    except Exception as e:
        self.db.rollback()
        logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
        return False, str(e)
async def update_service(
    self,
    service_id: str,
    connection_config: Optional[Dict[str, Any]] = None,
    enabled: Optional[bool] = None
) -> Tuple[bool, str]:
    """Update the connection configuration and/or enabled flag of a service.

    Args:
        service_id: Id of the service.
        connection_config: New connection settings; None leaves them as is.
        enabled: New enabled flag; None leaves it as is.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, reason)``. Failures
        roll the session back and are logged.
    """
    try:
        mcp_config = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.id == uuid.UUID(service_id)
        ).first()
        if not mcp_config:
            return False, "服务不存在"
        tool_config = mcp_config.base_config
        if connection_config is not None:
            mcp_config.connection_config = connection_config
            # Assign a new dict instead of mutating in place: a plain JSON
            # column does not notice in-place changes, so they would not be
            # written. This also copes with config_data being None.
            tool_config.config_data = {
                **(tool_config.config_data or {}),
                "connection_config": connection_config
            }
        if enabled is not None:
            tool_config.is_enabled = enabled
        self.db.commit()
        if connection_config is not None:
            if service_id in self._services:
                self._services[service_id]["connection_config"] = connection_config
            # Restart monitoring so it runs with the new configuration.
            await self._restart_service_monitoring(service_id)
        logger.info(f"MCP服务更新成功: {service_id}")
        return True, ""
    except Exception as e:
        self.db.rollback()
        logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
        return False, str(e)
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
    """Return the cached service state plus a live health check.

    Args:
        service_id: Id of the service.

    Returns:
        None for an unknown service. Otherwise a shallow copy of the
        in-memory service info with an added ``real_time_health`` entry;
        when the live check cannot be performed that entry reports
        ``healthy: False`` with the error text.
    """
    if service_id not in self._services:
        return None
    snapshot = dict(self._services[service_id])
    try:
        client = await self.connection_pool.get_client(
            snapshot["server_url"],
            snapshot["connection_config"]
        )
        live = await client.health_check()
    except Exception as e:
        live = {
            "healthy": False,
            "error": str(e),
            "timestamp": time.time()
        }
    snapshot["real_time_health"] = live
    return snapshot
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
    """List the services held in memory.

    Args:
        tenant_id: When truthy, only services of this tenant are returned.

    Returns:
        Shallow copies of the matching service info dicts.
    """
    return [
        info.copy()
        for info in self._services.values()
        if not (tenant_id and info["tenant_id"] != tenant_id)
    ]
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
    """Fetch the tools of a service from its server and refresh the caches.

    The tool names are stored in the in-memory service info and in the
    service's ``MCPToolConfig`` row. A failure to persist the names rolls
    the session back and is logged, but the fetched tools are still
    returned.

    Args:
        service_id: Id of the service.

    Returns:
        The tool entries reported by the server; an empty list for an
        unknown service or when the server cannot be queried.
    """
    if service_id not in self._services:
        return []
    service_info = self._services[service_id]
    try:
        client = await self.connection_pool.get_client(
            service_info["server_url"],
            service_info["connection_config"]
        )
        tools = await client.list_tools()
        tool_names = [tool.get("name") for tool in tools if tool.get("name")]
    except Exception as e:
        logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
        return []
    service_info["available_tools"] = tool_names
    try:
        mcp_config = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.id == uuid.UUID(service_id)
        ).first()
        if mcp_config:
            mcp_config.available_tools = tool_names
            self.db.commit()
    except Exception as e:
        # Keep the shared session usable after a failed write.
        self.db.rollback()
        logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
    return tools
async def call_service_tool(
    self,
    service_id: str,
    tool_name: str,
    arguments: Dict[str, Any],
    timeout: int = 30
) -> Dict[str, Any]:
    """Run a tool of a registered service and track the outcome in memory.

    Success marks the service healthy and resets its retry counter; a
    failure marks it ``error``, records the message, increments the retry
    counter and re-raises.

    Args:
        service_id: Id of the service.
        tool_name: Name of the tool to run.
        arguments: Arguments for the tool.
        timeout: Seconds to wait for the result.

    Returns:
        The tool result.

    Raises:
        ValueError: The service is not known to this manager.
    """
    if service_id not in self._services:
        raise ValueError(f"服务不存在: {service_id}")
    service_info = self._services[service_id]
    try:
        client = await self.connection_pool.get_client(
            service_info["server_url"],
            service_info["connection_config"]
        )
        result = await client.call_tool(tool_name, arguments, timeout)
    except Exception as e:
        service_info["status"] = "error"
        service_info["last_error"] = str(e)
        service_info["retry_count"] += 1
        logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
        raise
    service_info.update(
        status="healthy",
        last_health_check=time.time(),
        retry_count=0
    )
    return result
async def _load_existing_services(self):
    """Load every enabled MCP service from the database and start monitoring it.

    Any exception is logged; services loaded before the failure stay
    registered.
    """
    try:
        enabled_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
            ToolConfig.is_enabled == True
        ).all()
        for mcp_config in enabled_configs:
            base = mcp_config.base_config
            service_id = str(mcp_config.id)
            checked_at = mcp_config.last_health_check
            self._services[service_id] = {
                "id": service_id,
                "server_url": mcp_config.server_url,
                "connection_config": mcp_config.connection_config or {},
                "tenant_id": base.tenant_id,
                "available_tools": mcp_config.available_tools or [],
                "status": mcp_config.health_status or "unknown",
                # 0 stands for "never checked".
                "last_health_check": checked_at.timestamp() if checked_at else 0,
                "retry_count": 0,
                "created_at": base.created_at.timestamp()
            }
            await self._start_service_monitoring(service_id)
        logger.info(f"加载了 {len(enabled_configs)} 个MCP服务")
    except Exception as e:
        logger.error(f"加载现有服务失败: {e}")
async def _start_service_monitoring(self, service_id: str):
"""启动服务监控"""
if service_id in self._monitoring_tasks:
return
task = asyncio.create_task(self._monitor_service(service_id))
self._monitoring_tasks[service_id] = task
async def _stop_service_monitoring(self, service_id: str):
"""停止服务监控"""
if service_id in self._monitoring_tasks:
task = self._monitoring_tasks.pop(service_id)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _restart_service_monitoring(self, service_id: str):
"""重启服务监控"""
await self._stop_service_monitoring(service_id)
await self._start_service_monitoring(service_id)
async def _monitor_service(self, service_id: str):
    """Periodically health-check one service until it is removed or the manager stops.

    Each round obtains a pooled client, runs a health check, updates the
    in-memory service info (status, retry counter, tool names, last check
    time) and writes the health state to the database. After
    ``max_retry_attempts`` consecutive failures the loop sleeps five check
    intervals and resets the counter before continuing.

    Args:
        service_id: Id of the service to monitor.
    """
    try:
        while self._running and service_id in self._services:
            service_info = self._services[service_id]
            try:
                # Run the health check through a pooled client.
                client = await self.connection_pool.get_client(
                    service_info["server_url"],
                    service_info["connection_config"]
                )
                health_status = await client.health_check()
                if health_status["healthy"]:
                    # Healthy: clear the failure counter.
                    service_info["status"] = "healthy"
                    service_info["retry_count"] = 0
                    # Refresh the cached tool names; a failure here is only a warning.
                    try:
                        tools = await client.list_tools()
                        tool_names = [tool.get("name") for tool in tools if tool.get("name")]
                        service_info["available_tools"] = tool_names
                    except Exception as e:
                        logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
                else:
                    # Unhealthy: remember why and count the failure.
                    service_info["status"] = "unhealthy"
                    service_info["last_error"] = health_status.get("error", "健康检查失败")
                    service_info["retry_count"] += 1
                # NOTE(review): the timestamp is refreshed after failed checks
                # too, while _cleanup_failed_services compares it against a
                # 24h threshold - confirm that interplay is intended.
                service_info["last_health_check"] = time.time()
                # Persist the health state.
                await self._update_service_health_in_db(service_id, health_status)
            except Exception as e:
                # The check itself could not be carried out.
                service_info["status"] = "error"
                service_info["last_error"] = str(e)
                service_info["retry_count"] += 1
                service_info["last_health_check"] = time.time()
                logger.error(f"服务监控异常: {service_id}, 错误: {e}")
            # Back off after too many consecutive failures.
            if service_info["retry_count"] >= self.max_retry_attempts:
                logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
                await asyncio.sleep(self.health_check_interval * 5)  # extended wait
                service_info["retry_count"] = 0  # start counting afresh
            # Wait for the next round.
            await asyncio.sleep(self.health_check_interval)
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
    """Write a health check outcome to the service's ``MCPToolConfig`` row.

    Sets ``health_status``, ``last_health_check`` and ``error_message``
    (cleared when healthy). Failures are logged and roll the session back.

    Args:
        service_id: Id of the service.
        health_status: Result dict of a health check; ``healthy`` is
            required, ``error`` is read when unhealthy.
    """
    try:
        mcp_config = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.id == uuid.UUID(service_id)
        ).first()
        if mcp_config:
            if health_status["healthy"]:
                state, message = "healthy", None
            else:
                state, message = "unhealthy", health_status.get("error", "")
            mcp_config.health_status = state
            mcp_config.last_health_check = datetime.utcnow()
            mcp_config.error_message = message
            self.db.commit()
    except Exception as e:
        logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
        self.db.rollback()
async def _management_loop(self):
"""管理循环 - 处理服务清理等任务"""
try:
while self._running:
# 清理失效的服务
await self._cleanup_failed_services()
# 等待下次循环
await asyncio.sleep(300) # 5分钟
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"管理循环异常: {e}")
async def _cleanup_failed_services(self):
"""清理长期失效的服务"""
try:
current_time = time.time()
cleanup_threshold = 24 * 60 * 60 # 24小时
services_to_cleanup = []
for service_id, service_info in self._services.items():
# 检查服务是否长期失效
if (service_info["status"] in ["error", "unhealthy"] and
current_time - service_info["last_health_check"] > cleanup_threshold):
services_to_cleanup.append(service_id)
for service_id in services_to_cleanup:
logger.warning(f"清理长期失效的服务: {service_id}")
# 停止监控但不删除数据库记录
await self._stop_service_monitoring(service_id)
# 标记为禁用
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if tool_config:
tool_config.is_enabled = False
self.db.commit()
# 从内存移除
del self._services[service_id]
except Exception as e:
logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]:
    """Summarise the manager state.

    Returns:
        Running flag, number of known services, counts of healthy and of
        unhealthy/error services, number of monitor tasks and the
        connection pool status.
    """
    statuses = [service["status"] for service in self._services.values()]
    healthy_count = statuses.count("healthy")
    unhealthy_count = sum(1 for status in statuses if status in ["unhealthy", "error"])
    return {
        "running": self._running,
        "total_services": len(self._services),
        "healthy_services": healthy_count,
        "unhealthy_services": unhealthy_count,
        "monitoring_tasks": len(self._monitoring_tasks),
        "connection_pool_status": self.connection_pool.get_pool_status()
    }

View File

@@ -0,0 +1,436 @@
"""工具注册表 - 管理所有工具的元数据和状态"""
import uuid
import asyncio
from typing import Dict, List, Optional, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolType, ToolStatus, ToolExecution, ExecutionStatus
)
from app.core.logging_config import get_business_logger
from .base import BaseTool, ToolInfo
from .custom.base import CustomTool
from .mcp.base import MCPTool
logger = get_business_logger()
class ToolRegistry:
"""工具注册表 - 管理所有工具的元数据和实例"""
def __init__(self, db: Session):
    """Create an empty tool registry.

    Args:
        db: Database session used for all reads and writes.
    """
    self.db = db
    # Tool instances by tool id
    self._tools: Dict[str, BaseTool] = {}
    # Registered tool classes by class name
    self._tool_classes: Dict[str, Type[BaseTool]] = {}
    # Held by register_tool and unregister_tool
    self._lock = asyncio.Lock()
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
    """Register a tool class so instances can later be built by class name.

    Args:
        tool_class: The tool class.
        class_name: Key to register under; defaults to the class's
            ``__name__``. An existing entry with that key is replaced.
    """
    key = class_name or tool_class.__name__
    self._tool_classes[key] = tool_class
    logger.info(f"工具类已注册: {key}")
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
    """Persist a tool and cache its instance.

    A ``ToolConfig`` row is written together with the row specific to the
    tool type (builtin, custom or MCP). The tool's ``tool_id`` is set to
    the new row id.

    Args:
        tool: The tool instance to register.
        tenant_id: Owning tenant; a falsy value registers a global tool.

    Returns:
        True on success; False when a tool with the same name, tenant and
        type already exists or when an error occurred (rolled back and
        logged).
    """
    async with self._lock:
        try:
            # Global tools are stored with a NULL tenant.
            if tenant_id:
                tenant_clause = ToolConfig.tenant_id == tenant_id
            else:
                tenant_clause = ToolConfig.tenant_id.is_(None)
            duplicate = self.db.query(ToolConfig).filter(
                and_(
                    ToolConfig.name == tool.name,
                    tenant_clause,
                    ToolConfig.tool_type == tool.tool_type.value
                )
            ).first()
            if duplicate:
                logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
                return False
            tool_config = ToolConfig(
                name=tool.name,
                description=tool.description,
                tool_type=tool.tool_type.value,
                tenant_id=tenant_id,
                version=tool.version,
                tags=tool.tags,
                config_data=tool.config
            )
            self.db.add(tool_config)
            self.db.flush()  # assigns tool_config.id
            # Build the type-specific row; it shares the base row's id.
            detail_config = None
            if tool.tool_type == ToolType.BUILTIN:
                detail_config = BuiltinToolConfig(
                    id=tool_config.id,
                    tool_class=tool.__class__.__name__,
                    parameters=tool.config.get("parameters", {})
                )
            elif tool.tool_type == ToolType.CUSTOM:
                detail_config = CustomToolConfig(
                    id=tool_config.id,
                    schema_url=tool.config.get("schema_url"),
                    schema_content=tool.config.get("schema_content"),
                    auth_type=tool.config.get("auth_type", "none"),
                    auth_config=tool.config.get("auth_config", {}),
                    base_url=tool.config.get("base_url"),
                    timeout=tool.config.get("timeout", 30)
                )
            elif tool.tool_type == ToolType.MCP:
                detail_config = MCPToolConfig(
                    id=tool_config.id,
                    server_url=tool.config.get("server_url"),
                    connection_config=tool.config.get("connection_config", {}),
                    available_tools=tool.config.get("available_tools", [])
                )
            if detail_config is not None:
                self.db.add(detail_config)
            self.db.commit()
            registered_id = str(tool_config.id)
            tool.tool_id = registered_id
            self._tools[registered_id] = tool
            logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
            return True
        except Exception as e:
            self.db.rollback()
            logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
            return False
async def unregister_tool(self, tool_id: str) -> bool:
    """Delete a tool unless it still has pending or running executions.

    Args:
        tool_id: Id of the tool.

    Returns:
        True when the tool was deleted; False when it does not exist, has
        pending/running executions, or an error occurred (rolled back and
        logged).
    """
    async with self._lock:
        try:
            tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
            if not tool_config:
                logger.warning(f"工具不存在: {tool_id}")
                return False
            # Executions that have not finished block the removal.
            unfinished_states = [ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value]
            unfinished = self.db.query(ToolExecution).filter(
                and_(
                    ToolExecution.tool_config_id == uuid.UUID(tool_id),
                    ToolExecution.status.in_(unfinished_states)
                )
            ).count()
            if unfinished > 0:
                logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
                return False
            self.db.delete(tool_config)
            self.db.commit()
            # Drop the cached instance, if any.
            self._tools.pop(tool_id, None)
            logger.info(f"工具注销成功: {tool_id}")
            return True
        except Exception as e:
            self.db.rollback()
            logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
            return False
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
    """Return the tool instance for ``tool_id``, loading it on first use.

    A cached instance is returned as is. Otherwise the tool configuration
    is read from the database and, when its status is active, an instance
    is built and cached.

    Args:
        tool_id: Id of the tool.

    Returns:
        The tool instance, or None when the tool is missing, not active,
        cannot be instantiated, or loading raised (the error is logged).
    """
    try:
        return self._tools[tool_id]
    except KeyError:
        pass
    try:
        tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
        # NOTE(review): this reads tool_config.status while the other
        # methods filter on is_enabled - confirm the model exposes both.
        if tool_config and tool_config.status == ToolStatus.ACTIVE.value:
            loaded = self._load_tool_instance(tool_config)
            if loaded:
                self._tools[tool_id] = loaded
            return loaded
        return None
    except Exception as e:
        logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
        return None
def list_tools(
    self,
    tenant_id: Optional[uuid.UUID] = None,
    tool_type: Optional[ToolType] = None,
    status: Optional[ToolStatus] = None,
    tags: Optional[List[str]] = None
) -> List[ToolInfo]:
    """List tools matching the given filters.

    Args:
        tenant_id: When given, the tenant's tools plus global tools (NULL
            tenant) are returned; otherwise no tenant filter is applied.
        tool_type: Restrict to one tool type.
        status: ``ACTIVE`` / ``INACTIVE`` filter on the enabled flag; other
            values apply no filter.
        tags: Every listed tag must be contained in the tool's tags.

    Returns:
        One ``ToolInfo`` per matching tool, with ``parameters`` filled in
        when the tool instance can be obtained. An empty list when an
        error occurs (it is logged).
    """
    try:
        query = self.db.query(ToolConfig)
        if tenant_id:
            visible = or_(
                ToolConfig.tenant_id == tenant_id,
                ToolConfig.tenant_id.is_(None)
            )
            query = query.filter(visible)
        if tool_type:
            query = query.filter(ToolConfig.tool_type == tool_type.value)
        if status == ToolStatus.ACTIVE:
            query = query.filter(ToolConfig.is_enabled == True)
        elif status == ToolStatus.INACTIVE:
            query = query.filter(ToolConfig.is_enabled == False)
        for tag in tags or []:
            query = query.filter(ToolConfig.tags.contains([tag]))
        tool_infos = []
        for config in query.all():
            tool_id = str(config.id)
            if config.is_enabled:
                current_status = ToolStatus.ACTIVE
            else:
                current_status = ToolStatus.INACTIVE
            tool_info = ToolInfo(
                id=tool_id,
                name=config.name,
                description=config.description or "",
                tool_type=ToolType(config.tool_type),
                version=config.version,
                status=current_status,
                tags=config.tags or [],
                tenant_id=str(config.tenant_id) if config.tenant_id else None
            )
            # Parameters are only known to the tool instance.
            tool_instance = self.get_tool(tool_id)
            if tool_instance:
                tool_info.parameters = tool_instance.parameters
            tool_infos.append(tool_info)
        return tool_infos
    except Exception as e:
        logger.error(f"列出工具失败, 错误: {e}")
        return []
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
    """Persist a new status for a tool and mirror it on the cached instance.

    Args:
        tool_id: Tool ID as a UUID string.
        status: The new status.

    Returns:
        True when the status was stored; False when the tool does not exist
        or the update failed (the transaction is rolled back on failure).
    """
    try:
        tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
        if not tool_config:
            logger.warning(f"工具不存在: {tool_id}")
            return False

        # The model keeps its state in the `status` column. Writing an
        # unmapped `is_enabled` attribute would never reach the database.
        tool_config.status = status.value
        self.db.commit()

        # Keep the cached instance (if any) in sync with the stored state.
        # NOTE(review): a cached tool keeps being served by get_tool after it
        # is deactivated; confirm callers check `tool.status`.
        if tool_id in self._tools:
            self._tools[tool_id].status = status

        logger.info(f"工具状态更新成功: {tool_id} -> {status}")
        return True
    except Exception as e:
        self.db.rollback()
        logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
        return False
def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]:
    """Build a tool instance from its stored configuration.

    The instance receives a config dict made of the record's ``config_data``,
    the fields of the matching type-specific table, and the shared
    ``tenant_id`` / ``version`` / ``tags`` values (later keys win).

    Args:
        tool_config: The ToolConfig record to load, or None.
            NOTE(review): the annotation says ``type[ToolConfig]`` but the
            body reads instance attributes; it is used as an instance.

    Returns:
        The tool instance, or None when ``tool_config`` is None, the
        type-specific row is missing, the builtin class is not registered,
        or loading raised an error (which is logged).
    """
    if tool_config is None:
        # Nothing to load; also avoids dereferencing None in the error handler.
        return None

    try:
        tool_id = str(tool_config.id)
        # config_data is a nullable JSON column: treat NULL as "no settings".
        base_settings = tool_config.config_data or {}
        # Values every tool type receives, applied last so they take precedence.
        shared_settings = {
            "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
            "version": tool_config.version,
            "tags": tool_config.tags,
        }

        if tool_config.tool_type == ToolType.BUILTIN.value:
            # Builtin tools: the class must have been registered beforehand.
            builtin_config = self.db.query(BuiltinToolConfig).filter(
                BuiltinToolConfig.id == tool_config.id
            ).first()
            if builtin_config and builtin_config.tool_class in self._tool_classes:
                tool_class = self._tool_classes[builtin_config.tool_class]
                config = {
                    **base_settings,
                    "parameters": builtin_config.parameters,
                    **shared_settings,
                }
                return tool_class(tool_id, config)
        elif tool_config.tool_type == ToolType.CUSTOM.value:
            # Custom tools described by an OpenAPI schema.
            try:
                custom_config = self.db.query(CustomToolConfig).filter(
                    CustomToolConfig.id == tool_config.id
                ).first()
                if custom_config:
                    config = {
                        **base_settings,
                        "schema_url": custom_config.schema_url,
                        "schema_content": custom_config.schema_content,
                        "auth_type": custom_config.auth_type,
                        "auth_config": custom_config.auth_config,
                        "base_url": custom_config.base_url,
                        "timeout": custom_config.timeout,
                        **shared_settings,
                    }
                    return CustomTool(tool_id, config)
            except ImportError as e:
                logger.error(f"无法导入自定义工具模块: {e}")
        elif tool_config.tool_type == ToolType.MCP.value:
            # MCP tools served by a remote MCP server.
            try:
                mcp_config = self.db.query(MCPToolConfig).filter(
                    MCPToolConfig.id == tool_config.id
                ).first()
                if mcp_config:
                    config = {
                        **base_settings,
                        "server_url": mcp_config.server_url,
                        "connection_config": mcp_config.connection_config,
                        "available_tools": mcp_config.available_tools,
                        **shared_settings,
                    }
                    return MCPTool(tool_id, config)
            except ImportError as e:
                logger.error(f"无法导入MCP工具模块: {e}")
    except Exception as e:
        logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")
    return None
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
    """Return tool counts, optionally restricted to one tenant.

    Args:
        tenant_id: When given, only that tenant's tools are counted.

    Returns:
        A dict with ``total_tools``, ``active_tools``, ``inactive_tools``
        (every tool that is not active) and ``by_type`` (count per ToolType
        value); an empty dict when the query fails.
    """
    try:
        query = self.db.query(ToolConfig)
        if tenant_id:
            query = query.filter(ToolConfig.tenant_id == tenant_id)

        total_tools = query.count()
        # The model stores its state in `status`; there is no `is_enabled`
        # column, so "active" means status == ToolStatus.ACTIVE.
        active_tools = query.filter(
            ToolConfig.status == ToolStatus.ACTIVE.value
        ).count()

        # Per-type counts, one entry for every known tool type.
        type_stats = {
            tool_type.value: query.filter(
                ToolConfig.tool_type == tool_type.value
            ).count()
            for tool_type in ToolType
        }

        return {
            "total_tools": total_tools,
            "active_tools": active_tools,
            "inactive_tools": total_tools - active_tools,
            "by_type": type_stats
        }
    except Exception as e:
        logger.error(f"获取工具统计失败, 错误: {e}")
        return {}
def clear_cache(self):
    """Drop every cached tool instance and log that the cache was cleared."""
    self._tools.clear()
    logger.info("工具缓存已清空")

View File

@@ -4,8 +4,9 @@
基于 LangGraph 的工作流执行引擎。
"""
import datetime
import logging
# import uuid
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
@@ -15,6 +16,11 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry
# from app.core.tools.executor import ToolExecutor
# from app.core.tools.langchain_adapter import LangchainAdapter
# TOOL_MANAGEMENT_AVAILABLE = True
# from app.db import get_db
logger = logging.getLogger(__name__)
@@ -457,3 +463,179 @@ async def execute_workflow_stream(
)
async for event in executor.execute_stream(input_data):
yield event
# ==================== 工具管理系统集成 ====================
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
# """获取工作流可用的工具列表
#
# Args:
# workspace_id: 工作空间ID
# user_id: 用户ID
#
# Returns:
# 可用工具列表
# """
# if not TOOL_MANAGEMENT_AVAILABLE:
# logger.warning("工具管理系统不可用")
# return []
#
# try:
# db = next(get_db())
#
# # 创建工具注册表
# registry = ToolRegistry(db)
#
# # 注册内置工具类
# from app.core.tools.builtin import (
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
# )
# registry.register_tool_class(DateTimeTool)
# registry.register_tool_class(JsonTool)
# registry.register_tool_class(BaiduSearchTool)
# registry.register_tool_class(MinerUTool)
# registry.register_tool_class(TextInTool)
#
# # 获取活跃的工具
# import uuid
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
# active_tools = [tool for tool in tools if tool.status.value == "active"]
#
# # 转换为Langchain工具
# langchain_tools = []
# for tool_info in active_tools:
# try:
# tool_instance = registry.get_tool(tool_info.id)
# if tool_instance:
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
# langchain_tools.append(langchain_tool)
# except Exception as e:
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
#
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
# return langchain_tools
#
# except Exception as e:
# logger.error(f"获取工作流工具失败: {e}")
# return []
#
#
# class ToolWorkflowNode:
# """工具工作流节点 - 在工作流中执行工具"""
#
# def __init__(self, node_config: dict, workflow_config: dict):
# """初始化工具节点
#
# Args:
# node_config: 节点配置
# workflow_config: 工作流配置
# """
# self.node_config = node_config
# self.workflow_config = workflow_config
# self.tool_id = node_config.get("tool_id")
# self.tool_parameters = node_config.get("parameters", {})
#
# async def run(self, state: WorkflowState) -> WorkflowState:
# """执行工具节点"""
# if not TOOL_MANAGEMENT_AVAILABLE:
# logger.error("工具管理系统不可用")
# state["error"] = "工具管理系统不可用"
# return state
#
# try:
# from sqlalchemy.orm import Session
# db = next(get_db())
#
# # 创建工具执行器
# registry = ToolRegistry(db)
# executor = ToolExecutor(db, registry)
#
# # 准备参数(支持变量替换)
# parameters = self._prepare_parameters(state)
#
# # 执行工具
# result = await executor.execute_tool(
# tool_id=self.tool_id,
# parameters=parameters,
# user_id=uuid.UUID(state["user_id"]),
# workspace_id=uuid.UUID(state["workspace_id"])
# )
#
# # 更新状态
# node_id = self.node_config.get("id")
# if result.success:
# state["node_outputs"][node_id] = {
# "type": "tool",
# "tool_id": self.tool_id,
# "output": result.data,
# "execution_time": result.execution_time,
# "token_usage": result.token_usage
# }
#
# # 更新运行时变量
# if isinstance(result.data, dict):
# for key, value in result.data.items():
# state["runtime_vars"][f"{node_id}.{key}"] = value
# else:
# state["runtime_vars"][f"{node_id}.result"] = result.data
# else:
# state["error"] = result.error
# state["error_node"] = node_id
# state["node_outputs"][node_id] = {
# "type": "tool",
# "tool_id": self.tool_id,
# "error": result.error,
# "execution_time": result.execution_time
# }
#
# return state
#
# except Exception as e:
# logger.error(f"工具节点执行失败: {e}")
# state["error"] = str(e)
# state["error_node"] = self.node_config.get("id")
# return state
#
# def _prepare_parameters(self, state: WorkflowState) -> dict:
# """准备工具参数(支持变量替换)"""
# parameters = {}
#
# for key, value in self.tool_parameters.items():
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# # 变量替换
# var_path = value[2:-1]
#
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
# if "." in var_path:
# parts = var_path.split(".")
# current = state.get("variables", {})
#
# for part in parts:
# if isinstance(current, dict) and part in current:
# current = current[part]
# else:
# # 尝试从运行时变量获取
# runtime_key = ".".join(parts)
# current = state.get("runtime_vars", {}).get(runtime_key, value)
# break
#
# parameters[key] = current
# else:
# # 简单变量
# variables = state.get("variables", {})
# parameters[key] = variables.get(var_path, value)
# else:
# parameters[key] = value
#
# return parameters
#
#
# # 注册工具节点到NodeFactory如果存在
# try:
# from app.core.workflow.nodes import NodeFactory
# if hasattr(NodeFactory, 'register_node_type'):
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
# logger.info("工具节点已注册到工作流系统")
# except Exception as e:
# logger.warning(f"注册工具节点失败: {e}")

View File

@@ -21,6 +21,10 @@ from .multi_agent_model import MultiAgentConfig, AgentInvocation
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from .retrieval_info import RetrievalInfo
from .prompt_optimizer_model import PromptOptimizerSession, PromptOptimizerSessionHistory
from .tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
)
__all__ = [
"Tenants",
@@ -57,5 +61,15 @@ __all__ = [
"WorkflowNodeExecution",
"RetrievalInfo",
"PromptOptimizerSession",
"PromptOptimizerSessionHistory"
"PromptOptimizerSessionHistory",
"RetrievalInfo",
"ToolConfig",
"BuiltinToolConfig",
"CustomToolConfig",
"MCPToolConfig",
"ToolExecution",
"ToolType",
"ToolStatus",
"AuthType",
"ExecutionStatus"
]

View File

@@ -21,3 +21,6 @@ class Tenants(Base):
# Relationship to workspaces owned by the tenant
owned_workspaces = relationship("Workspace", back_populates="tenant")
# Relationship to tool configs owned by the tenant
tool_configs = relationship("ToolConfig", back_populates="tenant")

View File

@@ -0,0 +1,226 @@
"""工具管理相关数据模型"""
import uuid
from datetime import datetime
from enum import StrEnum
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
class ToolType(StrEnum):
    """Kinds of tool a configuration record can describe."""
    BUILTIN = "builtin"  # tool implemented by a built-in tool class
    CUSTOM = "custom"  # user-defined tool
    MCP = "mcp"  # tool provided through an MCP server
class ToolStatus(StrEnum):
    """Lifecycle states of a tool."""
    ACTIVE = "active"
    INACTIVE = "inactive"
    ERROR = "error"
    LOADING = "loading"
class AuthType(StrEnum):
    """Authentication schemes a custom tool can be configured with."""
    NONE = "none"
    API_KEY = "api_key"
    BEARER_TOKEN = "bearer_token"
class ExecutionStatus(StrEnum):
    """States of a single tool execution."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    TIMEOUT = "timeout"
class ToolConfig(Base):
    """Base configuration row shared by every tool type.

    Stored in ``tool_configs``. The type-specific tables in this module
    (builtin / custom / MCP) reference this row's ``id`` as their primary key.
    """
    __tablename__ = "tool_configs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String(255), nullable=False, index=True)
    description = Column(Text)
    # Stored as a plain string; expected to hold a ToolType value.
    tool_type = Column(String(50), nullable=False, index=True)
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True) # every tool must belong to a tenant
    status = Column(String(50), default=ToolStatus.INACTIVE.value, nullable=False, index=True) # tool status; new rows default to "inactive"

    # Tool-specific settings, stored as JSON.
    config_data = Column(JSON, default=dict)

    # Metadata.
    version = Column(String(50), default="1.0.0")
    tags = Column(JSON, default=list) # list of tags

    # Timestamps. datetime.now is called without a timezone, so the stored
    # values are naive local times.
    created_at = Column(DateTime, default=datetime.now, nullable=False)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)

    # Relationships.
    tenant = relationship("Tenants", back_populates="tool_configs")
    # Execution records are removed together with the tool ("all, delete-orphan").
    executions = relationship("ToolExecution", back_populates="tool_config", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<ToolConfig(id={self.id}, name={self.name}, type={self.tool_type}, status={self.status})>"
class BuiltinToolConfig(Base):
    """Extra configuration for built-in tools (table ``builtin_tool_configs``).

    The primary key is also a foreign key to ``tool_configs.id``, so each row
    extends exactly one base ToolConfig row.
    """
    __tablename__ = "builtin_tool_configs"

    id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
    tool_class = Column(String(255), nullable=False) # name of the tool class
    parameters = Column(JSON, default=dict) # tool parameter configuration

    # Relationship to the base configuration row.
    base_config = relationship("ToolConfig", foreign_keys=[id])

    def __repr__(self):
        return f"<BuiltinToolConfig(id={self.id}, tool_class={self.tool_class})>"
class CustomToolConfig(Base):
    """Extra configuration for custom tools (table ``custom_tool_configs``).

    The primary key is also a foreign key to ``tool_configs.id``, so each row
    extends exactly one base ToolConfig row.
    """
    __tablename__ = "custom_tool_configs"

    id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
    schema_url = Column(String(1000)) # OpenAPI schema URL
    schema_content = Column(JSON) # OpenAPI schema content

    # Authentication settings.
    auth_type = Column(String(50), default=AuthType.NONE.value, nullable=False)
    # NOTE(review): the original comment says this is stored encrypted; nothing
    # in this model enforces that — confirm where encryption happens.
    auth_config = Column(JSON, default=dict) # authentication configuration

    # API settings.
    base_url = Column(String(1000)) # API base URL
    timeout = Column(Integer, default=30) # timeout in seconds

    # Relationship to the base configuration row.
    base_config = relationship("ToolConfig", foreign_keys=[id])

    def __repr__(self):
        return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
class MCPToolConfig(Base):
    """Extra configuration for MCP tools (table ``mcp_tool_configs``).

    The primary key is also a foreign key to ``tool_configs.id``, so each row
    extends exactly one base ToolConfig row.
    """
    __tablename__ = "mcp_tool_configs"

    id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
    server_url = Column(String(1000), nullable=False) # MCP server URL
    connection_config = Column(JSON, default=dict) # connection settings

    # Service health information.
    last_health_check = Column(DateTime)
    health_status = Column(String(50), default="unknown")
    error_message = Column(Text)

    # List of tools available on the server.
    available_tools = Column(JSON, default=list)

    # Relationship to the base configuration row.
    base_config = relationship("ToolConfig", foreign_keys=[id])

    def __repr__(self):
        return f"<MCPToolConfig(id={self.id}, server_url={self.server_url})>"
class ToolExecution(Base):
    """Record of a single tool execution (table ``tool_executions``)."""
    __tablename__ = "tool_executions"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tool_config_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False, index=True)

    # Execution details.
    execution_id = Column(String(255), nullable=False, index=True) # execution ID; can be used to correlate with workflows etc.
    status = Column(String(50), default=ExecutionStatus.PENDING.value, nullable=False, index=True)

    # Input and output.
    input_data = Column(JSON) # input parameters
    output_data = Column(JSON) # output result
    error_message = Column(Text) # error message

    # Performance metrics.
    started_at = Column(DateTime, nullable=False, index=True)
    completed_at = Column(DateTime)
    execution_time = Column(Float) # execution time in seconds

    # Token usage, when applicable.
    token_usage = Column(JSON)

    # Caller information. user_id is nullable; workspace_id is required.
    user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True)
    workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False, index=True)

    # Relationships.
    tool_config = relationship("ToolConfig", back_populates="executions")
    user = relationship("User")
    workspace = relationship("Workspace")

    def __repr__(self):
        return f"<ToolExecution(id={self.id}, status={self.status}, tool={self.tool_config_id})>"
# class ToolDependency(Base):
# """工具依赖关系模型"""
# __tablename__ = "tool_dependencies"
#
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False)
# depends_on_tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False)
#
# # 依赖类型和版本要求
# dependency_type = Column(String(50), default="required") # required, optional
# version_constraint = Column(String(100)) # 版本约束,如 ">=1.0.0"
#
# # 时间戳
# created_at = Column(DateTime, default=datetime.now, nullable=False)
#
# # 关联关系
# tool = relationship("ToolConfig", foreign_keys=[tool_id])
# depends_on_tool = relationship("ToolConfig", foreign_keys=[depends_on_tool_id])
#
# def __repr__(self):
# return f"<ToolDependency(tool={self.tool_id}, depends_on={self.depends_on_tool_id})>"
# class PluginConfig(Base):
# """插件配置模型"""
# __tablename__ = "plugin_configs"
#
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# name = Column(String(255), nullable=False, unique=True, index=True)
# description = Column(Text)
#
# # 插件信息
# plugin_path = Column(String(1000), nullable=False) # 插件文件路径
# entry_point = Column(String(255), nullable=False) # 入口点
# version = Column(String(50), default="1.0.0")
#
# # 状态
# is_enabled = Column(Boolean, default=True, nullable=False)
# is_loaded = Column(Boolean, default=False, nullable=False)
# load_error = Column(Text) # 加载错误信息
#
# # 配置
# config_schema = Column(JSON) # 配置schema
# config_data = Column(JSON, default=dict) # 配置数据
#
# # 依赖
# dependencies = Column(JSON, default=list) # 依赖的其他插件
#
# # 时间戳
# created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
# updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
# last_loaded_at = Column(DateTime)
#
# def __repr__(self):
# return f"<PluginConfig(id={self.id}, name={self.name}, version={self.version})>"

View File

@@ -14,6 +14,7 @@ from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.repositories import workspace_repository, knowledge_repository
logger = get_business_logger()
@@ -328,4 +329,4 @@ def create_agent_invocation_tool(
)
return f"调用 Agent 失败: {str(e)}"
return invoke_agent
return invoke_agent

374
api/test_tool_system.py Normal file
View File

@@ -0,0 +1,374 @@
#!/usr/bin/env python3
"""
工具管理系统基础测试脚本
用于验证系统的基本功能是否正常
"""
import asyncio
import uuid
from datetime import datetime
# 测试导入
def test_imports():
    """Check that every tool-system module can be imported.

    Each group is imported in turn; the first ImportError is reported and
    ends the check.

    Returns:
        True when all groups import cleanly, False otherwise.
    """
    print("测试模块导入...")

    # Base tool abstractions.
    try:
        from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType
    except ImportError as err:
        print(f"✗ 基础工具模块导入失败: {err}")
        return False
    else:
        print("✓ 基础工具模块导入成功")

    # Built-in tools.
    try:
        from app.core.tools.builtin.datetime_tool import DateTimeTool
        from app.core.tools.builtin.json_tool import JsonTool
    except ImportError as err:
        print(f"✗ 内置工具模块导入失败: {err}")
        return False
    else:
        print("✓ 内置工具模块导入成功")

    # Langchain adapter.
    try:
        from app.core.tools.langchain_adapter import LangchainAdapter
    except ImportError as err:
        print(f"✗ Langchain适配器导入失败: {err}")
        return False
    else:
        print("✓ Langchain适配器导入成功")

    # ORM models.
    try:
        from app.models.tool_model import ToolConfig, ToolType, ToolStatus
    except ImportError as err:
        print(f"✗ 工具模型导入失败: {err}")
        return False
    else:
        print("✓ 工具模型导入成功")

    # Custom (OpenAPI) tools.
    try:
        from app.core.tools.custom import CustomTool, OpenAPISchemaParser, AuthManager
    except ImportError as err:
        print(f"✗ 自定义工具模块导入失败: {err}")
        return False
    else:
        print("✓ 自定义工具模块导入成功")

    # MCP tools.
    try:
        from app.core.tools.mcp import MCPTool, MCPClient, MCPServiceManager
    except ImportError as err:
        print(f"✗ MCP工具模块导入失败: {err}")
        return False
    else:
        print("✓ MCP工具模块导入成功")

    return True
def test_tool_creation():
    """Instantiate the built-in datetime tool and check its basic attributes.

    Returns:
        True when the tool is created with the expected name, type and a
        non-empty parameter list; False when anything raises.
    """
    print("\n测试工具创建...")
    try:
        from app.core.tools.builtin.datetime_tool import DateTimeTool

        # tenant_id=None marks the tool as global.
        settings = {
            "parameters": {"timezone": "UTC"},
            "tenant_id": None,
            "version": "1.0.0",
            "tags": ["time", "utility", "builtin"],
        }
        tool = DateTimeTool(str(uuid.uuid4()), settings)

        # Verify the attributes exposed by the instance.
        assert tool.name == "datetime_tool"
        assert tool.tool_type.value == "builtin"
        assert len(tool.parameters) > 0

        print("✓ 时间工具创建成功(全局工具)")
        return True
    except Exception as err:
        print(f"✗ 工具创建失败: {err}")
        return False
async def test_tool_execution():
    """Run the datetime tool's "now" operation and check the result.

    Returns:
        True when execution succeeds and the result carries a ``datetime``
        entry and a positive execution time; False when anything raises.
    """
    print("\n测试工具执行...")
    try:
        from app.core.tools.builtin.datetime_tool import DateTimeTool

        # tenant_id=None marks the tool as global.
        settings = {
            "parameters": {"timezone": "UTC"},
            "tenant_id": None,
            "version": "1.0.0",
        }
        tool = DateTimeTool(str(uuid.uuid4()), settings)

        # Ask the tool for the current time.
        outcome = await tool.safe_execute(operation="now")
        assert outcome.success == True
        assert "datetime" in outcome.data
        assert outcome.execution_time > 0

        print("✓ 工具执行成功")
        print(f" 执行时间: {outcome.execution_time:.3f}")
        print(f" 返回数据: {outcome.data}")
        return True
    except Exception as err:
        print(f"✗ 工具执行失败: {err}")
        return False
def test_langchain_adapter():
    """Check that the JSON tool is Langchain-compatible and describable.

    Returns:
        True when the compatibility check passes and the generated
        description has the expected keys; False otherwise.
    """
    print("\n测试Langchain适配器...")
    try:
        from app.core.tools.builtin.json_tool import JsonTool
        from app.core.tools.langchain_adapter import LangchainAdapter

        # tenant_id=None marks the tool as global.
        settings = {
            "parameters": {"indent": 2},
            "tenant_id": None,
            "version": "1.0.0",
        }
        tool = JsonTool(str(uuid.uuid4()), settings)

        # Compatibility check comes first; bail out with the reported issues.
        compatible, problems = LangchainAdapter.validate_langchain_compatibility(tool)
        if not compatible:
            print(f"✗ Langchain兼容性验证失败: {problems}")
            return False

        # The description must expose the tool's name and parameters.
        described = LangchainAdapter.create_tool_description(tool)
        for required_key in ("name", "parameters"):
            assert required_key in described
        assert described["langchain_compatible"] == True

        print("✓ Langchain适配器测试成功")
        return True
    except Exception as err:
        print(f"✗ Langchain适配器测试失败: {err}")
        return False
def test_config_manager():
    """Create a ConfigManager and check the shape of its summary.

    Returns:
        True when the summary contains ``config_dir`` and ``total_configs``;
        False when anything raises.
    """
    print("\n测试配置管理器...")
    try:
        from app.core.tools.config_manager import ConfigManager

        # Build a manager and ask it for its configuration summary.
        overview = ConfigManager().get_config_summary()
        for required_key in ("config_dir", "total_configs"):
            assert required_key in overview

        print("✓ 配置管理器测试成功")
        print(f" 配置目录: {overview['config_dir']}")
        print(f" 总配置数: {overview['total_configs']}")
        return True
    except Exception as err:
        print(f"✗ 配置管理器测试失败: {err}")
        return False
def test_schema_parser():
    """Validate a minimal OpenAPI document and extract tool info from it.

    Returns:
        True when the schema validates and the extracted info carries the
        expected name and operation; False when anything raises.
    """
    print("\n测试OpenAPI Schema解析器...")
    try:
        from app.core.tools.custom.schema_parser import OpenAPISchemaParser

        parser = OpenAPISchemaParser()

        # Minimal OpenAPI 3.0 document with a single GET operation.
        get_operation = {
            "summary": "测试接口",
            "operationId": "test_operation",
            "responses": {"200": {"description": "成功"}},
        }
        test_schema = {
            "openapi": "3.0.0",
            "info": {
                "title": "Test API",
                "version": "1.0.0",
                "description": "测试API",
            },
            "paths": {"/test": {"get": get_operation}},
        }

        # The document must pass validation.
        is_valid, error_msg = parser.validate_schema(test_schema)
        assert is_valid, f"Schema验证失败: {error_msg}"

        # Extracted tool info must reflect the title and the operation ID.
        extracted = parser.extract_tool_info(test_schema)
        assert extracted["name"] == "Test API"
        assert "test_operation" in extracted["operations"]

        print("✓ OpenAPI Schema解析器测试成功")
        return True
    except Exception as err:
        print(f"✗ OpenAPI Schema解析器测试失败: {err}")
        return False
def test_auth_manager():
    """Validate API-key and bearer-token configs and apply API-key auth.

    Returns:
        True when both configurations validate and the API key ends up in
        the request headers; False when anything raises.
    """
    print("\n测试认证管理器...")
    try:
        from app.core.tools.custom.auth_manager import AuthManager
        from app.models.tool_model import AuthType

        manager = AuthManager()

        # API-key configuration: key sent in the X-API-Key header.
        api_key_config = {
            "api_key": "test-key-123",
            "key_name": "X-API-Key",
            "location": "header",
        }
        is_valid, error_msg = manager.validate_auth_config(AuthType.API_KEY, api_key_config)
        assert is_valid, f"API Key配置验证失败: {error_msg}"

        # Bearer-token configuration.
        bearer_config = {"token": "bearer-token-123"}
        is_valid, error_msg = manager.validate_auth_config(AuthType.BEARER_TOKEN, bearer_config)
        assert is_valid, f"Bearer Token配置验证失败: {error_msg}"

        # Applying API-key auth must add the key to the headers.
        new_url, new_headers, new_params = manager.apply_authentication(
            AuthType.API_KEY, api_key_config, "https://api.example.com/test", {}, {}
        )
        assert "X-API-Key" in new_headers
        assert new_headers["X-API-Key"] == "test-key-123"

        print("✓ 认证管理器测试成功")
        return True
    except Exception as err:
        print(f"✗ 认证管理器测试失败: {err}")
        return False
def test_builtin_initializer():
    """Construct BuiltinToolInitializer against a stub session and query status.

    No real initialization is performed because that needs a database
    connection; only construction and the status query are exercised.

    Returns:
        True when the status query returns a list; False when anything raises.
    """
    print("\n测试内置工具初始化器...")
    try:
        from app.core.tools.builtin_initializer import BuiltinToolInitializer

        class MockDB:
            """Stand-in session: chainable query/filter that yields no rows."""

            def _chain(self, *args):
                return self

            query = _chain
            filter = _chain

            def first(self):
                return None

            def all(self):
                return []

        initializer = BuiltinToolInitializer(MockDB())

        # With the stub session there is no data, so only the type is checked.
        status = initializer.get_builtin_tools_status()
        assert isinstance(status, list)

        print("✓ 内置工具初始化器测试成功")
        return True
    except Exception as err:
        print(f"✗ 内置工具初始化器测试失败: {err}")
        return False
async def main():
    """Run every basic check in order and print a summary.

    Coroutine checks are awaited, plain functions are called directly. A
    check counts as passed only when it returns a truthy value; an exception
    escaping a check is reported and counted as a failure.

    Returns:
        True when all checks passed, False otherwise.
    """
    banner = "=" * 50
    print(banner)
    print("工具管理系统基础测试")
    print(banner)

    tests = [
        ("模块导入", test_imports),
        ("工具创建", test_tool_creation),
        ("工具执行", test_tool_execution),
        ("Langchain适配", test_langchain_adapter),
        ("配置管理", test_config_manager),
        ("Schema解析器", test_schema_parser),
        ("认证管理器", test_auth_manager),
        ("内置工具初始化器", test_builtin_initializer)
    ]
    total = len(tests)
    passed = 0

    for label, check in tests:
        try:
            outcome = await check() if asyncio.iscoroutinefunction(check) else check()
            passed += 1 if outcome else 0
        except Exception as err:
            print(f"{label}测试异常: {err}")

    print("\n" + banner)
    print(f"测试结果: {passed}/{total} 通过")

    all_passed = passed == total
    if all_passed:
        print("🎉 所有基础测试通过!工具管理系统基本功能正常。")
    else:
        print("⚠️ 部分测试失败,请检查相关模块。")
    return all_passed
if __name__ == "__main__":
    import sys

    # main() returns True only when every check passed. Turn that into the
    # process exit status so shells and CI can detect a failed run.
    sys.exit(0 if asyncio.run(main()) else 1)