Merge pull request #25 from SuanmoSuanyangTechnology/fix/workflow
Fix/workflow
This commit is contained in:
@@ -117,7 +117,7 @@ async def get_prompt_opt(
|
||||
user_require=data.message
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
|
||||
@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: bool
|
||||
looping: Annotated[bool, lambda x, y: x and y]
|
||||
|
||||
# Input variables (passed from configured variables)
|
||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||
|
||||
@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
|
||||
retries -= 1
|
||||
if retries > 0:
|
||||
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
|
||||
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"HTTP request node exception: {e}")
|
||||
else:
|
||||
match self.typed_config.error_handle.method:
|
||||
case HttpErrorHandle.NONE:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning error response"
|
||||
)
|
||||
return HttpRequestNodeOutput(
|
||||
body="",
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
).model_dump()
|
||||
case HttpErrorHandle.DEFAULT:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
|
||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||
)
|
||||
return "ERROR"
|
||||
raise RuntimeError("http request failed")
|
||||
|
||||
@@ -210,8 +210,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
# TODO:其他重排序方式支持
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
return [chunk.model_dump() for chunk in final_rs]
|
||||
return [chunk.page_content for chunk in final_rs]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""LLM 节点配置"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
@@ -41,6 +43,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
description="模型配置 ID"
|
||||
)
|
||||
|
||||
context: Any = Field(
|
||||
default="",
|
||||
description="上下文"
|
||||
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -5,11 +5,13 @@ LLM 节点实现
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -63,6 +65,13 @@ class LLMNode(BaseNode):
|
||||
- user/human: 用户消息(HumanMessage)
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
|
||||
def _render_context(self, message,state):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||
return re.sub(r"{{context}}", context, message)
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
@@ -83,6 +92,7 @@ class LLMNode(BaseNode):
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.get("role", "user").lower()
|
||||
content_template = msg_config.get("content", "")
|
||||
content_template = self._render_context(content_template, state)
|
||||
content = self._render_template(content_template, state)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
|
||||
@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
|
||||
|
||||
return await MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=self.typed_config.message,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=self.typed_config.config_id,
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
|
||||
|
||||
return await MemoryAgentService().write_memory(
|
||||
group_id=end_user_id,
|
||||
message=self.typed_config.message,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=self.typed_config.config_id,
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
|
||||
@@ -87,10 +87,11 @@ class WorkflowValidator:
|
||||
return graphs
|
||||
|
||||
@classmethod
|
||||
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置
|
||||
|
||||
Args:
|
||||
publish: 发布验证标识
|
||||
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
||||
|
||||
Returns:
|
||||
@@ -114,7 +115,7 @@ class WorkflowValidator:
|
||||
|
||||
graphs = cls.get_subgraph(workflow_config)
|
||||
logger.info(graphs)
|
||||
for graph in graphs:
|
||||
for index, graph in enumerate(graphs):
|
||||
nodes = graph.get("nodes", [])
|
||||
edges = graph.get("edges", [])
|
||||
variables = graph.get("variables", [])
|
||||
@@ -125,7 +126,8 @@ class WorkflowValidator:
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||
|
||||
# 2. 验证 end 节点(至少一个)
|
||||
if index == len(graphs) - 1:
|
||||
# 2. 验证 主图end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
@@ -159,6 +161,8 @@ class WorkflowValidator:
|
||||
elif target not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||
|
||||
if publish:
|
||||
# 仅在发布时验证所有节点可达
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
@@ -288,7 +292,7 @@ class WorkflowValidator:
|
||||
(is_valid, errors): 是否有效和错误列表
|
||||
"""
|
||||
# 先执行基础验证
|
||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
||||
is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
|
||||
|
||||
if not is_valid:
|
||||
return False, errors
|
||||
|
||||
@@ -231,9 +231,9 @@ class PromptOptimizerService:
|
||||
if m:
|
||||
prompt_index = m.start()
|
||||
prompt_finished = True
|
||||
yield {"type": "delta", "content": buffer[idx:prompt_index]}
|
||||
yield {"content": buffer[idx:prompt_index]}
|
||||
else:
|
||||
yield {"type": "delta", "content": cache[idx:]}
|
||||
yield {"content": cache[idx:]}
|
||||
if len(cache) != 0:
|
||||
idx = len(cache)
|
||||
|
||||
@@ -249,8 +249,8 @@ class PromptOptimizerService:
|
||||
role=RoleType.ASSISTANT,
|
||||
content=desc
|
||||
)
|
||||
|
||||
yield {"type": "done", "desc": optim_result.get("desc")}
|
||||
variables = self.parser_prompt_variables(optim_result.get("prompt"))
|
||||
yield {"desc": optim_result.get("desc"), "variables": variables}
|
||||
|
||||
@staticmethod
|
||||
def parser_prompt_variables(prompt: str):
|
||||
|
||||
Reference in New Issue
Block a user