Merge branch 'develop' into feature/20251219_myh

# Conflicts:
#	api/app/core/workflow/executor.py
#	api/app/core/workflow/nodes/node_factory.py
#	api/app/core/workflow/nodes/question_classifier/node.py
This commit is contained in:
mengyonghao
2026-01-05 11:10:01 +08:00
284 changed files with 21282 additions and 2779 deletions

View File

@@ -3,7 +3,7 @@ import secrets
from typing import Optional, Union
from datetime import datetime
from app.schemas.api_key_schema import ApiKeyType
from app.models.api_key_model import ApiKeyType
from fastapi import Response
from fastapi.responses import JSONResponse

View File

@@ -48,7 +48,6 @@ class RAGExcelParser:
logging.info(f"pandas with default engine load error: {ex}, try calamine instead")
file_like_object.seek(0)
df = pd.read_excel(file_like_object, engine="calamine")
print("lxc1")
return RAGExcelParser._dataframe_to_workbook(df)
except Exception as e_pandas:
raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")
@@ -215,19 +214,35 @@ class RAGExcelParser:
continue
if not rows:
continue
# 获取表头
ti = list(rows[0])
for r in list(rows[1:]):
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
header_fields = []
for cell in ti:
if cell.value: # 只添加有值的表头
header_fields.append(str(cell.value))
# 如果有数据行,处理数据行;否则只处理表头
data_rows = rows[1:]
if data_rows:
for r in data_rows:
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
else:
# 只有表头的情况
if header_fields:
line = "; ".join(header_fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
return res
@staticmethod

View File

@@ -1,7 +1,7 @@
"""工具管理核心模块"""
from .base import BaseTool, ToolResult, ToolParameter
from .langchain_adapter import LangchainAdapter
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
from app.core.tools.langchain_adapter import LangchainAdapter
# 可选导入,避免导入错误
try:

View File

@@ -193,7 +193,7 @@ class BaseTool(ABC):
def to_langchain_tool(self):
"""转换为Langchain工具格式"""
from .langchain_adapter import LangchainAdapter
from app.core.tools.langchain_adapter import LangchainAdapter
return LangchainAdapter.convert_tool(self)
def __repr__(self):

View File

@@ -1,11 +1,11 @@
"""内置工具模块"""
from .base import BuiltinTool
from .datetime_tool import DateTimeTool
from .json_tool import JsonTool
from .baidu_search_tool import BaiduSearchTool
from .mineru_tool import MinerUTool
from .textin_tool import TextInTool
from app.core.tools.builtin.base import BuiltinTool
from app.core.tools.builtin.datetime_tool import DateTimeTool
from app.core.tools.builtin.json_tool import JsonTool
from app.core.tools.builtin.baidu_search_tool import BaiduSearchTool
from app.core.tools.builtin.mineru_tool import MinerUTool
from app.core.tools.builtin.textin_tool import TextInTool
__all__ = [
"BuiltinTool",

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class BaiduSearchTool(BuiltinTool):

View File

@@ -5,7 +5,7 @@ from typing import List
import pytz
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class DateTimeTool(BuiltinTool):
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"]
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
),
ToolParameter(
name="input_value",

View File

@@ -7,7 +7,7 @@ import xml.etree.ElementTree as ET
from xml.dom import minidom
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class JsonTool(BuiltinTool):
@@ -29,8 +29,7 @@ class JsonTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge",
"extract", "insert", "replace", "delete", "parse"]
enum=["insert", "replace", "delete", "parse"]
),
ToolParameter(
name="input_data",

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class MinerUTool(BuiltinTool):

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
from app.core.tools.builtin.base import BuiltinTool
class TextInTool(BuiltinTool):

View File

@@ -1,8 +1,8 @@
"""自定义工具模块"""
from .base import CustomTool
from .schema_parser import OpenAPISchemaParser
from .auth_manager import AuthManager
from app.core.tools.custom.base import CustomTool
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
from app.core.tools.custom.auth_manager import AuthManager
__all__ = [
"CustomTool",

View File

@@ -1,8 +1,8 @@
"""MCP工具模块"""
from .base import MCPTool
from .client import MCPClient, MCPConnectionPool
from .service_manager import MCPServiceManager
from app.core.tools.mcp.base import MCPTool
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
from app.core.tools.mcp.service_manager import MCPServiceManager
__all__ = [
"MCPTool",

View File

@@ -1,7 +1,6 @@
"""MCP工具基类"""
import time
from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool

View File

@@ -204,7 +204,7 @@ class MCPClient:
)
init_response = json.loads(response)
if "error" in init_response:
if init_response.get("error", None) is not None:
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
return True
@@ -325,7 +325,7 @@ class MCPClient:
try:
response = await self._send_request(request_data, timeout)
if "error" in response:
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")

View File

@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
logger = get_business_logger()

View File

@@ -17,6 +17,8 @@ from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.tool import ToolNode
__all__ = [
"BaseNode",
@@ -33,5 +35,7 @@ __all__ = [
"AssignerNode",
"HttpRequestNode",
"JinjaRenderNode",
"ParameterExtractorNode"
"ParameterExtractorNode",
"QuestionClassifierNode",
"ToolNode"
]

View File

@@ -21,6 +21,7 @@ from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [
@@ -45,4 +46,5 @@ __all__ = [
"LoopNodeConfig",
"IterationNodeConfig",
"QuestionClassifierNodeConfig"
"ToolNodeConfig"
]

View File

@@ -24,6 +24,7 @@ from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode
logger = logging.getLogger(__name__)
@@ -44,7 +45,8 @@ WorkflowNode = Union[
CycleGraphNode,
BreakNode,
ParameterExtractorNode,
QuestionClassifierNode
QuestionClassifierNode,
ToolNode
]
@@ -73,6 +75,7 @@ class NodeFactory:
NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode,
NodeType.CYCLE_START: StartNode,
NodeType.TOOL: ToolNode,
}
@classmethod

View File

@@ -26,4 +26,3 @@ class QuestionClassifierNodeConfig(BaseNodeConfig):
default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
description="用户提示词模板"
)
output_variable: str = Field(default="class_name", description="输出分类结果的变量名")

View File

@@ -12,32 +12,36 @@ from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
DEFAULT_CASE_PREFIX = "CASE"
DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode):
"""问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = QuestionClassifierNodeConfig(**self.config)
self.category_to_case_map = self._build_category_case_map()
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
with get_db_read() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
base_url = api_config.api_base
model_type = config.type
return RedBearLLM(
RedBearModelConfig(
model_name=model_name,
@@ -48,47 +52,72 @@ class QuestionClassifierNode(BaseNode):
type=ModelType(model_type)
)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
def _build_category_case_map(self) -> dict[str, str]:
"""
预构建 分类名称 -> CASE标识 的映射字典
示例:{"产品咨询": "CASE1", "售后问题": "CASE2"}
"""
category_map = {}
categories = self.typed_config.categories or []
for idx, class_item in enumerate(categories, start=1):
category_name = class_item.class_name.strip()
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
category_map[category_name] = case_tag
return category_map
async def execute(self, state: WorkflowState) -> str:
"""执行问题分类"""
question = self.typed_config.input_variable
supplement_prompt = ""
if self.typed_config.user_supplement_prompt is not None:
supplement_prompt = self.typed_config.user_supplement_prompt
category_names = [class_item.class_name for class_item in self.typed_config.categories]
supplement_prompt = self.typed_config.user_supplement_prompt or ""
categories = self.typed_config.categories or []
category_names = [class_item.class_name.strip() for class_item in categories]
category_count = len(category_names)
if not question:
logger.warning(f"节点 {self.node_id} 未获取到输入问题")
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"}
logger.warning(
f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count}"
)
# 若分类列表为空返回默认unknown分支否则返回CASE1
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
llm = self._get_llm_instance()
try:
llm = self._get_llm_instance()
# 渲染用户提示词模板,支持工作流变量
user_prompt = self._render_template(
self.typed_config.user_prompt.format(
question=question,
categories=", ".join(category_names),
supplement_prompt=supplement_prompt
),
state
)
# 渲染用户提示词模板,支持工作流变量
user_prompt = self._render_template(
self.typed_config.user_prompt.format(
question=question,
categories=", ".join(category_names),
supplement_prompt=supplement_prompt
),
state
)
messages = [
("system", self.typed_config.system_prompt),
("user", user_prompt),
]
messages = [
("system", self.typed_config.system_prompt),
("user", user_prompt),
]
response = await llm.ainvoke(messages)
result = response.content.strip()
response = await llm.ainvoke(messages)
result = response.content.strip()
if result in category_names:
category = result
else:
logger.warning(f"LLM返回了未知类别: {result}")
category = category_names[0] if category_names else "unknown"
if result in category_names:
category = result
else:
logger.warning(f"LLM返回了未知类别: {result}")
category = category_names[0] if category_names else "unknown"
log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
return {self.typed_config.output_variable: category}
return f"CASE{category_names.index(category) + 1}"
except Exception as e:
logger.error(
f"节点 {self.node_id} 分类执行异常:{str(e)}",
exc_info=True # 打印堆栈信息,便于调试
)
# 异常时返回默认分支,保证工作流容错性
if category_count > 0:
return DEFAULT_EMPTY_QUESTION_CASE
return "unknown"

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.tool.node import ToolNode
__all__ = ["ToolNode", "ToolNodeConfig"]

View File

@@ -0,0 +1,9 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
class ToolNodeConfig(BaseNodeConfig):
"""工具节点配置"""
tool_id: str = Field(..., description="工具ID")
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")

View File

@@ -0,0 +1,72 @@
import logging
import uuid
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.services.tool_service import ToolService
from app.db import get_db_read
logger = logging.getLogger(__name__)
class ToolNode(BaseNode):
"""工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = ToolNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行工具"""
# 获取租户ID和用户ID
tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state)
# 如果没有租户ID尝试从工作流ID获取
if not tenant_id:
workflow_id = self.get_variable("sys.workflow_id", state)
if workflow_id:
from app.repositories.tool_repository import ToolRepository
with get_db_read() as db:
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
if not tenant_id:
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
# logger.error(f"节点 {self.node_id} 缺少租户ID")
# return {"error": "缺少租户ID"}
# 渲染工具参数
rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items():
rendered_value = self._render_template(param_template, state)
rendered_parameters[param_name] = rendered_value
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
print(self.typed_config.tool_id)
# 执行工具
with get_db_read() as db:
tool_service = ToolService(db)
result = await tool_service.execute_tool(
tool_id=self.typed_config.tool_id,
parameters=rendered_parameters,
tenant_id=tenant_id,
user_id=user_id
)
print(result)
if result.success:
logger.info(f"节点 {self.node_id} 工具执行成功")
return {
"success": True,
"data": result.data,
"execution_time": result.execution_time
}
else:
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return {
"success": False,
"error": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time
}