feat(workflow): support context injection in LLM node
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""LLM 节点配置"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
|
||||
|
||||
class MessageConfig(BaseModel):
|
||||
"""消息配置"""
|
||||
|
||||
|
||||
role: str = Field(
|
||||
...,
|
||||
description="消息角色:system, user, assistant"
|
||||
)
|
||||
|
||||
|
||||
content: str = Field(
|
||||
...,
|
||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||
)
|
||||
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, v: str) -> str:
|
||||
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
1. 简单模式:使用 prompt 字段
|
||||
2. 消息模式:使用 messages 字段(推荐)
|
||||
"""
|
||||
|
||||
|
||||
model_id: str = Field(
|
||||
...,
|
||||
description="模型配置 ID"
|
||||
)
|
||||
|
||||
|
||||
context: Any = Field(
|
||||
default="",
|
||||
description="上下文"
|
||||
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
description="提示词模板(简单模式),支持变量引用"
|
||||
)
|
||||
|
||||
|
||||
# 消息模式(推荐)
|
||||
messages: list[MessageConfig] | None = Field(
|
||||
default=None,
|
||||
description="消息列表(消息模式),支持多轮对话"
|
||||
)
|
||||
|
||||
|
||||
# 模型参数
|
||||
temperature: float | None = Field(
|
||||
default=0.7,
|
||||
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
le=2.0,
|
||||
description="温度参数,控制输出的随机性"
|
||||
)
|
||||
|
||||
|
||||
max_tokens: int | None = Field(
|
||||
default=1000,
|
||||
ge=1,
|
||||
le=32000,
|
||||
description="最大生成 token 数"
|
||||
)
|
||||
|
||||
|
||||
top_p: float | None = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Top-p 采样参数"
|
||||
)
|
||||
|
||||
|
||||
frequency_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="频率惩罚"
|
||||
)
|
||||
|
||||
|
||||
presence_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="存在惩罚"
|
||||
)
|
||||
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
|
||||
|
||||
@field_validator("messages", "prompt")
|
||||
@classmethod
|
||||
def validate_input_mode(cls, v, info):
|
||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||
# 这个验证在 model_validator 中更合适
|
||||
return v
|
||||
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
|
||||
Reference in New Issue
Block a user