From 7f823ee72ee5e07fb9710549dc2184e708ade5e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Mon, 29 Dec 2025 18:32:29 +0800 Subject: [PATCH 1/6] feat(tool system): The specific method for obtaining the tool and the parameters to be passed --- api/app/core/tools/__init__.py | 4 ++-- api/app/core/tools/base.py | 2 +- api/app/core/tools/builtin/__init__.py | 12 ++++++------ api/app/core/tools/builtin/baidu_search_tool.py | 2 +- api/app/core/tools/builtin/datetime_tool.py | 2 +- api/app/core/tools/builtin/json_tool.py | 2 +- api/app/core/tools/builtin/mineru_tool.py | 2 +- api/app/core/tools/builtin/textin_tool.py | 2 +- api/app/core/tools/custom/__init__.py | 6 +++--- api/app/core/tools/mcp/__init__.py | 6 +++--- api/app/core/tools/mcp/base.py | 1 - api/app/core/tools/mcp/service_manager.py | 2 +- 12 files changed, 21 insertions(+), 22 deletions(-) diff --git a/api/app/core/tools/__init__.py b/api/app/core/tools/__init__.py index 714dc851..9d9407ad 100644 --- a/api/app/core/tools/__init__.py +++ b/api/app/core/tools/__init__.py @@ -1,7 +1,7 @@ """工具管理核心模块""" -from .base import BaseTool, ToolResult, ToolParameter -from .langchain_adapter import LangchainAdapter +from app.core.tools.base import BaseTool, ToolResult, ToolParameter +from app.core.tools.langchain_adapter import LangchainAdapter # 可选导入,避免导入错误 try: diff --git a/api/app/core/tools/base.py b/api/app/core/tools/base.py index c9771ef0..ec15c50f 100644 --- a/api/app/core/tools/base.py +++ b/api/app/core/tools/base.py @@ -193,7 +193,7 @@ class BaseTool(ABC): def to_langchain_tool(self): """转换为Langchain工具格式""" - from .langchain_adapter import LangchainAdapter + from app.core.tools.langchain_adapter import LangchainAdapter return LangchainAdapter.convert_tool(self) def __repr__(self): diff --git a/api/app/core/tools/builtin/__init__.py b/api/app/core/tools/builtin/__init__.py index 3813402c..7d2ea0ef 100644 --- a/api/app/core/tools/builtin/__init__.py +++ b/api/app/core/tools/builtin/__init__.py @@ -1,11 +1,11 @@ """内置工具模块""" -from .base import BuiltinTool -from .datetime_tool import DateTimeTool -from .json_tool import JsonTool -from .baidu_search_tool import BaiduSearchTool -from .mineru_tool import MinerUTool -from .textin_tool import TextInTool +from app.core.tools.builtin.base import BuiltinTool +from app.core.tools.builtin.datetime_tool import DateTimeTool +from app.core.tools.builtin.json_tool import JsonTool +from app.core.tools.builtin.baidu_search_tool import BaiduSearchTool +from app.core.tools.builtin.mineru_tool import MinerUTool +from app.core.tools.builtin.textin_tool import TextInTool __all__ = [ "BuiltinTool", diff --git a/api/app/core/tools/builtin/baidu_search_tool.py b/api/app/core/tools/builtin/baidu_search_tool.py index fddd6eb7..e1f80f34 100644 --- a/api/app/core/tools/builtin/baidu_search_tool.py +++ b/api/app/core/tools/builtin/baidu_search_tool.py @@ -4,7 +4,7 @@ from typing import List, Dict, Any import aiohttp from app.core.tools.base import ToolParameter, ToolResult, ParameterType -from .base import BuiltinTool +from app.core.tools.builtin.base import BuiltinTool class BaiduSearchTool(BuiltinTool): diff --git a/api/app/core/tools/builtin/datetime_tool.py b/api/app/core/tools/builtin/datetime_tool.py index 647914b2..9cad3579 100644 --- a/api/app/core/tools/builtin/datetime_tool.py +++ b/api/app/core/tools/builtin/datetime_tool.py @@ -5,7 +5,7 @@ from typing import List import pytz from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType -from .base import BuiltinTool +from app.core.tools.builtin.base import BuiltinTool class DateTimeTool(BuiltinTool): diff --git a/api/app/core/tools/builtin/json_tool.py b/api/app/core/tools/builtin/json_tool.py index 62cd98d3..12f1e688 100644 --- a/api/app/core/tools/builtin/json_tool.py +++ b/api/app/core/tools/builtin/json_tool.py @@ -7,7 +7,7 @@ import xml.etree.ElementTree as ET from xml.dom import minidom from app.core.tools.base import ToolParameter, ToolResult, ParameterType -from .base import BuiltinTool +from app.core.tools.builtin.base import BuiltinTool class JsonTool(BuiltinTool): diff --git a/api/app/core/tools/builtin/mineru_tool.py b/api/app/core/tools/builtin/mineru_tool.py index b2a544c0..c480d6d8 100644 --- a/api/app/core/tools/builtin/mineru_tool.py +++ b/api/app/core/tools/builtin/mineru_tool.py @@ -4,7 +4,7 @@ from typing import List, Dict, Any import aiohttp from app.core.tools.base import ToolParameter, ToolResult, ParameterType -from .base import BuiltinTool +from app.core.tools.builtin.base import BuiltinTool class MinerUTool(BuiltinTool): diff --git a/api/app/core/tools/builtin/textin_tool.py b/api/app/core/tools/builtin/textin_tool.py index e5218416..4ec32659 100644 --- a/api/app/core/tools/builtin/textin_tool.py +++ b/api/app/core/tools/builtin/textin_tool.py @@ -4,7 +4,7 @@ from typing import List, Dict, Any import aiohttp from app.core.tools.base import ToolParameter, ToolResult, ParameterType -from .base import BuiltinTool +from app.core.tools.builtin.base import BuiltinTool class TextInTool(BuiltinTool): diff --git a/api/app/core/tools/custom/__init__.py b/api/app/core/tools/custom/__init__.py index 87b0488a..d56265e7 100644 --- a/api/app/core/tools/custom/__init__.py +++ b/api/app/core/tools/custom/__init__.py @@ -1,8 +1,8 @@ """自定义工具模块""" -from .base import CustomTool -from .schema_parser import OpenAPISchemaParser -from .auth_manager import AuthManager +from app.core.tools.custom.base import CustomTool +from app.core.tools.custom.schema_parser import OpenAPISchemaParser +from app.core.tools.custom.auth_manager import AuthManager __all__ = [ "CustomTool", diff --git a/api/app/core/tools/mcp/__init__.py b/api/app/core/tools/mcp/__init__.py index faf13ceb..4c9519b3 100644 --- a/api/app/core/tools/mcp/__init__.py +++ b/api/app/core/tools/mcp/__init__.py @@ -1,8 +1,8 @@ """MCP工具模块""" -from .base import MCPTool -from .client import MCPClient, MCPConnectionPool -from .service_manager import MCPServiceManager +from app.core.tools.mcp.base import MCPTool +from app.core.tools.mcp.client import MCPClient, MCPConnectionPool +from app.core.tools.mcp.service_manager import MCPServiceManager __all__ = [ "MCPTool", diff --git a/api/app/core/tools/mcp/base.py b/api/app/core/tools/mcp/base.py index ca77f528..3fa103ab 100644 --- a/api/app/core/tools/mcp/base.py +++ b/api/app/core/tools/mcp/base.py @@ -1,7 +1,6 @@ """MCP工具基类""" import time from typing import Dict, Any, List -import aiohttp from app.models.tool_model import ToolType from app.core.tools.base import BaseTool diff --git a/api/app/core/tools/mcp/service_manager.py b/api/app/core/tools/mcp/service_manager.py index 51d01535..f7349201 100644 --- a/api/app/core/tools/mcp/service_manager.py +++ b/api/app/core/tools/mcp/service_manager.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import Session from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus from app.core.logging_config import get_business_logger -from .client import MCPClient, MCPConnectionPool +from app.core.tools.mcp.client import MCPClient, MCPConnectionPool logger = get_business_logger() From 8e893662f38e843ba833b121b827f8f4b6a9a840 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Tue, 30 Dec 2025 10:00:37 +0800 Subject: [PATCH 2/6] feat(tool system): add mcp testing services --- api_key_mcp_server.py | 38 +++++++++++ basic_auth_mcp_server.py | 45 +++++++++++++ bearer_token_mcp_server.py | 40 ++++++++++++ mcp_base.py | 111 +++++++++++++++++++++++++++++++ simple_mcp_server.py | 130 +++++++++++++++++++++++++++++++++++++ 5 files changed, 364 insertions(+) create mode 100644 api_key_mcp_server.py create mode 100644 basic_auth_mcp_server.py create mode 100644 bearer_token_mcp_server.py create mode 100644 mcp_base.py create mode 100644 simple_mcp_server.py diff --git a/api_key_mcp_server.py b/api_key_mcp_server.py new file mode 100644 index 00000000..f611dc59 --- /dev/null +++ b/api_key_mcp_server.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +"""API Key认证MCP服务器""" + +from fastapi import FastAPI, HTTPException, Depends, Header +from typing import Optional +import uvicorn +from mcp_base import MCPRequest, handle_mcp_request, TOOLS + +app = FastAPI(title="API Key MCP Server", version="1.0.0") + +# API Key配置 +API_KEYS = {"test-api-key", "demo-key-123"} + +def verify_api_key(x_api_key: Optional[str] = Header(None)): + """验证API Key""" + if x_api_key and x_api_key in API_KEYS: + return True + raise HTTPException(status_code=401, detail="Invalid API Key") + +@app.get("/") +async def root(): + return {"name": "API Key MCP Server", "version": "1.0.0", "auth_type": "api_key"} + +@app.get("/health") +async def health(): + return {"status": "healthy", "tools": len(TOOLS), "auth_type": "api_key"} + +@app.post("/mcp") +async def mcp_handler(request: MCPRequest, _: bool = Depends(verify_api_key)): + return await handle_mcp_request(request, "API Key MCP Server") + +if __name__ == "__main__": + print("启动API Key认证MCP服务器...") + print("访问 http://localhost:8004 查看服务状态") + print("MCP端点: http://localhost:8004/mcp") + print("认证方式: API Key (Header: X-API-Key)") + print("测试API Keys: test-api-key, demo-key-123") + uvicorn.run(app, host="0.0.0.0", port=8004) \ No newline at end of file diff --git a/basic_auth_mcp_server.py b/basic_auth_mcp_server.py new file mode 100644 index 00000000..11bb5595 --- /dev/null +++ b/basic_auth_mcp_server.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +"""Basic Auth认证MCP服务器""" + +from fastapi import FastAPI, HTTPException, Depends, Header +from typing import Optional +import uvicorn +import base64 +from mcp_base import MCPRequest, handle_mcp_request, TOOLS + +app = FastAPI(title="Basic Auth MCP Server", version="1.0.0") + +# Basic Auth配置 +BASIC_AUTH_USERS = {"admin": "password", "user": "secret"} + +def verify_basic_auth(authorization: Optional[str] = Header(None)): + """验证Basic Auth""" + if authorization and authorization.startswith("Basic "): + try: + credentials = base64.b64decode(authorization.split(" ")[1]).decode() + username, password = credentials.split(":", 1) + if username in BASIC_AUTH_USERS and BASIC_AUTH_USERS[username] == password: + return True + except: + pass + raise HTTPException(status_code=401, detail="Invalid Basic Auth") + +@app.get("/") +async def root(): + return {"name": "Basic Auth MCP Server", "version": "1.0.0", "auth_type": "basic_auth"} + +@app.get("/health") +async def health(): + return {"status": "healthy", "tools": len(TOOLS), "auth_type": "basic_auth"} + +@app.post("/mcp") +async def mcp_handler(request: MCPRequest, _: bool = Depends(verify_basic_auth)): + return await handle_mcp_request(request, "Basic Auth MCP Server") + +if __name__ == "__main__": + print("启动Basic Auth认证MCP服务器...") + print("访问 http://localhost:8006 查看服务状态") + print("MCP端点: http://localhost:8006/mcp") + print("认证方式: Basic Auth (Header: Authorization: Basic )") + print("测试用户: admin:password, user:secret") + uvicorn.run(app, host="0.0.0.0", port=8006) \ No newline at end of file diff --git a/bearer_token_mcp_server.py b/bearer_token_mcp_server.py new file mode 100644 index 00000000..57d27f2f --- /dev/null +++ b/bearer_token_mcp_server.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +"""Bearer Token认证MCP服务器""" + +from fastapi import FastAPI, HTTPException, Depends, Header +from typing import Optional +import uvicorn +from mcp_base import MCPRequest, handle_mcp_request, TOOLS + +app = FastAPI(title="Bearer Token MCP Server", version="1.0.0") + +# Bearer Token配置 +BEARER_TOKENS = {"bearer-token-123", "demo-bearer-token"} + +def verify_bearer_token(authorization: Optional[str] = Header(None)): + """验证Bearer Token""" + if authorization and authorization.startswith("Bearer "): + token = authorization.split(" ")[1] + if token in BEARER_TOKENS: + return True + raise HTTPException(status_code=401, detail="Invalid Bearer Token") + +@app.get("/") +async def root(): + return {"name": "Bearer Token MCP Server", "version": "1.0.0", "auth_type": "bearer_token"} + +@app.get("/health") +async def health(): + return {"status": "healthy", "tools": len(TOOLS), "auth_type": "bearer_token"} + +@app.post("/mcp") +async def mcp_handler(request: MCPRequest, _: bool = Depends(verify_bearer_token)): + return await handle_mcp_request(request, "Bearer Token MCP Server") + +if __name__ == "__main__": + print("启动Bearer Token认证MCP服务器...") + print("访问 http://localhost:8005 查看服务状态") + print("MCP端点: http://localhost:8005/mcp") + print("认证方式: Bearer Token (Header: Authorization: Bearer )") + print("测试Bearer Tokens: bearer-token-123, demo-bearer-token") + uvicorn.run(app, host="0.0.0.0", port=8005) \ No newline at end of file diff --git a/mcp_base.py b/mcp_base.py new file mode 100644 index 00000000..f571e2fa --- /dev/null +++ b/mcp_base.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +"""MCP服务器基础模块 - 共享的模型和处理逻辑""" + +from pydantic import BaseModel +from typing import Dict, Any + +class MCPRequest(BaseModel): + jsonrpc: str = "2.0" + id: str + method: str + params: Dict[str, Any] = {} + +class MCPResponse(BaseModel): + jsonrpc: str = "2.0" + id: str + result: Any = None + error: Dict[str, Any] = None + +# 工具定义 +TOOLS = [ + { + "name": "calculator", + "description": "简单计算器", + "inputSchema": { + "type": "object", + "properties": { + "expression": {"type": "string", "description": "数学表达式"} + }, + "required": ["expression"] + } + }, + { + "name": "echo", + "description": "回显工具", + "inputSchema": { + "type": "object", + "properties": { + "message": {"type": "string", "description": "要回显的消息"} + }, + "required": ["message"] + } + } +] + +async def handle_mcp_request(request: MCPRequest, server_name: str = "MCP Server"): + """处理MCP请求""" + try: + if request.method == "initialize": + return MCPResponse( + id=request.id, + result={ + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {"listChanged": True}}, + "serverInfo": {"name": server_name, "version": "1.0.0"} + } + ) + + elif request.method == "tools/list": + return MCPResponse( + id=request.id, + result={"tools": TOOLS} + ) + + elif request.method == "tools/call": + tool_name = request.params.get("name") + arguments = request.params.get("arguments", {}) + + if tool_name == "calculator": + try: + expression = arguments.get("expression", "") + result = eval(expression) + return MCPResponse( + id=request.id, + result={"content": [{"type": "text", "text": f"结果: {result}"}]} + ) + except Exception as e: + return MCPResponse( + id=request.id, + error={"code": -1, "message": f"计算错误: {str(e)}"} + ) + + elif tool_name == "echo": + message = arguments.get("message", "") + return MCPResponse( + id=request.id, + result={"content": [{"type": "text", "text": f"Echo: {message}"}]} + ) + + else: + return MCPResponse( + id=request.id, + error={"code": -1, "message": f"未知工具: {tool_name}"} + ) + + elif request.method == "ping": + return MCPResponse( + id=request.id, + result={"status": "pong"} + ) + + else: + return MCPResponse( + id=request.id, + error={"code": -1, "message": f"未知方法: {request.method}"} + ) + + except Exception as e: + return MCPResponse( + id=request.id, + error={"code": -1, "message": str(e)} + ) \ No newline at end of file diff --git a/simple_mcp_server.py b/simple_mcp_server.py new file mode 100644 index 00000000..fa299e37 --- /dev/null +++ b/simple_mcp_server.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +"""简化的MCP服务器 - 用于测试MCP工具集成""" + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from typing import Dict, Any, List +import uvicorn + +app = FastAPI(title="Simple MCP Server", version="1.0.0") + +class MCPRequest(BaseModel): + jsonrpc: str = "2.0" + id: str + method: str + params: Dict[str, Any] = {} + +class MCPResponse(BaseModel): + jsonrpc: str = "2.0" + id: str + result: Any = None + error: Dict[str, Any] = None + +# 可用工具定义 +TOOLS = [ + { + "name": "calculator", + "description": "简单计算器", + "inputSchema": { + "type": "object", + "properties": { + "expression": {"type": "string", "description": "数学表达式"} + }, + "required": ["expression"] + } + }, + { + "name": "echo", + "description": "回显工具", + "inputSchema": { + "type": "object", + "properties": { + "message": {"type": "string", "description": "要回显的消息"} + }, + "required": ["message"] + } + } +] + +@app.get("/") +async def root(): + return {"name": "Simple MCP Server", "version": "1.0.0"} + +@app.get("/health") +async def health(): + return {"status": "healthy", "tools": len(TOOLS)} + +@app.post("/mcp") +async def mcp_handler(request: MCPRequest): + """处理MCP请求""" + try: + if request.method == "initialize": + return MCPResponse( + id=request.id, + result={ + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {"listChanged": True}}, + "serverInfo": {"name": "Simple MCP Server", "version": "1.0.0"} + } + ) + + elif request.method == "tools/list": + return MCPResponse( + id=request.id, + result={"tools": TOOLS} + ) + + elif request.method == "tools/call": + tool_name = request.params.get("name") + arguments = request.params.get("arguments", {}) + + if tool_name == "calculator": + try: + expression = arguments.get("expression", "") + result = eval(expression) # 注意:生产环境不要用eval + return MCPResponse( + id=request.id, + result={"content": [{"type": "text", "text": f"结果: {result}"}]} + ) + except Exception as e: + return MCPResponse( + id=request.id, + error={"code": -1, "message": f"计算错误: {str(e)}"} + ) + + elif tool_name == "echo": + message = arguments.get("message", "") + return MCPResponse( + id=request.id, + result={"content": [{"type": "text", "text": f"Echo: {message}"}]} + ) + + else: + return MCPResponse( + id=request.id, + error={"code": -1, "message": f"未知工具: {tool_name}"} + ) + + elif request.method == "ping": + return MCPResponse( + id=request.id, + result={"status": "pong"} + ) + + else: + return MCPResponse( + id=request.id, + error={"code": -1, "message": f"未知方法: {request.method}"} + ) + + except Exception as e: + return MCPResponse( + id=request.id, + error={"code": -1, "message": str(e)} + ) + +if __name__ == "__main__": + print("启动简化MCP服务器...") + print("访问 http://localhost:8002 查看服务状态") + print("MCP端点: http://localhost:8002/mcp") + uvicorn.run(app, host="0.0.0.0", port=8002) \ No newline at end of file From e6c35e5f5ad90e45d0ba69e79f2997176ab4963b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Tue, 30 Dec 2025 21:07:24 +0800 Subject: [PATCH 3/6] feat(tool system): add all methods for obtaining the tool --- api/app/controllers/tool_controller.py | 19 ++- api/app/core/tools/builtin/datetime_tool.py | 2 +- api/app/core/tools/builtin/json_tool.py | 3 +- api/app/core/tools/mcp/client.py | 4 +- api/app/repositories/tool_repository.py | 30 +++- api/app/services/tool_service.py | 159 ++++++++++++++++++++ 6 files changed, 208 insertions(+), 9 deletions(-) diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index 79c87205..479686ef 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -60,6 +60,22 @@ async def list_tools( raise HTTPException(status_code=500, detail=str(e)) +@router.get("/{tool_id}/methods", response_model=ApiResponse) +async def get_tool_methods( + tool_id: str, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """获取工具的所有方法""" + try: + methods = await service.get_tool_methods(tool_id, current_user.tenant_id) + if methods is None: + raise HTTPException(status_code=404, detail="工具不存在") + return success(data=methods, msg="获取工具方法成功") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/{tool_id}", response_model=ApiResponse) async def get_tool( tool_id: str, @@ -159,7 +175,8 @@ async def execute_tool( workspace_id=current_user.current_workspace_id, timeout=request.timeout ) - + if not result.success: + raise HTTPException(status_code=400, detail=result["error"]) return success( data={ "success": result.success, diff --git a/api/app/core/tools/builtin/datetime_tool.py b/api/app/core/tools/builtin/datetime_tool.py index 9cad3579..7b6fa8ef 100644 --- a/api/app/core/tools/builtin/datetime_tool.py +++ b/api/app/core/tools/builtin/datetime_tool.py @@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool): type=ParameterType.STRING, description="操作类型", required=True, - enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"] + enum=["format", "convert_timezone", "timestamp_to_datetime", "now"] ), ToolParameter( name="input_value", diff --git a/api/app/core/tools/builtin/json_tool.py b/api/app/core/tools/builtin/json_tool.py index 12f1e688..f22e9370 100644 --- a/api/app/core/tools/builtin/json_tool.py +++ b/api/app/core/tools/builtin/json_tool.py @@ -29,8 +29,7 @@ class JsonTool(BuiltinTool): type=ParameterType.STRING, description="操作类型", required=True, - enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", - "extract", "insert", "replace", "delete", "parse"] + enum=["insert", "replace", "delete", "parse"] ), ToolParameter( name="input_data", diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 2e37f2b1..a1d2ecaa 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -204,7 +204,7 @@ class MCPClient: ) init_response = json.loads(response) - if "error" in init_response: + if init_response.get("error", None) is not None: raise MCPProtocolError(f"初始化失败: {init_response['error']}") return True @@ -325,7 +325,7 @@ class MCPClient: try: response = await self._send_request(request_data, timeout) - if "error" in response: + if response.get("error", None) is not None: error = response["error"] raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}") diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index dc78e761..3aa7b16e 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -1,10 +1,9 @@ """工具数据访问层""" import uuid -from typing import List, Optional, Dict, Any +from typing import List, Optional from sqlalchemy.orm import Session -from sqlalchemy import func, or_ +from sqlalchemy import func -from app.repositories.base_repository import BaseRepository from app.models.tool_model import ( ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, ToolExecution, ToolType, ToolStatus @@ -14,6 +13,31 @@ from app.models.tool_model import ( class ToolRepository: """工具仓储类""" + @staticmethod + def get_tenant_id_by_workflow_id(db: Session, workflow_id: uuid.UUID) -> Optional[uuid.UUID]: + """根据工作流ID获取tenant_id + + Args: + db: 数据库会话 + workflow_id: 工作流配置ID + + Returns: + tenant_id或None + """ + from app.models.app_model import App + from app.models.workflow_model import WorkflowConfig + from app.models.workspace_model import Workspace + + result = db.query(Workspace.tenant_id).join( + App, App.workspace_id == Workspace.id + ).join( + WorkflowConfig, WorkflowConfig.app_id == App.id + ).filter( + WorkflowConfig.id == workflow_id + ).first() + + return result[0] if result else None + @staticmethod def find_by_tenant( db: Session, diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 783df81a..50cca957 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -297,6 +297,165 @@ class ToolService: self.db.commit() logger.info(f"租户 {tenant_id} 内置工具初始化完成") + async def get_tool_methods(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[List[Dict[str, Any]]]: + """获取工具的所有方法 + + Args: + tool_id: 工具ID + tenant_id: 租户ID + + Returns: + 方法列表或None + """ + config = self._get_tool_config(tool_id, tenant_id) + if not config: + return None + + try: + if config.tool_type == ToolType.BUILTIN.value: + return await self._get_builtin_tool_methods(config) + elif config.tool_type == ToolType.CUSTOM.value: + return await self._get_custom_tool_methods(config) + elif config.tool_type == ToolType.MCP.value: + return await self._get_mcp_tool_methods(config) + else: + return [] + + except Exception as e: + logger.error(f"获取工具方法失败: {tool_id}, {e}") + return [] + + async def _get_builtin_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: + """获取内置工具的方法""" + builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id) + if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS: + return [] + + # 获取工具实例 + tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + if not tool_instance: + return [] + + # 检查是否有operation参数 + operation_param = None + for param in tool_instance.parameters: + if param.name == "operation" and param.enum: + operation_param = param + break + + if operation_param: + # 有多个操作 + methods = [] + for operation in operation_param.enum: + methods.append({ + "method_id": f"{config.name}_{operation}", + "name": operation, + "description": f"{config.description} - {operation}", + "parameters": [p for p in tool_instance.parameters if p.name != "operation"] + }) + return methods + else: + # 只有一个方法 + return [{ + "method_id": config.name, + "name": config.name, + "description": config.description, + "parameters": [p for p in tool_instance.parameters if p.name != "operation"] + }] + + async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: + """获取自定义工具的方法""" + custom_config = self.custom_repo.find_by_tool_id(self.db, config.id) + if not custom_config: + return [] + + try: + from app.core.tools.custom.schema_parser import OpenAPISchemaParser + parser = OpenAPISchemaParser() + + # 解析schema + if custom_config.schema_content: + success, schema, error = parser.parse_from_content(custom_config.schema_content, "application/json") + elif custom_config.schema_url: + success, schema, error = await parser.parse_from_url(custom_config.schema_url) + else: + return [] + + if not success: + return [] + + # 提取操作 + tool_info = parser.extract_tool_info(schema) + operations = tool_info.get("operations", {}) + + methods = [] + for operation_id, operation in operations.items(): + # 生成参数列表 + parameters = [] + + # 路径和查询参数 + for param_name, param_info in operation.get("parameters", {}).items(): + parameters.append({ + "name": param_name, + "type": param_info.get("type", "string"), + "description": param_info.get("description", ""), + "required": param_info.get("required", False), + "enum": param_info.get("enum"), + "default": param_info.get("default") + }) + + # 请求体参数 + request_body = operation.get("request_body") + if request_body: + schema_props = request_body.get("schema", {}).get("properties", {}) + required_props = request_body.get("schema", {}).get("required", []) + + for prop_name, prop_schema in schema_props.items(): + parameters.append({ + "name": prop_name, + "type": prop_schema.get("type", "string"), + "description": prop_schema.get("description", ""), + "required": prop_name in required_props, + "enum": prop_schema.get("enum"), + "default": prop_schema.get("default") + }) + + methods.append({ + "method_id": operation_id, + "name": operation.get("summary", operation_id), + "description": operation.get("description", ""), + "method": operation.get("method", "GET"), + "path": operation.get("path", "/"), + "parameters": parameters + }) + + return methods + + except Exception as e: + logger.error(f"解析自定义工具schema失败: {e}") + return [] + + async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: + """获取MCP工具的方法""" + mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) + if not mcp_config: + return [] + + available_tools = mcp_config.available_tools or [] + if not available_tools: + return [] + + methods = [] + for tool_name in available_tools: + methods.append({ + "method_id": tool_name, + "name": tool_name, + "description": f"MCP工具: {tool_name}", + "parameters": [] # MCP工具参数需要动态获取 + }) + + return methods + def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]: """获取工具统计信息""" try: From 0475d804725c41f91f8508f968cf9dbefd9b9e26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Tue, 30 Dec 2025 21:08:05 +0800 Subject: [PATCH 4/6] feat(tool system): add workflow tool nodes --- api/app/core/workflow/nodes/__init__.py | 6 +- api/app/core/workflow/nodes/configs.py | 2 + api/app/core/workflow/nodes/node_factory.py | 5 +- api/app/core/workflow/nodes/tool/__init__.py | 4 ++ api/app/core/workflow/nodes/tool/config.py | 9 +++ api/app/core/workflow/nodes/tool/node.py | 72 ++++++++++++++++++++ 6 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 api/app/core/workflow/nodes/tool/__init__.py create mode 100644 api/app/core/workflow/nodes/tool/config.py create mode 100644 api/app/core/workflow/nodes/tool/node.py diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 174fa877..926f86e4 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -17,6 +17,8 @@ from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode +from app.core.workflow.nodes.question_classifier import QuestionClassifierNode +from app.core.workflow.nodes.tool import ToolNode __all__ = [ "BaseNode", @@ -33,5 +35,7 @@ __all__ = [ "AssignerNode", "HttpRequestNode", "JinjaRenderNode", - "ParameterExtractorNode" + "ParameterExtractorNode", + "QuestionClassifierNode", + "ToolNode" ] diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index 2ba23d4c..6e9c2c51 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -21,6 +21,7 @@ from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig +from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig __all__ = [ @@ -45,4 +46,5 @@ __all__ = [ "LoopNodeConfig", "IterationNodeConfig", "QuestionClassifierNodeConfig" + "ToolNodeConfig" ] diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index ed26533d..df565efe 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -24,6 +24,7 @@ from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.breaker import BreakNode +from app.core.workflow.nodes.tool import ToolNode logger = logging.getLogger(__name__) @@ -44,7 +45,8 @@ WorkflowNode = Union[ CycleGraphNode, BreakNode, ParameterExtractorNode, - QuestionClassifierNode + QuestionClassifierNode, + ToolNode ] @@ -72,6 +74,7 @@ class NodeFactory: NodeType.LOOP: CycleGraphNode, NodeType.ITERATION: CycleGraphNode, NodeType.BREAK: BreakNode, + NodeType.TOOL: ToolNode, } @classmethod diff --git a/api/app/core/workflow/nodes/tool/__init__.py b/api/app/core/workflow/nodes/tool/__init__.py new file mode 100644 index 00000000..8392f05c --- /dev/null +++ b/api/app/core/workflow/nodes/tool/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.tool.config import ToolNodeConfig +from app.core.workflow.nodes.tool.node import ToolNode + +__all__ = ["ToolNode", "ToolNodeConfig"] \ No newline at end of file diff --git a/api/app/core/workflow/nodes/tool/config.py b/api/app/core/workflow/nodes/tool/config.py new file mode 100644 index 00000000..487efae2 --- /dev/null +++ b/api/app/core/workflow/nodes/tool/config.py @@ -0,0 +1,9 @@ +from pydantic import Field +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class ToolNodeConfig(BaseNodeConfig): + """工具节点配置""" + + tool_id: str = Field(..., description="工具ID") + tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量") diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py new file mode 100644 index 00000000..993a3804 --- /dev/null +++ b/api/app/core/workflow/nodes/tool/node.py @@ -0,0 +1,72 @@ +import logging +import uuid +from typing import Any + +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.tool.config import ToolNodeConfig +from app.services.tool_service import ToolService +from app.db import get_db_read + +logger = logging.getLogger(__name__) + + +class ToolNode(BaseNode): + """工具节点""" + + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = ToolNodeConfig(**self.config) + + async def execute(self, state: WorkflowState) -> dict[str, Any]: + """执行工具""" + # 获取租户ID和用户ID + tenant_id = self.get_variable("sys.tenant_id", state) + user_id = self.get_variable("sys.user_id", state) + + # 如果没有租户ID,尝试从工作流ID获取 + if not tenant_id: + workflow_id = self.get_variable("sys.workflow_id", state) + if workflow_id: + from app.repositories.tool_repository import ToolRepository + with get_db_read() as db: + tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id) + + if not tenant_id: + tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097") + # logger.error(f"节点 {self.node_id} 缺少租户ID") + # return {"error": "缺少租户ID"} + + # 渲染工具参数 + rendered_parameters = {} + for param_name, param_template in self.typed_config.tool_parameters.items(): + rendered_value = self._render_template(param_template, state) + rendered_parameters[param_name] = rendered_value + + logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}") + print(self.typed_config.tool_id) + + # 执行工具 + with get_db_read() as db: + tool_service = ToolService(db) + result = await tool_service.execute_tool( + tool_id=self.typed_config.tool_id, + parameters=rendered_parameters, + tenant_id=tenant_id, + user_id=user_id + ) + print(result) + if result.success: + logger.info(f"节点 {self.node_id} 工具执行成功") + return { + "success": True, + "data": result.data, + "execution_time": result.execution_time + } + else: + logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}") + return { + "success": False, + "error": result.error, + "error_code": result.error_code, + "execution_time": result.execution_time + } \ No newline at end of file From a8c5368d49422dea380ce99dcf81911d743fdd80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Sun, 4 Jan 2026 15:36:24 +0800 Subject: [PATCH 5/6] feat(home page): add statistical interface --- api/app/controllers/__init__.py | 2 + api/app/controllers/home_page_controller.py | 29 ++++ api/app/core/api_key_utils.py | 2 +- api/app/repositories/home_page_repository.py | 137 +++++++++++++++++++ api/app/schemas/home_page_schema.py | 32 +++++ api/app/services/home_page_service.py | 67 +++++++++ 6 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 api/app/controllers/home_page_controller.py create mode 100644 api/app/repositories/home_page_repository.py create mode 100644 api/app/schemas/home_page_schema.py create mode 100644 api/app/services/home_page_service.py diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 13e66ea7..2cddfb30 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -33,6 +33,7 @@ from . import ( emotion_config_controller, prompt_optimizer_controller, tool_controller, + home_page_controller, ) from . import user_memory_controllers @@ -70,5 +71,6 @@ manager_router.include_router(emotion_config_controller.router) manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(memory_reflection_controller.router) manager_router.include_router(tool_controller.router) +manager_router.include_router(home_page_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/home_page_controller.py b/api/app/controllers/home_page_controller.py new file mode 100644 index 00000000..6665eec1 --- /dev/null +++ b/api/app/controllers/home_page_controller.py @@ -0,0 +1,29 @@ +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from app.core.response_utils import success +from app.db import get_db +from app.dependencies import get_current_user +from app.models.user_model import User +from app.schemas.response_schema import ApiResponse +from app.services.home_page_service import HomePageService + +router = APIRouter(prefix="/home-page", tags=["Home Page"]) + +@router.get("/statistics", response_model=ApiResponse) +def get_home_statistics( + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取首页统计数据""" + statistics = HomePageService.get_home_statistics(db, current_user.tenant_id) + return success(data=statistics, msg="统计数据获取成功") + +@router.get("/workspaces", response_model=ApiResponse) +def get_workspace_list( + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取工作空间列表""" + workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id) + return success(data=workspace_list, msg="工作空间列表获取成功") \ No newline at end of file diff --git a/api/app/core/api_key_utils.py b/api/app/core/api_key_utils.py index 877ddd01..fb6b9552 100644 --- a/api/app/core/api_key_utils.py +++ b/api/app/core/api_key_utils.py @@ -3,7 +3,7 @@ import secrets from typing import Optional, Union from datetime import datetime -from app.schemas.api_key_schema import ApiKeyType +from app.models.api_key_model import ApiKeyType from fastapi import Response from fastapi.responses import JSONResponse diff --git a/api/app/repositories/home_page_repository.py b/api/app/repositories/home_page_repository.py new file mode 100644 index 00000000..e37f1f00 --- /dev/null +++ b/api/app/repositories/home_page_repository.py @@ -0,0 +1,137 @@ +from datetime import datetime, timedelta +from sqlalchemy.orm import Session +from sqlalchemy import func +from uuid import UUID +from typing import Dict + +from app.models.end_user_model import EndUser +from app.models.user_model import User +from app.models.workspace_model import Workspace, WorkspaceMember +from app.models.models_model import ModelConfig +from app.models.app_model import App + +class HomePageRepository: + + @staticmethod + def get_model_statistics(db: Session, tenant_id: UUID, month_start: datetime) -> tuple[int, int]: + """获取模型统计数据""" + total_models = db.query(ModelConfig).filter( + ModelConfig.tenant_id == tenant_id, + ModelConfig.is_active == True + ).count() + + new_models_this_month = db.query(ModelConfig).filter( + ModelConfig.tenant_id == tenant_id, + ModelConfig.is_active == True, + ModelConfig.created_at >= month_start + ).count() + + return total_models, new_models_this_month + + @staticmethod + def get_workspace_statistics(db: Session, tenant_id: UUID, month_start: datetime) -> tuple[int, int]: + """获取工作空间统计数据""" + active_workspaces = db.query(Workspace).filter( + Workspace.tenant_id == tenant_id, + Workspace.is_active == True + ).count() + + new_workspaces_this_month = db.query(Workspace).filter( + Workspace.tenant_id == tenant_id, + Workspace.is_active == True, + Workspace.created_at >= month_start + ).count() + + return active_workspaces, new_workspaces_this_month + + @staticmethod + def get_user_statistics(db: Session, tenant_id: UUID, month_start: datetime) -> tuple[int, int]: + """获取用户统计数据""" + workspace_ids = db.query(Workspace.id).filter( + Workspace.tenant_id == tenant_id, + Workspace.is_active == True + ).subquery() + + total_users = db.query(EndUser).join( + App, + EndUser.app_id == App.id + ).filter( + App.workspace_id.in_(workspace_ids), + App.is_active == True, + App.status == "active" + ).count() + + new_users_this_month = db.query(EndUser).join( + App, + EndUser.app_id == App.id + ).filter( + App.workspace_id.in_(workspace_ids), + App.is_active == True, + App.status == "active", + EndUser.created_at >= month_start + ).count() + + return total_users, new_users_this_month + + @staticmethod + def get_app_statistics(db: Session, tenant_id: UUID, week_start: datetime) -> tuple[int, int]: + """获取应用统计数据""" + workspace_ids = db.query(Workspace.id).filter( + Workspace.tenant_id == tenant_id, + Workspace.is_active == True + ).subquery() + + running_apps = db.query(App).filter( + App.workspace_id.in_(workspace_ids), + App.is_active == True, + App.status == "active" + ).count() + + new_apps_this_week = db.query(App).filter( + App.workspace_id.in_(workspace_ids), + App.is_active == True, + App.status == "active", + App.created_at >= week_start + ).count() + + return running_apps, new_apps_this_week + + @staticmethod + def get_workspaces_with_counts(db: Session, tenant_id: UUID) -> tuple[list[Workspace], Dict[UUID, int], Dict[UUID, int]]: + """批量获取工作空间及其统计数据""" + # 获取工作空间列表 + workspaces = db.query(Workspace).filter( + Workspace.tenant_id == tenant_id, + Workspace.is_active == True + ).all() + + workspace_ids = [ws.id for ws in workspaces] + + # 批量获取应用数量 + app_counts = db.query( + App.workspace_id, + func.count(App.id).label('count') + ).filter( + App.workspace_id.in_(workspace_ids), + App.is_active, + App.status == "active" + ).group_by(App.workspace_id).all() + + app_count_dict = {workspace_id: count for workspace_id, count in app_counts} + + # 批量获取用户数量 + user_counts = db.query( + App.workspace_id, + func.count(EndUser.id).label('count') + ).join( + EndUser, + EndUser.app_id == App.id + ).filter( + App.workspace_id.in_(workspace_ids), + App.is_active, + App.status == "active" + ).group_by(App.workspace_id).all() + + user_count_dict = {workspace_id: count for workspace_id, count in user_counts} + + return workspaces, app_count_dict, user_count_dict \ No newline at end of file diff --git a/api/app/schemas/home_page_schema.py b/api/app/schemas/home_page_schema.py new file mode 100644 index 00000000..de223e17 --- /dev/null +++ b/api/app/schemas/home_page_schema.py @@ -0,0 +1,32 @@ +from datetime import datetime +from pydantic import BaseModel, field_serializer +from typing import Optional + +from app.core.api_key_utils import datetime_to_timestamp + + +class HomeStatistics(BaseModel): + """首页统计数据""" + total_models: int + new_models_this_month: int + active_workspaces: int + new_workspaces_this_month: int + total_users: int + new_users_this_month: int + running_apps: int + new_apps_this_week: int + +class WorkspaceInfo(BaseModel): + """工作空间信息""" + id: str + name: str + icon: Optional[str] + description: Optional[str] + app_count: int + user_count: int + created_at: datetime + + @field_serializer('created_at') + @classmethod + def serialize_datetime(cls, v: datetime) -> Optional[int]: + return datetime_to_timestamp(v) \ No newline at end of file diff --git a/api/app/services/home_page_service.py b/api/app/services/home_page_service.py new file mode 100644 index 00000000..909da25f --- /dev/null +++ b/api/app/services/home_page_service.py @@ -0,0 +1,67 @@ +from datetime import datetime, timedelta +from sqlalchemy.orm import Session +from uuid import UUID + +from app.repositories.home_page_repository import HomePageRepository +from app.schemas.home_page_schema import HomeStatistics, WorkspaceInfo + +class HomePageService: + + @staticmethod + def get_home_statistics(db: Session, tenant_id: UUID) -> HomeStatistics: + """获取首页统计数据""" + # 计算时间范围 + now = datetime.now() + month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + week_start = now - timedelta(days=now.weekday()) + week_start = week_start.replace(hour=0, minute=0, second=0, microsecond=0) + + # 获取各项统计数据 + total_models, new_models_this_month = HomePageRepository.get_model_statistics( + db, tenant_id, month_start + ) + + active_workspaces, new_workspaces_this_month = HomePageRepository.get_workspace_statistics( + db, tenant_id, month_start + ) + + total_users, new_users_this_month = HomePageRepository.get_user_statistics( + db, tenant_id, month_start + ) + + running_apps, new_apps_this_week = HomePageRepository.get_app_statistics( + db, tenant_id, week_start + ) + + return HomeStatistics( + total_models=total_models, + new_models_this_month=new_models_this_month, + active_workspaces=active_workspaces, + new_workspaces_this_month=new_workspaces_this_month, + total_users=total_users, + new_users_this_month=new_users_this_month, + running_apps=running_apps, + new_apps_this_week=new_apps_this_week + ) + + @staticmethod + def get_workspace_list(db: Session, tenant_id: UUID) -> list[WorkspaceInfo]: + """获取工作空间列表(优化版本)""" + workspaces, app_count_dict, user_count_dict= HomePageRepository.get_workspaces_with_counts( + db, tenant_id + ) + + workspace_list = [] + for workspace in workspaces: + workspace_info = WorkspaceInfo( + id=str(workspace.id), + name=workspace.name, + icon=workspace.icon, + description=workspace.description, + app_count=app_count_dict.get(workspace.id, 0), + user_count=user_count_dict.get(workspace.id, 0), + created_at=workspace.created_at + ) + workspace_list.append(workspace_info) + + return workspace_list \ No newline at end of file From c0b29dd9384437dd905c98141e4afe8dc8b69d48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Sun, 4 Jan 2026 19:06:51 +0800 Subject: [PATCH 6/6] feat(workflow_node): question classifier node optimization --- api/app/core/workflow/executor.py | 8 +- .../nodes/question_classifier/config.py | 1 - .../nodes/question_classifier/node.py | 107 +++++++++++------- 3 files changed, 70 insertions(+), 46 deletions(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 0d0879d7..fe75eace 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -219,17 +219,13 @@ class WorkflowExecutor: # 创建节点实例(现在 start 和 end 也会被创建) node_instance = NodeFactory.create_node(node, self.workflow_config) - if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: - expressions = node_instance.build_conditional_edge_expressions() - - # Number of branches, usually matches the number of conditional expressions - branch_number = len(expressions) + if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]: # Find all edges whose source is the current node related_edge = [edge for edge in self.edges if edge.get("source") == node_id] # Iterate over each branch - for idx in range(branch_number): + for idx in range(len(related_edge)): # Generate a condition expression for each edge # Used later to determine which branch to take based on the node's output # Assumes node output `node..output` matches the edge's label diff --git a/api/app/core/workflow/nodes/question_classifier/config.py b/api/app/core/workflow/nodes/question_classifier/config.py index f3b2cc20..998e2fb4 100644 --- a/api/app/core/workflow/nodes/question_classifier/config.py +++ b/api/app/core/workflow/nodes/question_classifier/config.py @@ -26,4 +26,3 @@ class QuestionClassifierNodeConfig(BaseNodeConfig): default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", description="用户提示词模板" ) - output_variable: str = Field(default="class_name", description="输出分类结果的变量名") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index bd3c8752..67f53801 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -12,6 +12,9 @@ from app.services.model_service import ModelConfigService logger = logging.getLogger(__name__) +DEFAULT_CASE_PREFIX = "CASE" +DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1" + class QuestionClassifierNode(BaseNode): """问题分类器节点""" @@ -19,6 +22,7 @@ class QuestionClassifierNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config = QuestionClassifierNodeConfig(**self.config) + self.category_to_case_map = self._build_category_case_map() def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" @@ -47,48 +51,73 @@ class QuestionClassifierNode(BaseNode): ), type=ModelType(model_type) ) + + def _build_category_case_map(self) -> dict[str, str]: + """ + 预构建 分类名称 -> CASE标识 的映射字典 + 示例:{"产品咨询": "CASE1", "售后问题": "CASE2"} + """ + category_map = {} + categories = self.typed_config.categories or [] + for idx, class_item in enumerate(categories, start=1): + category_name = class_item.class_name.strip() + case_tag = f"{DEFAULT_CASE_PREFIX}{idx}" + category_map[category_name] = case_tag + return category_map - async def execute(self, state: WorkflowState) -> dict[str, Any]: + async def execute(self, state: WorkflowState) -> str: """执行问题分类""" question = self.typed_config.input_variable - - supplement_prompt = "" - if self.typed_config.user_supplement_prompt is not None: - supplement_prompt = self.typed_config.user_supplement_prompt - - category_names = [class_item.class_name for class_item in self.typed_config.categories] + supplement_prompt = self.typed_config.user_supplement_prompt or "" + categories = self.typed_config.categories or [] + category_names = [class_item.class_name.strip() for class_item in categories] + category_count = len(category_names) if not question: - logger.warning(f"节点 {self.node_id} 未获取到输入问题") - return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"} - - llm = self._get_llm_instance() - - # 渲染用户提示词模板,支持工作流变量 - user_prompt = self._render_template( - self.typed_config.user_prompt.format( - question=question, - categories=", ".join(category_names), - supplement_prompt=supplement_prompt - ), - state - ) - - messages = [ - ("system", self.typed_config.system_prompt), - ("user", user_prompt), - ] - - response = await llm.ainvoke(messages) - result = response.content.strip() - - if result in category_names: - category = result - else: - logger.warning(f"LLM返回了未知类别: {result}") - category = category_names[0] if category_names else "unknown" + logger.warning( + f"节点 {self.node_id} 未获取到输入问题,使用默认分支" + f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})" + ) + # 若分类列表为空,返回默认unknown分支,否则返回CASE1 + return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown" - log_supplement = supplement_prompt if supplement_prompt else "无" - logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") - - return {self.typed_config.output_variable: category} \ No newline at end of file + try: + llm = self._get_llm_instance() + + # 渲染用户提示词模板,支持工作流变量 + user_prompt = self._render_template( + self.typed_config.user_prompt.format( + question=question, + categories=", ".join(category_names), + supplement_prompt=supplement_prompt + ), + state + ) + + messages = [ + ("system", self.typed_config.system_prompt), + ("user", user_prompt), + ] + + response = await llm.ainvoke(messages) + result = response.content.strip() + + if result in category_names: + category = result + else: + logger.warning(f"LLM返回了未知类别: {result}") + category = category_names[0] if category_names else "unknown" + + log_supplement = supplement_prompt if supplement_prompt else "无" + logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") + + return f"CASE{category_names.index(category) + 1}" + except Exception as e: + logger.error( + f"节点 {self.node_id} 分类执行异常:{str(e)}", + exc_info=True # 打印堆栈信息,便于调试 + ) + # 异常时返回默认分支,保证工作流容错性 + if category_count > 0: + return DEFAULT_EMPTY_QUESTION_CASE + return "unknown"