diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 2e5758eb..3cece96b 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -33,7 +33,7 @@ class EndNode(BaseNode): # 获取配置的输出模板 output_template = self.config.get("output") - + # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state) @@ -45,17 +45,17 @@ class EndNode(BaseNode): total_nodes = len(node_outputs) logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") - + return output - + def _extract_referenced_nodes(self, template: str) -> list[str]: """从模板中提取引用的节点 ID - + 例如:'结果:{{llm_qa.output}}' -> ['llm_qa'] - + Args: template: 模板字符串 - + Returns: 引用的节点 ID 列表 """ @@ -63,44 +63,44 @@ class EndNode(BaseNode): pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}' matches = re.findall(pattern, template) return list(set(matches)) # 去重 - + def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]: """解析模板,分离静态文本和动态引用 - + 例如:'你好 {{llm.output}}, 这是后缀' 返回:[ {"type": "static", "content": "你好 "}, {"type": "dynamic", "node_id": "llm", "field": "output"}, {"type": "static", "content": ", 这是后缀"} ] - + Args: template: 模板字符串 state: 工作流状态 - + Returns: 模板部分列表 """ import re - + parts = [] last_end = 0 - + # 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格) pattern = r'\{\{\s*([^}]+?)\s*\}\}' - + for match in re.finditer(pattern, template): start, end = match.span() - + # 添加前面的静态文本 if start > last_end: static_text = template[last_end:start] if static_text: parts.append({"type": "static", "content": static_text}) - + # 解析动态引用 ref = match.group(1).strip() - + # 检查是否是节点引用(如 llm.output 或 llm_qa.output) if '.' in ref: node_id, field = ref.split('.', 1) @@ -115,62 +115,62 @@ class EndNode(BaseNode): # 直接渲染这部分 rendered = self._render_template(f"{{{{{ref}}}}}", state) parts.append({"type": "static", "content": rendered}) - + last_end = end - + # 添加最后的静态文本 if last_end < len(template): static_text = template[last_end:] if static_text: parts.append({"type": "static", "content": static_text}) - + return parts - + async def execute_stream(self, state: WorkflowState): """流式执行 end 节点业务逻辑 - + 智能输出策略: 1. 检测模板中是否引用了直接上游节点 2. 如果引用了,只输出该引用**之后**的部分(后缀) 3. 前缀和引用内容已经在上游节点流式输出时发送了 - + 示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' - 直接上游节点是 llm_qa - 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 - LLM 内容在 LLM 节点流式输出 - End 节点只输出 ' lalalalala a'(后缀,一次性输出) - + Args: state: 工作流状态 - + Yields: 完成标记 """ logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") - + # 获取配置的输出模板 output_template = self.config.get("output") - + if not output_template: output = "工作流已完成" yield {"__final__": True, "result": output} return - + # 找到直接上游节点 direct_upstream_nodes = [] for edge in self.workflow_config.get("edges", []): if edge.get("target") == self.node_id: source_node_id = edge.get("source") direct_upstream_nodes.append(source_node_id) - + logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") - + # 解析模板部分 parts = self._parse_template_parts(output_template, state) logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") for i, part in enumerate(parts): logger.info(f"[模板解析] part[{i}]: {part}") - + # 找到第一个引用直接上游节点的动态引用 upstream_ref_index = None for i, part in enumerate(parts): @@ -178,12 +178,12 @@ class EndNode(BaseNode): upstream_ref_index = i logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") break - + if upstream_ref_index is None: # 没有引用直接上游节点,输出完整模板内容 output = self._render_template(output_template, state) logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'") - + # 通过 writer 发送完整内容(作为一个 message chunk) from langgraph.config import get_stream_writer writer = get_stream_writer() @@ -196,14 +196,14 @@ class EndNode(BaseNode): "is_suffix": False }) logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") - + # yield 完成标记 yield {"__final__": True, "result": output} return - + # 有引用直接上游节点,只输出该引用之后的部分(后缀) logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") - + # 收集后缀部分 suffix_parts = [] logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}") @@ -214,7 +214,7 @@ class EndNode(BaseNode): # 静态文本 logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'") suffix_parts.append(part["content"]) - + elif part["type"] == "dynamic": # Other dynamic references (if there are multiple references) node_id = part["node_id"] @@ -229,21 +229,21 @@ class EndNode(BaseNode): except Exception as e: logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}") content = "" - + # Convert to string if not None suffix_parts.append(str(content) if content is not None else "") # 拼接后缀 suffix = "".join(suffix_parts) - + # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) full_output = self._render_template(output_template, state) - + logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") logger.info(f"[后缀调试] 后缀是否为空: {not suffix}") - + if suffix: logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})") # 一次性输出后缀(作为单个 chunk) @@ -266,8 +266,8 @@ class EndNode(BaseNode): # 统计信息 node_outputs = state.get("node_outputs", {}) total_nodes = len(node_outputs) - + logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点") - + # yield 完成标记(包含完整输出) yield {"__final__": True, "result": full_output} diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 8f809923..65826d84 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.models import RedBearLLM, RedBearModelConfig from app.db import get_db_context +from app.models import ModelType from app.services.model_service import ModelConfigService from app.core.exceptions import BusinessException @@ -136,7 +137,7 @@ class LLMNode(BaseNode): base_url=api_base, extra_params=extra_params ), - type=model_type + type=ModelType(model_type) ) logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 1abace67..2ae31d4d 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -7,6 +7,7 @@ import logging from typing import Any, Union +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.end import EndNode @@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.assigner import AssignerNode logger = logging.getLogger(__name__) @@ -26,6 +28,8 @@ WorkflowNode = Union[ IfElseNode, AgentNode, TransformNode, + AssignerNode, + KnowledgeRetrievalNode, ] @@ -42,7 +46,9 @@ class NodeFactory: NodeType.LLM: LLMNode, NodeType.AGENT: AgentNode, NodeType.TRANSFORM: TransformNode, - NodeType.IF_ELSE: IfElseNode + NodeType.IF_ELSE: IfElseNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.ASSIGNER: AssignerNode, } @classmethod @@ -82,10 +88,6 @@ class NodeFactory: """ node_type = node_config.get("type") - # 跳过条件节点(由 LangGraph 处理) - if node_type == "condition": - return None - # 获取节点类 node_class = cls._node_types.get(node_type) if not node_class: diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index 1f589dab..0f97c349 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -10,7 +10,10 @@ """ import logging -from typing import Any +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from app.core.workflow.nodes import WorkflowState logger = logging.getLogger(__name__) @@ -82,7 +85,7 @@ class VariablePool: >>> pool.set(["conv", "user_name"], "张三") """ - def __init__(self, state: dict[str, Any]): + def __init__(self, state: "WorkflowState"): """初始化变量池 Args: diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 91c1d9c7..2e60ef1c 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -15,25 +15,6 @@ class ModelType(StrEnum): EMBEDDING = "embedding" RERANK = "rerank" - @classmethod - def from_str(cls, value: str) -> "ModelType": - """ - Get a ModelType enum instance from a string value. - - Args: - value (str): The string representation of the model type. - - Returns: - ModelType: The corresponding ModelType enum object. - - Raises: - ValueError: If the given value does not match any ModelType. - """ - try: - return cls(value) - except ValueError: - raise ValueError(f"Invalid ModelType: {value}") - class ModelProvider(StrEnum): """模型提供商枚举""" diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 5355474f..6af794b1 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -169,7 +169,7 @@ class PromptOptimizerService: provider=api_config.provider, api_key=api_config.api_key, base_url=api_config.api_base - ), type=ModelType.from_str(model_config.type)) + ), type=ModelType(model_config.type)) # build message messages = [