feat(apikey system): tool system development
This commit is contained in:
11
api/app/core/tools/custom/__init__.py
Normal file
11
api/app/core/tools/custom/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""自定义工具模块"""
|
||||
|
||||
from .base import CustomTool
|
||||
from .schema_parser import OpenAPISchemaParser
|
||||
from .auth_manager import AuthManager
|
||||
|
||||
__all__ = [
|
||||
"CustomTool",
|
||||
"OpenAPISchemaParser",
|
||||
"AuthManager"
|
||||
]
|
||||
525
api/app/core/tools/custom/auth_manager.py
Normal file
525
api/app/core/tools/custom/auth_manager.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""认证管理器 - 处理自定义工具的认证配置"""
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from typing import Dict, Any, Tuple
|
||||
from urllib.parse import quote
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import AuthType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AuthManager:
|
||||
"""认证管理器 - 支持多种认证方式"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化认证管理器"""
|
||||
self.supported_auth_types = [
|
||||
AuthType.NONE,
|
||||
AuthType.API_KEY,
|
||||
AuthType.BEARER_TOKEN
|
||||
]
|
||||
|
||||
def validate_auth_config(self, auth_type: AuthType, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证认证配置
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
auth_config: 认证配置
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
try:
|
||||
if auth_type not in self.supported_auth_types:
|
||||
return False, f"不支持的认证类型: {auth_type}"
|
||||
|
||||
if auth_type == AuthType.NONE:
|
||||
return True, ""
|
||||
|
||||
elif auth_type == AuthType.API_KEY:
|
||||
return self._validate_api_key_config(auth_config)
|
||||
|
||||
elif auth_type == AuthType.BEARER_TOKEN:
|
||||
return self._validate_bearer_token_config(auth_config)
|
||||
|
||||
return False, "未知的认证类型"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"验证认证配置时出错: {e}"
|
||||
|
||||
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证API Key认证配置
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
api_key = auth_config.get("api_key")
|
||||
if not api_key:
|
||||
return False, "API Key不能为空"
|
||||
|
||||
if not isinstance(api_key, str):
|
||||
return False, "API Key必须是字符串"
|
||||
|
||||
# 验证key名称
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
if not isinstance(key_name, str):
|
||||
return False, "API Key名称必须是字符串"
|
||||
|
||||
# 验证位置
|
||||
key_location = auth_config.get("location", "header")
|
||||
if key_location not in ["header", "query", "cookie"]:
|
||||
return False, "API Key位置必须是 header、query 或 cookie"
|
||||
|
||||
return True, ""
|
||||
|
||||
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证Bearer Token认证配置
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
token = auth_config.get("token")
|
||||
if not token:
|
||||
return False, "Bearer Token不能为空"
|
||||
|
||||
if not isinstance(token, str):
|
||||
return False, "Bearer Token必须是字符串"
|
||||
|
||||
return True, ""
|
||||
|
||||
def apply_authentication(
|
||||
self,
|
||||
auth_type: AuthType,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
params: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
|
||||
"""应用认证到请求
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
auth_config: 认证配置
|
||||
url: 请求URL
|
||||
headers: 请求头
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
(修改后的URL, 修改后的headers, 修改后的params)
|
||||
"""
|
||||
try:
|
||||
if auth_type == AuthType.NONE:
|
||||
return url, headers, params
|
||||
|
||||
elif auth_type == AuthType.API_KEY:
|
||||
return self._apply_api_key_auth(auth_config, url, headers, params)
|
||||
|
||||
elif auth_type == AuthType.BEARER_TOKEN:
|
||||
return self._apply_bearer_token_auth(auth_config, url, headers, params)
|
||||
|
||||
else:
|
||||
logger.warning(f"不支持的认证类型: {auth_type}")
|
||||
return url, headers, params
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"应用认证时出错: {e}")
|
||||
return url, headers, params
|
||||
|
||||
def _apply_api_key_auth(
|
||||
self,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
params: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
|
||||
"""应用API Key认证
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
url: 请求URL
|
||||
headers: 请求头
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
(修改后的URL, 修改后的headers, 修改后的params)
|
||||
"""
|
||||
api_key = auth_config.get("api_key")
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
location = auth_config.get("location", "header")
|
||||
|
||||
if location == "header":
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif location == "query":
|
||||
# 添加到URL查询参数
|
||||
separator = "&" if "?" in url else "?"
|
||||
encoded_key = quote(str(api_key))
|
||||
url += f"{separator}{key_name}={encoded_key}"
|
||||
|
||||
elif location == "cookie":
|
||||
# 添加到Cookie头
|
||||
cookie_value = f"{key_name}={api_key}"
|
||||
if "Cookie" in headers:
|
||||
headers["Cookie"] += f"; {cookie_value}"
|
||||
else:
|
||||
headers["Cookie"] = cookie_value
|
||||
|
||||
return url, headers, params
|
||||
|
||||
def _apply_bearer_token_auth(
|
||||
self,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
params: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
|
||||
"""应用Bearer Token认证
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
url: 请求URL
|
||||
headers: 请求头
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
(修改后的URL, 修改后的headers, 修改后的params)
|
||||
"""
|
||||
token = auth_config.get("token")
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
return url, headers, params
|
||||
|
||||
def encrypt_auth_config(self, auth_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
|
||||
"""加密认证配置中的敏感信息
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
encryption_key: 加密密钥
|
||||
|
||||
Returns:
|
||||
加密后的认证配置
|
||||
"""
|
||||
try:
|
||||
encrypted_config = auth_config.copy()
|
||||
|
||||
# 需要加密的字段
|
||||
sensitive_fields = ["api_key", "token", "secret", "password"]
|
||||
|
||||
for field in sensitive_fields:
|
||||
if field in encrypted_config:
|
||||
value = encrypted_config[field]
|
||||
if isinstance(value, str) and value:
|
||||
encrypted_value = self._encrypt_string(value, encryption_key)
|
||||
encrypted_config[field] = encrypted_value
|
||||
encrypted_config[f"{field}_encrypted"] = True
|
||||
|
||||
return encrypted_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加密认证配置失败: {e}")
|
||||
return auth_config
|
||||
|
||||
def decrypt_auth_config(self, encrypted_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
|
||||
"""解密认证配置中的敏感信息
|
||||
|
||||
Args:
|
||||
encrypted_config: 加密的认证配置
|
||||
encryption_key: 解密密钥
|
||||
|
||||
Returns:
|
||||
解密后的认证配置
|
||||
"""
|
||||
try:
|
||||
decrypted_config = encrypted_config.copy()
|
||||
|
||||
# 需要解密的字段
|
||||
sensitive_fields = ["api_key", "token", "secret", "password"]
|
||||
|
||||
for field in sensitive_fields:
|
||||
if field in decrypted_config and decrypted_config.get(f"{field}_encrypted"):
|
||||
encrypted_value = decrypted_config[field]
|
||||
if isinstance(encrypted_value, str) and encrypted_value:
|
||||
decrypted_value = self._decrypt_string(encrypted_value, encryption_key)
|
||||
decrypted_config[field] = decrypted_value
|
||||
# 移除加密标记
|
||||
decrypted_config.pop(f"{field}_encrypted", None)
|
||||
|
||||
return decrypted_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解密认证配置失败: {e}")
|
||||
return encrypted_config
|
||||
|
||||
def _encrypt_string(self, value: str, key: str) -> str:
|
||||
"""加密字符串
|
||||
|
||||
Args:
|
||||
value: 要加密的字符串
|
||||
key: 加密密钥
|
||||
|
||||
Returns:
|
||||
加密后的字符串(Base64编码)
|
||||
"""
|
||||
try:
|
||||
# 使用HMAC-SHA256进行简单加密
|
||||
key_bytes = key.encode('utf-8')
|
||||
value_bytes = value.encode('utf-8')
|
||||
|
||||
# 生成HMAC
|
||||
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
|
||||
signature = hmac_obj.hexdigest()
|
||||
|
||||
# 组合原始值和签名,然后Base64编码
|
||||
combined = f"{value}:{signature}"
|
||||
encrypted = base64.b64encode(combined.encode('utf-8')).decode('utf-8')
|
||||
|
||||
return encrypted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加密字符串失败: {e}")
|
||||
return value
|
||||
|
||||
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
|
||||
"""解密字符串
|
||||
|
||||
Args:
|
||||
encrypted_value: 加密的字符串
|
||||
key: 解密密钥
|
||||
|
||||
Returns:
|
||||
解密后的字符串
|
||||
"""
|
||||
try:
|
||||
# Base64解码
|
||||
decoded = base64.b64decode(encrypted_value.encode('utf-8')).decode('utf-8')
|
||||
|
||||
# 分离原始值和签名
|
||||
if ':' not in decoded:
|
||||
return encrypted_value # 可能不是加密的值
|
||||
|
||||
value, signature = decoded.rsplit(':', 1)
|
||||
|
||||
# 验证签名
|
||||
key_bytes = key.encode('utf-8')
|
||||
value_bytes = value.encode('utf-8')
|
||||
|
||||
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
|
||||
expected_signature = hmac_obj.hexdigest()
|
||||
|
||||
if signature == expected_signature:
|
||||
return value
|
||||
else:
|
||||
logger.warning("解密时签名验证失败")
|
||||
return encrypted_value
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解密字符串失败: {e}")
|
||||
return encrypted_value
|
||||
|
||||
def test_authentication(
|
||||
self,
|
||||
auth_type: AuthType,
|
||||
auth_config: Dict[str, Any],
|
||||
test_url: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""测试认证配置
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
auth_config: 认证配置
|
||||
test_url: 测试URL(可选)
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
try:
|
||||
# 验证配置
|
||||
is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
|
||||
if not is_valid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_msg,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
# 如果没有测试URL,只验证配置
|
||||
if not test_url:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "认证配置有效",
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
# 构建测试请求
|
||||
headers = {"User-Agent": "AuthManager-Test/1.0"}
|
||||
params = {}
|
||||
|
||||
# 应用认证
|
||||
test_url, headers, params = self.apply_authentication(
|
||||
auth_type, auth_config, test_url, headers, params
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "认证配置测试成功",
|
||||
"auth_type": auth_type.value,
|
||||
"test_url": test_url,
|
||||
"headers": {k: v for k, v in headers.items() if k != "Authorization"}, # 不返回敏感信息
|
||||
"has_auth_header": "Authorization" in headers
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"auth_type": auth_type.value if auth_type else "unknown"
|
||||
}
|
||||
|
||||
async def test_authentication_with_request(
|
||||
self,
|
||||
auth_type: AuthType,
|
||||
auth_config: Dict[str, Any],
|
||||
test_url: str,
|
||||
timeout: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""通过实际HTTP请求测试认证
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
auth_config: 认证配置
|
||||
test_url: 测试URL
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
try:
|
||||
# 验证配置
|
||||
is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
|
||||
if not is_valid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_msg,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
# 构建请求
|
||||
headers = {"User-Agent": "AuthManager-Test/1.0"}
|
||||
params = {}
|
||||
|
||||
# 应用认证
|
||||
test_url, headers, params = self.apply_authentication(
|
||||
auth_type, auth_config, test_url, headers, params
|
||||
)
|
||||
|
||||
# 发送测试请求
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.get(test_url, headers=headers) as response:
|
||||
status_code = response.status
|
||||
|
||||
# 根据状态码判断认证是否成功
|
||||
if status_code == 200:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "认证测试成功",
|
||||
"status_code": status_code,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
elif status_code == 401:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "认证失败 - 401 Unauthorized",
|
||||
"status_code": status_code,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
elif status_code == 403:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "认证失败 - 403 Forbidden",
|
||||
"status_code": status_code,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"请求成功,状态码: {status_code}",
|
||||
"status_code": status_code,
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"网络请求失败: {e}",
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"测试认证时出错: {e}",
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
|
||||
"""获取认证配置模板
|
||||
|
||||
Args:
|
||||
auth_type: 认证类型
|
||||
|
||||
Returns:
|
||||
配置模板
|
||||
"""
|
||||
templates = {
|
||||
AuthType.NONE: {},
|
||||
|
||||
AuthType.API_KEY: {
|
||||
"api_key": "",
|
||||
"key_name": "X-API-Key",
|
||||
"location": "header", # header, query, cookie
|
||||
"description": "API Key认证配置"
|
||||
},
|
||||
|
||||
AuthType.BEARER_TOKEN: {
|
||||
"token": "",
|
||||
"description": "Bearer Token认证配置"
|
||||
}
|
||||
}
|
||||
|
||||
return templates.get(auth_type, {})
|
||||
|
||||
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""遮蔽认证配置中的敏感信息
|
||||
|
||||
Args:
|
||||
auth_config: 认证配置
|
||||
|
||||
Returns:
|
||||
遮蔽敏感信息后的配置
|
||||
"""
|
||||
masked_config = auth_config.copy()
|
||||
|
||||
# 需要遮蔽的字段
|
||||
sensitive_fields = ["api_key", "token", "secret", "password"]
|
||||
|
||||
for field in sensitive_fields:
|
||||
if field in masked_config:
|
||||
value = masked_config[field]
|
||||
if isinstance(value, str) and len(value) > 4:
|
||||
# 只显示前2位和后2位
|
||||
masked_config[field] = f"{value[:2]}***{value[-2:]}"
|
||||
elif isinstance(value, str) and value:
|
||||
masked_config[field] = "***"
|
||||
|
||||
return masked_config
|
||||
318
api/app/core/tools/custom/base.py
Normal file
318
api/app/core/tools/custom/base.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""自定义工具基类"""
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
import aiohttp
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from app.models.tool_model import ToolType, AuthType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class CustomTool(BaseTool):
|
||||
"""自定义工具 - 基于OpenAPI schema的工具"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
"""初始化自定义工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
config: 工具配置
|
||||
"""
|
||||
super().__init__(tool_id, config)
|
||||
self.schema_content = config.get("schema_content", {})
|
||||
self.schema_url = config.get("schema_url")
|
||||
self.auth_type = AuthType(config.get("auth_type", "none"))
|
||||
self.auth_config = config.get("auth_config", {})
|
||||
self.base_url = config.get("base_url", "")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
|
||||
# 解析schema
|
||||
self._parsed_operations = self._parse_openapi_schema()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""工具名称"""
|
||||
if self.schema_content:
|
||||
info = self.schema_content.get("info", {})
|
||||
return info.get("title", f"custom_tool_{self.tool_id[:8]}")
|
||||
return f"custom_tool_{self.tool_id[:8]}"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""工具描述"""
|
||||
if self.schema_content:
|
||||
info = self.schema_content.get("info", {})
|
||||
return info.get("description", "自定义API工具")
|
||||
return "自定义API工具"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.CUSTOM
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""工具参数定义"""
|
||||
params = []
|
||||
|
||||
# 添加操作选择参数
|
||||
if len(self._parsed_operations) > 1:
|
||||
params.append(ToolParameter(
|
||||
name="operation",
|
||||
type=ParameterType.STRING,
|
||||
description="要执行的操作",
|
||||
required=True,
|
||||
enum=list(self._parsed_operations.keys())
|
||||
))
|
||||
|
||||
# 添加通用参数(基于第一个操作的参数)
|
||||
if self._parsed_operations:
|
||||
first_operation = next(iter(self._parsed_operations.values()))
|
||||
for param_name, param_info in first_operation.get("parameters", {}).items():
|
||||
params.append(ToolParameter(
|
||||
name=param_name,
|
||||
type=self._convert_openapi_type(param_info.get("type", "string")),
|
||||
description=param_info.get("description", ""),
|
||||
required=param_info.get("required", False),
|
||||
default=param_info.get("default"),
|
||||
enum=param_info.get("enum"),
|
||||
minimum=param_info.get("minimum"),
|
||||
maximum=param_info.get("maximum"),
|
||||
pattern=param_info.get("pattern")
|
||||
))
|
||||
|
||||
return params
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行自定义工具"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确定要执行的操作
|
||||
operation_name = kwargs.get("operation")
|
||||
if not operation_name and len(self._parsed_operations) == 1:
|
||||
operation_name = next(iter(self._parsed_operations.keys()))
|
||||
|
||||
if not operation_name or operation_name not in self._parsed_operations:
|
||||
raise ValueError(f"无效的操作: {operation_name}")
|
||||
|
||||
operation = self._parsed_operations[operation_name]
|
||||
|
||||
# 构建请求
|
||||
url = self._build_request_url(operation, kwargs)
|
||||
headers = self._build_request_headers(operation)
|
||||
data = self._build_request_data(operation, kwargs)
|
||||
|
||||
# 发送HTTP请求
|
||||
result = await self._send_http_request(
|
||||
method=operation["method"],
|
||||
url=url,
|
||||
headers=headers,
|
||||
data=data
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="CUSTOM_TOOL_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def _parse_openapi_schema(self) -> Dict[str, Any]:
|
||||
"""解析OpenAPI schema"""
|
||||
operations = {}
|
||||
|
||||
if not self.schema_content:
|
||||
return operations
|
||||
|
||||
paths = self.schema_content.get("paths", {})
|
||||
|
||||
for path, path_item in paths.items():
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() in ["get", "post", "put", "delete", "patch"]:
|
||||
operation_id = operation.get("operationId", f"{method}_{path.replace('/', '_')}")
|
||||
|
||||
# 解析参数
|
||||
parameters = {}
|
||||
if "parameters" in operation:
|
||||
for param in operation["parameters"]:
|
||||
param_name = param.get("name")
|
||||
param_schema = param.get("schema", {})
|
||||
parameters[param_name] = {
|
||||
"type": param_schema.get("type", "string"),
|
||||
"description": param.get("description", ""),
|
||||
"required": param.get("required", False),
|
||||
"in": param.get("in", "query"),
|
||||
**param_schema
|
||||
}
|
||||
|
||||
# 解析请求体
|
||||
request_body = None
|
||||
if "requestBody" in operation:
|
||||
content = operation["requestBody"].get("content", {})
|
||||
if "application/json" in content:
|
||||
request_body = content["application/json"].get("schema", {})
|
||||
|
||||
operations[operation_id] = {
|
||||
"method": method.upper(),
|
||||
"path": path,
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": parameters,
|
||||
"request_body": request_body
|
||||
}
|
||||
|
||||
return operations
|
||||
|
||||
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
|
||||
"""转换OpenAPI类型到内部类型"""
|
||||
type_mapping = {
|
||||
"string": ParameterType.STRING,
|
||||
"integer": ParameterType.INTEGER,
|
||||
"number": ParameterType.NUMBER,
|
||||
"boolean": ParameterType.BOOLEAN,
|
||||
"array": ParameterType.ARRAY,
|
||||
"object": ParameterType.OBJECT
|
||||
}
|
||||
return type_mapping.get(openapi_type, ParameterType.STRING)
|
||||
|
||||
def _build_request_url(self, operation: Dict[str, Any], params: Dict[str, Any]) -> str:
|
||||
"""构建请求URL"""
|
||||
path = operation["path"]
|
||||
|
||||
# 替换路径参数
|
||||
for param_name, param_info in operation.get("parameters", {}).items():
|
||||
if param_info.get("in") == "path" and param_name in params:
|
||||
path = path.replace(f"{{{param_name}}}", str(params[param_name]))
|
||||
|
||||
# 构建完整URL
|
||||
if self.base_url:
|
||||
url = urljoin(self.base_url, path.lstrip("/"))
|
||||
else:
|
||||
# 从schema中获取服务器URL
|
||||
servers = self.schema_content.get("servers", [])
|
||||
if servers:
|
||||
base_url = servers[0].get("url", "")
|
||||
url = urljoin(base_url, path.lstrip("/"))
|
||||
else:
|
||||
url = path
|
||||
|
||||
# 添加查询参数
|
||||
query_params = {}
|
||||
for param_name, param_info in operation.get("parameters", {}).items():
|
||||
if param_info.get("in") == "query" and param_name in params:
|
||||
query_params[param_name] = params[param_name]
|
||||
|
||||
if query_params:
|
||||
from urllib.parse import urlencode
|
||||
url += "?" + urlencode(query_params)
|
||||
|
||||
return url
|
||||
|
||||
def _build_request_headers(self, operation: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "CustomTool/1.0"
|
||||
}
|
||||
|
||||
# 添加认证头
|
||||
if self.auth_type == AuthType.API_KEY:
|
||||
api_key = self.auth_config.get("api_key")
|
||||
key_name = self.auth_config.get("key_name", "X-API-Key")
|
||||
if api_key:
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif self.auth_type == AuthType.BEARER_TOKEN:
|
||||
token = self.auth_config.get("token")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
return headers
|
||||
|
||||
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""构建请求数据"""
|
||||
if operation["method"] in ["POST", "PUT", "PATCH"]:
|
||||
request_body = operation.get("request_body")
|
||||
if request_body:
|
||||
# 构建请求体数据
|
||||
data = {}
|
||||
properties = request_body.get("properties", {})
|
||||
|
||||
for prop_name, prop_schema in properties.items():
|
||||
if prop_name in params:
|
||||
data[prop_name] = params[prop_name]
|
||||
|
||||
return data if data else None
|
||||
|
||||
return None
|
||||
|
||||
async def _send_http_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
) -> Any:
|
||||
"""发送HTTP请求"""
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
kwargs = {
|
||||
"headers": headers
|
||||
}
|
||||
|
||||
if data and method in ["POST", "PUT", "PATCH"]:
|
||||
kwargs["json"] = data
|
||||
|
||||
async with session.request(method, url, **kwargs) as response:
|
||||
if response.status >= 400:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"HTTP {response.status}: {error_text}")
|
||||
|
||||
# 尝试解析JSON响应
|
||||
try:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
return await response.text()
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, schema_url: str, auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
|
||||
"""从URL导入OpenAPI schema创建工具"""
|
||||
import uuid
|
||||
if not tool_id:
|
||||
tool_id = str(uuid.uuid4())
|
||||
|
||||
config = {
|
||||
"schema_url": schema_url,
|
||||
"auth_config": auth_config,
|
||||
"auth_type": auth_config.get("type", "none")
|
||||
}
|
||||
|
||||
# 这里应该异步加载schema,为了简化暂时返回空配置
|
||||
return cls(tool_id, config)
|
||||
|
||||
@classmethod
|
||||
def from_schema(cls, schema_dict: Dict[str, Any], auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
|
||||
"""从schema字典创建工具"""
|
||||
import uuid
|
||||
if not tool_id:
|
||||
tool_id = str(uuid.uuid4())
|
||||
|
||||
config = {
|
||||
"schema_content": schema_dict,
|
||||
"auth_config": auth_config,
|
||||
"auth_type": auth_config.get("type", "none")
|
||||
}
|
||||
|
||||
return cls(tool_id, config)
|
||||
477
api/app/core/tools/custom/schema_parser.py
Normal file
477
api/app/core/tools/custom/schema_parser.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""OpenAPI Schema解析器"""
|
||||
import json
|
||||
import yaml
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class OpenAPISchemaParser:
|
||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化解析器"""
|
||||
self.supported_versions = ["3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]
|
||||
|
||||
async def parse_from_url(self, schema_url: str, timeout: int = 30) -> Tuple[bool, Dict[str, Any], str]:
|
||||
"""从URL解析OpenAPI schema
|
||||
|
||||
Args:
|
||||
schema_url: Schema URL
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
(是否成功, schema内容, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 验证URL格式
|
||||
parsed_url = urlparse(schema_url)
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
return False, {}, "无效的URL格式"
|
||||
|
||||
# 下载schema
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.get(schema_url) as response:
|
||||
if response.status != 200:
|
||||
return False, {}, f"HTTP错误: {response.status}"
|
||||
|
||||
content_type = response.headers.get('content-type', '').lower()
|
||||
content = await response.text()
|
||||
|
||||
# 解析内容
|
||||
schema_dict = self._parse_content(content, content_type)
|
||||
if not schema_dict:
|
||||
return False, {}, "无法解析schema内容"
|
||||
|
||||
# 验证schema
|
||||
is_valid, error_msg = self.validate_schema(schema_dict)
|
||||
if not is_valid:
|
||||
return False, {}, error_msg
|
||||
|
||||
return True, schema_dict, ""
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, {}, "请求超时"
|
||||
except Exception as e:
|
||||
logger.error(f"从URL解析schema失败: {schema_url}, 错误: {e}")
|
||||
return False, {}, str(e)
|
||||
|
||||
def parse_from_content(self, content: str, content_type: str = "application/json") -> Tuple[bool, Dict[str, Any], str]:
|
||||
"""从内容解析OpenAPI schema
|
||||
|
||||
Args:
|
||||
content: Schema内容
|
||||
content_type: 内容类型
|
||||
|
||||
Returns:
|
||||
(是否成功, schema内容, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 解析内容
|
||||
schema_dict = self._parse_content(content, content_type)
|
||||
if not schema_dict:
|
||||
return False, {}, "无法解析schema内容"
|
||||
|
||||
# 验证schema
|
||||
is_valid, error_msg = self.validate_schema(schema_dict)
|
||||
if not is_valid:
|
||||
return False, {}, error_msg
|
||||
|
||||
return True, schema_dict, ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析schema内容失败: {e}")
|
||||
return False, {}, str(e)
|
||||
|
||||
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
|
||||
"""解析内容为字典
|
||||
|
||||
Args:
|
||||
content: 内容字符串
|
||||
content_type: 内容类型
|
||||
|
||||
Returns:
|
||||
解析后的字典,失败返回None
|
||||
"""
|
||||
try:
|
||||
# 根据内容类型解析
|
||||
if 'json' in content_type:
|
||||
return json.loads(content)
|
||||
elif 'yaml' in content_type or 'yml' in content_type:
|
||||
return yaml.safe_load(content)
|
||||
else:
|
||||
# 尝试自动检测格式
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
return yaml.safe_load(content)
|
||||
except yaml.YAMLError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"解析内容失败: {e}")
|
||||
return None
|
||||
|
||||
def validate_schema(self, schema_dict: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证OpenAPI schema
|
||||
|
||||
Args:
|
||||
schema_dict: Schema字典
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 检查基本结构
|
||||
if not isinstance(schema_dict, dict):
|
||||
return False, "Schema必须是JSON对象"
|
||||
|
||||
# 检查OpenAPI版本
|
||||
openapi_version = schema_dict.get("openapi")
|
||||
if not openapi_version:
|
||||
return False, "缺少openapi版本字段"
|
||||
|
||||
if openapi_version not in self.supported_versions:
|
||||
return False, f"不支持的OpenAPI版本: {openapi_version}"
|
||||
|
||||
# 检查必需字段
|
||||
required_fields = ["info", "paths"]
|
||||
for field in required_fields:
|
||||
if field not in schema_dict:
|
||||
return False, f"缺少必需字段: {field}"
|
||||
|
||||
# 验证info字段
|
||||
info = schema_dict.get("info", {})
|
||||
if not isinstance(info, dict):
|
||||
return False, "info字段必须是对象"
|
||||
|
||||
if "title" not in info:
|
||||
return False, "info.title字段是必需的"
|
||||
|
||||
# 验证paths字段
|
||||
paths = schema_dict.get("paths", {})
|
||||
if not isinstance(paths, dict):
|
||||
return False, "paths字段必须是对象"
|
||||
|
||||
# 验证至少有一个路径
|
||||
if not paths:
|
||||
return False, "至少需要定义一个API路径"
|
||||
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
return False, f"验证schema时出错: {e}"
|
||||
|
||||
def extract_tool_info(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""从schema提取工具信息
|
||||
|
||||
Args:
|
||||
schema_dict: Schema字典
|
||||
|
||||
Returns:
|
||||
工具信息字典
|
||||
"""
|
||||
info = schema_dict.get("info", {})
|
||||
|
||||
return {
|
||||
"name": info.get("title", "Custom API Tool"),
|
||||
"description": info.get("description", ""),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
"servers": schema_dict.get("servers", []),
|
||||
"operations": self._extract_operations(schema_dict)
|
||||
}
|
||||
|
||||
def _extract_operations(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取API操作信息
|
||||
|
||||
Args:
|
||||
schema_dict: Schema字典
|
||||
|
||||
Returns:
|
||||
操作信息字典
|
||||
"""
|
||||
operations = {}
|
||||
paths = schema_dict.get("paths", {})
|
||||
|
||||
for path, path_item in paths.items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() not in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
continue
|
||||
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
# 生成操作ID
|
||||
operation_id = operation.get("operationId")
|
||||
if not operation_id:
|
||||
operation_id = f"{method.lower()}_{path.replace('/', '_').replace('{', '').replace('}', '')}"
|
||||
|
||||
# 提取操作信息
|
||||
operations[operation_id] = {
|
||||
"method": method.upper(),
|
||||
"path": path,
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": self._extract_parameters(operation),
|
||||
"request_body": self._extract_request_body(operation),
|
||||
"responses": self._extract_responses(operation),
|
||||
"tags": operation.get("tags", [])
|
||||
}
|
||||
|
||||
return operations
|
||||
|
||||
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取操作参数
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
|
||||
Returns:
|
||||
参数信息字典
|
||||
"""
|
||||
parameters = {}
|
||||
|
||||
for param in operation.get("parameters", []):
|
||||
if not isinstance(param, dict):
|
||||
continue
|
||||
|
||||
param_name = param.get("name")
|
||||
if not param_name:
|
||||
continue
|
||||
|
||||
param_schema = param.get("schema", {})
|
||||
|
||||
parameters[param_name] = {
|
||||
"name": param_name,
|
||||
"in": param.get("in", "query"),
|
||||
"description": param.get("description", ""),
|
||||
"required": param.get("required", False),
|
||||
"type": param_schema.get("type", "string"),
|
||||
"format": param_schema.get("format"),
|
||||
"enum": param_schema.get("enum"),
|
||||
"default": param_schema.get("default"),
|
||||
"minimum": param_schema.get("minimum"),
|
||||
"maximum": param_schema.get("maximum"),
|
||||
"pattern": param_schema.get("pattern"),
|
||||
"example": param.get("example") or param_schema.get("example")
|
||||
}
|
||||
|
||||
return parameters
|
||||
|
||||
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""提取请求体信息
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
|
||||
Returns:
|
||||
请求体信息,如果没有返回None
|
||||
"""
|
||||
request_body = operation.get("requestBody")
|
||||
if not request_body:
|
||||
return None
|
||||
|
||||
content = request_body.get("content", {})
|
||||
|
||||
# 优先使用application/json
|
||||
if "application/json" in content:
|
||||
schema = content["application/json"].get("schema", {})
|
||||
elif content:
|
||||
# 使用第一个可用的内容类型
|
||||
first_content_type = next(iter(content.keys()))
|
||||
schema = content[first_content_type].get("schema", {})
|
||||
else:
|
||||
return None
|
||||
|
||||
return {
|
||||
"description": request_body.get("description", ""),
|
||||
"required": request_body.get("required", False),
|
||||
"schema": schema,
|
||||
"content_types": list(content.keys())
|
||||
}
|
||||
|
||||
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取响应信息
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
|
||||
Returns:
|
||||
响应信息字典
|
||||
"""
|
||||
responses = {}
|
||||
|
||||
for status_code, response in operation.get("responses", {}).items():
|
||||
if not isinstance(response, dict):
|
||||
continue
|
||||
|
||||
content = response.get("content", {})
|
||||
schema = None
|
||||
|
||||
# 尝试获取响应schema
|
||||
if "application/json" in content:
|
||||
schema = content["application/json"].get("schema")
|
||||
elif content:
|
||||
first_content_type = next(iter(content.keys()))
|
||||
schema = content[first_content_type].get("schema")
|
||||
|
||||
responses[status_code] = {
|
||||
"description": response.get("description", ""),
|
||||
"schema": schema,
|
||||
"content_types": list(content.keys()) if content else []
|
||||
}
|
||||
|
||||
return responses
|
||||
|
||||
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""生成工具参数定义
|
||||
|
||||
Args:
|
||||
operations: 操作信息字典
|
||||
|
||||
Returns:
|
||||
参数定义列表
|
||||
"""
|
||||
parameters = []
|
||||
|
||||
# 如果有多个操作,添加操作选择参数
|
||||
if len(operations) > 1:
|
||||
parameters.append({
|
||||
"name": "operation",
|
||||
"type": "string",
|
||||
"description": "要执行的操作",
|
||||
"required": True,
|
||||
"enum": list(operations.keys())
|
||||
})
|
||||
|
||||
# 收集所有参数(去重)
|
||||
all_params = {}
|
||||
|
||||
for operation_id, operation in operations.items():
|
||||
# 路径参数和查询参数
|
||||
for param_name, param_info in operation.get("parameters", {}).items():
|
||||
if param_name not in all_params:
|
||||
all_params[param_name] = {
|
||||
"name": param_name,
|
||||
"type": param_info.get("type", "string"),
|
||||
"description": param_info.get("description", ""),
|
||||
"required": param_info.get("required", False),
|
||||
"enum": param_info.get("enum"),
|
||||
"default": param_info.get("default"),
|
||||
"minimum": param_info.get("minimum"),
|
||||
"maximum": param_info.get("maximum"),
|
||||
"pattern": param_info.get("pattern")
|
||||
}
|
||||
|
||||
# 请求体参数
|
||||
request_body = operation.get("request_body")
|
||||
if request_body:
|
||||
schema = request_body.get("schema", {})
|
||||
properties = schema.get("properties", {})
|
||||
|
||||
for prop_name, prop_schema in properties.items():
|
||||
if prop_name not in all_params:
|
||||
all_params[prop_name] = {
|
||||
"name": prop_name,
|
||||
"type": prop_schema.get("type", "string"),
|
||||
"description": prop_schema.get("description", ""),
|
||||
"required": prop_name in schema.get("required", []),
|
||||
"enum": prop_schema.get("enum"),
|
||||
"default": prop_schema.get("default"),
|
||||
"minimum": prop_schema.get("minimum"),
|
||||
"maximum": prop_schema.get("maximum"),
|
||||
"pattern": prop_schema.get("pattern")
|
||||
}
|
||||
|
||||
# 转换为参数列表
|
||||
parameters.extend(all_params.values())
|
||||
|
||||
return parameters
|
||||
|
||||
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
|
||||
"""验证操作参数
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
params: 输入参数
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息列表)
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# 验证路径参数和查询参数
|
||||
for param_name, param_info in operation.get("parameters", {}).items():
|
||||
if param_info.get("required", False) and param_name not in params:
|
||||
errors.append(f"缺少必需参数: {param_name}")
|
||||
|
||||
if param_name in params:
|
||||
value = params[param_name]
|
||||
param_type = param_info.get("type", "string")
|
||||
|
||||
# 类型验证
|
||||
if not self._validate_parameter_type(value, param_type):
|
||||
errors.append(f"参数 {param_name} 类型错误,期望: {param_type}")
|
||||
|
||||
# 枚举验证
|
||||
enum_values = param_info.get("enum")
|
||||
if enum_values and value not in enum_values:
|
||||
errors.append(f"参数 {param_name} 值无效,必须是: {enum_values}")
|
||||
|
||||
# 验证请求体参数
|
||||
request_body = operation.get("request_body")
|
||||
if request_body:
|
||||
schema = request_body.get("schema", {})
|
||||
required_props = schema.get("required", [])
|
||||
properties = schema.get("properties", {})
|
||||
|
||||
for prop_name in required_props:
|
||||
if prop_name not in params:
|
||||
errors.append(f"缺少必需的请求体参数: {prop_name}")
|
||||
|
||||
for prop_name, value in params.items():
|
||||
if prop_name in properties:
|
||||
prop_schema = properties[prop_name]
|
||||
prop_type = prop_schema.get("type", "string")
|
||||
|
||||
if not self._validate_parameter_type(value, prop_type):
|
||||
errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}")
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
|
||||
"""验证参数类型
|
||||
|
||||
Args:
|
||||
value: 参数值
|
||||
expected_type: 期望类型
|
||||
|
||||
Returns:
|
||||
是否类型匹配
|
||||
"""
|
||||
if value is None:
|
||||
return True
|
||||
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": (int, float),
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict
|
||||
}
|
||||
|
||||
expected_python_type = type_mapping.get(expected_type)
|
||||
if expected_python_type:
|
||||
return isinstance(value, expected_python_type)
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user