Merge branch 'develop' into feature/20251219_myh
# Conflicts: # api/app/core/workflow/executor.py # api/app/core/workflow/nodes/node_factory.py # api/app/core/workflow/nodes/question_classifier/node.py
This commit is contained in:
@@ -17,6 +17,8 @@ from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
@@ -33,5 +35,7 @@ __all__ = [
|
||||
"AssignerNode",
|
||||
"HttpRequestNode",
|
||||
"JinjaRenderNode",
|
||||
"ParameterExtractorNode"
|
||||
"ParameterExtractorNode",
|
||||
"QuestionClassifierNode",
|
||||
"ToolNode"
|
||||
]
|
||||
|
||||
@@ -21,6 +21,7 @@ from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
__all__ = [
|
||||
@@ -45,4 +46,5 @@ __all__ = [
|
||||
"LoopNodeConfig",
|
||||
"IterationNodeConfig",
|
||||
"QuestionClassifierNodeConfig"
|
||||
"ToolNodeConfig"
|
||||
]
|
||||
|
||||
@@ -24,6 +24,7 @@ from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,7 +45,8 @@ WorkflowNode = Union[
|
||||
CycleGraphNode,
|
||||
BreakNode,
|
||||
ParameterExtractorNode,
|
||||
QuestionClassifierNode
|
||||
QuestionClassifierNode,
|
||||
ToolNode
|
||||
]
|
||||
|
||||
|
||||
@@ -73,6 +75,7 @@ class NodeFactory:
|
||||
NodeType.ITERATION: CycleGraphNode,
|
||||
NodeType.BREAK: BreakNode,
|
||||
NodeType.CYCLE_START: StartNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -26,4 +26,3 @@ class QuestionClassifierNodeConfig(BaseNodeConfig):
|
||||
default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
description="用户提示词模板"
|
||||
)
|
||||
output_variable: str = Field(default="class_name", description="输出分类结果的变量名")
|
||||
|
||||
@@ -12,32 +12,36 @@ from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CASE_PREFIX = "CASE"
|
||||
DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
|
||||
|
||||
|
||||
class QuestionClassifierNode(BaseNode):
|
||||
"""问题分类器节点"""
|
||||
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
|
||||
self.category_to_case_map = self._build_category_case_map()
|
||||
|
||||
def _get_llm_instance(self) -> RedBearLLM:
|
||||
"""获取LLM实例"""
|
||||
with get_db_read() as db:
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
|
||||
|
||||
|
||||
if not config:
|
||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
|
||||
api_config = config.api_keys[0]
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
base_url = api_config.api_base
|
||||
model_type = config.type
|
||||
|
||||
|
||||
return RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
@@ -48,47 +52,72 @@ class QuestionClassifierNode(BaseNode):
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
def _build_category_case_map(self) -> dict[str, str]:
|
||||
"""
|
||||
预构建 分类名称 -> CASE标识 的映射字典
|
||||
示例:{"产品咨询": "CASE1", "售后问题": "CASE2"}
|
||||
"""
|
||||
category_map = {}
|
||||
categories = self.typed_config.categories or []
|
||||
for idx, class_item in enumerate(categories, start=1):
|
||||
category_name = class_item.class_name.strip()
|
||||
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
|
||||
category_map[category_name] = case_tag
|
||||
return category_map
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
"""执行问题分类"""
|
||||
question = self.typed_config.input_variable
|
||||
|
||||
supplement_prompt = ""
|
||||
if self.typed_config.user_supplement_prompt is not None:
|
||||
supplement_prompt = self.typed_config.user_supplement_prompt
|
||||
|
||||
category_names = [class_item.class_name for class_item in self.typed_config.categories]
|
||||
|
||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||
categories = self.typed_config.categories or []
|
||||
category_names = [class_item.class_name.strip() for class_item in categories]
|
||||
category_count = len(category_names)
|
||||
|
||||
if not question:
|
||||
logger.warning(f"节点 {self.node_id} 未获取到输入问题")
|
||||
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"}
|
||||
logger.warning(
|
||||
f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
|
||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
||||
)
|
||||
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
|
||||
|
||||
llm = self._get_llm_instance()
|
||||
try:
|
||||
llm = self._get_llm_instance()
|
||||
|
||||
# 渲染用户提示词模板,支持工作流变量
|
||||
user_prompt = self._render_template(
|
||||
self.typed_config.user_prompt.format(
|
||||
question=question,
|
||||
categories=", ".join(category_names),
|
||||
supplement_prompt=supplement_prompt
|
||||
),
|
||||
state
|
||||
)
|
||||
# 渲染用户提示词模板,支持工作流变量
|
||||
user_prompt = self._render_template(
|
||||
self.typed_config.user_prompt.format(
|
||||
question=question,
|
||||
categories=", ".join(category_names),
|
||||
supplement_prompt=supplement_prompt
|
||||
),
|
||||
state
|
||||
)
|
||||
|
||||
messages = [
|
||||
("system", self.typed_config.system_prompt),
|
||||
("user", user_prompt),
|
||||
]
|
||||
messages = [
|
||||
("system", self.typed_config.system_prompt),
|
||||
("user", user_prompt),
|
||||
]
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
result = response.content.strip()
|
||||
response = await llm.ainvoke(messages)
|
||||
result = response.content.strip()
|
||||
|
||||
if result in category_names:
|
||||
category = result
|
||||
else:
|
||||
logger.warning(f"LLM返回了未知类别: {result}")
|
||||
category = category_names[0] if category_names else "unknown"
|
||||
if result in category_names:
|
||||
category = result
|
||||
else:
|
||||
logger.warning(f"LLM返回了未知类别: {result}")
|
||||
category = category_names[0] if category_names else "unknown"
|
||||
|
||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||
|
||||
return {self.typed_config.output_variable: category}
|
||||
return f"CASE{category_names.index(category) + 1}"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
||||
exc_info=True # 打印堆栈信息,便于调试
|
||||
)
|
||||
# 异常时返回默认分支,保证工作流容错性
|
||||
if category_count > 0:
|
||||
return DEFAULT_EMPTY_QUESTION_CASE
|
||||
return "unknown"
|
||||
|
||||
4
api/app/core/workflow/nodes/tool/__init__.py
Normal file
4
api/app/core/workflow/nodes/tool/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.tool.node import ToolNode
|
||||
|
||||
__all__ = ["ToolNode", "ToolNodeConfig"]
|
||||
9
api/app/core/workflow/nodes/tool/config.py
Normal file
9
api/app/core/workflow/nodes/tool/config.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from pydantic import Field
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class ToolNodeConfig(BaseNodeConfig):
|
||||
"""工具节点配置"""
|
||||
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||
72
api/app/core/workflow/nodes/tool/node.py
Normal file
72
api/app/core/workflow/nodes/tool/node.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.services.tool_service import ToolService
|
||||
from app.db import get_db_read
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
"""工具节点"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行工具"""
|
||||
# 获取租户ID和用户ID
|
||||
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||
user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
# 如果没有租户ID,尝试从工作流ID获取
|
||||
if not tenant_id:
|
||||
workflow_id = self.get_variable("sys.workflow_id", state)
|
||||
if workflow_id:
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
with get_db_read() as db:
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
|
||||
|
||||
if not tenant_id:
|
||||
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
|
||||
# logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
# return {"error": "缺少租户ID"}
|
||||
|
||||
# 渲染工具参数
|
||||
rendered_parameters = {}
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
rendered_parameters[param_name] = rendered_value
|
||||
|
||||
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||
print(self.typed_config.tool_id)
|
||||
|
||||
# 执行工具
|
||||
with get_db_read() as db:
|
||||
tool_service = ToolService(db)
|
||||
result = await tool_service.execute_tool(
|
||||
tool_id=self.typed_config.tool_id,
|
||||
parameters=rendered_parameters,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
print(result)
|
||||
if result.success:
|
||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||
return {
|
||||
"success": True,
|
||||
"data": result.data,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
else:
|
||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
Reference in New Issue
Block a user