Compare commits
2 Commits
feature/me
...
feat/wxy-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cef33fce0d | ||
|
|
d9f08860bc |
@@ -2,6 +2,7 @@
|
|||||||
# Author: Eternity
|
# Author: Eternity
|
||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/10 13:33
|
# @Time : 2026/2/10 13:33
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
@@ -141,9 +142,10 @@ class GraphBuilder:
|
|||||||
|
|
||||||
for node_info in source_nodes:
|
for node_info in source_nodes:
|
||||||
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
||||||
branch_nodes.append(
|
if node_info.get("branch") is not None:
|
||||||
(node_info["id"], node_info["branch"])
|
branch_nodes.append(
|
||||||
)
|
(node_info["id"], node_info["branch"])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
|
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
|
||||||
output_nodes.append(node_info["id"])
|
output_nodes.append(node_info["id"])
|
||||||
@@ -314,9 +316,12 @@ class GraphBuilder:
|
|||||||
for idx in range(len(related_edge)):
|
for idx in range(len(related_edge)):
|
||||||
# Generate a condition expression for each edge
|
# Generate a condition expression for each edge
|
||||||
# Used later to determine which branch to take based on the node's output
|
# Used later to determine which branch to take based on the node's output
|
||||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
# For LLM nodes, use branch_signal field for routing (output is dynamic text)
|
||||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
# For other branch nodes (e.g. HTTP), use output field
|
||||||
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
|
route_field = "branch_signal" if node_type == NodeType.LLM else "output"
|
||||||
|
related_edge[idx]['condition'] = (
|
||||||
|
f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
|
||||||
|
)
|
||||||
|
|
||||||
if node_instance:
|
if node_instance:
|
||||||
# Wrap node's run method to avoid closure issues
|
# Wrap node's run method to avoid closure issues
|
||||||
|
|||||||
@@ -18,10 +18,17 @@ class AssignerNode(BaseNode):
|
|||||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.variable_updater = True
|
self.variable_updater = True
|
||||||
self.typed_config: AssignerNodeConfig | None = None
|
self.typed_config: AssignerNodeConfig | None = None
|
||||||
|
self._input_data: dict[str, Any] | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
|
"""提取节点输入,如果有缓存的执行前数据则使用缓存"""
|
||||||
|
if self._input_data is not None:
|
||||||
|
return self._input_data
|
||||||
|
return {"config": self._resolve_config(self.config, variable_pool)}
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute the assignment operation defined by this node.
|
Execute the assignment operation defined by this node.
|
||||||
@@ -34,6 +41,9 @@ class AssignerNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
None or the result of the assignment operation.
|
None or the result of the assignment operation.
|
||||||
"""
|
"""
|
||||||
|
# 在执行前提取并缓存输入数据(捕获执行前的变量值)
|
||||||
|
self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
|
||||||
|
|
||||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||||
self.typed_config = AssignerNodeConfig(**self.config)
|
self.typed_config = AssignerNodeConfig(**self.config)
|
||||||
logger.info(f"节点 {self.node_id} 开始执行")
|
logger.info(f"节点 {self.node_id} 开始执行")
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class NodeType(StrEnum):
|
|||||||
NOTES = "notes"
|
NOTES = "notes"
|
||||||
|
|
||||||
|
|
||||||
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
|
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
|
||||||
|
|
||||||
|
|
||||||
class ComparisonOperator(StrEnum):
|
class ComparisonOperator(StrEnum):
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import uuid
|
|||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||||
|
from app.core.workflow.nodes.enums import HttpErrorHandle
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
|
||||||
|
|
||||||
@@ -49,6 +50,20 @@ class MemoryWindowSetting(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMErrorHandleConfig(BaseModel):
|
||||||
|
"""LLM 异常处理配置"""
|
||||||
|
|
||||||
|
method: HttpErrorHandle = Field(
|
||||||
|
default=HttpErrorHandle.NONE,
|
||||||
|
description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支",
|
||||||
|
)
|
||||||
|
|
||||||
|
output: str = Field(
|
||||||
|
default="",
|
||||||
|
description="LLM 异常时返回的默认输出文本(method=default 时生效)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMNodeConfig(BaseNodeConfig):
|
class LLMNodeConfig(BaseNodeConfig):
|
||||||
"""LLM 节点配置
|
"""LLM 节点配置
|
||||||
|
|
||||||
@@ -152,6 +167,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
description="输出变量定义(自动生成,通常不需要修改)"
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
error_handle: LLMErrorHandleConfig = Field(
|
||||||
|
default_factory=LLMErrorHandleConfig,
|
||||||
|
description="LLM 异常处理配置",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("messages", "prompt")
|
@field_validator("messages", "prompt")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_input_mode(cls, v):
|
def validate_input_mode(cls, v):
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from app.core.models import RedBearLLM, RedBearModelConfig
|
|||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
|
from app.core.workflow.nodes.enums import HttpErrorHandle
|
||||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
@@ -76,7 +77,7 @@ class LLMNode(BaseNode):
|
|||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING}
|
return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
|
||||||
|
|
||||||
def _render_context(self, message: str, variable_pool: VariablePool):
|
def _render_context(self, message: str, variable_pool: VariablePool):
|
||||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||||
@@ -239,7 +240,7 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
"""非流式执行 LLM 调用
|
"""非流式执行 LLM 调用
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -247,28 +248,36 @@ class LLMNode(BaseNode):
|
|||||||
variable_pool: 变量池
|
variable_pool: 变量池
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLM 响应消息
|
dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
|
||||||
|
{"llm_result": None, "branch_signal": "ERROR"} on branch error
|
||||||
"""
|
"""
|
||||||
# self.typed_config = LLMNodeConfig(**self.config)
|
try:
|
||||||
llm = await self._prepare_llm(state, variable_pool, False)
|
# self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
llm = await self._prepare_llm(state, variable_pool, False)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|
||||||
# 调用 LLM(支持字符串或消息列表)
|
# 调用 LLM(支持字符串或消息列表)
|
||||||
response = await llm.ainvoke(self.messages)
|
response = await llm.ainvoke(self.messages)
|
||||||
# 提取内容
|
# 提取内容
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, 'content'):
|
||||||
content = self.process_model_output(response.content)
|
content = self.process_model_output(response.content)
|
||||||
else:
|
else:
|
||||||
content = str(response)
|
content = str(response)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||||
|
|
||||||
# 返回 AIMessage(包含响应元数据)
|
# 返回 AIMessage(包含响应元数据)
|
||||||
return AIMessage(content=content, response_metadata={
|
return {
|
||||||
**response.response_metadata,
|
"llm_result": AIMessage(content=content, response_metadata={
|
||||||
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
**response.response_metadata,
|
||||||
})
|
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
||||||
|
}),
|
||||||
|
"branch_signal": "SUCCESS",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"节点 {self.node_id} LLM 调用失败: {e}")
|
||||||
|
return self._handle_llm_error(e)
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""提取输入数据(用于记录)"""
|
"""提取输入数据(用于记录)"""
|
||||||
@@ -286,16 +295,36 @@ class LLMNode(BaseNode):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> str:
|
def _extract_output(self, business_result: Any) -> dict:
|
||||||
"""从 AIMessage 中提取文本内容"""
|
"""从业务结果中提取输出变量
|
||||||
|
|
||||||
|
支持新旧两种格式:
|
||||||
|
- 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"}
|
||||||
|
- 旧格式:AIMessage(向后兼容)
|
||||||
|
"""
|
||||||
|
if isinstance(business_result, dict) and "branch_signal" in business_result:
|
||||||
|
llm_result = business_result.get("llm_result")
|
||||||
|
if isinstance(llm_result, AIMessage):
|
||||||
|
return {
|
||||||
|
"output": llm_result.content,
|
||||||
|
"branch_signal": business_result["branch_signal"],
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"output": str(llm_result) if llm_result else "",
|
||||||
|
"branch_signal": business_result["branch_signal"],
|
||||||
|
}
|
||||||
|
# 旧格式向后兼容
|
||||||
if isinstance(business_result, AIMessage):
|
if isinstance(business_result, AIMessage):
|
||||||
return business_result.content
|
return {"output": business_result.content, "branch_signal": "SUCCESS"}
|
||||||
return str(business_result)
|
return {"output": str(business_result), "branch_signal": "SUCCESS"}
|
||||||
|
|
||||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
"""从 AIMessage 中提取 token 使用情况"""
|
"""从业务结果中提取 token 使用情况"""
|
||||||
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
llm_result = business_result
|
||||||
usage = business_result.response_metadata.get('token_usage')
|
if isinstance(business_result, dict):
|
||||||
|
llm_result = business_result.get("llm_result", business_result)
|
||||||
|
if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
|
||||||
|
usage = llm_result.response_metadata.get('token_usage')
|
||||||
if usage:
|
if usage:
|
||||||
return {
|
return {
|
||||||
"prompt_tokens": usage.get('input_tokens', 0),
|
"prompt_tokens": usage.get('input_tokens', 0),
|
||||||
@@ -304,6 +333,44 @@ class LLMNode(BaseNode):
|
|||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _handle_llm_error(self, error: Exception) -> dict:
|
||||||
|
"""处理 LLM 调用异常,根据 error_handle 配置决定行为
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: LLM 调用中捕获的异常
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode,
|
||||||
|
or default output for default mode
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
原异常(当 error_handle.method 为 NONE 时)
|
||||||
|
"""
|
||||||
|
if self.typed_config is None:
|
||||||
|
raise error
|
||||||
|
|
||||||
|
match self.typed_config.error_handle.method:
|
||||||
|
case HttpErrorHandle.NONE:
|
||||||
|
raise error
|
||||||
|
case HttpErrorHandle.DEFAULT:
|
||||||
|
logger.warning(
|
||||||
|
f"节点 {self.node_id}: LLM 调用失败,返回默认输出"
|
||||||
|
)
|
||||||
|
default_output = self.typed_config.error_handle.output or ""
|
||||||
|
return {
|
||||||
|
"llm_result": AIMessage(content=default_output, response_metadata={}),
|
||||||
|
"branch_signal": "SUCCESS",
|
||||||
|
}
|
||||||
|
case HttpErrorHandle.BRANCH:
|
||||||
|
logger.warning(
|
||||||
|
f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"llm_result": None,
|
||||||
|
"branch_signal": "ERROR",
|
||||||
|
}
|
||||||
|
raise error
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
"""流式执行 LLM 调用
|
"""流式执行 LLM 调用
|
||||||
|
|
||||||
@@ -316,54 +383,58 @@ class LLMNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
self.typed_config = LLMNodeConfig(**self.config)
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
|
||||||
llm = await self._prepare_llm(state, variable_pool, True)
|
try:
|
||||||
|
llm = await self._prepare_llm(state, variable_pool, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
# 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
last_meta_data = {}
|
last_meta_data = {}
|
||||||
last_usage_metadata = {}
|
last_usage_metadata = {}
|
||||||
async for chunk in llm.astream(self.messages):
|
async for chunk in llm.astream(self.messages):
|
||||||
if hasattr(chunk, 'content'):
|
if hasattr(chunk, 'content'):
|
||||||
content = self.process_model_output(chunk.content)
|
content = self.process_model_output(chunk.content)
|
||||||
else:
|
else:
|
||||||
content = str(chunk)
|
content = str(chunk)
|
||||||
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
|
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
|
||||||
last_meta_data = chunk.response_metadata
|
last_meta_data = chunk.response_metadata
|
||||||
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
||||||
last_usage_metadata = chunk.usage_metadata
|
last_usage_metadata = chunk.usage_metadata
|
||||||
|
|
||||||
# 只有当内容不为空时才处理
|
# 只有当内容不为空时才处理
|
||||||
if content:
|
if content:
|
||||||
full_response += content
|
full_response += content
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
|
|
||||||
# 流式返回每个文本片段
|
# 流式返回每个文本片段
|
||||||
yield {
|
yield {
|
||||||
"__final__": False,
|
"__final__": False,
|
||||||
"chunk": content
|
"chunk": content
|
||||||
}
|
}
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"__final__": False,
|
"__final__": False,
|
||||||
"chunk": "",
|
"chunk": "",
|
||||||
"done": True
|
"done": True
|
||||||
}
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
|
||||||
|
|
||||||
# 构建完整的 AIMessage(包含元数据)
|
|
||||||
final_message = AIMessage(
|
|
||||||
content=full_response,
|
|
||||||
response_metadata={
|
|
||||||
**last_meta_data,
|
|
||||||
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
|
|
||||||
}
|
}
|
||||||
)
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||||
|
|
||||||
# yield 完成标记
|
# 构建完整的 AIMessage(包含元数据)
|
||||||
yield {"__final__": True, "result": final_message}
|
final_message = AIMessage(
|
||||||
|
content=full_response,
|
||||||
|
response_metadata={
|
||||||
|
**last_meta_data,
|
||||||
|
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# yield 完成标记
|
||||||
|
yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}")
|
||||||
|
error_result = self._handle_llm_error(e)
|
||||||
|
yield {"__final__": True, "result": error_result}
|
||||||
|
|||||||
Reference in New Issue
Block a user