style(workflow): remove unnecessary indentation
This commit is contained in:
@@ -33,7 +33,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
@@ -45,17 +45,17 @@ class EndNode(BaseNode):
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _extract_referenced_nodes(self, template: str) -> list[str]:
|
||||
"""从模板中提取引用的节点 ID
|
||||
|
||||
|
||||
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
|
||||
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
|
||||
Returns:
|
||||
引用的节点 ID 列表
|
||||
"""
|
||||
@@ -63,44 +63,44 @@ class EndNode(BaseNode):
|
||||
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||
matches = re.findall(pattern, template)
|
||||
return list(set(matches)) # 去重
|
||||
|
||||
|
||||
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
|
||||
"""解析模板,分离静态文本和动态引用
|
||||
|
||||
|
||||
例如:'你好 {{llm.output}}, 这是后缀'
|
||||
返回:[
|
||||
{"type": "static", "content": "你好 "},
|
||||
{"type": "dynamic", "node_id": "llm", "field": "output"},
|
||||
{"type": "static", "content": ", 这是后缀"}
|
||||
]
|
||||
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
|
||||
Returns:
|
||||
模板部分列表
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
parts = []
|
||||
last_end = 0
|
||||
|
||||
|
||||
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
|
||||
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
|
||||
|
||||
|
||||
for match in re.finditer(pattern, template):
|
||||
start, end = match.span()
|
||||
|
||||
|
||||
# 添加前面的静态文本
|
||||
if start > last_end:
|
||||
static_text = template[last_end:start]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
|
||||
# 解析动态引用
|
||||
ref = match.group(1).strip()
|
||||
|
||||
|
||||
# 检查是否是节点引用(如 llm.output 或 llm_qa.output)
|
||||
if '.' in ref:
|
||||
node_id, field = ref.split('.', 1)
|
||||
@@ -115,62 +115,62 @@ class EndNode(BaseNode):
|
||||
# 直接渲染这部分
|
||||
rendered = self._render_template(f"{{{{{ref}}}}}", state)
|
||||
parts.append({"type": "static", "content": rendered})
|
||||
|
||||
|
||||
last_end = end
|
||||
|
||||
|
||||
# 添加最后的静态文本
|
||||
if last_end < len(template):
|
||||
static_text = template[last_end:]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 end 节点业务逻辑
|
||||
|
||||
|
||||
智能输出策略:
|
||||
1. 检测模板中是否引用了直接上游节点
|
||||
2. 如果引用了,只输出该引用**之后**的部分(后缀)
|
||||
3. 前缀和引用内容已经在上游节点流式输出时发送了
|
||||
|
||||
|
||||
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
||||
- 直接上游节点是 llm_qa
|
||||
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
|
||||
- LLM 内容在 LLM 节点流式输出
|
||||
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
|
||||
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
|
||||
Yields:
|
||||
完成标记
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
||||
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
|
||||
if not output_template:
|
||||
output = "工作流已完成"
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
|
||||
# 找到直接上游节点
|
||||
direct_upstream_nodes = []
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
if edge.get("target") == self.node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
|
||||
|
||||
|
||||
# 解析模板部分
|
||||
parts = self._parse_template_parts(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
|
||||
for i, part in enumerate(parts):
|
||||
logger.info(f"[模板解析] part[{i}]: {part}")
|
||||
|
||||
|
||||
# 找到第一个引用直接上游节点的动态引用
|
||||
upstream_ref_index = None
|
||||
for i, part in enumerate(parts):
|
||||
@@ -178,12 +178,12 @@ class EndNode(BaseNode):
|
||||
upstream_ref_index = i
|
||||
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
|
||||
break
|
||||
|
||||
|
||||
if upstream_ref_index is None:
|
||||
# 没有引用直接上游节点,输出完整模板内容
|
||||
output = self._render_template(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")
|
||||
|
||||
|
||||
# 通过 writer 发送完整内容(作为一个 message chunk)
|
||||
from langgraph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
@@ -196,14 +196,14 @@ class EndNode(BaseNode):
|
||||
"is_suffix": False
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
|
||||
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
|
||||
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
|
||||
|
||||
|
||||
# 收集后缀部分
|
||||
suffix_parts = []
|
||||
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}")
|
||||
@@ -214,7 +214,7 @@ class EndNode(BaseNode):
|
||||
# 静态文本
|
||||
logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'")
|
||||
suffix_parts.append(part["content"])
|
||||
|
||||
|
||||
elif part["type"] == "dynamic":
|
||||
# Other dynamic references (if there are multiple references)
|
||||
node_id = part["node_id"]
|
||||
@@ -229,21 +229,21 @@ class EndNode(BaseNode):
|
||||
except Exception as e:
|
||||
logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}")
|
||||
content = ""
|
||||
|
||||
|
||||
# Convert to string if not None
|
||||
suffix_parts.append(str(content) if content is not None else "")
|
||||
|
||||
# 拼接后缀
|
||||
suffix = "".join(suffix_parts)
|
||||
|
||||
|
||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||
full_output = self._render_template(output_template, state)
|
||||
|
||||
|
||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
||||
logger.info(f"[后缀调试] 后缀是否为空: {not suffix}")
|
||||
|
||||
|
||||
if suffix:
|
||||
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})")
|
||||
# 一次性输出后缀(作为单个 chunk)
|
||||
@@ -266,8 +266,8 @@ class EndNode(BaseNode):
|
||||
# 统计信息
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
|
||||
|
||||
|
||||
# yield 完成标记(包含完整输出)
|
||||
yield {"__final__": True, "result": full_output}
|
||||
|
||||
@@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -136,7 +137,7 @@ class LLMNode(BaseNode):
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
),
|
||||
type=model_type
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
@@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,6 +28,8 @@ WorkflowNode = Union[
|
||||
IfElseNode,
|
||||
AgentNode,
|
||||
TransformNode,
|
||||
AssignerNode,
|
||||
KnowledgeRetrievalNode,
|
||||
]
|
||||
|
||||
|
||||
@@ -42,7 +46,9 @@ class NodeFactory:
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
NodeType.IF_ELSE: IfElseNode
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.ASSIGNER: AssignerNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -82,10 +88,6 @@ class NodeFactory:
|
||||
"""
|
||||
node_type = node_config.get("type")
|
||||
|
||||
# 跳过条件节点(由 LangGraph 处理)
|
||||
if node_type == "condition":
|
||||
return None
|
||||
|
||||
# 获取节点类
|
||||
node_class = cls._node_types.get(node_type)
|
||||
if not node_class:
|
||||
|
||||
@@ -10,7 +10,10 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -82,7 +85,7 @@ class VariablePool:
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
"""
|
||||
|
||||
def __init__(self, state: dict[str, Any]):
|
||||
def __init__(self, state: "WorkflowState"):
|
||||
"""初始化变量池
|
||||
|
||||
Args:
|
||||
|
||||
@@ -15,25 +15,6 @@ class ModelType(StrEnum):
|
||||
EMBEDDING = "embedding"
|
||||
RERANK = "rerank"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, value: str) -> "ModelType":
|
||||
"""
|
||||
Get a ModelType enum instance from a string value.
|
||||
|
||||
Args:
|
||||
value (str): The string representation of the model type.
|
||||
|
||||
Returns:
|
||||
ModelType: The corresponding ModelType enum object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the given value does not match any ModelType.
|
||||
"""
|
||||
try:
|
||||
return cls(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid ModelType: {value}")
|
||||
|
||||
|
||||
class ModelProvider(StrEnum):
|
||||
"""模型提供商枚举"""
|
||||
|
||||
@@ -169,7 +169,7 @@ class PromptOptimizerService:
|
||||
provider=api_config.provider,
|
||||
api_key=api_config.api_key,
|
||||
base_url=api_config.api_base
|
||||
), type=ModelType.from_str(model_config.type))
|
||||
), type=ModelType(model_config.type))
|
||||
|
||||
# build message
|
||||
messages = [
|
||||
|
||||
Reference in New Issue
Block a user