From 9fb7d7d0590a225a162424592953ad1381e3c3af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Fri, 26 Dec 2025 19:11:20 +0800 Subject: [PATCH] feat(tool system): Optimization of the tool system 1. Optimization of the JSON tool, add insert, replace, delete, parse 2. Optimization of the mcp test_connection 3. tool list desc 4. datetime_tool default timezone set Asia/Shanghai --- api/app/core/tools/builtin/datetime_tool.py | 10 +- api/app/core/tools/builtin/json_tool.py | 273 +++++++++++++++++++- api/app/core/tools/mcp/client.py | 22 +- api/app/repositories/tool_repository.py | 2 +- api/app/schemas/tool_schema.py | 2 +- api/app/services/tool_service.py | 16 +- 6 files changed, 300 insertions(+), 25 deletions(-) diff --git a/api/app/core/tools/builtin/datetime_tool.py b/api/app/core/tools/builtin/datetime_tool.py index 9e5ab9f6..647914b2 100644 --- a/api/app/core/tools/builtin/datetime_tool.py +++ b/api/app/core/tools/builtin/datetime_tool.py @@ -157,8 +157,8 @@ class DateTimeTool(BuiltinTool): input_value = kwargs.get("input_value") input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") - from_timezone = kwargs.get("from_timezone", "UTC") - to_timezone = kwargs.get("to_timezone", "UTC") + from_timezone = kwargs.get("from_timezone", "Asia/Shanghai") + to_timezone = kwargs.get("to_timezone", "Asia/Shanghai") if not input_value: raise ValueError("input_value 参数是必需的") @@ -197,7 +197,7 @@ class DateTimeTool(BuiltinTool): """时间戳转日期时间""" input_value = kwargs.get("input_value") output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") - timezone_str = kwargs.get("to_timezone", "UTC") + timezone_str = kwargs.get("to_timezone", "Asia/Shanghai") if not input_value: raise ValueError("input_value 参数是必需的") @@ -227,7 +227,7 @@ class DateTimeTool(BuiltinTool): """日期时间转时间戳""" input_value = kwargs.get("input_value") input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") - timezone_str = kwargs.get("from_timezone", "UTC") + timezone_str = kwargs.get("from_timezone", "Asia/Shanghai") if not input_value: raise ValueError("input_value 参数是必需的") @@ -258,7 +258,7 @@ class DateTimeTool(BuiltinTool): input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") calculation = kwargs.get("calculation") - timezone_str = kwargs.get("from_timezone", "UTC") + timezone_str = kwargs.get("from_timezone", "Asia/Shanghai") if not input_value: raise ValueError("input_value 参数是必需的") diff --git a/api/app/core/tools/builtin/json_tool.py b/api/app/core/tools/builtin/json_tool.py index d2b73bba..62cd98d3 100644 --- a/api/app/core/tools/builtin/json_tool.py +++ b/api/app/core/tools/builtin/json_tool.py @@ -29,7 +29,8 @@ class JsonTool(BuiltinTool): type=ParameterType.STRING, description="操作类型", required=True, - enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", "extract"] + enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", + "extract", "insert", "replace", "delete", "parse"] ), ToolParameter( name="input_data", @@ -69,7 +70,25 @@ class JsonTool(BuiltinTool): ToolParameter( name="json_path", type=ParameterType.STRING, - description="JSON路径表达式(用于extract操作,如:$.user.name)", + description="JSON路径表达式(用于extract、insert、replace、delete、parse操作,如:$.user.name或users[0].name)", + required=False + ), + ToolParameter( + name="new_value", + type=ParameterType.STRING, + description="新值(用于insert和replace操作)", + required=False + ), + ToolParameter( + name="old_text", + type=ParameterType.STRING, + description="要替换的原文本(用于replace操作)", + required=False + ), + ToolParameter( + name="new_text", + type=ParameterType.STRING, + description="替换后的新文本(用于replace操作)", required=False ) ] @@ -105,6 +124,14 @@ class JsonTool(BuiltinTool): result = self._merge_json(input_data, kwargs) elif operation == "extract": result = self._extract_json_path(input_data, kwargs) + elif operation == "insert": + result = self._insert_json_value(input_data, kwargs) + elif operation == "replace": + result = self._replace_json_value(input_data, kwargs) + elif operation == "delete": + result = self._delete_json_key(input_data, kwargs) + elif operation == "parse": + result = self._parse_json_value(input_data, kwargs) else: raise ValueError(f"不支持的操作类型: {operation}") @@ -415,6 +442,248 @@ class JsonTool(BuiltinTool): "extracted_data": None } + @staticmethod + def _insert_json_value(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """插入JSON值""" + json_path = kwargs.get("json_path") + new_value = kwargs.get("new_value") + + if not json_path: + raise ValueError("json_path 参数是必需的") + if new_value is None: + raise ValueError("new_value 参数是必需的") + + data = json.loads(input_data) + + try: + parsed_value = json.loads(new_value) + except (json.JSONDecodeError, TypeError): + parsed_value = new_value + + # 解析路径 + path_parts = json_path.replace('$.', '').split('.') + + try: + # 导航到父节点 + current = data + for part in path_parts[:-1]: + if '[' in part and ']' in part: + # 处理数组索引 + key, index_str = part.split('[') + index = int(index_str.rstrip(']')) + if key: + current = current[key] + current = current[index] + else: + current = current[part] + + # 插入新值 + last_part = path_parts[-1] + if '[' in last_part and ']' in last_part: + # 数组操作 + key, index_str = last_part.split('[') + index = int(index_str.rstrip(']')) + if key: + current[key][index] = parsed_value + else: + current[index] = parsed_value + else: + # 对象操作 + current[last_part] = parsed_value + + result_json = json.dumps(data, indent=2, ensure_ascii=False) + + return { + "operation": "insert", + "json_path": json_path, + "success": True, + "new_value": parsed_value, + "result_data": result_json + } + + except (KeyError, IndexError, TypeError, ValueError) as e: + return { + "operation": "insert", + "json_path": json_path, + "success": False, + "error": str(e), + "result_data": input_data + } + + @staticmethod + def _replace_json_value(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """替换JSON值中的文本""" + json_path = kwargs.get("json_path") + old_text = kwargs.get("old_text") + new_text = kwargs.get("new_text") + + if not json_path: + raise ValueError("json_path 参数是必需的") + if old_text is None: + raise ValueError("old_text 参数是必需的") + if new_text is None: + raise ValueError("new_text 参数是必需的") + + data = json.loads(input_data) + + # 解析路径 + path_parts = json_path.replace('$.', '').split('.') + + try: + # 导航到目标值 + current = data + for part in path_parts[:-1]: + if '[' in part and ']' in part: + key, index_str = part.split('[') + index = int(index_str.rstrip(']')) + if key: + current = current[key] + current = current[index] + else: + current = current[part] + + # 获取并替换值 + last_part = path_parts[-1] + if '[' in last_part and ']' in last_part: + key, index_str = last_part.split('[') + index = int(index_str.rstrip(']')) + if key: + original_value = str(current[key][index]) + current[key][index] = original_value.replace(old_text, new_text) + else: + original_value = str(current[index]) + current[index] = original_value.replace(old_text, new_text) + else: + original_value = str(current[last_part]) + current[last_part] = original_value.replace(old_text, new_text) + + result_json = json.dumps(data, indent=2, ensure_ascii=False) + + return { + "operation": "replace", + "json_path": json_path, + "success": True, + "old_text": old_text, + "new_text": new_text, + "original_value": original_value, + "result_data": result_json + } + + except (KeyError, IndexError, TypeError) as e: + return { + "operation": "replace", + "json_path": json_path, + "success": False, + "error": str(e), + "result_data": input_data + } + + @staticmethod + def _delete_json_key(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """删除JSON键值对""" + json_path = kwargs.get("json_path") + + if not json_path: + raise ValueError("json_path 参数是必需的") + + data = json.loads(input_data) + + # 解析路径 + path_parts = json_path.replace('$.', '').split('.') + + try: + # 导航到父节点 + current = data + for part in path_parts[:-1]: + if '[' in part and ']' in part: + key, index_str = part.split('[') + index = int(index_str.rstrip(']')) + if key: + current = current[key] + current = current[index] + else: + current = current[part] + + # 删除目标键或元素 + last_part = path_parts[-1] + deleted_value = None + + if '[' in last_part and ']' in last_part: + # 数组操作 + key, index_str = last_part.split('[') + index = int(index_str.rstrip(']')) + if key: + deleted_value = current[key].pop(index) + else: + deleted_value = current.pop(index) + else: + # 对象操作 + deleted_value = current.pop(last_part) + + result_json = json.dumps(data, indent=2, ensure_ascii=False) + + return { + "operation": "delete", + "json_path": json_path, + "success": True, + "deleted_value": deleted_value, + "result_data": result_json + } + + except (KeyError, IndexError, TypeError) as e: + return { + "operation": "delete", + "json_path": json_path, + "success": False, + "error": str(e), + "result_data": input_data + } + + @staticmethod + def _parse_json_value(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """解析获取JSON值""" + json_path = kwargs.get("json_path") + + if not json_path: + raise ValueError("json_path 参数是必需的") + + data = json.loads(input_data) + + # 解析路径 + path_parts = json_path.replace('$.', '').split('.') + + try: + # 导航到目标值 + current = data + for part in path_parts: + if '[' in part and ']' in part: + # 处理数组索引 + key, index_str = part.split('[') + index = int(index_str.rstrip(']')) + if key: + current = current[key] + current = current[index] + else: + current = current[part] + + return { + "operation": "parse", + "json_path": json_path, + "success": True, + "value": current, + "value_type": type(current).__name__, + "value_json": json.dumps(current, indent=2, ensure_ascii=False) if isinstance(current, (dict, list)) else str(current) + } + + except (KeyError, IndexError, TypeError) as e: + return { + "operation": "parse", + "json_path": json_path, + "success": False, + "error": str(e), + "value": None + } + def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]: """分析JSON结构""" if isinstance(data, dict): diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 997e6e84..2e37f2b1 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -379,9 +379,8 @@ class MCPClient: Returns: 响应数据 """ - request_id = str(request_data["id"]) - if self.connection_type == "websocket": + request_id = str(request_data["id"]) return await self._send_websocket_request(request_data, request_id, timeout) else: return await self._send_http_request(request_data, timeout) @@ -423,12 +422,19 @@ class MCPClient: json=request_data, timeout=aiohttp.ClientTimeout(total=timeout) ) as response: - - if response.status != 200: - error_text = await response.text() - raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") - - return await response.json() + if response.status == 200: + return await response.json() + else: + async with self._session.post( + self.server_url, + json=request_data, + timeout=aiohttp.ClientTimeout(total=timeout) + ) as root_response: + if root_response.status != 200: + error_text = await root_response.text() + raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") + + return await response.json() except aiohttp.ClientError as e: raise MCPConnectionError(f"HTTP请求失败: {e}") diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index bc8db683..dc78e761 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -36,7 +36,7 @@ class ToolRepository: query = query.filter(ToolConfig.status == status.value) if is_enabled is not None: query = query.filter(ToolConfig.is_enabled == is_enabled) - + query = query.order_by(ToolConfig.created_at.desc()) return query.all() @staticmethod diff --git a/api/app/schemas/tool_schema.py b/api/app/schemas/tool_schema.py index ef539934..baabe186 100644 --- a/api/app/schemas/tool_schema.py +++ b/api/app/schemas/tool_schema.py @@ -225,7 +225,7 @@ class CustomToolCreateRequest(BaseModel): class ParseSchemaRequest(BaseModel): """解析Schema请求""" - schema_content: Optional[Dict[str, Any]] = Field(None, description="OpenAPI schema内容") + schema_content: Optional[str] = Field(None, description="OpenAPI schema内容") schema_url: Optional[str] = Field(None, description="OpenAPI schema URL") diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 69d99c18..783df81a 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -461,7 +461,7 @@ class ToolService: parser = OpenAPISchemaParser() if config.get("schema_content"): - success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]), "application/json") + success, schema, _ = parser.parse_from_content(config["schema_content"], "application/json") else: success, schema, _ = parser.parse_from_url(config["schema_url"]) @@ -515,7 +515,7 @@ class ToolService: parser = OpenAPISchemaParser() if config.get("schema_content"): - success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]), + success, schema, _ = parser.parse_from_content(config["schema_content"], "application/json") else: success, schema, _ = parser.parse_from_url(config["schema_url"]) @@ -686,7 +686,7 @@ class ToolService: if await client.connect(): try: - tools = await client.list_tools() + # tools = await client.list_tools() await client.disconnect() # 更新连接状态 @@ -701,7 +701,8 @@ class ToolService: return { "success": True, "message": "MCP连接成功", - "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)} + # "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)} + "details": {"server_url": mcp_config.server_url} } except Exception as e: await client.disconnect() @@ -739,16 +740,15 @@ class ToolService: return {"success": False, "message": f"MCP测试异常: {str(e)}"} @staticmethod - async def parse_openapi_schema(schema_data: Dict[str, Any] = None, schema_url: str = None) -> Dict[str, Any]: + async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]: """解析OpenAPI schema获取接口信息""" try: from app.core.tools.custom.schema_parser import OpenAPISchemaParser parser = OpenAPISchemaParser() - - # 使用现有的解析器 + if schema_data: - success, schema, error = parser.parse_from_content(json.dumps(schema_data), "application/json") + success, schema, error = parser.parse_from_content(schema_data, "application/json") elif schema_url: success, schema, error = await parser.parse_from_url(schema_url) else: