Merge pull request #25 from SuanmoSuanyangTechnology/fix/workflow
Fix/workflow
This commit is contained in:
@@ -117,7 +117,7 @@ async def get_prompt_opt(
|
|||||||
user_require=data.message
|
user_require=data.message
|
||||||
):
|
):
|
||||||
# chunk 是 prompt 的增量内容
|
# chunk 是 prompt 的增量内容
|
||||||
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
|
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
|
|||||||
|
|
||||||
# Set of loop node IDs, used for assigning values in loop nodes
|
# Set of loop node IDs, used for assigning values in loop nodes
|
||||||
cycle_nodes: list
|
cycle_nodes: list
|
||||||
looping: bool
|
looping: Annotated[bool, lambda x, y: x and y]
|
||||||
|
|
||||||
# Input variables (passed from configured variables)
|
# Input variables (passed from configured variables)
|
||||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||||
|
|||||||
@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
|
|||||||
retries -= 1
|
retries -= 1
|
||||||
if retries > 0:
|
if retries > 0:
|
||||||
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
|
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
|
||||||
|
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"HTTP request node exception: {e}")
|
||||||
else:
|
else:
|
||||||
match self.typed_config.error_handle.method:
|
match self.typed_config.error_handle.method:
|
||||||
case HttpErrorHandle.NONE:
|
|
||||||
logger.warning(
|
|
||||||
f"Node {self.node_id}: HTTP request failed, returning error response"
|
|
||||||
)
|
|
||||||
return HttpRequestNodeOutput(
|
|
||||||
body="",
|
|
||||||
status_code=resp.status_code,
|
|
||||||
headers=resp.headers,
|
|
||||||
).model_dump()
|
|
||||||
case HttpErrorHandle.DEFAULT:
|
case HttpErrorHandle.DEFAULT:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||||
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
|
|||||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||||
)
|
)
|
||||||
return "ERROR"
|
return "ERROR"
|
||||||
|
raise RuntimeError("http request failed")
|
||||||
|
|||||||
@@ -203,15 +203,16 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
score_threshold=kb_config.similarity_threshold)
|
score_threshold=kb_config.similarity_threshold)
|
||||||
# Deduplicate hybrid retrieval results
|
# Deduplicate hy brid retrieval results
|
||||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||||
vector_service.reranker = self.get_reranker_model()
|
vector_service.reranker = self.get_reranker_model()
|
||||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("Unknown retrieval type")
|
raise RuntimeError("Unknown retrieval type")
|
||||||
vector_service.reranker = self.get_reranker_model()
|
vector_service.reranker = self.get_reranker_model()
|
||||||
|
# TODO:其他重排序方式支持
|
||||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||||
)
|
)
|
||||||
return [chunk.model_dump() for chunk in final_rs]
|
return [chunk.page_content for chunk in final_rs]
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""LLM 节点配置"""
|
"""LLM 节点配置"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||||
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
|
|||||||
|
|
||||||
class MessageConfig(BaseModel):
|
class MessageConfig(BaseModel):
|
||||||
"""消息配置"""
|
"""消息配置"""
|
||||||
|
|
||||||
role: str = Field(
|
role: str = Field(
|
||||||
...,
|
...,
|
||||||
description="消息角色:system, user, assistant"
|
description="消息角色:system, user, assistant"
|
||||||
)
|
)
|
||||||
|
|
||||||
content: str = Field(
|
content: str = Field(
|
||||||
...,
|
...,
|
||||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("role")
|
@field_validator("role")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_role(cls, v: str) -> str:
|
def validate_role(cls, v: str) -> str:
|
||||||
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
1. 简单模式:使用 prompt 字段
|
1. 简单模式:使用 prompt 字段
|
||||||
2. 消息模式:使用 messages 字段(推荐)
|
2. 消息模式:使用 messages 字段(推荐)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_id: str = Field(
|
model_id: str = Field(
|
||||||
...,
|
...,
|
||||||
description="模型配置 ID"
|
description="模型配置 ID"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context: Any = Field(
|
||||||
|
default="",
|
||||||
|
description="上下文"
|
||||||
|
)
|
||||||
|
|
||||||
# 简单模式
|
# 简单模式
|
||||||
prompt: str | None = Field(
|
prompt: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="提示词模板(简单模式),支持变量引用"
|
description="提示词模板(简单模式),支持变量引用"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 消息模式(推荐)
|
# 消息模式(推荐)
|
||||||
messages: list[MessageConfig] | None = Field(
|
messages: list[MessageConfig] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="消息列表(消息模式),支持多轮对话"
|
description="消息列表(消息模式),支持多轮对话"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 模型参数
|
# 模型参数
|
||||||
temperature: float | None = Field(
|
temperature: float | None = Field(
|
||||||
default=0.7,
|
default=0.7,
|
||||||
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
le=2.0,
|
le=2.0,
|
||||||
description="温度参数,控制输出的随机性"
|
description="温度参数,控制输出的随机性"
|
||||||
)
|
)
|
||||||
|
|
||||||
max_tokens: int | None = Field(
|
max_tokens: int | None = Field(
|
||||||
default=1000,
|
default=1000,
|
||||||
ge=1,
|
ge=1,
|
||||||
le=32000,
|
le=32000,
|
||||||
description="最大生成 token 数"
|
description="最大生成 token 数"
|
||||||
)
|
)
|
||||||
|
|
||||||
top_p: float | None = Field(
|
top_p: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=0.0,
|
ge=0.0,
|
||||||
le=1.0,
|
le=1.0,
|
||||||
description="Top-p 采样参数"
|
description="Top-p 采样参数"
|
||||||
)
|
)
|
||||||
|
|
||||||
frequency_penalty: float | None = Field(
|
frequency_penalty: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=-2.0,
|
ge=-2.0,
|
||||||
le=2.0,
|
le=2.0,
|
||||||
description="频率惩罚"
|
description="频率惩罚"
|
||||||
)
|
)
|
||||||
|
|
||||||
presence_penalty: float | None = Field(
|
presence_penalty: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=-2.0,
|
ge=-2.0,
|
||||||
le=2.0,
|
le=2.0,
|
||||||
description="存在惩罚"
|
description="存在惩罚"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 输出变量定义
|
# 输出变量定义
|
||||||
output_variables: list[VariableDefinition] = Field(
|
output_variables: list[VariableDefinition] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
],
|
],
|
||||||
description="输出变量定义(自动生成,通常不需要修改)"
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("messages", "prompt")
|
@field_validator("messages", "prompt")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_input_mode(cls, v, info):
|
def validate_input_mode(cls, v, info):
|
||||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||||
# 这个验证在 model_validator 中更合适
|
# 这个验证在 model_validator 中更合适
|
||||||
return v
|
return v
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_schema_extra = {
|
json_schema_extra = {
|
||||||
"examples": [
|
"examples": [
|
||||||
|
|||||||
@@ -5,15 +5,17 @@ LLM 节点实现
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import ModelType
|
from app.models import ModelType
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
|
|
||||||
@@ -63,8 +65,15 @@ class LLMNode(BaseNode):
|
|||||||
- user/human: 用户消息(HumanMessage)
|
- user/human: 用户消息(HumanMessage)
|
||||||
- ai/assistant: AI 消息(AIMessage)
|
- ai/assistant: AI 消息(AIMessage)
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
|
||||||
|
def _render_context(self, message,state):
|
||||||
|
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||||
|
return re.sub(r"{{context}}", context, message)
|
||||||
|
|
||||||
|
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
|
||||||
"""准备 LLM 实例(公共逻辑)
|
"""准备 LLM 实例(公共逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -76,15 +85,16 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
# 1. 处理消息格式(优先使用 messages)
|
# 1. 处理消息格式(优先使用 messages)
|
||||||
messages_config = self.config.get("messages")
|
messages_config = self.config.get("messages")
|
||||||
|
|
||||||
if messages_config:
|
if messages_config:
|
||||||
# 使用 LangChain 消息格式
|
# 使用 LangChain 消息格式
|
||||||
messages = []
|
messages = []
|
||||||
for msg_config in messages_config:
|
for msg_config in messages_config:
|
||||||
role = msg_config.get("role", "user").lower()
|
role = msg_config.get("role", "user").lower()
|
||||||
content_template = msg_config.get("content", "")
|
content_template = msg_config.get("content", "")
|
||||||
|
content_template = self._render_context(content_template, state)
|
||||||
content = self._render_template(content_template, state)
|
content = self._render_template(content_template, state)
|
||||||
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append(SystemMessage(content=content))
|
messages.append(SystemMessage(content=content))
|
||||||
@@ -95,7 +105,7 @@ class LLMNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append(HumanMessage(content=content))
|
messages.append(HumanMessage(content=content))
|
||||||
|
|
||||||
prompt_or_messages = messages
|
prompt_or_messages = messages
|
||||||
else:
|
else:
|
||||||
# 使用简单的 prompt 格式(向后兼容)
|
# 使用简单的 prompt 格式(向后兼容)
|
||||||
@@ -106,17 +116,17 @@ class LLMNode(BaseNode):
|
|||||||
model_id = self.config.get("model_id")
|
model_id = self.config.get("model_id")
|
||||||
if not model_id:
|
if not model_id:
|
||||||
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||||
|
|
||||||
# 3. 在 with 块内完成所有数据库操作和数据提取
|
# 3. 在 with 块内完成所有数据库操作和数据提取
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||||
|
|
||||||
if not config:
|
if not config:
|
||||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
if not config.api_keys or len(config.api_keys) == 0:
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
# 在 Session 关闭前提取所有需要的数据
|
# 在 Session 关闭前提取所有需要的数据
|
||||||
api_config = config.api_keys[0]
|
api_config = config.api_keys[0]
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
@@ -124,26 +134,26 @@ class LLMNode(BaseNode):
|
|||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
api_base = api_config.api_base
|
api_base = api_config.api_base
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||||
extra_params = {"streaming": stream} if stream else {}
|
extra_params = {"streaming": stream} if stream else {}
|
||||||
|
|
||||||
llm = RedBearLLM(
|
llm = RedBearLLM(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
extra_params=extra_params
|
extra_params=extra_params
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||||
|
|
||||||
return llm, prompt_or_messages
|
return llm, prompt_or_messages
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||||
"""非流式执行 LLM 调用
|
"""非流式执行 LLM 调用
|
||||||
|
|
||||||
@@ -153,10 +163,10 @@ class LLMNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
LLM 响应消息
|
LLM 响应消息
|
||||||
"""
|
"""
|
||||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|
||||||
# 调用 LLM(支持字符串或消息列表)
|
# 调用 LLM(支持字符串或消息列表)
|
||||||
response = await llm.ainvoke(prompt_or_messages)
|
response = await llm.ainvoke(prompt_or_messages)
|
||||||
# 提取内容
|
# 提取内容
|
||||||
@@ -164,16 +174,16 @@ class LLMNode(BaseNode):
|
|||||||
content = response.content
|
content = response.content
|
||||||
else:
|
else:
|
||||||
content = str(response)
|
content = str(response)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||||
|
|
||||||
# 返回 AIMessage(包含响应元数据)
|
# 返回 AIMessage(包含响应元数据)
|
||||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""提取输入数据(用于记录)"""
|
"""提取输入数据(用于记录)"""
|
||||||
_, prompt_or_messages = self._prepare_llm(state)
|
_, prompt_or_messages = self._prepare_llm(state)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -186,13 +196,13 @@ class LLMNode(BaseNode):
|
|||||||
"max_tokens": self.config.get("max_tokens")
|
"max_tokens": self.config.get("max_tokens")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> str:
|
def _extract_output(self, business_result: Any) -> str:
|
||||||
"""从 AIMessage 中提取文本内容"""
|
"""从 AIMessage 中提取文本内容"""
|
||||||
if isinstance(business_result, AIMessage):
|
if isinstance(business_result, AIMessage):
|
||||||
return business_result.content
|
return business_result.content
|
||||||
return str(business_result)
|
return str(business_result)
|
||||||
|
|
||||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
"""从 AIMessage 中提取 token 使用情况"""
|
"""从 AIMessage 中提取 token 使用情况"""
|
||||||
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
||||||
@@ -204,7 +214,7 @@ class LLMNode(BaseNode):
|
|||||||
"total_tokens": usage.get('total_tokens', 0)
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState):
|
async def execute_stream(self, state: WorkflowState):
|
||||||
"""流式执行 LLM 调用
|
"""流式执行 LLM 调用
|
||||||
|
|
||||||
@@ -215,26 +225,26 @@ class LLMNode(BaseNode):
|
|||||||
文本片段(chunk)或完成标记
|
文本片段(chunk)或完成标记
|
||||||
"""
|
"""
|
||||||
from langgraph.config import get_stream_writer
|
from langgraph.config import get_stream_writer
|
||||||
|
|
||||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||||
|
|
||||||
# 检查是否有注入的 End 节点前缀配置
|
# 检查是否有注入的 End 节点前缀配置
|
||||||
writer = get_stream_writer()
|
writer = get_stream_writer()
|
||||||
end_prefix = getattr(self, '_end_node_prefix', None)
|
end_prefix = getattr(self, '_end_node_prefix', None)
|
||||||
|
|
||||||
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
||||||
if end_prefix:
|
if end_prefix:
|
||||||
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
||||||
|
|
||||||
if end_prefix:
|
if end_prefix:
|
||||||
# 渲染前缀(可能包含其他变量)
|
# 渲染前缀(可能包含其他变量)
|
||||||
try:
|
try:
|
||||||
rendered_prefix = self._render_template(end_prefix, state)
|
rendered_prefix = self._render_template(end_prefix, state)
|
||||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||||
|
|
||||||
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
||||||
writer({
|
writer({
|
||||||
"type": "message", # End 相关的内容都是 message 类型
|
"type": "message", # End 相关的内容都是 message 类型
|
||||||
@@ -246,12 +256,12 @@ class LLMNode(BaseNode):
|
|||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
last_chunk = None
|
last_chunk = None
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
# 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
async for chunk in llm.astream(prompt_or_messages):
|
async for chunk in llm.astream(prompt_or_messages):
|
||||||
# 提取内容
|
# 提取内容
|
||||||
@@ -259,18 +269,18 @@ class LLMNode(BaseNode):
|
|||||||
content = chunk.content
|
content = chunk.content
|
||||||
else:
|
else:
|
||||||
content = str(chunk)
|
content = str(chunk)
|
||||||
|
|
||||||
# 只有当内容不为空时才处理
|
# 只有当内容不为空时才处理
|
||||||
if content:
|
if content:
|
||||||
full_response += content
|
full_response += content
|
||||||
last_chunk = chunk
|
last_chunk = chunk
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
|
|
||||||
# 流式返回每个文本片段
|
# 流式返回每个文本片段
|
||||||
yield content
|
yield content
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||||
|
|
||||||
# 构建完整的 AIMessage(包含元数据)
|
# 构建完整的 AIMessage(包含元数据)
|
||||||
if isinstance(last_chunk, AIMessage):
|
if isinstance(last_chunk, AIMessage):
|
||||||
final_message = AIMessage(
|
final_message = AIMessage(
|
||||||
@@ -279,6 +289,6 @@ class LLMNode(BaseNode):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
final_message = AIMessage(content=full_response)
|
final_message = AIMessage(content=full_response)
|
||||||
|
|
||||||
# yield 完成标记
|
# yield 完成标记
|
||||||
yield {"__final__": True, "result": final_message}
|
yield {"__final__": True, "result": final_message}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
|
|||||||
|
|
||||||
return await MemoryAgentService().read_memory(
|
return await MemoryAgentService().read_memory(
|
||||||
group_id=end_user_id,
|
group_id=end_user_id,
|
||||||
message=self.typed_config.message,
|
message=self._render_template(self.typed_config.message, state),
|
||||||
config_id=self.typed_config.config_id,
|
config_id=self.typed_config.config_id,
|
||||||
search_switch=self.typed_config.search_switch,
|
search_switch=self.typed_config.search_switch,
|
||||||
history=[],
|
history=[],
|
||||||
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
|
|||||||
|
|
||||||
return await MemoryAgentService().write_memory(
|
return await MemoryAgentService().write_memory(
|
||||||
group_id=end_user_id,
|
group_id=end_user_id,
|
||||||
message=self.typed_config.message,
|
message=self._render_template(self.typed_config.message, state),
|
||||||
config_id=self.typed_config.config_id,
|
config_id=self.typed_config.config_id,
|
||||||
db=db,
|
db=db,
|
||||||
storage_type="neo4j",
|
storage_type="neo4j",
|
||||||
|
|||||||
@@ -87,10 +87,11 @@ class WorkflowValidator:
|
|||||||
return graphs
|
return graphs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
|
||||||
"""验证工作流配置
|
"""验证工作流配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
publish: 发布验证标识
|
||||||
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -114,7 +115,7 @@ class WorkflowValidator:
|
|||||||
|
|
||||||
graphs = cls.get_subgraph(workflow_config)
|
graphs = cls.get_subgraph(workflow_config)
|
||||||
logger.info(graphs)
|
logger.info(graphs)
|
||||||
for graph in graphs:
|
for index, graph in enumerate(graphs):
|
||||||
nodes = graph.get("nodes", [])
|
nodes = graph.get("nodes", [])
|
||||||
edges = graph.get("edges", [])
|
edges = graph.get("edges", [])
|
||||||
variables = graph.get("variables", [])
|
variables = graph.get("variables", [])
|
||||||
@@ -125,10 +126,11 @@ class WorkflowValidator:
|
|||||||
elif len(start_nodes) > 1:
|
elif len(start_nodes) > 1:
|
||||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||||
|
|
||||||
# 2. 验证 end 节点(至少一个)
|
if index == len(graphs) - 1:
|
||||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
# 2. 验证 主图end 节点(至少一个)
|
||||||
if len(end_nodes) == 0:
|
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||||
errors.append("工作流必须至少有一个 end 节点")
|
if len(end_nodes) == 0:
|
||||||
|
errors.append("工作流必须至少有一个 end 节点")
|
||||||
|
|
||||||
# 3. 验证节点 ID 唯一性
|
# 3. 验证节点 ID 唯一性
|
||||||
node_ids = [n.get("id") for n in nodes]
|
node_ids = [n.get("id") for n in nodes]
|
||||||
@@ -159,15 +161,17 @@ class WorkflowValidator:
|
|||||||
elif target not in node_id_set:
|
elif target not in node_id_set:
|
||||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||||
|
|
||||||
# 6. 验证所有节点可达(从 start 节点出发)
|
if publish:
|
||||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
# 仅在发布时验证所有节点可达
|
||||||
reachable = WorkflowValidator._get_reachable_nodes(
|
# 6. 验证所有节点可达(从 start 节点出发)
|
||||||
start_nodes[0]["id"],
|
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||||
edges
|
reachable = WorkflowValidator._get_reachable_nodes(
|
||||||
)
|
start_nodes[0]["id"],
|
||||||
unreachable = node_id_set - reachable
|
edges
|
||||||
if unreachable:
|
)
|
||||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
unreachable = node_id_set - reachable
|
||||||
|
if unreachable:
|
||||||
|
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||||
|
|
||||||
# 7. 检测循环依赖(非 loop 节点)
|
# 7. 检测循环依赖(非 loop 节点)
|
||||||
if not errors: # 只有在前面验证通过时才检查循环
|
if not errors: # 只有在前面验证通过时才检查循环
|
||||||
@@ -288,7 +292,7 @@ class WorkflowValidator:
|
|||||||
(is_valid, errors): 是否有效和错误列表
|
(is_valid, errors): 是否有效和错误列表
|
||||||
"""
|
"""
|
||||||
# 先执行基础验证
|
# 先执行基础验证
|
||||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
|
||||||
|
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
return False, errors
|
return False, errors
|
||||||
|
|||||||
@@ -231,9 +231,9 @@ class PromptOptimizerService:
|
|||||||
if m:
|
if m:
|
||||||
prompt_index = m.start()
|
prompt_index = m.start()
|
||||||
prompt_finished = True
|
prompt_finished = True
|
||||||
yield {"type": "delta", "content": buffer[idx:prompt_index]}
|
yield {"content": buffer[idx:prompt_index]}
|
||||||
else:
|
else:
|
||||||
yield {"type": "delta", "content": cache[idx:]}
|
yield {"content": cache[idx:]}
|
||||||
if len(cache) != 0:
|
if len(cache) != 0:
|
||||||
idx = len(cache)
|
idx = len(cache)
|
||||||
|
|
||||||
@@ -249,8 +249,8 @@ class PromptOptimizerService:
|
|||||||
role=RoleType.ASSISTANT,
|
role=RoleType.ASSISTANT,
|
||||||
content=desc
|
content=desc
|
||||||
)
|
)
|
||||||
|
variables = self.parser_prompt_variables(optim_result.get("prompt"))
|
||||||
yield {"type": "done", "desc": optim_result.get("desc")}
|
yield {"desc": optim_result.get("desc"), "variables": variables}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parser_prompt_variables(prompt: str):
|
def parser_prompt_variables(prompt: str):
|
||||||
|
|||||||
Reference in New Issue
Block a user