Merge #73 into develop from feature/20251219_xjn

feat(tool system): Optimization of the tool system

* feature/20251219_xjn: (1 commits)
  feat(tool system): Optimization of the tool system

Signed-off-by: 谢俊男 <accounts_6853d0ea6f8174722fb0c8f1@mail.teambition.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/73
This commit is contained in:
朱文辉
2025-12-26 19:15:34 +08:00
6 changed files with 300 additions and 25 deletions

View File

@@ -157,8 +157,8 @@ class DateTimeTool(BuiltinTool):
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
from_timezone = kwargs.get("from_timezone", "UTC")
to_timezone = kwargs.get("to_timezone", "UTC")
from_timezone = kwargs.get("from_timezone", "Asia/Shanghai")
to_timezone = kwargs.get("to_timezone", "Asia/Shanghai")
if not input_value:
raise ValueError("input_value 参数是必需的")
@@ -197,7 +197,7 @@ class DateTimeTool(BuiltinTool):
"""时间戳转日期时间"""
input_value = kwargs.get("input_value")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
timezone_str = kwargs.get("to_timezone", "UTC")
timezone_str = kwargs.get("to_timezone", "Asia/Shanghai")
if not input_value:
raise ValueError("input_value 参数是必需的")
@@ -227,7 +227,7 @@ class DateTimeTool(BuiltinTool):
"""日期时间转时间戳"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
timezone_str = kwargs.get("from_timezone", "UTC")
timezone_str = kwargs.get("from_timezone", "Asia/Shanghai")
if not input_value:
raise ValueError("input_value 参数是必需的")
@@ -258,7 +258,7 @@ class DateTimeTool(BuiltinTool):
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
calculation = kwargs.get("calculation")
timezone_str = kwargs.get("from_timezone", "UTC")
timezone_str = kwargs.get("from_timezone", "Asia/Shanghai")
if not input_value:
raise ValueError("input_value 参数是必需的")

View File

@@ -29,7 +29,8 @@ class JsonTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", "extract"]
enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge",
"extract", "insert", "replace", "delete", "parse"]
),
ToolParameter(
name="input_data",
@@ -69,7 +70,25 @@ class JsonTool(BuiltinTool):
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式用于extract操作,如:$.user.name",
description="JSON路径表达式用于extract、insert、replace、delete、parse操作$.user.name或users[0].name",
required=False
),
ToolParameter(
name="new_value",
type=ParameterType.STRING,
description="新值用于insert和replace操作",
required=False
),
ToolParameter(
name="old_text",
type=ParameterType.STRING,
description="要替换的原文本用于replace操作",
required=False
),
ToolParameter(
name="new_text",
type=ParameterType.STRING,
description="替换后的新文本用于replace操作",
required=False
)
]
@@ -105,6 +124,14 @@ class JsonTool(BuiltinTool):
result = self._merge_json(input_data, kwargs)
elif operation == "extract":
result = self._extract_json_path(input_data, kwargs)
elif operation == "insert":
result = self._insert_json_value(input_data, kwargs)
elif operation == "replace":
result = self._replace_json_value(input_data, kwargs)
elif operation == "delete":
result = self._delete_json_key(input_data, kwargs)
elif operation == "parse":
result = self._parse_json_value(input_data, kwargs)
else:
raise ValueError(f"不支持的操作类型: {operation}")
@@ -415,6 +442,248 @@ class JsonTool(BuiltinTool):
"extracted_data": None
}
@staticmethod
def _insert_json_value(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""插入JSON值"""
json_path = kwargs.get("json_path")
new_value = kwargs.get("new_value")
if not json_path:
raise ValueError("json_path 参数是必需的")
if new_value is None:
raise ValueError("new_value 参数是必需的")
data = json.loads(input_data)
try:
parsed_value = json.loads(new_value)
except (json.JSONDecodeError, TypeError):
parsed_value = new_value
# 解析路径
path_parts = json_path.replace('$.', '').split('.')
try:
# 导航到父节点
current = data
for part in path_parts[:-1]:
if '[' in part and ']' in part:
# 处理数组索引
key, index_str = part.split('[')
index = int(index_str.rstrip(']'))
if key:
current = current[key]
current = current[index]
else:
current = current[part]
# 插入新值
last_part = path_parts[-1]
if '[' in last_part and ']' in last_part:
# 数组操作
key, index_str = last_part.split('[')
index = int(index_str.rstrip(']'))
if key:
current[key][index] = parsed_value
else:
current[index] = parsed_value
else:
# 对象操作
current[last_part] = parsed_value
result_json = json.dumps(data, indent=2, ensure_ascii=False)
return {
"operation": "insert",
"json_path": json_path,
"success": True,
"new_value": parsed_value,
"result_data": result_json
}
except (KeyError, IndexError, TypeError, ValueError) as e:
return {
"operation": "insert",
"json_path": json_path,
"success": False,
"error": str(e),
"result_data": input_data
}
@staticmethod
def _replace_json_value(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""替换JSON值中的文本"""
json_path = kwargs.get("json_path")
old_text = kwargs.get("old_text")
new_text = kwargs.get("new_text")
if not json_path:
raise ValueError("json_path 参数是必需的")
if old_text is None:
raise ValueError("old_text 参数是必需的")
if new_text is None:
raise ValueError("new_text 参数是必需的")
data = json.loads(input_data)
# 解析路径
path_parts = json_path.replace('$.', '').split('.')
try:
# 导航到目标值
current = data
for part in path_parts[:-1]:
if '[' in part and ']' in part:
key, index_str = part.split('[')
index = int(index_str.rstrip(']'))
if key:
current = current[key]
current = current[index]
else:
current = current[part]
# 获取并替换值
last_part = path_parts[-1]
if '[' in last_part and ']' in last_part:
key, index_str = last_part.split('[')
index = int(index_str.rstrip(']'))
if key:
original_value = str(current[key][index])
current[key][index] = original_value.replace(old_text, new_text)
else:
original_value = str(current[index])
current[index] = original_value.replace(old_text, new_text)
else:
original_value = str(current[last_part])
current[last_part] = original_value.replace(old_text, new_text)
result_json = json.dumps(data, indent=2, ensure_ascii=False)
return {
"operation": "replace",
"json_path": json_path,
"success": True,
"old_text": old_text,
"new_text": new_text,
"original_value": original_value,
"result_data": result_json
}
except (KeyError, IndexError, TypeError) as e:
return {
"operation": "replace",
"json_path": json_path,
"success": False,
"error": str(e),
"result_data": input_data
}
@staticmethod
def _delete_json_key(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""删除JSON键值对"""
json_path = kwargs.get("json_path")
if not json_path:
raise ValueError("json_path 参数是必需的")
data = json.loads(input_data)
# 解析路径
path_parts = json_path.replace('$.', '').split('.')
try:
# 导航到父节点
current = data
for part in path_parts[:-1]:
if '[' in part and ']' in part:
key, index_str = part.split('[')
index = int(index_str.rstrip(']'))
if key:
current = current[key]
current = current[index]
else:
current = current[part]
# 删除目标键或元素
last_part = path_parts[-1]
deleted_value = None
if '[' in last_part and ']' in last_part:
# 数组操作
key, index_str = last_part.split('[')
index = int(index_str.rstrip(']'))
if key:
deleted_value = current[key].pop(index)
else:
deleted_value = current.pop(index)
else:
# 对象操作
deleted_value = current.pop(last_part)
result_json = json.dumps(data, indent=2, ensure_ascii=False)
return {
"operation": "delete",
"json_path": json_path,
"success": True,
"deleted_value": deleted_value,
"result_data": result_json
}
except (KeyError, IndexError, TypeError) as e:
return {
"operation": "delete",
"json_path": json_path,
"success": False,
"error": str(e),
"result_data": input_data
}
@staticmethod
def _parse_json_value(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""解析获取JSON值"""
json_path = kwargs.get("json_path")
if not json_path:
raise ValueError("json_path 参数是必需的")
data = json.loads(input_data)
# 解析路径
path_parts = json_path.replace('$.', '').split('.')
try:
# 导航到目标值
current = data
for part in path_parts:
if '[' in part and ']' in part:
# 处理数组索引
key, index_str = part.split('[')
index = int(index_str.rstrip(']'))
if key:
current = current[key]
current = current[index]
else:
current = current[part]
return {
"operation": "parse",
"json_path": json_path,
"success": True,
"value": current,
"value_type": type(current).__name__,
"value_json": json.dumps(current, indent=2, ensure_ascii=False) if isinstance(current, (dict, list)) else str(current)
}
except (KeyError, IndexError, TypeError) as e:
return {
"operation": "parse",
"json_path": json_path,
"success": False,
"error": str(e),
"value": None
}
def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]:
"""分析JSON结构"""
if isinstance(data, dict):

View File

@@ -379,9 +379,8 @@ class MCPClient:
Returns:
响应数据
"""
request_id = str(request_data["id"])
if self.connection_type == "websocket":
request_id = str(request_data["id"])
return await self._send_websocket_request(request_data, request_id, timeout)
else:
return await self._send_http_request(request_data, timeout)
@@ -423,12 +422,19 @@ class MCPClient:
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
if response.status == 200:
return await response.json()
else:
async with self._session.post(
self.server_url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
) as root_response:
if root_response.status != 200:
error_text = await root_response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP请求失败: {e}")

View File

@@ -36,7 +36,7 @@ class ToolRepository:
query = query.filter(ToolConfig.status == status.value)
if is_enabled is not None:
query = query.filter(ToolConfig.is_enabled == is_enabled)
query = query.order_by(ToolConfig.created_at.desc())
return query.all()
@staticmethod

View File

@@ -225,7 +225,7 @@ class CustomToolCreateRequest(BaseModel):
class ParseSchemaRequest(BaseModel):
"""解析Schema请求"""
schema_content: Optional[Dict[str, Any]] = Field(None, description="OpenAPI schema内容")
schema_content: Optional[str] = Field(None, description="OpenAPI schema内容")
schema_url: Optional[str] = Field(None, description="OpenAPI schema URL")

View File

@@ -461,7 +461,7 @@ class ToolService:
parser = OpenAPISchemaParser()
if config.get("schema_content"):
success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]), "application/json")
success, schema, _ = parser.parse_from_content(config["schema_content"], "application/json")
else:
success, schema, _ = parser.parse_from_url(config["schema_url"])
@@ -515,7 +515,7 @@ class ToolService:
parser = OpenAPISchemaParser()
if config.get("schema_content"):
success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]),
success, schema, _ = parser.parse_from_content(config["schema_content"],
"application/json")
else:
success, schema, _ = parser.parse_from_url(config["schema_url"])
@@ -686,7 +686,7 @@ class ToolService:
if await client.connect():
try:
tools = await client.list_tools()
# tools = await client.list_tools()
await client.disconnect()
# 更新连接状态
@@ -701,7 +701,8 @@ class ToolService:
return {
"success": True,
"message": "MCP连接成功",
"details": {"server_url": mcp_config.server_url, "tools_count": len(tools)}
# "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)}
"details": {"server_url": mcp_config.server_url}
}
except Exception as e:
await client.disconnect()
@@ -739,16 +740,15 @@ class ToolService:
return {"success": False, "message": f"MCP测试异常: {str(e)}"}
@staticmethod
async def parse_openapi_schema(schema_data: Dict[str, Any] = None, schema_url: str = None) -> Dict[str, Any]:
async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]:
"""解析OpenAPI schema获取接口信息"""
try:
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
parser = OpenAPISchemaParser()
# 使用现有的解析器
if schema_data:
success, schema, error = parser.parse_from_content(json.dumps(schema_data), "application/json")
success, schema, error = parser.parse_from_content(schema_data, "application/json")
elif schema_url:
success, schema, error = await parser.parse_from_url(schema_url)
else: