Merge pull request #25 from SuanmoSuanyangTechnology/fix/workflow

Fix/workflow
This commit is contained in:
Mark
2026-01-06 12:01:42 +08:00
committed by GitHub
9 changed files with 108 additions and 90 deletions

View File

@@ -117,7 +117,7 @@ async def get_prompt_opt(
user_require=data.message user_require=data.message
): ):
# chunk 是 prompt 的增量内容 # chunk 是 prompt 的增量内容
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n" yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
return StreamingResponse( return StreamingResponse(
event_generator(), event_generator(),

View File

@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
# Set of loop node IDs, used for assigning values in loop nodes # Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list cycle_nodes: list
looping: bool looping: Annotated[bool, lambda x, y: x and y]
# Input variables (passed from configured variables) # Input variables (passed from configured variables)
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)

View File

@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
retries -= 1 retries -= 1
if retries > 0: if retries > 0:
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000) await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
raise e
except Exception as e:
raise RuntimeError(f"HTTP request node exception: {e}")
else: else:
match self.typed_config.error_handle.method: match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning error response"
)
return HttpRequestNodeOutput(
body="",
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
case HttpErrorHandle.DEFAULT: case HttpErrorHandle.DEFAULT:
logger.warning( logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result" f"Node {self.node_id}: HTTP request failed, returning default result"
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
f"Node {self.node_id}: HTTP request failed, switching to error handling branch" f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
) )
return "ERROR" return "ERROR"
raise RuntimeError("http request failed")

View File

@@ -203,15 +203,16 @@ class KnowledgeRetrievalNode(BaseNode):
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices, indices=indices,
score_threshold=kb_config.similarity_threshold) score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results # Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2) unique_rs = self._deduplicate_docs(rs1, rs2)
vector_service.reranker = self.get_reranker_model() vector_service.reranker = self.get_reranker_model()
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
case _: case _:
raise RuntimeError("Unknown retrieval type") raise RuntimeError("Unknown retrieval type")
vector_service.reranker = self.get_reranker_model() vector_service.reranker = self.get_reranker_model()
# TODO其他重排序方式支持
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
logger.info( logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
) )
return [chunk.model_dump() for chunk in final_rs] return [chunk.page_content for chunk in final_rs]

View File

@@ -1,5 +1,7 @@
"""LLM 节点配置""" """LLM 节点配置"""
from typing import Any
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
@@ -41,6 +43,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="模型配置 ID" description="模型配置 ID"
) )
context: Any = Field(
default="",
description="上下文"
)
# 简单模式 # 简单模式
prompt: str | None = Field( prompt: str | None = Field(
default=None, default=None,

View File

@@ -5,11 +5,13 @@ LLM 节点实现
""" """
import logging import logging
import re
from typing import Any from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.db import get_db_context from app.db import get_db_context
from app.models import ModelType from app.models import ModelType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
@@ -63,8 +65,15 @@ class LLMNode(BaseNode):
- user/human: 用户消息HumanMessage - user/human: 用户消息HumanMessage
- ai/assistant: AI 消息AIMessage - ai/assistant: AI 消息AIMessage
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = LLMNodeConfig(**self.config)
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]: def _render_context(self, message,state):
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
return re.sub(r"{{context}}", context, message)
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑) """准备 LLM 实例(公共逻辑)
Args: Args:
@@ -83,6 +92,7 @@ class LLMNode(BaseNode):
for msg_config in messages_config: for msg_config in messages_config:
role = msg_config.get("role", "user").lower() role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "") content_template = msg_config.get("content", "")
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state) content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
@@ -153,7 +163,7 @@ class LLMNode(BaseNode):
Returns: Returns:
LLM 响应消息 LLM 响应消息
""" """
llm, prompt_or_messages = self._prepare_llm(state,True) llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")

View File

@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
return await MemoryAgentService().read_memory( return await MemoryAgentService().read_memory(
group_id=end_user_id, group_id=end_user_id,
message=self.typed_config.message, message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id, config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,
history=[], history=[],
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
return await MemoryAgentService().write_memory( return await MemoryAgentService().write_memory(
group_id=end_user_id, group_id=end_user_id,
message=self.typed_config.message, message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id, config_id=self.typed_config.config_id,
db=db, db=db,
storage_type="neo4j", storage_type="neo4j",

View File

@@ -87,10 +87,11 @@ class WorkflowValidator:
return graphs return graphs
@classmethod @classmethod
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
"""验证工作流配置 """验证工作流配置
Args: Args:
publish: 发布验证标识
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型 workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
Returns: Returns:
@@ -114,7 +115,7 @@ class WorkflowValidator:
graphs = cls.get_subgraph(workflow_config) graphs = cls.get_subgraph(workflow_config)
logger.info(graphs) logger.info(graphs)
for graph in graphs: for index, graph in enumerate(graphs):
nodes = graph.get("nodes", []) nodes = graph.get("nodes", [])
edges = graph.get("edges", []) edges = graph.get("edges", [])
variables = graph.get("variables", []) variables = graph.get("variables", [])
@@ -125,10 +126,11 @@ class WorkflowValidator:
elif len(start_nodes) > 1: elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}") errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个) if index == len(graphs) - 1:
end_nodes = [n for n in nodes if n.get("type") == NodeType.END] # 2. 验证 主图end 节点(至少一个)
if len(end_nodes) == 0: end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
errors.append("工作流必须至少有一个 end 节点") if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性 # 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes] node_ids = [n.get("id") for n in nodes]
@@ -159,15 +161,17 @@ class WorkflowValidator:
elif target not in node_id_set: elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}") errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发) if publish:
if start_nodes and not errors: # 只有在前面验证通过时才检查可达 # 仅在发布时验证所有节点可达
reachable = WorkflowValidator._get_reachable_nodes( # 6. 验证所有节点可达(从 start 节点出发)
start_nodes[0]["id"], if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
edges reachable = WorkflowValidator._get_reachable_nodes(
) start_nodes[0]["id"],
unreachable = node_id_set - reachable edges
if unreachable: )
errors.append(f"以下节点无法从 start 节点到达: {unreachable}") unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点) # 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环 if not errors: # 只有在前面验证通过时才检查循环
@@ -288,7 +292,7 @@ class WorkflowValidator:
(is_valid, errors): 是否有效和错误列表 (is_valid, errors): 是否有效和错误列表
""" """
# 先执行基础验证 # 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config) is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
if not is_valid: if not is_valid:
return False, errors return False, errors

View File

@@ -231,9 +231,9 @@ class PromptOptimizerService:
if m: if m:
prompt_index = m.start() prompt_index = m.start()
prompt_finished = True prompt_finished = True
yield {"type": "delta", "content": buffer[idx:prompt_index]} yield {"content": buffer[idx:prompt_index]}
else: else:
yield {"type": "delta", "content": cache[idx:]} yield {"content": cache[idx:]}
if len(cache) != 0: if len(cache) != 0:
idx = len(cache) idx = len(cache)
@@ -249,8 +249,8 @@ class PromptOptimizerService:
role=RoleType.ASSISTANT, role=RoleType.ASSISTANT,
content=desc content=desc
) )
variables = self.parser_prompt_variables(optim_result.get("prompt"))
yield {"type": "done", "desc": optim_result.get("desc")} yield {"desc": optim_result.get("desc"), "variables": variables}
@staticmethod @staticmethod
def parser_prompt_variables(prompt: str): def parser_prompt_variables(prompt: str):