style(workflow): remove unnecessary indentation

This commit is contained in:
mengyonghao
2025-12-22 16:18:25 +08:00
parent c15a987701
commit 75ee591202
6 changed files with 57 additions and 70 deletions

View File

@@ -33,7 +33,7 @@ class EndNode(BaseNode):
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
# 如果配置了输出模板,使用模板渲染;否则使用默认输出 # 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template: if output_template:
output = self._render_template(output_template, state) output = self._render_template(output_template, state)
@@ -45,17 +45,17 @@ class EndNode(BaseNode):
total_nodes = len(node_outputs) total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output return output
def _extract_referenced_nodes(self, template: str) -> list[str]: def _extract_referenced_nodes(self, template: str) -> list[str]:
"""从模板中提取引用的节点 ID """从模板中提取引用的节点 ID
例如:'结果:{{llm_qa.output}}' -> ['llm_qa'] 例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
Args: Args:
template: 模板字符串 template: 模板字符串
Returns: Returns:
引用的节点 ID 列表 引用的节点 ID 列表
""" """
@@ -63,44 +63,44 @@ class EndNode(BaseNode):
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}' pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template) matches = re.findall(pattern, template)
return list(set(matches)) # 去重 return list(set(matches)) # 去重
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]: def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
"""解析模板,分离静态文本和动态引用 """解析模板,分离静态文本和动态引用
例如:'你好 {{llm.output}}, 这是后缀' 例如:'你好 {{llm.output}}, 这是后缀'
返回:[ 返回:[
{"type": "static", "content": "你好 "}, {"type": "static", "content": "你好 "},
{"type": "dynamic", "node_id": "llm", "field": "output"}, {"type": "dynamic", "node_id": "llm", "field": "output"},
{"type": "static", "content": ", 这是后缀"} {"type": "static", "content": ", 这是后缀"}
] ]
Args: Args:
template: 模板字符串 template: 模板字符串
state: 工作流状态 state: 工作流状态
Returns: Returns:
模板部分列表 模板部分列表
""" """
import re import re
parts = [] parts = []
last_end = 0 last_end = 0
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格) # 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
pattern = r'\{\{\s*([^}]+?)\s*\}\}' pattern = r'\{\{\s*([^}]+?)\s*\}\}'
for match in re.finditer(pattern, template): for match in re.finditer(pattern, template):
start, end = match.span() start, end = match.span()
# 添加前面的静态文本 # 添加前面的静态文本
if start > last_end: if start > last_end:
static_text = template[last_end:start] static_text = template[last_end:start]
if static_text: if static_text:
parts.append({"type": "static", "content": static_text}) parts.append({"type": "static", "content": static_text})
# 解析动态引用 # 解析动态引用
ref = match.group(1).strip() ref = match.group(1).strip()
# 检查是否是节点引用(如 llm.output 或 llm_qa.output # 检查是否是节点引用(如 llm.output 或 llm_qa.output
if '.' in ref: if '.' in ref:
node_id, field = ref.split('.', 1) node_id, field = ref.split('.', 1)
@@ -115,62 +115,62 @@ class EndNode(BaseNode):
# 直接渲染这部分 # 直接渲染这部分
rendered = self._render_template(f"{{{{{ref}}}}}", state) rendered = self._render_template(f"{{{{{ref}}}}}", state)
parts.append({"type": "static", "content": rendered}) parts.append({"type": "static", "content": rendered})
last_end = end last_end = end
# 添加最后的静态文本 # 添加最后的静态文本
if last_end < len(template): if last_end < len(template):
static_text = template[last_end:] static_text = template[last_end:]
if static_text: if static_text:
parts.append({"type": "static", "content": static_text}) parts.append({"type": "static", "content": static_text})
return parts return parts
async def execute_stream(self, state: WorkflowState): async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑 """流式执行 end 节点业务逻辑
智能输出策略: 智能输出策略:
1. 检测模板中是否引用了直接上游节点 1. 检测模板中是否引用了直接上游节点
2. 如果引用了,只输出该引用**之后**的部分(后缀) 2. 如果引用了,只输出该引用**之后**的部分(后缀)
3. 前缀和引用内容已经在上游节点流式输出时发送了 3. 前缀和引用内容已经在上游节点流式输出时发送了
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' 示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- 直接上游节点是 llm_qa - 直接上游节点是 llm_qa
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 - 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
- LLM 内容在 LLM 节点流式输出 - LLM 内容在 LLM 节点流式输出
- End 节点只输出 ' lalalalala a'(后缀,一次性输出) - End 节点只输出 ' lalalalala a'(后缀,一次性输出)
Args: Args:
state: 工作流状态 state: 工作流状态
Yields: Yields:
完成标记 完成标记
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
if not output_template: if not output_template:
output = "工作流已完成" output = "工作流已完成"
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# 找到直接上游节点 # 找到直接上游节点
direct_upstream_nodes = [] direct_upstream_nodes = []
for edge in self.workflow_config.get("edges", []): for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id: if edge.get("target") == self.node_id:
source_node_id = edge.get("source") source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id) direct_upstream_nodes.append(source_node_id)
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
# 解析模板部分 # 解析模板部分
parts = self._parse_template_parts(output_template, state) parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
for i, part in enumerate(parts): for i, part in enumerate(parts):
logger.info(f"[模板解析] part[{i}]: {part}") logger.info(f"[模板解析] part[{i}]: {part}")
# 找到第一个引用直接上游节点的动态引用 # 找到第一个引用直接上游节点的动态引用
upstream_ref_index = None upstream_ref_index = None
for i, part in enumerate(parts): for i, part in enumerate(parts):
@@ -178,12 +178,12 @@ class EndNode(BaseNode):
upstream_ref_index = i upstream_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
break break
if upstream_ref_index is None: if upstream_ref_index is None:
# 没有引用直接上游节点,输出完整模板内容 # 没有引用直接上游节点,输出完整模板内容
output = self._render_template(output_template, state) output = self._render_template(output_template, state)
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'") logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")
# 通过 writer 发送完整内容(作为一个 message chunk # 通过 writer 发送完整内容(作为一个 message chunk
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
writer = get_stream_writer() writer = get_stream_writer()
@@ -196,14 +196,14 @@ class EndNode(BaseNode):
"is_suffix": False "is_suffix": False
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
# yield 完成标记 # yield 完成标记
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# 有引用直接上游节点,只输出该引用之后的部分(后缀) # 有引用直接上游节点,只输出该引用之后的部分(后缀)
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
# 收集后缀部分 # 收集后缀部分
suffix_parts = [] suffix_parts = []
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1}{len(parts) - 1}") logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1}{len(parts) - 1}")
@@ -214,7 +214,7 @@ class EndNode(BaseNode):
# 静态文本 # 静态文本
logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'") logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'")
suffix_parts.append(part["content"]) suffix_parts.append(part["content"])
elif part["type"] == "dynamic": elif part["type"] == "dynamic":
# Other dynamic references (if there are multiple references) # Other dynamic references (if there are multiple references)
node_id = part["node_id"] node_id = part["node_id"]
@@ -229,21 +229,21 @@ class EndNode(BaseNode):
except Exception as e: except Exception as e:
logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}") logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}")
content = "" content = ""
# Convert to string if not None # Convert to string if not None
suffix_parts.append(str(content) if content is not None else "") suffix_parts.append(str(content) if content is not None else "")
# 拼接后缀 # 拼接后缀
suffix = "".join(suffix_parts) suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state) full_output = self._render_template(output_template, state)
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
logger.info(f"[后缀调试] 后缀是否为空: {not suffix}") logger.info(f"[后缀调试] 后缀是否为空: {not suffix}")
if suffix: if suffix:
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})") logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})")
# 一次性输出后缀(作为单个 chunk # 一次性输出后缀(作为单个 chunk
@@ -266,8 +266,8 @@ class EndNode(BaseNode):
# 统计信息 # 统计信息
node_outputs = state.get("node_outputs", {}) node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs) total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点") logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
# yield 完成标记(包含完整输出) # yield 完成标记(包含完整输出)
yield {"__final__": True, "result": full_output} yield {"__final__": True, "result": full_output}

View File

@@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.db import get_db_context from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
@@ -136,7 +137,7 @@ class LLMNode(BaseNode):
base_url=api_base, base_url=api_base,
extra_params=extra_params extra_params=extra_params
), ),
type=model_type type=ModelType(model_type)
) )
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")

View File

@@ -7,6 +7,7 @@
import logging import logging
from typing import Any, Union from typing import Any, Union
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.end import EndNode
@@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.assigner import AssignerNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,6 +28,8 @@ WorkflowNode = Union[
IfElseNode, IfElseNode,
AgentNode, AgentNode,
TransformNode, TransformNode,
AssignerNode,
KnowledgeRetrievalNode,
] ]
@@ -42,7 +46,9 @@ class NodeFactory:
NodeType.LLM: LLMNode, NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode, NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode, NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode NodeType.IF_ELSE: IfElseNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,
} }
@classmethod @classmethod
@@ -82,10 +88,6 @@ class NodeFactory:
""" """
node_type = node_config.get("type") node_type = node_config.get("type")
# 跳过条件节点(由 LangGraph 处理)
if node_type == "condition":
return None
# 获取节点类 # 获取节点类
node_class = cls._node_types.get(node_type) node_class = cls._node_types.get(node_type)
if not node_class: if not node_class:

View File

@@ -10,7 +10,10 @@
""" """
import logging import logging
from typing import Any from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from app.core.workflow.nodes import WorkflowState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -82,7 +85,7 @@ class VariablePool:
>>> pool.set(["conv", "user_name"], "张三") >>> pool.set(["conv", "user_name"], "张三")
""" """
def __init__(self, state: dict[str, Any]): def __init__(self, state: "WorkflowState"):
"""初始化变量池 """初始化变量池
Args: Args:

View File

@@ -15,25 +15,6 @@ class ModelType(StrEnum):
EMBEDDING = "embedding" EMBEDDING = "embedding"
RERANK = "rerank" RERANK = "rerank"
@classmethod
def from_str(cls, value: str) -> "ModelType":
"""
Get a ModelType enum instance from a string value.
Args:
value (str): The string representation of the model type.
Returns:
ModelType: The corresponding ModelType enum object.
Raises:
ValueError: If the given value does not match any ModelType.
"""
try:
return cls(value)
except ValueError:
raise ValueError(f"Invalid ModelType: {value}")
class ModelProvider(StrEnum): class ModelProvider(StrEnum):
"""模型提供商枚举""" """模型提供商枚举"""

View File

@@ -169,7 +169,7 @@ class PromptOptimizerService:
provider=api_config.provider, provider=api_config.provider,
api_key=api_config.api_key, api_key=api_config.api_key,
base_url=api_config.api_base base_url=api_config.api_base
), type=ModelType.from_str(model_config.type)) ), type=ModelType(model_config.type))
# build message # build message
messages = [ messages = [