Merge #102 into develop from feature/20251219_xjn

feat(workflow_node): question classifier node optimization

* feature/20251219_xjn: (9 commits)
  feat(tool system): The specific method for obtaining the tool and the parameters to be passed
  feat(tool system): add mcp testing services
  Merge branch 'refs/heads/develop' into feature/20251219_xjn
  feat(tool system): add all methods for obtaining the tool
  feat(tool system): add workflow tool nodes
  Merge branch 'refs/heads/develop' into feature/20251219_xjn
  feat(home page): add statistical interface
  Merge branch 'refs/heads/develop' into feature/20251219_xjn
  feat(workflow_node): question classifier node optimization

Signed-off-by: 谢俊男 <accounts_6853d0ea6f8174722fb0c8f1@mail.teambition.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/102
This commit is contained in:
朱文辉
2026-01-05 10:46:53 +08:00
36 changed files with 1027 additions and 80 deletions

View File

@@ -33,6 +33,7 @@ from . import (
emotion_config_controller, emotion_config_controller,
prompt_optimizer_controller, prompt_optimizer_controller,
tool_controller, tool_controller,
home_page_controller,
) )
from . import user_memory_controllers from . import user_memory_controllers
@@ -70,5 +71,6 @@ manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router) manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(tool_controller.router) manager_router.include_router(tool_controller.router)
manager_router.include_router(home_page_controller.router)
__all__ = ["manager_router"] __all__ = ["manager_router"]

View File

@@ -0,0 +1,29 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.services.home_page_service import HomePageService
router = APIRouter(prefix="/home-page", tags=["Home Page"])
@router.get("/statistics", response_model=ApiResponse)
def get_home_statistics(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取首页统计数据"""
statistics = HomePageService.get_home_statistics(db, current_user.tenant_id)
return success(data=statistics, msg="统计数据获取成功")
@router.get("/workspaces", response_model=ApiResponse)
def get_workspace_list(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取工作空间列表"""
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
return success(data=workspace_list, msg="工作空间列表获取成功")

View File

@@ -60,6 +60,22 @@ async def list_tools(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("/{tool_id}/methods", response_model=ApiResponse)
async def get_tool_methods(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具的所有方法"""
try:
methods = await service.get_tool_methods(tool_id, current_user.tenant_id)
if methods is None:
raise HTTPException(status_code=404, detail="工具不存在")
return success(data=methods, msg="获取工具方法成功")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{tool_id}", response_model=ApiResponse) @router.get("/{tool_id}", response_model=ApiResponse)
async def get_tool( async def get_tool(
tool_id: str, tool_id: str,
@@ -159,7 +175,8 @@ async def execute_tool(
workspace_id=current_user.current_workspace_id, workspace_id=current_user.current_workspace_id,
timeout=request.timeout timeout=request.timeout
) )
if not result.success:
raise HTTPException(status_code=400, detail=result["error"])
return success( return success(
data={ data={
"success": result.success, "success": result.success,

View File

@@ -3,7 +3,7 @@ import secrets
from typing import Optional, Union from typing import Optional, Union
from datetime import datetime from datetime import datetime
from app.schemas.api_key_schema import ApiKeyType from app.models.api_key_model import ApiKeyType
from fastapi import Response from fastapi import Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse

View File

@@ -1,7 +1,7 @@
"""工具管理核心模块""" """工具管理核心模块"""
from .base import BaseTool, ToolResult, ToolParameter from app.core.tools.base import BaseTool, ToolResult, ToolParameter
from .langchain_adapter import LangchainAdapter from app.core.tools.langchain_adapter import LangchainAdapter
# 可选导入,避免导入错误 # 可选导入,避免导入错误
try: try:

View File

@@ -193,7 +193,7 @@ class BaseTool(ABC):
def to_langchain_tool(self): def to_langchain_tool(self):
"""转换为Langchain工具格式""" """转换为Langchain工具格式"""
from .langchain_adapter import LangchainAdapter from app.core.tools.langchain_adapter import LangchainAdapter
return LangchainAdapter.convert_tool(self) return LangchainAdapter.convert_tool(self)
def __repr__(self): def __repr__(self):

View File

@@ -1,11 +1,11 @@
"""内置工具模块""" """内置工具模块"""
from .base import BuiltinTool from app.core.tools.builtin.base import BuiltinTool
from .datetime_tool import DateTimeTool from app.core.tools.builtin.datetime_tool import DateTimeTool
from .json_tool import JsonTool from app.core.tools.builtin.json_tool import JsonTool
from .baidu_search_tool import BaiduSearchTool from app.core.tools.builtin.baidu_search_tool import BaiduSearchTool
from .mineru_tool import MinerUTool from app.core.tools.builtin.mineru_tool import MinerUTool
from .textin_tool import TextInTool from app.core.tools.builtin.textin_tool import TextInTool
__all__ = [ __all__ = [
"BuiltinTool", "BuiltinTool",

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool from app.core.tools.builtin.base import BuiltinTool
class BaiduSearchTool(BuiltinTool): class BaiduSearchTool(BuiltinTool):

View File

@@ -5,7 +5,7 @@ from typing import List
import pytz import pytz
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool from app.core.tools.builtin.base import BuiltinTool
class DateTimeTool(BuiltinTool): class DateTimeTool(BuiltinTool):
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING, type=ParameterType.STRING,
description="操作类型", description="操作类型",
required=True, required=True,
enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"] enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
), ),
ToolParameter( ToolParameter(
name="input_value", name="input_value",

View File

@@ -7,7 +7,7 @@ import xml.etree.ElementTree as ET
from xml.dom import minidom from xml.dom import minidom
from app.core.tools.base import ToolParameter, ToolResult, ParameterType from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool from app.core.tools.builtin.base import BuiltinTool
class JsonTool(BuiltinTool): class JsonTool(BuiltinTool):
@@ -29,8 +29,7 @@ class JsonTool(BuiltinTool):
type=ParameterType.STRING, type=ParameterType.STRING,
description="操作类型", description="操作类型",
required=True, required=True,
enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", enum=["insert", "replace", "delete", "parse"]
"extract", "insert", "replace", "delete", "parse"]
), ),
ToolParameter( ToolParameter(
name="input_data", name="input_data",

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool from app.core.tools.builtin.base import BuiltinTool
class MinerUTool(BuiltinTool): class MinerUTool(BuiltinTool):

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool from app.core.tools.builtin.base import BuiltinTool
class TextInTool(BuiltinTool): class TextInTool(BuiltinTool):

View File

@@ -1,8 +1,8 @@
"""自定义工具模块""" """自定义工具模块"""
from .base import CustomTool from app.core.tools.custom.base import CustomTool
from .schema_parser import OpenAPISchemaParser from app.core.tools.custom.schema_parser import OpenAPISchemaParser
from .auth_manager import AuthManager from app.core.tools.custom.auth_manager import AuthManager
__all__ = [ __all__ = [
"CustomTool", "CustomTool",

View File

@@ -1,8 +1,8 @@
"""MCP工具模块""" """MCP工具模块"""
from .base import MCPTool from app.core.tools.mcp.base import MCPTool
from .client import MCPClient, MCPConnectionPool from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
from .service_manager import MCPServiceManager from app.core.tools.mcp.service_manager import MCPServiceManager
__all__ = [ __all__ = [
"MCPTool", "MCPTool",

View File

@@ -1,7 +1,6 @@
"""MCP工具基类""" """MCP工具基类"""
import time import time
from typing import Dict, Any, List from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool from app.core.tools.base import BaseTool

View File

@@ -204,7 +204,7 @@ class MCPClient:
) )
init_response = json.loads(response) init_response = json.loads(response)
if "error" in init_response: if init_response.get("error", None) is not None:
raise MCPProtocolError(f"初始化失败: {init_response['error']}") raise MCPProtocolError(f"初始化失败: {init_response['error']}")
return True return True
@@ -325,7 +325,7 @@ class MCPClient:
try: try:
response = await self._send_request(request_data, timeout) response = await self._send_request(request_data, timeout)
if "error" in response: if response.get("error", None) is not None:
error = response["error"] error = response["error"]
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}") raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")

View File

@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
logger = get_business_logger() logger = get_business_logger()

View File

@@ -219,17 +219,13 @@ class WorkflowExecutor:
# 创建节点实例(现在 start 和 end 也会被创建) # 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config) node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node # Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id] related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
# Iterate over each branch # Iterate over each branch
for idx in range(branch_number): for idx in range(len(related_edge)):
# Generate a condition expression for each edge # Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output # Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label # Assumes node output `node.<node_id>.output` matches the edge's label

View File

@@ -17,6 +17,8 @@ from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.tool import ToolNode
__all__ = [ __all__ = [
"BaseNode", "BaseNode",
@@ -33,5 +35,7 @@ __all__ = [
"AssignerNode", "AssignerNode",
"HttpRequestNode", "HttpRequestNode",
"JinjaRenderNode", "JinjaRenderNode",
"ParameterExtractorNode" "ParameterExtractorNode",
"QuestionClassifierNode",
"ToolNode"
] ]

View File

@@ -21,6 +21,7 @@ from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [ __all__ = [
@@ -45,4 +46,5 @@ __all__ = [
"LoopNodeConfig", "LoopNodeConfig",
"IterationNodeConfig", "IterationNodeConfig",
"QuestionClassifierNodeConfig" "QuestionClassifierNodeConfig"
"ToolNodeConfig"
] ]

View File

@@ -24,6 +24,7 @@ from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -44,7 +45,8 @@ WorkflowNode = Union[
CycleGraphNode, CycleGraphNode,
BreakNode, BreakNode,
ParameterExtractorNode, ParameterExtractorNode,
QuestionClassifierNode QuestionClassifierNode,
ToolNode
] ]
@@ -72,6 +74,7 @@ class NodeFactory:
NodeType.LOOP: CycleGraphNode, NodeType.LOOP: CycleGraphNode,
NodeType.ITERATION: CycleGraphNode, NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode, NodeType.BREAK: BreakNode,
NodeType.TOOL: ToolNode,
} }
@classmethod @classmethod

View File

@@ -26,4 +26,3 @@ class QuestionClassifierNodeConfig(BaseNodeConfig):
default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
description="用户提示词模板" description="用户提示词模板"
) )
output_variable: str = Field(default="class_name", description="输出分类结果的变量名")

View File

@@ -12,6 +12,9 @@ from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_CASE_PREFIX = "CASE"
DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode): class QuestionClassifierNode(BaseNode):
"""问题分类器节点""" """问题分类器节点"""
@@ -19,6 +22,7 @@ class QuestionClassifierNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.typed_config = QuestionClassifierNodeConfig(**self.config) self.typed_config = QuestionClassifierNodeConfig(**self.config)
self.category_to_case_map = self._build_category_case_map()
def _get_llm_instance(self) -> RedBearLLM: def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例""" """获取LLM实例"""
@@ -48,47 +52,72 @@ class QuestionClassifierNode(BaseNode):
type=ModelType(model_type) type=ModelType(model_type)
) )
async def execute(self, state: WorkflowState) -> dict[str, Any]: def _build_category_case_map(self) -> dict[str, str]:
"""
预构建 分类名称 -> CASE标识 的映射字典
示例:{"产品咨询": "CASE1", "售后问题": "CASE2"}
"""
category_map = {}
categories = self.typed_config.categories or []
for idx, class_item in enumerate(categories, start=1):
category_name = class_item.class_name.strip()
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
category_map[category_name] = case_tag
return category_map
async def execute(self, state: WorkflowState) -> str:
"""执行问题分类""" """执行问题分类"""
question = self.typed_config.input_variable question = self.typed_config.input_variable
supplement_prompt = self.typed_config.user_supplement_prompt or ""
supplement_prompt = "" categories = self.typed_config.categories or []
if self.typed_config.user_supplement_prompt is not None: category_names = [class_item.class_name.strip() for class_item in categories]
supplement_prompt = self.typed_config.user_supplement_prompt category_count = len(category_names)
category_names = [class_item.class_name for class_item in self.typed_config.categories]
if not question: if not question:
logger.warning(f"节点 {self.node_id} 未获取到输入问题") logger.warning(
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"} f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count}"
)
# 若分类列表为空返回默认unknown分支否则返回CASE1
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
llm = self._get_llm_instance() try:
llm = self._get_llm_instance()
# 渲染用户提示词模板,支持工作流变量 # 渲染用户提示词模板,支持工作流变量
user_prompt = self._render_template( user_prompt = self._render_template(
self.typed_config.user_prompt.format( self.typed_config.user_prompt.format(
question=question, question=question,
categories=", ".join(category_names), categories=", ".join(category_names),
supplement_prompt=supplement_prompt supplement_prompt=supplement_prompt
), ),
state state
) )
messages = [ messages = [
("system", self.typed_config.system_prompt), ("system", self.typed_config.system_prompt),
("user", user_prompt), ("user", user_prompt),
] ]
response = await llm.ainvoke(messages) response = await llm.ainvoke(messages)
result = response.content.strip() result = response.content.strip()
if result in category_names: if result in category_names:
category = result category = result
else: else:
logger.warning(f"LLM返回了未知类别: {result}") logger.warning(f"LLM返回了未知类别: {result}")
category = category_names[0] if category_names else "unknown" category = category_names[0] if category_names else "unknown"
log_supplement = supplement_prompt if supplement_prompt else "" log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
return {self.typed_config.output_variable: category} return f"CASE{category_names.index(category) + 1}"
except Exception as e:
logger.error(
f"节点 {self.node_id} 分类执行异常:{str(e)}",
exc_info=True # 打印堆栈信息,便于调试
)
# 异常时返回默认分支,保证工作流容错性
if category_count > 0:
return DEFAULT_EMPTY_QUESTION_CASE
return "unknown"

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.tool.node import ToolNode
__all__ = ["ToolNode", "ToolNodeConfig"]

View File

@@ -0,0 +1,9 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
class ToolNodeConfig(BaseNodeConfig):
"""工具节点配置"""
tool_id: str = Field(..., description="工具ID")
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")

View File

@@ -0,0 +1,72 @@
import logging
import uuid
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.services.tool_service import ToolService
from app.db import get_db_read
logger = logging.getLogger(__name__)
class ToolNode(BaseNode):
"""工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = ToolNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行工具"""
# 获取租户ID和用户ID
tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state)
# 如果没有租户ID尝试从工作流ID获取
if not tenant_id:
workflow_id = self.get_variable("sys.workflow_id", state)
if workflow_id:
from app.repositories.tool_repository import ToolRepository
with get_db_read() as db:
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
if not tenant_id:
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
# logger.error(f"节点 {self.node_id} 缺少租户ID")
# return {"error": "缺少租户ID"}
# 渲染工具参数
rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items():
rendered_value = self._render_template(param_template, state)
rendered_parameters[param_name] = rendered_value
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
print(self.typed_config.tool_id)
# 执行工具
with get_db_read() as db:
tool_service = ToolService(db)
result = await tool_service.execute_tool(
tool_id=self.typed_config.tool_id,
parameters=rendered_parameters,
tenant_id=tenant_id,
user_id=user_id
)
print(result)
if result.success:
logger.info(f"节点 {self.node_id} 工具执行成功")
return {
"success": True,
"data": result.data,
"execution_time": result.execution_time
}
else:
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return {
"success": False,
"error": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time
}

View File

@@ -0,0 +1,137 @@
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func
from uuid import UUID
from typing import Dict
from app.models.end_user_model import EndUser
from app.models.user_model import User
from app.models.workspace_model import Workspace, WorkspaceMember
from app.models.models_model import ModelConfig
from app.models.app_model import App
class HomePageRepository:
@staticmethod
def get_model_statistics(db: Session, tenant_id: UUID, month_start: datetime) -> tuple[int, int]:
"""获取模型统计数据"""
total_models = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True
).count()
new_models_this_month = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True,
ModelConfig.created_at >= month_start
).count()
return total_models, new_models_this_month
@staticmethod
def get_workspace_statistics(db: Session, tenant_id: UUID, month_start: datetime) -> tuple[int, int]:
"""获取工作空间统计数据"""
active_workspaces = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True
).count()
new_workspaces_this_month = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True,
Workspace.created_at >= month_start
).count()
return active_workspaces, new_workspaces_this_month
@staticmethod
def get_user_statistics(db: Session, tenant_id: UUID, month_start: datetime) -> tuple[int, int]:
"""获取用户统计数据"""
workspace_ids = db.query(Workspace.id).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True
).subquery()
total_users = db.query(EndUser).join(
App,
EndUser.app_id == App.id
).filter(
App.workspace_id.in_(workspace_ids),
App.is_active == True,
App.status == "active"
).count()
new_users_this_month = db.query(EndUser).join(
App,
EndUser.app_id == App.id
).filter(
App.workspace_id.in_(workspace_ids),
App.is_active == True,
App.status == "active",
EndUser.created_at >= month_start
).count()
return total_users, new_users_this_month
@staticmethod
def get_app_statistics(db: Session, tenant_id: UUID, week_start: datetime) -> tuple[int, int]:
"""获取应用统计数据"""
workspace_ids = db.query(Workspace.id).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True
).subquery()
running_apps = db.query(App).filter(
App.workspace_id.in_(workspace_ids),
App.is_active == True,
App.status == "active"
).count()
new_apps_this_week = db.query(App).filter(
App.workspace_id.in_(workspace_ids),
App.is_active == True,
App.status == "active",
App.created_at >= week_start
).count()
return running_apps, new_apps_this_week
@staticmethod
def get_workspaces_with_counts(db: Session, tenant_id: UUID) -> tuple[list[Workspace], Dict[UUID, int], Dict[UUID, int]]:
"""批量获取工作空间及其统计数据"""
# 获取工作空间列表
workspaces = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active == True
).all()
workspace_ids = [ws.id for ws in workspaces]
# 批量获取应用数量
app_counts = db.query(
App.workspace_id,
func.count(App.id).label('count')
).filter(
App.workspace_id.in_(workspace_ids),
App.is_active,
App.status == "active"
).group_by(App.workspace_id).all()
app_count_dict = {workspace_id: count for workspace_id, count in app_counts}
# 批量获取用户数量
user_counts = db.query(
App.workspace_id,
func.count(EndUser.id).label('count')
).join(
EndUser,
EndUser.app_id == App.id
).filter(
App.workspace_id.in_(workspace_ids),
App.is_active,
App.status == "active"
).group_by(App.workspace_id).all()
user_count_dict = {workspace_id: count for workspace_id, count in user_counts}
return workspaces, app_count_dict, user_count_dict

View File

@@ -1,10 +1,9 @@
"""工具数据访问层""" """工具数据访问层"""
import uuid import uuid
from typing import List, Optional, Dict, Any from typing import List, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import func, or_ from sqlalchemy import func
from app.repositories.base_repository import BaseRepository
from app.models.tool_model import ( from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolExecution, ToolType, ToolStatus ToolExecution, ToolType, ToolStatus
@@ -14,6 +13,31 @@ from app.models.tool_model import (
class ToolRepository: class ToolRepository:
"""工具仓储类""" """工具仓储类"""
@staticmethod
def get_tenant_id_by_workflow_id(db: Session, workflow_id: uuid.UUID) -> Optional[uuid.UUID]:
"""根据工作流ID获取tenant_id
Args:
db: 数据库会话
workflow_id: 工作流配置ID
Returns:
tenant_id或None
"""
from app.models.app_model import App
from app.models.workflow_model import WorkflowConfig
from app.models.workspace_model import Workspace
result = db.query(Workspace.tenant_id).join(
App, App.workspace_id == Workspace.id
).join(
WorkflowConfig, WorkflowConfig.app_id == App.id
).filter(
WorkflowConfig.id == workflow_id
).first()
return result[0] if result else None
@staticmethod @staticmethod
def find_by_tenant( def find_by_tenant(
db: Session, db: Session,

View File

@@ -0,0 +1,32 @@
from datetime import datetime
from pydantic import BaseModel, field_serializer
from typing import Optional
from app.core.api_key_utils import datetime_to_timestamp
class HomeStatistics(BaseModel):
"""首页统计数据"""
total_models: int
new_models_this_month: int
active_workspaces: int
new_workspaces_this_month: int
total_users: int
new_users_this_month: int
running_apps: int
new_apps_this_week: int
class WorkspaceInfo(BaseModel):
"""工作空间信息"""
id: str
name: str
icon: Optional[str]
description: Optional[str]
app_count: int
user_count: int
created_at: datetime
@field_serializer('created_at')
@classmethod
def serialize_datetime(cls, v: datetime) -> Optional[int]:
return datetime_to_timestamp(v)

View File

@@ -0,0 +1,67 @@
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from uuid import UUID
from app.repositories.home_page_repository import HomePageRepository
from app.schemas.home_page_schema import HomeStatistics, WorkspaceInfo
class HomePageService:
@staticmethod
def get_home_statistics(db: Session, tenant_id: UUID) -> HomeStatistics:
"""获取首页统计数据"""
# 计算时间范围
now = datetime.now()
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
week_start = now - timedelta(days=now.weekday())
week_start = week_start.replace(hour=0, minute=0, second=0, microsecond=0)
# 获取各项统计数据
total_models, new_models_this_month = HomePageRepository.get_model_statistics(
db, tenant_id, month_start
)
active_workspaces, new_workspaces_this_month = HomePageRepository.get_workspace_statistics(
db, tenant_id, month_start
)
total_users, new_users_this_month = HomePageRepository.get_user_statistics(
db, tenant_id, month_start
)
running_apps, new_apps_this_week = HomePageRepository.get_app_statistics(
db, tenant_id, week_start
)
return HomeStatistics(
total_models=total_models,
new_models_this_month=new_models_this_month,
active_workspaces=active_workspaces,
new_workspaces_this_month=new_workspaces_this_month,
total_users=total_users,
new_users_this_month=new_users_this_month,
running_apps=running_apps,
new_apps_this_week=new_apps_this_week
)
@staticmethod
def get_workspace_list(db: Session, tenant_id: UUID) -> list[WorkspaceInfo]:
"""获取工作空间列表(优化版本)"""
workspaces, app_count_dict, user_count_dict= HomePageRepository.get_workspaces_with_counts(
db, tenant_id
)
workspace_list = []
for workspace in workspaces:
workspace_info = WorkspaceInfo(
id=str(workspace.id),
name=workspace.name,
icon=workspace.icon,
description=workspace.description,
app_count=app_count_dict.get(workspace.id, 0),
user_count=user_count_dict.get(workspace.id, 0),
created_at=workspace.created_at
)
workspace_list.append(workspace_info)
return workspace_list

View File

@@ -297,6 +297,165 @@ class ToolService:
self.db.commit() self.db.commit()
logger.info(f"租户 {tenant_id} 内置工具初始化完成") logger.info(f"租户 {tenant_id} 内置工具初始化完成")
async def get_tool_methods(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[List[Dict[str, Any]]]:
"""获取工具的所有方法
Args:
tool_id: 工具ID
tenant_id: 租户ID
Returns:
方法列表或None
"""
config = self._get_tool_config(tool_id, tenant_id)
if not config:
return None
try:
if config.tool_type == ToolType.BUILTIN.value:
return await self._get_builtin_tool_methods(config)
elif config.tool_type == ToolType.CUSTOM.value:
return await self._get_custom_tool_methods(config)
elif config.tool_type == ToolType.MCP.value:
return await self._get_mcp_tool_methods(config)
else:
return []
except Exception as e:
logger.error(f"获取工具方法失败: {tool_id}, {e}")
return []
async def _get_builtin_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取内置工具的方法"""
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS:
return []
# 获取工具实例
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
if not tool_instance:
return []
# 检查是否有operation参数
operation_param = None
for param in tool_instance.parameters:
if param.name == "operation" and param.enum:
operation_param = param
break
if operation_param:
# 有多个操作
methods = []
for operation in operation_param.enum:
methods.append({
"method_id": f"{config.name}_{operation}",
"name": operation,
"description": f"{config.description} - {operation}",
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
})
return methods
else:
# 只有一个方法
return [{
"method_id": config.name,
"name": config.name,
"description": config.description,
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
}]
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取自定义工具的方法"""
custom_config = self.custom_repo.find_by_tool_id(self.db, config.id)
if not custom_config:
return []
try:
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
parser = OpenAPISchemaParser()
# 解析schema
if custom_config.schema_content:
success, schema, error = parser.parse_from_content(custom_config.schema_content, "application/json")
elif custom_config.schema_url:
success, schema, error = await parser.parse_from_url(custom_config.schema_url)
else:
return []
if not success:
return []
# 提取操作
tool_info = parser.extract_tool_info(schema)
operations = tool_info.get("operations", {})
methods = []
for operation_id, operation in operations.items():
# 生成参数列表
parameters = []
# 路径和查询参数
for param_name, param_info in operation.get("parameters", {}).items():
parameters.append({
"name": param_name,
"type": param_info.get("type", "string"),
"description": param_info.get("description", ""),
"required": param_info.get("required", False),
"enum": param_info.get("enum"),
"default": param_info.get("default")
})
# 请求体参数
request_body = operation.get("request_body")
if request_body:
schema_props = request_body.get("schema", {}).get("properties", {})
required_props = request_body.get("schema", {}).get("required", [])
for prop_name, prop_schema in schema_props.items():
parameters.append({
"name": prop_name,
"type": prop_schema.get("type", "string"),
"description": prop_schema.get("description", ""),
"required": prop_name in required_props,
"enum": prop_schema.get("enum"),
"default": prop_schema.get("default")
})
methods.append({
"method_id": operation_id,
"name": operation.get("summary", operation_id),
"description": operation.get("description", ""),
"method": operation.get("method", "GET"),
"path": operation.get("path", "/"),
"parameters": parameters
})
return methods
except Exception as e:
logger.error(f"解析自定义工具schema失败: {e}")
return []
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取MCP工具的方法"""
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if not mcp_config:
return []
available_tools = mcp_config.available_tools or []
if not available_tools:
return []
methods = []
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": [] # MCP工具参数需要动态获取
})
return methods
def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]: def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]:
"""获取工具统计信息""" """获取工具统计信息"""
try: try:

38
api_key_mcp_server.py Normal file
View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""API Key认证MCP服务器"""
from fastapi import FastAPI, HTTPException, Depends, Header
from typing import Optional
import uvicorn
from mcp_base import MCPRequest, handle_mcp_request, TOOLS
app = FastAPI(title="API Key MCP Server", version="1.0.0")
# API Key配置
API_KEYS = {"test-api-key", "demo-key-123"}
def verify_api_key(x_api_key: Optional[str] = Header(None)):
"""验证API Key"""
if x_api_key and x_api_key in API_KEYS:
return True
raise HTTPException(status_code=401, detail="Invalid API Key")
@app.get("/")
async def root():
return {"name": "API Key MCP Server", "version": "1.0.0", "auth_type": "api_key"}
@app.get("/health")
async def health():
return {"status": "healthy", "tools": len(TOOLS), "auth_type": "api_key"}
@app.post("/mcp")
async def mcp_handler(request: MCPRequest, _: bool = Depends(verify_api_key)):
return await handle_mcp_request(request, "API Key MCP Server")
if __name__ == "__main__":
print("启动API Key认证MCP服务器...")
print("访问 http://localhost:8004 查看服务状态")
print("MCP端点: http://localhost:8004/mcp")
print("认证方式: API Key (Header: X-API-Key)")
print("测试API Keys: test-api-key, demo-key-123")
uvicorn.run(app, host="0.0.0.0", port=8004)

45
basic_auth_mcp_server.py Normal file
View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python3
"""Basic Auth认证MCP服务器"""
from fastapi import FastAPI, HTTPException, Depends, Header
from typing import Optional
import uvicorn
import base64
from mcp_base import MCPRequest, handle_mcp_request, TOOLS
app = FastAPI(title="Basic Auth MCP Server", version="1.0.0")
# Basic Auth配置
BASIC_AUTH_USERS = {"admin": "password", "user": "secret"}
def verify_basic_auth(authorization: Optional[str] = Header(None)):
"""验证Basic Auth"""
if authorization and authorization.startswith("Basic "):
try:
credentials = base64.b64decode(authorization.split(" ")[1]).decode()
username, password = credentials.split(":", 1)
if username in BASIC_AUTH_USERS and BASIC_AUTH_USERS[username] == password:
return True
except:
pass
raise HTTPException(status_code=401, detail="Invalid Basic Auth")
@app.get("/")
async def root():
return {"name": "Basic Auth MCP Server", "version": "1.0.0", "auth_type": "basic_auth"}
@app.get("/health")
async def health():
return {"status": "healthy", "tools": len(TOOLS), "auth_type": "basic_auth"}
@app.post("/mcp")
async def mcp_handler(request: MCPRequest, _: bool = Depends(verify_basic_auth)):
return await handle_mcp_request(request, "Basic Auth MCP Server")
if __name__ == "__main__":
print("启动Basic Auth认证MCP服务器...")
print("访问 http://localhost:8006 查看服务状态")
print("MCP端点: http://localhost:8006/mcp")
print("认证方式: Basic Auth (Header: Authorization: Basic <base64>)")
print("测试用户: admin:password, user:secret")
uvicorn.run(app, host="0.0.0.0", port=8006)

View File

@@ -0,0 +1,40 @@
#!/usr/bin/env python3
"""Bearer Token认证MCP服务器"""
from fastapi import FastAPI, HTTPException, Depends, Header
from typing import Optional
import uvicorn
from mcp_base import MCPRequest, handle_mcp_request, TOOLS
app = FastAPI(title="Bearer Token MCP Server", version="1.0.0")
# Bearer Token配置
BEARER_TOKENS = {"bearer-token-123", "demo-bearer-token"}
def verify_bearer_token(authorization: Optional[str] = Header(None)):
"""验证Bearer Token"""
if authorization and authorization.startswith("Bearer "):
token = authorization.split(" ")[1]
if token in BEARER_TOKENS:
return True
raise HTTPException(status_code=401, detail="Invalid Bearer Token")
@app.get("/")
async def root():
return {"name": "Bearer Token MCP Server", "version": "1.0.0", "auth_type": "bearer_token"}
@app.get("/health")
async def health():
return {"status": "healthy", "tools": len(TOOLS), "auth_type": "bearer_token"}
@app.post("/mcp")
async def mcp_handler(request: MCPRequest, _: bool = Depends(verify_bearer_token)):
return await handle_mcp_request(request, "Bearer Token MCP Server")
if __name__ == "__main__":
print("启动Bearer Token认证MCP服务器...")
print("访问 http://localhost:8005 查看服务状态")
print("MCP端点: http://localhost:8005/mcp")
print("认证方式: Bearer Token (Header: Authorization: Bearer <token>)")
print("测试Bearer Tokens: bearer-token-123, demo-bearer-token")
uvicorn.run(app, host="0.0.0.0", port=8005)

111
mcp_base.py Normal file
View File

@@ -0,0 +1,111 @@
#!/usr/bin/env python3
"""MCP服务器基础模块 - 共享的模型和处理逻辑"""
from pydantic import BaseModel
from typing import Dict, Any
class MCPRequest(BaseModel):
jsonrpc: str = "2.0"
id: str
method: str
params: Dict[str, Any] = {}
class MCPResponse(BaseModel):
jsonrpc: str = "2.0"
id: str
result: Any = None
error: Dict[str, Any] = None
# 工具定义
TOOLS = [
{
"name": "calculator",
"description": "简单计算器",
"inputSchema": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "数学表达式"}
},
"required": ["expression"]
}
},
{
"name": "echo",
"description": "回显工具",
"inputSchema": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "要回显的消息"}
},
"required": ["message"]
}
}
]
async def handle_mcp_request(request: MCPRequest, server_name: str = "MCP Server"):
"""处理MCP请求"""
try:
if request.method == "initialize":
return MCPResponse(
id=request.id,
result={
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {"listChanged": True}},
"serverInfo": {"name": server_name, "version": "1.0.0"}
}
)
elif request.method == "tools/list":
return MCPResponse(
id=request.id,
result={"tools": TOOLS}
)
elif request.method == "tools/call":
tool_name = request.params.get("name")
arguments = request.params.get("arguments", {})
if tool_name == "calculator":
try:
expression = arguments.get("expression", "")
result = eval(expression)
return MCPResponse(
id=request.id,
result={"content": [{"type": "text", "text": f"结果: {result}"}]}
)
except Exception as e:
return MCPResponse(
id=request.id,
error={"code": -1, "message": f"计算错误: {str(e)}"}
)
elif tool_name == "echo":
message = arguments.get("message", "")
return MCPResponse(
id=request.id,
result={"content": [{"type": "text", "text": f"Echo: {message}"}]}
)
else:
return MCPResponse(
id=request.id,
error={"code": -1, "message": f"未知工具: {tool_name}"}
)
elif request.method == "ping":
return MCPResponse(
id=request.id,
result={"status": "pong"}
)
else:
return MCPResponse(
id=request.id,
error={"code": -1, "message": f"未知方法: {request.method}"}
)
except Exception as e:
return MCPResponse(
id=request.id,
error={"code": -1, "message": str(e)}
)

130
simple_mcp_server.py Normal file
View File

@@ -0,0 +1,130 @@
#!/usr/bin/env python3
"""简化的MCP服务器 - 用于测试MCP工具集成"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, Any, List
import uvicorn
app = FastAPI(title="Simple MCP Server", version="1.0.0")
class MCPRequest(BaseModel):
jsonrpc: str = "2.0"
id: str
method: str
params: Dict[str, Any] = {}
class MCPResponse(BaseModel):
jsonrpc: str = "2.0"
id: str
result: Any = None
error: Dict[str, Any] = None
# 可用工具定义
TOOLS = [
{
"name": "calculator",
"description": "简单计算器",
"inputSchema": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "数学表达式"}
},
"required": ["expression"]
}
},
{
"name": "echo",
"description": "回显工具",
"inputSchema": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "要回显的消息"}
},
"required": ["message"]
}
}
]
@app.get("/")
async def root():
return {"name": "Simple MCP Server", "version": "1.0.0"}
@app.get("/health")
async def health():
return {"status": "healthy", "tools": len(TOOLS)}
@app.post("/mcp")
async def mcp_handler(request: MCPRequest):
"""处理MCP请求"""
try:
if request.method == "initialize":
return MCPResponse(
id=request.id,
result={
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {"listChanged": True}},
"serverInfo": {"name": "Simple MCP Server", "version": "1.0.0"}
}
)
elif request.method == "tools/list":
return MCPResponse(
id=request.id,
result={"tools": TOOLS}
)
elif request.method == "tools/call":
tool_name = request.params.get("name")
arguments = request.params.get("arguments", {})
if tool_name == "calculator":
try:
expression = arguments.get("expression", "")
result = eval(expression) # 注意生产环境不要用eval
return MCPResponse(
id=request.id,
result={"content": [{"type": "text", "text": f"结果: {result}"}]}
)
except Exception as e:
return MCPResponse(
id=request.id,
error={"code": -1, "message": f"计算错误: {str(e)}"}
)
elif tool_name == "echo":
message = arguments.get("message", "")
return MCPResponse(
id=request.id,
result={"content": [{"type": "text", "text": f"Echo: {message}"}]}
)
else:
return MCPResponse(
id=request.id,
error={"code": -1, "message": f"未知工具: {tool_name}"}
)
elif request.method == "ping":
return MCPResponse(
id=request.id,
result={"status": "pong"}
)
else:
return MCPResponse(
id=request.id,
error={"code": -1, "message": f"未知方法: {request.method}"}
)
except Exception as e:
return MCPResponse(
id=request.id,
error={"code": -1, "message": str(e)}
)
if __name__ == "__main__":
print("启动简化MCP服务器...")
print("访问 http://localhost:8002 查看服务状态")
print("MCP端点: http://localhost:8002/mcp")
uvicorn.run(app, host="0.0.0.0", port=8002)