feat(apikey system): tool system development

This commit is contained in:
谢俊男
2025-12-20 15:24:28 +08:00
parent 3fbd4f206e
commit c26af11f76
39 changed files with 9338 additions and 4 deletions

View File

@@ -0,0 +1,11 @@
"""自定义工具模块"""
from .base import CustomTool
from .schema_parser import OpenAPISchemaParser
from .auth_manager import AuthManager
__all__ = [
"CustomTool",
"OpenAPISchemaParser",
"AuthManager"
]

View File

@@ -0,0 +1,525 @@
"""认证管理器 - 处理自定义工具的认证配置"""
import base64
import hashlib
import hmac
import time
from typing import Dict, Any, Tuple
from urllib.parse import quote
import aiohttp
from app.models.tool_model import AuthType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class AuthManager:
"""认证管理器 - 支持多种认证方式"""
def __init__(self):
"""初始化认证管理器"""
self.supported_auth_types = [
AuthType.NONE,
AuthType.API_KEY,
AuthType.BEARER_TOKEN
]
def validate_auth_config(self, auth_type: AuthType, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证认证配置
Args:
auth_type: 认证类型
auth_config: 认证配置
Returns:
(是否有效, 错误信息)
"""
try:
if auth_type not in self.supported_auth_types:
return False, f"不支持的认证类型: {auth_type}"
if auth_type == AuthType.NONE:
return True, ""
elif auth_type == AuthType.API_KEY:
return self._validate_api_key_config(auth_config)
elif auth_type == AuthType.BEARER_TOKEN:
return self._validate_bearer_token_config(auth_config)
return False, "未知的认证类型"
except Exception as e:
return False, f"验证认证配置时出错: {e}"
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证API Key认证配置
Args:
auth_config: 认证配置
Returns:
(是否有效, 错误信息)
"""
api_key = auth_config.get("api_key")
if not api_key:
return False, "API Key不能为空"
if not isinstance(api_key, str):
return False, "API Key必须是字符串"
# 验证key名称
key_name = auth_config.get("key_name", "X-API-Key")
if not isinstance(key_name, str):
return False, "API Key名称必须是字符串"
# 验证位置
key_location = auth_config.get("location", "header")
if key_location not in ["header", "query", "cookie"]:
return False, "API Key位置必须是 header、query 或 cookie"
return True, ""
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证Bearer Token认证配置
Args:
auth_config: 认证配置
Returns:
(是否有效, 错误信息)
"""
token = auth_config.get("token")
if not token:
return False, "Bearer Token不能为空"
if not isinstance(token, str):
return False, "Bearer Token必须是字符串"
return True, ""
def apply_authentication(
self,
auth_type: AuthType,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
"""应用认证到请求
Args:
auth_type: 认证类型
auth_config: 认证配置
url: 请求URL
headers: 请求头
params: 请求参数
Returns:
(修改后的URL, 修改后的headers, 修改后的params)
"""
try:
if auth_type == AuthType.NONE:
return url, headers, params
elif auth_type == AuthType.API_KEY:
return self._apply_api_key_auth(auth_config, url, headers, params)
elif auth_type == AuthType.BEARER_TOKEN:
return self._apply_bearer_token_auth(auth_config, url, headers, params)
else:
logger.warning(f"不支持的认证类型: {auth_type}")
return url, headers, params
except Exception as e:
logger.error(f"应用认证时出错: {e}")
return url, headers, params
def _apply_api_key_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
"""应用API Key认证
Args:
auth_config: 认证配置
url: 请求URL
headers: 请求头
params: 请求参数
Returns:
(修改后的URL, 修改后的headers, 修改后的params)
"""
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
location = auth_config.get("location", "header")
if location == "header":
headers[key_name] = api_key
elif location == "query":
# 添加到URL查询参数
separator = "&" if "?" in url else "?"
encoded_key = quote(str(api_key))
url += f"{separator}{key_name}={encoded_key}"
elif location == "cookie":
# 添加到Cookie头
cookie_value = f"{key_name}={api_key}"
if "Cookie" in headers:
headers["Cookie"] += f"; {cookie_value}"
else:
headers["Cookie"] = cookie_value
return url, headers, params
def _apply_bearer_token_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
"""应用Bearer Token认证
Args:
auth_config: 认证配置
url: 请求URL
headers: 请求头
params: 请求参数
Returns:
(修改后的URL, 修改后的headers, 修改后的params)
"""
token = auth_config.get("token")
headers["Authorization"] = f"Bearer {token}"
return url, headers, params
def encrypt_auth_config(self, auth_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
"""加密认证配置中的敏感信息
Args:
auth_config: 认证配置
encryption_key: 加密密钥
Returns:
加密后的认证配置
"""
try:
encrypted_config = auth_config.copy()
# 需要加密的字段
sensitive_fields = ["api_key", "token", "secret", "password"]
for field in sensitive_fields:
if field in encrypted_config:
value = encrypted_config[field]
if isinstance(value, str) and value:
encrypted_value = self._encrypt_string(value, encryption_key)
encrypted_config[field] = encrypted_value
encrypted_config[f"{field}_encrypted"] = True
return encrypted_config
except Exception as e:
logger.error(f"加密认证配置失败: {e}")
return auth_config
def decrypt_auth_config(self, encrypted_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
"""解密认证配置中的敏感信息
Args:
encrypted_config: 加密的认证配置
encryption_key: 解密密钥
Returns:
解密后的认证配置
"""
try:
decrypted_config = encrypted_config.copy()
# 需要解密的字段
sensitive_fields = ["api_key", "token", "secret", "password"]
for field in sensitive_fields:
if field in decrypted_config and decrypted_config.get(f"{field}_encrypted"):
encrypted_value = decrypted_config[field]
if isinstance(encrypted_value, str) and encrypted_value:
decrypted_value = self._decrypt_string(encrypted_value, encryption_key)
decrypted_config[field] = decrypted_value
# 移除加密标记
decrypted_config.pop(f"{field}_encrypted", None)
return decrypted_config
except Exception as e:
logger.error(f"解密认证配置失败: {e}")
return encrypted_config
def _encrypt_string(self, value: str, key: str) -> str:
"""加密字符串
Args:
value: 要加密的字符串
key: 加密密钥
Returns:
加密后的字符串Base64编码
"""
try:
# 使用HMAC-SHA256进行简单加密
key_bytes = key.encode('utf-8')
value_bytes = value.encode('utf-8')
# 生成HMAC
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
signature = hmac_obj.hexdigest()
# 组合原始值和签名然后Base64编码
combined = f"{value}:{signature}"
encrypted = base64.b64encode(combined.encode('utf-8')).decode('utf-8')
return encrypted
except Exception as e:
logger.error(f"加密字符串失败: {e}")
return value
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
"""解密字符串
Args:
encrypted_value: 加密的字符串
key: 解密密钥
Returns:
解密后的字符串
"""
try:
# Base64解码
decoded = base64.b64decode(encrypted_value.encode('utf-8')).decode('utf-8')
# 分离原始值和签名
if ':' not in decoded:
return encrypted_value # 可能不是加密的值
value, signature = decoded.rsplit(':', 1)
# 验证签名
key_bytes = key.encode('utf-8')
value_bytes = value.encode('utf-8')
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
expected_signature = hmac_obj.hexdigest()
if signature == expected_signature:
return value
else:
logger.warning("解密时签名验证失败")
return encrypted_value
except Exception as e:
logger.error(f"解密字符串失败: {e}")
return encrypted_value
def test_authentication(
self,
auth_type: AuthType,
auth_config: Dict[str, Any],
test_url: str = None
) -> Dict[str, Any]:
"""测试认证配置
Args:
auth_type: 认证类型
auth_config: 认证配置
test_url: 测试URL可选
Returns:
测试结果
"""
try:
# 验证配置
is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
if not is_valid:
return {
"success": False,
"error": error_msg,
"auth_type": auth_type.value
}
# 如果没有测试URL只验证配置
if not test_url:
return {
"success": True,
"message": "认证配置有效",
"auth_type": auth_type.value
}
# 构建测试请求
headers = {"User-Agent": "AuthManager-Test/1.0"}
params = {}
# 应用认证
test_url, headers, params = self.apply_authentication(
auth_type, auth_config, test_url, headers, params
)
return {
"success": True,
"message": "认证配置测试成功",
"auth_type": auth_type.value,
"test_url": test_url,
"headers": {k: v for k, v in headers.items() if k != "Authorization"}, # 不返回敏感信息
"has_auth_header": "Authorization" in headers
}
except Exception as e:
return {
"success": False,
"error": str(e),
"auth_type": auth_type.value if auth_type else "unknown"
}
async def test_authentication_with_request(
self,
auth_type: AuthType,
auth_config: Dict[str, Any],
test_url: str,
timeout: int = 10
) -> Dict[str, Any]:
"""通过实际HTTP请求测试认证
Args:
auth_type: 认证类型
auth_config: 认证配置
test_url: 测试URL
timeout: 超时时间(秒)
Returns:
测试结果
"""
try:
# 验证配置
is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
if not is_valid:
return {
"success": False,
"error": error_msg,
"auth_type": auth_type.value
}
# 构建请求
headers = {"User-Agent": "AuthManager-Test/1.0"}
params = {}
# 应用认证
test_url, headers, params = self.apply_authentication(
auth_type, auth_config, test_url, headers, params
)
# 发送测试请求
client_timeout = aiohttp.ClientTimeout(total=timeout)
async with aiohttp.ClientSession(timeout=client_timeout) as session:
async with session.get(test_url, headers=headers) as response:
status_code = response.status
# 根据状态码判断认证是否成功
if status_code == 200:
return {
"success": True,
"message": "认证测试成功",
"status_code": status_code,
"auth_type": auth_type.value
}
elif status_code == 401:
return {
"success": False,
"error": "认证失败 - 401 Unauthorized",
"status_code": status_code,
"auth_type": auth_type.value
}
elif status_code == 403:
return {
"success": False,
"error": "认证失败 - 403 Forbidden",
"status_code": status_code,
"auth_type": auth_type.value
}
else:
return {
"success": True,
"message": f"请求成功,状态码: {status_code}",
"status_code": status_code,
"auth_type": auth_type.value
}
except aiohttp.ClientError as e:
return {
"success": False,
"error": f"网络请求失败: {e}",
"auth_type": auth_type.value
}
except Exception as e:
return {
"success": False,
"error": f"测试认证时出错: {e}",
"auth_type": auth_type.value
}
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
"""获取认证配置模板
Args:
auth_type: 认证类型
Returns:
配置模板
"""
templates = {
AuthType.NONE: {},
AuthType.API_KEY: {
"api_key": "",
"key_name": "X-API-Key",
"location": "header", # header, query, cookie
"description": "API Key认证配置"
},
AuthType.BEARER_TOKEN: {
"token": "",
"description": "Bearer Token认证配置"
}
}
return templates.get(auth_type, {})
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
"""遮蔽认证配置中的敏感信息
Args:
auth_config: 认证配置
Returns:
遮蔽敏感信息后的配置
"""
masked_config = auth_config.copy()
# 需要遮蔽的字段
sensitive_fields = ["api_key", "token", "secret", "password"]
for field in sensitive_fields:
if field in masked_config:
value = masked_config[field]
if isinstance(value, str) and len(value) > 4:
# 只显示前2位和后2位
masked_config[field] = f"{value[:2]}***{value[-2:]}"
elif isinstance(value, str) and value:
masked_config[field] = "***"
return masked_config

View File

@@ -0,0 +1,318 @@
"""自定义工具基类"""
import time
from typing import Dict, Any, List, Optional
import aiohttp
from urllib.parse import urljoin
from app.models.tool_model import ToolType, AuthType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class CustomTool(BaseTool):
"""自定义工具 - 基于OpenAPI schema的工具"""
def __init__(self, tool_id: str, config: Dict[str, Any]):
"""初始化自定义工具
Args:
tool_id: 工具ID
config: 工具配置
"""
super().__init__(tool_id, config)
self.schema_content = config.get("schema_content", {})
self.schema_url = config.get("schema_url")
self.auth_type = AuthType(config.get("auth_type", "none"))
self.auth_config = config.get("auth_config", {})
self.base_url = config.get("base_url", "")
self.timeout = config.get("timeout", 30)
# 解析schema
self._parsed_operations = self._parse_openapi_schema()
@property
def name(self) -> str:
"""工具名称"""
if self.schema_content:
info = self.schema_content.get("info", {})
return info.get("title", f"custom_tool_{self.tool_id[:8]}")
return f"custom_tool_{self.tool_id[:8]}"
@property
def description(self) -> str:
"""工具描述"""
if self.schema_content:
info = self.schema_content.get("info", {})
return info.get("description", "自定义API工具")
return "自定义API工具"
@property
def tool_type(self) -> ToolType:
"""工具类型"""
return ToolType.CUSTOM
@property
def parameters(self) -> List[ToolParameter]:
"""工具参数定义"""
params = []
# 添加操作选择参数
if len(self._parsed_operations) > 1:
params.append(ToolParameter(
name="operation",
type=ParameterType.STRING,
description="要执行的操作",
required=True,
enum=list(self._parsed_operations.keys())
))
# 添加通用参数(基于第一个操作的参数)
if self._parsed_operations:
first_operation = next(iter(self._parsed_operations.values()))
for param_name, param_info in first_operation.get("parameters", {}).items():
params.append(ToolParameter(
name=param_name,
type=self._convert_openapi_type(param_info.get("type", "string")),
description=param_info.get("description", ""),
required=param_info.get("required", False),
default=param_info.get("default"),
enum=param_info.get("enum"),
minimum=param_info.get("minimum"),
maximum=param_info.get("maximum"),
pattern=param_info.get("pattern")
))
return params
async def execute(self, **kwargs) -> ToolResult:
"""执行自定义工具"""
start_time = time.time()
try:
# 确定要执行的操作
operation_name = kwargs.get("operation")
if not operation_name and len(self._parsed_operations) == 1:
operation_name = next(iter(self._parsed_operations.keys()))
if not operation_name or operation_name not in self._parsed_operations:
raise ValueError(f"无效的操作: {operation_name}")
operation = self._parsed_operations[operation_name]
# 构建请求
url = self._build_request_url(operation, kwargs)
headers = self._build_request_headers(operation)
data = self._build_request_data(operation, kwargs)
# 发送HTTP请求
result = await self._send_http_request(
method=operation["method"],
url=url,
headers=headers,
data=data
)
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result,
execution_time=execution_time
)
except Exception as e:
execution_time = time.time() - start_time
return ToolResult.error_result(
error=str(e),
error_code="CUSTOM_TOOL_ERROR",
execution_time=execution_time
)
def _parse_openapi_schema(self) -> Dict[str, Any]:
"""解析OpenAPI schema"""
operations = {}
if not self.schema_content:
return operations
paths = self.schema_content.get("paths", {})
for path, path_item in paths.items():
for method, operation in path_item.items():
if method.lower() in ["get", "post", "put", "delete", "patch"]:
operation_id = operation.get("operationId", f"{method}_{path.replace('/', '_')}")
# 解析参数
parameters = {}
if "parameters" in operation:
for param in operation["parameters"]:
param_name = param.get("name")
param_schema = param.get("schema", {})
parameters[param_name] = {
"type": param_schema.get("type", "string"),
"description": param.get("description", ""),
"required": param.get("required", False),
"in": param.get("in", "query"),
**param_schema
}
# 解析请求体
request_body = None
if "requestBody" in operation:
content = operation["requestBody"].get("content", {})
if "application/json" in content:
request_body = content["application/json"].get("schema", {})
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": parameters,
"request_body": request_body
}
return operations
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
"""转换OpenAPI类型到内部类型"""
type_mapping = {
"string": ParameterType.STRING,
"integer": ParameterType.INTEGER,
"number": ParameterType.NUMBER,
"boolean": ParameterType.BOOLEAN,
"array": ParameterType.ARRAY,
"object": ParameterType.OBJECT
}
return type_mapping.get(openapi_type, ParameterType.STRING)
def _build_request_url(self, operation: Dict[str, Any], params: Dict[str, Any]) -> str:
"""构建请求URL"""
path = operation["path"]
# 替换路径参数
for param_name, param_info in operation.get("parameters", {}).items():
if param_info.get("in") == "path" and param_name in params:
path = path.replace(f"{{{param_name}}}", str(params[param_name]))
# 构建完整URL
if self.base_url:
url = urljoin(self.base_url, path.lstrip("/"))
else:
# 从schema中获取服务器URL
servers = self.schema_content.get("servers", [])
if servers:
base_url = servers[0].get("url", "")
url = urljoin(base_url, path.lstrip("/"))
else:
url = path
# 添加查询参数
query_params = {}
for param_name, param_info in operation.get("parameters", {}).items():
if param_info.get("in") == "query" and param_name in params:
query_params[param_name] = params[param_name]
if query_params:
from urllib.parse import urlencode
url += "?" + urlencode(query_params)
return url
def _build_request_headers(self, operation: Dict[str, Any]) -> Dict[str, str]:
"""构建请求头"""
headers = {
"Content-Type": "application/json",
"User-Agent": "CustomTool/1.0"
}
# 添加认证头
if self.auth_type == AuthType.API_KEY:
api_key = self.auth_config.get("api_key")
key_name = self.auth_config.get("key_name", "X-API-Key")
if api_key:
headers[key_name] = api_key
elif self.auth_type == AuthType.BEARER_TOKEN:
token = self.auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
return headers
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""构建请求数据"""
if operation["method"] in ["POST", "PUT", "PATCH"]:
request_body = operation.get("request_body")
if request_body:
# 构建请求体数据
data = {}
properties = request_body.get("properties", {})
for prop_name, prop_schema in properties.items():
if prop_name in params:
data[prop_name] = params[prop_name]
return data if data else None
return None
async def _send_http_request(
self,
method: str,
url: str,
headers: Dict[str, str],
data: Optional[Dict[str, Any]] = None
) -> Any:
"""发送HTTP请求"""
timeout = aiohttp.ClientTimeout(total=self.timeout)
async with aiohttp.ClientSession(timeout=timeout) as session:
kwargs = {
"headers": headers
}
if data and method in ["POST", "PUT", "PATCH"]:
kwargs["json"] = data
async with session.request(method, url, **kwargs) as response:
if response.status >= 400:
error_text = await response.text()
raise Exception(f"HTTP {response.status}: {error_text}")
# 尝试解析JSON响应
try:
return await response.json()
except Exception as e:
return await response.text()
@classmethod
def from_url(cls, schema_url: str, auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
"""从URL导入OpenAPI schema创建工具"""
import uuid
if not tool_id:
tool_id = str(uuid.uuid4())
config = {
"schema_url": schema_url,
"auth_config": auth_config,
"auth_type": auth_config.get("type", "none")
}
# 这里应该异步加载schema为了简化暂时返回空配置
return cls(tool_id, config)
@classmethod
def from_schema(cls, schema_dict: Dict[str, Any], auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
"""从schema字典创建工具"""
import uuid
if not tool_id:
tool_id = str(uuid.uuid4())
config = {
"schema_content": schema_dict,
"auth_config": auth_config,
"auth_type": auth_config.get("type", "none")
}
return cls(tool_id, config)

View File

@@ -0,0 +1,477 @@
"""OpenAPI Schema解析器"""
import json
import yaml
from typing import Dict, Any, List, Optional, Tuple
from urllib.parse import urlparse
import aiohttp
import asyncio
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class OpenAPISchemaParser:
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
def __init__(self):
"""初始化解析器"""
self.supported_versions = ["3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]
async def parse_from_url(self, schema_url: str, timeout: int = 30) -> Tuple[bool, Dict[str, Any], str]:
"""从URL解析OpenAPI schema
Args:
schema_url: Schema URL
timeout: 超时时间(秒)
Returns:
(是否成功, schema内容, 错误信息)
"""
try:
# 验证URL格式
parsed_url = urlparse(schema_url)
if not parsed_url.scheme or not parsed_url.netloc:
return False, {}, "无效的URL格式"
# 下载schema
client_timeout = aiohttp.ClientTimeout(total=timeout)
async with aiohttp.ClientSession(timeout=client_timeout) as session:
async with session.get(schema_url) as response:
if response.status != 200:
return False, {}, f"HTTP错误: {response.status}"
content_type = response.headers.get('content-type', '').lower()
content = await response.text()
# 解析内容
schema_dict = self._parse_content(content, content_type)
if not schema_dict:
return False, {}, "无法解析schema内容"
# 验证schema
is_valid, error_msg = self.validate_schema(schema_dict)
if not is_valid:
return False, {}, error_msg
return True, schema_dict, ""
except asyncio.TimeoutError:
return False, {}, "请求超时"
except Exception as e:
logger.error(f"从URL解析schema失败: {schema_url}, 错误: {e}")
return False, {}, str(e)
def parse_from_content(self, content: str, content_type: str = "application/json") -> Tuple[bool, Dict[str, Any], str]:
"""从内容解析OpenAPI schema
Args:
content: Schema内容
content_type: 内容类型
Returns:
(是否成功, schema内容, 错误信息)
"""
try:
# 解析内容
schema_dict = self._parse_content(content, content_type)
if not schema_dict:
return False, {}, "无法解析schema内容"
# 验证schema
is_valid, error_msg = self.validate_schema(schema_dict)
if not is_valid:
return False, {}, error_msg
return True, schema_dict, ""
except Exception as e:
logger.error(f"解析schema内容失败: {e}")
return False, {}, str(e)
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
"""解析内容为字典
Args:
content: 内容字符串
content_type: 内容类型
Returns:
解析后的字典失败返回None
"""
try:
# 根据内容类型解析
if 'json' in content_type:
return json.loads(content)
elif 'yaml' in content_type or 'yml' in content_type:
return yaml.safe_load(content)
else:
# 尝试自动检测格式
try:
return json.loads(content)
except json.JSONDecodeError:
try:
return yaml.safe_load(content)
except yaml.YAMLError:
return None
except Exception as e:
logger.error(f"解析内容失败: {e}")
return None
def validate_schema(self, schema_dict: Dict[str, Any]) -> Tuple[bool, str]:
"""验证OpenAPI schema
Args:
schema_dict: Schema字典
Returns:
(是否有效, 错误信息)
"""
try:
# 检查基本结构
if not isinstance(schema_dict, dict):
return False, "Schema必须是JSON对象"
# 检查OpenAPI版本
openapi_version = schema_dict.get("openapi")
if not openapi_version:
return False, "缺少openapi版本字段"
if openapi_version not in self.supported_versions:
return False, f"不支持的OpenAPI版本: {openapi_version}"
# 检查必需字段
required_fields = ["info", "paths"]
for field in required_fields:
if field not in schema_dict:
return False, f"缺少必需字段: {field}"
# 验证info字段
info = schema_dict.get("info", {})
if not isinstance(info, dict):
return False, "info字段必须是对象"
if "title" not in info:
return False, "info.title字段是必需的"
# 验证paths字段
paths = schema_dict.get("paths", {})
if not isinstance(paths, dict):
return False, "paths字段必须是对象"
# 验证至少有一个路径
if not paths:
return False, "至少需要定义一个API路径"
return True, ""
except Exception as e:
return False, f"验证schema时出错: {e}"
def extract_tool_info(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
"""从schema提取工具信息
Args:
schema_dict: Schema字典
Returns:
工具信息字典
"""
info = schema_dict.get("info", {})
return {
"name": info.get("title", "Custom API Tool"),
"description": info.get("description", ""),
"version": info.get("version", "1.0.0"),
"servers": schema_dict.get("servers", []),
"operations": self._extract_operations(schema_dict)
}
def _extract_operations(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
"""提取API操作信息
Args:
schema_dict: Schema字典
Returns:
操作信息字典
"""
operations = {}
paths = schema_dict.get("paths", {})
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
for method, operation in path_item.items():
if method.lower() not in ["get", "post", "put", "delete", "patch", "head", "options"]:
continue
if not isinstance(operation, dict):
continue
# 生成操作ID
operation_id = operation.get("operationId")
if not operation_id:
operation_id = f"{method.lower()}_{path.replace('/', '_').replace('{', '').replace('}', '')}"
# 提取操作信息
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": self._extract_parameters(operation),
"request_body": self._extract_request_body(operation),
"responses": self._extract_responses(operation),
"tags": operation.get("tags", [])
}
return operations
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取操作参数
Args:
operation: 操作定义
Returns:
参数信息字典
"""
parameters = {}
for param in operation.get("parameters", []):
if not isinstance(param, dict):
continue
param_name = param.get("name")
if not param_name:
continue
param_schema = param.get("schema", {})
parameters[param_name] = {
"name": param_name,
"in": param.get("in", "query"),
"description": param.get("description", ""),
"required": param.get("required", False),
"type": param_schema.get("type", "string"),
"format": param_schema.get("format"),
"enum": param_schema.get("enum"),
"default": param_schema.get("default"),
"minimum": param_schema.get("minimum"),
"maximum": param_schema.get("maximum"),
"pattern": param_schema.get("pattern"),
"example": param.get("example") or param_schema.get("example")
}
return parameters
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""提取请求体信息
Args:
operation: 操作定义
Returns:
请求体信息如果没有返回None
"""
request_body = operation.get("requestBody")
if not request_body:
return None
content = request_body.get("content", {})
# 优先使用application/json
if "application/json" in content:
schema = content["application/json"].get("schema", {})
elif content:
# 使用第一个可用的内容类型
first_content_type = next(iter(content.keys()))
schema = content[first_content_type].get("schema", {})
else:
return None
return {
"description": request_body.get("description", ""),
"required": request_body.get("required", False),
"schema": schema,
"content_types": list(content.keys())
}
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取响应信息
Args:
operation: 操作定义
Returns:
响应信息字典
"""
responses = {}
for status_code, response in operation.get("responses", {}).items():
if not isinstance(response, dict):
continue
content = response.get("content", {})
schema = None
# 尝试获取响应schema
if "application/json" in content:
schema = content["application/json"].get("schema")
elif content:
first_content_type = next(iter(content.keys()))
schema = content[first_content_type].get("schema")
responses[status_code] = {
"description": response.get("description", ""),
"schema": schema,
"content_types": list(content.keys()) if content else []
}
return responses
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
"""生成工具参数定义
Args:
operations: 操作信息字典
Returns:
参数定义列表
"""
parameters = []
# 如果有多个操作,添加操作选择参数
if len(operations) > 1:
parameters.append({
"name": "operation",
"type": "string",
"description": "要执行的操作",
"required": True,
"enum": list(operations.keys())
})
# 收集所有参数(去重)
all_params = {}
for operation_id, operation in operations.items():
# 路径参数和查询参数
for param_name, param_info in operation.get("parameters", {}).items():
if param_name not in all_params:
all_params[param_name] = {
"name": param_name,
"type": param_info.get("type", "string"),
"description": param_info.get("description", ""),
"required": param_info.get("required", False),
"enum": param_info.get("enum"),
"default": param_info.get("default"),
"minimum": param_info.get("minimum"),
"maximum": param_info.get("maximum"),
"pattern": param_info.get("pattern")
}
# 请求体参数
request_body = operation.get("request_body")
if request_body:
schema = request_body.get("schema", {})
properties = schema.get("properties", {})
for prop_name, prop_schema in properties.items():
if prop_name not in all_params:
all_params[prop_name] = {
"name": prop_name,
"type": prop_schema.get("type", "string"),
"description": prop_schema.get("description", ""),
"required": prop_name in schema.get("required", []),
"enum": prop_schema.get("enum"),
"default": prop_schema.get("default"),
"minimum": prop_schema.get("minimum"),
"maximum": prop_schema.get("maximum"),
"pattern": prop_schema.get("pattern")
}
# 转换为参数列表
parameters.extend(all_params.values())
return parameters
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
"""验证操作参数
Args:
operation: 操作定义
params: 输入参数
Returns:
(是否有效, 错误信息列表)
"""
errors = []
# 验证路径参数和查询参数
for param_name, param_info in operation.get("parameters", {}).items():
if param_info.get("required", False) and param_name not in params:
errors.append(f"缺少必需参数: {param_name}")
if param_name in params:
value = params[param_name]
param_type = param_info.get("type", "string")
# 类型验证
if not self._validate_parameter_type(value, param_type):
errors.append(f"参数 {param_name} 类型错误,期望: {param_type}")
# 枚举验证
enum_values = param_info.get("enum")
if enum_values and value not in enum_values:
errors.append(f"参数 {param_name} 值无效,必须是: {enum_values}")
# 验证请求体参数
request_body = operation.get("request_body")
if request_body:
schema = request_body.get("schema", {})
required_props = schema.get("required", [])
properties = schema.get("properties", {})
for prop_name in required_props:
if prop_name not in params:
errors.append(f"缺少必需的请求体参数: {prop_name}")
for prop_name, value in params.items():
if prop_name in properties:
prop_schema = properties[prop_name]
prop_type = prop_schema.get("type", "string")
if not self._validate_parameter_type(value, prop_type):
errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}")
return len(errors) == 0, errors
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
"""验证参数类型
Args:
value: 参数值
expected_type: 期望类型
Returns:
是否类型匹配
"""
if value is None:
return True
type_mapping = {
"string": str,
"integer": int,
"number": (int, float),
"boolean": bool,
"array": list,
"object": dict
}
expected_python_type = type_mapping.get(expected_type)
if expected_python_type:
return isinstance(value, expected_python_type)
return True