feat(workflow): support context injection in LLM node

This commit is contained in:
mengyonghao
2026-01-05 17:37:45 +08:00
parent d4a87187cb
commit e1e77f70f9

View File

@@ -1,5 +1,7 @@
"""LLM 节点配置""" """LLM 节点配置"""
from typing import Any
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
class MessageConfig(BaseModel): class MessageConfig(BaseModel):
"""消息配置""" """消息配置"""
role: str = Field( role: str = Field(
..., ...,
description="消息角色system, user, assistant" description="消息角色system, user, assistant"
) )
content: str = Field( content: str = Field(
..., ...,
description="消息内容,支持模板变量,如:{{ sys.message }}" description="消息内容,支持模板变量,如:{{ sys.message }}"
) )
@field_validator("role") @field_validator("role")
@classmethod @classmethod
def validate_role(cls, v: str) -> str: def validate_role(cls, v: str) -> str:
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
1. 简单模式:使用 prompt 字段 1. 简单模式:使用 prompt 字段
2. 消息模式:使用 messages 字段(推荐) 2. 消息模式:使用 messages 字段(推荐)
""" """
model_id: str = Field( model_id: str = Field(
..., ...,
description="模型配置 ID" description="模型配置 ID"
) )
context: Any = Field(
default="",
description="上下文"
)
# 简单模式 # 简单模式
prompt: str | None = Field( prompt: str | None = Field(
default=None, default=None,
description="提示词模板(简单模式),支持变量引用" description="提示词模板(简单模式),支持变量引用"
) )
# 消息模式(推荐) # 消息模式(推荐)
messages: list[MessageConfig] | None = Field( messages: list[MessageConfig] | None = Field(
default=None, default=None,
description="消息列表(消息模式),支持多轮对话" description="消息列表(消息模式),支持多轮对话"
) )
# 模型参数 # 模型参数
temperature: float | None = Field( temperature: float | None = Field(
default=0.7, default=0.7,
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
le=2.0, le=2.0,
description="温度参数,控制输出的随机性" description="温度参数,控制输出的随机性"
) )
max_tokens: int | None = Field( max_tokens: int | None = Field(
default=1000, default=1000,
ge=1, ge=1,
le=32000, le=32000,
description="最大生成 token 数" description="最大生成 token 数"
) )
top_p: float | None = Field( top_p: float | None = Field(
default=None, default=None,
ge=0.0, ge=0.0,
le=1.0, le=1.0,
description="Top-p 采样参数" description="Top-p 采样参数"
) )
frequency_penalty: float | None = Field( frequency_penalty: float | None = Field(
default=None, default=None,
ge=-2.0, ge=-2.0,
le=2.0, le=2.0,
description="频率惩罚" description="频率惩罚"
) )
presence_penalty: float | None = Field( presence_penalty: float | None = Field(
default=None, default=None,
ge=-2.0, ge=-2.0,
le=2.0, le=2.0,
description="存在惩罚" description="存在惩罚"
) )
# 输出变量定义 # 输出变量定义
output_variables: list[VariableDefinition] = Field( output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [ default_factory=lambda: [
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
], ],
description="输出变量定义(自动生成,通常不需要修改)" description="输出变量定义(自动生成,通常不需要修改)"
) )
@field_validator("messages", "prompt") @field_validator("messages", "prompt")
@classmethod @classmethod
def validate_input_mode(cls, v, info): def validate_input_mode(cls, v, info):
"""验证输入模式prompt 和 messages 至少有一个""" """验证输入模式prompt 和 messages 至少有一个"""
# 这个验证在 model_validator 中更合适 # 这个验证在 model_validator 中更合适
return v return v
class Config: class Config:
json_schema_extra = { json_schema_extra = {
"examples": [ "examples": [