Merge pull request #25 from SuanmoSuanyangTechnology/fix/workflow
Fix/workflow
This commit is contained in:
@@ -117,7 +117,7 @@ async def get_prompt_opt(
|
|||||||
user_require=data.message
|
user_require=data.message
|
||||||
):
|
):
|
||||||
# chunk 是 prompt 的增量内容
|
# chunk 是 prompt 的增量内容
|
||||||
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
|
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
|
|||||||
|
|
||||||
# Set of loop node IDs, used for assigning values in loop nodes
|
# Set of loop node IDs, used for assigning values in loop nodes
|
||||||
cycle_nodes: list
|
cycle_nodes: list
|
||||||
looping: bool
|
looping: Annotated[bool, lambda x, y: x and y]
|
||||||
|
|
||||||
# Input variables (passed from configured variables)
|
# Input variables (passed from configured variables)
|
||||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||||
|
|||||||
@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
|
|||||||
retries -= 1
|
retries -= 1
|
||||||
if retries > 0:
|
if retries > 0:
|
||||||
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
|
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
|
||||||
|
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"HTTP request node exception: {e}")
|
||||||
else:
|
else:
|
||||||
match self.typed_config.error_handle.method:
|
match self.typed_config.error_handle.method:
|
||||||
case HttpErrorHandle.NONE:
|
|
||||||
logger.warning(
|
|
||||||
f"Node {self.node_id}: HTTP request failed, returning error response"
|
|
||||||
)
|
|
||||||
return HttpRequestNodeOutput(
|
|
||||||
body="",
|
|
||||||
status_code=resp.status_code,
|
|
||||||
headers=resp.headers,
|
|
||||||
).model_dump()
|
|
||||||
case HttpErrorHandle.DEFAULT:
|
case HttpErrorHandle.DEFAULT:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||||
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
|
|||||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||||
)
|
)
|
||||||
return "ERROR"
|
return "ERROR"
|
||||||
|
raise RuntimeError("http request failed")
|
||||||
|
|||||||
@@ -203,15 +203,16 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
score_threshold=kb_config.similarity_threshold)
|
score_threshold=kb_config.similarity_threshold)
|
||||||
# Deduplicate hybrid retrieval results
|
# Deduplicate hy brid retrieval results
|
||||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||||
vector_service.reranker = self.get_reranker_model()
|
vector_service.reranker = self.get_reranker_model()
|
||||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("Unknown retrieval type")
|
raise RuntimeError("Unknown retrieval type")
|
||||||
vector_service.reranker = self.get_reranker_model()
|
vector_service.reranker = self.get_reranker_model()
|
||||||
|
# TODO:其他重排序方式支持
|
||||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||||
)
|
)
|
||||||
return [chunk.model_dump() for chunk in final_rs]
|
return [chunk.page_content for chunk in final_rs]
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""LLM 节点配置"""
|
"""LLM 节点配置"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||||
@@ -41,6 +43,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
description="模型配置 ID"
|
description="模型配置 ID"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context: Any = Field(
|
||||||
|
default="",
|
||||||
|
description="上下文"
|
||||||
|
)
|
||||||
|
|
||||||
# 简单模式
|
# 简单模式
|
||||||
prompt: str | None = Field(
|
prompt: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ LLM 节点实现
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import ModelType
|
from app.models import ModelType
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
@@ -63,8 +65,15 @@ class LLMNode(BaseNode):
|
|||||||
- user/human: 用户消息(HumanMessage)
|
- user/human: 用户消息(HumanMessage)
|
||||||
- ai/assistant: AI 消息(AIMessage)
|
- ai/assistant: AI 消息(AIMessage)
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
|
||||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
def _render_context(self, message,state):
|
||||||
|
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||||
|
return re.sub(r"{{context}}", context, message)
|
||||||
|
|
||||||
|
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
|
||||||
"""准备 LLM 实例(公共逻辑)
|
"""准备 LLM 实例(公共逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -83,6 +92,7 @@ class LLMNode(BaseNode):
|
|||||||
for msg_config in messages_config:
|
for msg_config in messages_config:
|
||||||
role = msg_config.get("role", "user").lower()
|
role = msg_config.get("role", "user").lower()
|
||||||
content_template = msg_config.get("content", "")
|
content_template = msg_config.get("content", "")
|
||||||
|
content_template = self._render_context(content_template, state)
|
||||||
content = self._render_template(content_template, state)
|
content = self._render_template(content_template, state)
|
||||||
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
@@ -153,7 +163,7 @@ class LLMNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
LLM 响应消息
|
LLM 响应消息
|
||||||
"""
|
"""
|
||||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
|
|||||||
|
|
||||||
return await MemoryAgentService().read_memory(
|
return await MemoryAgentService().read_memory(
|
||||||
group_id=end_user_id,
|
group_id=end_user_id,
|
||||||
message=self.typed_config.message,
|
message=self._render_template(self.typed_config.message, state),
|
||||||
config_id=self.typed_config.config_id,
|
config_id=self.typed_config.config_id,
|
||||||
search_switch=self.typed_config.search_switch,
|
search_switch=self.typed_config.search_switch,
|
||||||
history=[],
|
history=[],
|
||||||
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
|
|||||||
|
|
||||||
return await MemoryAgentService().write_memory(
|
return await MemoryAgentService().write_memory(
|
||||||
group_id=end_user_id,
|
group_id=end_user_id,
|
||||||
message=self.typed_config.message,
|
message=self._render_template(self.typed_config.message, state),
|
||||||
config_id=self.typed_config.config_id,
|
config_id=self.typed_config.config_id,
|
||||||
db=db,
|
db=db,
|
||||||
storage_type="neo4j",
|
storage_type="neo4j",
|
||||||
|
|||||||
@@ -87,10 +87,11 @@ class WorkflowValidator:
|
|||||||
return graphs
|
return graphs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
|
||||||
"""验证工作流配置
|
"""验证工作流配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
publish: 发布验证标识
|
||||||
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -114,7 +115,7 @@ class WorkflowValidator:
|
|||||||
|
|
||||||
graphs = cls.get_subgraph(workflow_config)
|
graphs = cls.get_subgraph(workflow_config)
|
||||||
logger.info(graphs)
|
logger.info(graphs)
|
||||||
for graph in graphs:
|
for index, graph in enumerate(graphs):
|
||||||
nodes = graph.get("nodes", [])
|
nodes = graph.get("nodes", [])
|
||||||
edges = graph.get("edges", [])
|
edges = graph.get("edges", [])
|
||||||
variables = graph.get("variables", [])
|
variables = graph.get("variables", [])
|
||||||
@@ -125,10 +126,11 @@ class WorkflowValidator:
|
|||||||
elif len(start_nodes) > 1:
|
elif len(start_nodes) > 1:
|
||||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||||
|
|
||||||
# 2. 验证 end 节点(至少一个)
|
if index == len(graphs) - 1:
|
||||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
# 2. 验证 主图end 节点(至少一个)
|
||||||
if len(end_nodes) == 0:
|
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||||
errors.append("工作流必须至少有一个 end 节点")
|
if len(end_nodes) == 0:
|
||||||
|
errors.append("工作流必须至少有一个 end 节点")
|
||||||
|
|
||||||
# 3. 验证节点 ID 唯一性
|
# 3. 验证节点 ID 唯一性
|
||||||
node_ids = [n.get("id") for n in nodes]
|
node_ids = [n.get("id") for n in nodes]
|
||||||
@@ -159,15 +161,17 @@ class WorkflowValidator:
|
|||||||
elif target not in node_id_set:
|
elif target not in node_id_set:
|
||||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||||
|
|
||||||
# 6. 验证所有节点可达(从 start 节点出发)
|
if publish:
|
||||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
# 仅在发布时验证所有节点可达
|
||||||
reachable = WorkflowValidator._get_reachable_nodes(
|
# 6. 验证所有节点可达(从 start 节点出发)
|
||||||
start_nodes[0]["id"],
|
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||||
edges
|
reachable = WorkflowValidator._get_reachable_nodes(
|
||||||
)
|
start_nodes[0]["id"],
|
||||||
unreachable = node_id_set - reachable
|
edges
|
||||||
if unreachable:
|
)
|
||||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
unreachable = node_id_set - reachable
|
||||||
|
if unreachable:
|
||||||
|
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||||
|
|
||||||
# 7. 检测循环依赖(非 loop 节点)
|
# 7. 检测循环依赖(非 loop 节点)
|
||||||
if not errors: # 只有在前面验证通过时才检查循环
|
if not errors: # 只有在前面验证通过时才检查循环
|
||||||
@@ -288,7 +292,7 @@ class WorkflowValidator:
|
|||||||
(is_valid, errors): 是否有效和错误列表
|
(is_valid, errors): 是否有效和错误列表
|
||||||
"""
|
"""
|
||||||
# 先执行基础验证
|
# 先执行基础验证
|
||||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
|
||||||
|
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
return False, errors
|
return False, errors
|
||||||
|
|||||||
@@ -231,9 +231,9 @@ class PromptOptimizerService:
|
|||||||
if m:
|
if m:
|
||||||
prompt_index = m.start()
|
prompt_index = m.start()
|
||||||
prompt_finished = True
|
prompt_finished = True
|
||||||
yield {"type": "delta", "content": buffer[idx:prompt_index]}
|
yield {"content": buffer[idx:prompt_index]}
|
||||||
else:
|
else:
|
||||||
yield {"type": "delta", "content": cache[idx:]}
|
yield {"content": cache[idx:]}
|
||||||
if len(cache) != 0:
|
if len(cache) != 0:
|
||||||
idx = len(cache)
|
idx = len(cache)
|
||||||
|
|
||||||
@@ -249,8 +249,8 @@ class PromptOptimizerService:
|
|||||||
role=RoleType.ASSISTANT,
|
role=RoleType.ASSISTANT,
|
||||||
content=desc
|
content=desc
|
||||||
)
|
)
|
||||||
|
variables = self.parser_prompt_variables(optim_result.get("prompt"))
|
||||||
yield {"type": "done", "desc": optim_result.get("desc")}
|
yield {"desc": optim_result.get("desc"), "variables": variables}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parser_prompt_variables(prompt: str):
|
def parser_prompt_variables(prompt: str):
|
||||||
|
|||||||
Reference in New Issue
Block a user