feat(workflow): support context injection in LLM node
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
"""LLM 节点配置"""
|
"""LLM 节点配置"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||||
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
|
|||||||
|
|
||||||
class MessageConfig(BaseModel):
|
class MessageConfig(BaseModel):
|
||||||
"""消息配置"""
|
"""消息配置"""
|
||||||
|
|
||||||
role: str = Field(
|
role: str = Field(
|
||||||
...,
|
...,
|
||||||
description="消息角色:system, user, assistant"
|
description="消息角色:system, user, assistant"
|
||||||
)
|
)
|
||||||
|
|
||||||
content: str = Field(
|
content: str = Field(
|
||||||
...,
|
...,
|
||||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("role")
|
@field_validator("role")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_role(cls, v: str) -> str:
|
def validate_role(cls, v: str) -> str:
|
||||||
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
1. 简单模式:使用 prompt 字段
|
1. 简单模式:使用 prompt 字段
|
||||||
2. 消息模式:使用 messages 字段(推荐)
|
2. 消息模式:使用 messages 字段(推荐)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_id: str = Field(
|
model_id: str = Field(
|
||||||
...,
|
...,
|
||||||
description="模型配置 ID"
|
description="模型配置 ID"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context: Any = Field(
|
||||||
|
default="",
|
||||||
|
description="上下文"
|
||||||
|
)
|
||||||
|
|
||||||
# 简单模式
|
# 简单模式
|
||||||
prompt: str | None = Field(
|
prompt: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="提示词模板(简单模式),支持变量引用"
|
description="提示词模板(简单模式),支持变量引用"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 消息模式(推荐)
|
# 消息模式(推荐)
|
||||||
messages: list[MessageConfig] | None = Field(
|
messages: list[MessageConfig] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="消息列表(消息模式),支持多轮对话"
|
description="消息列表(消息模式),支持多轮对话"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 模型参数
|
# 模型参数
|
||||||
temperature: float | None = Field(
|
temperature: float | None = Field(
|
||||||
default=0.7,
|
default=0.7,
|
||||||
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
le=2.0,
|
le=2.0,
|
||||||
description="温度参数,控制输出的随机性"
|
description="温度参数,控制输出的随机性"
|
||||||
)
|
)
|
||||||
|
|
||||||
max_tokens: int | None = Field(
|
max_tokens: int | None = Field(
|
||||||
default=1000,
|
default=1000,
|
||||||
ge=1,
|
ge=1,
|
||||||
le=32000,
|
le=32000,
|
||||||
description="最大生成 token 数"
|
description="最大生成 token 数"
|
||||||
)
|
)
|
||||||
|
|
||||||
top_p: float | None = Field(
|
top_p: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=0.0,
|
ge=0.0,
|
||||||
le=1.0,
|
le=1.0,
|
||||||
description="Top-p 采样参数"
|
description="Top-p 采样参数"
|
||||||
)
|
)
|
||||||
|
|
||||||
frequency_penalty: float | None = Field(
|
frequency_penalty: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=-2.0,
|
ge=-2.0,
|
||||||
le=2.0,
|
le=2.0,
|
||||||
description="频率惩罚"
|
description="频率惩罚"
|
||||||
)
|
)
|
||||||
|
|
||||||
presence_penalty: float | None = Field(
|
presence_penalty: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=-2.0,
|
ge=-2.0,
|
||||||
le=2.0,
|
le=2.0,
|
||||||
description="存在惩罚"
|
description="存在惩罚"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 输出变量定义
|
# 输出变量定义
|
||||||
output_variables: list[VariableDefinition] = Field(
|
output_variables: list[VariableDefinition] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
],
|
],
|
||||||
description="输出变量定义(自动生成,通常不需要修改)"
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("messages", "prompt")
|
@field_validator("messages", "prompt")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_input_mode(cls, v, info):
|
def validate_input_mode(cls, v, info):
|
||||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||||
# 这个验证在 model_validator 中更合适
|
# 这个验证在 model_validator 中更合适
|
||||||
return v
|
return v
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_schema_extra = {
|
json_schema_extra = {
|
||||||
"examples": [
|
"examples": [
|
||||||
|
|||||||
Reference in New Issue
Block a user