feat(workflow): support context injection in LLM node

This commit is contained in:
mengyonghao
2026-01-05 17:17:52 +08:00
parent 78207aca34
commit 35db38c2de

View File

@@ -5,11 +5,13 @@ LLM 节点实现
""" """
import logging import logging
import re
from typing import Any from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.db import get_db_context from app.db import get_db_context
from app.models import ModelType from app.models import ModelType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
@@ -63,8 +65,15 @@ class LLMNode(BaseNode):
- user/human: 用户消息HumanMessage - user/human: 用户消息HumanMessage
- ai/assistant: AI 消息AIMessage - ai/assistant: AI 消息AIMessage
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = LLMNodeConfig(**self.config)
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]: def _render_context(self, message,state):
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
return re.sub(r"{{context}}", context, message)
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑) """准备 LLM 实例(公共逻辑)
Args: Args:
@@ -83,6 +92,7 @@ class LLMNode(BaseNode):
for msg_config in messages_config: for msg_config in messages_config:
role = msg_config.get("role", "user").lower() role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "") content_template = msg_config.get("content", "")
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state) content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
@@ -153,7 +163,7 @@ class LLMNode(BaseNode):
Returns: Returns:
LLM 响应消息 LLM 响应消息
""" """
llm, prompt_or_messages = self._prepare_llm(state,True) llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")