feat(workflow): support context injection in LLM node
This commit is contained in:
@@ -5,11 +5,13 @@ LLM 节点实现
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import ModelType
|
from app.models import ModelType
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
@@ -63,8 +65,15 @@ class LLMNode(BaseNode):
|
|||||||
- user/human: 用户消息(HumanMessage)
|
- user/human: 用户消息(HumanMessage)
|
||||||
- ai/assistant: AI 消息(AIMessage)
|
- ai/assistant: AI 消息(AIMessage)
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
|
||||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
def _render_context(self, message,state):
|
||||||
|
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||||
|
return re.sub(r"{{context}}", context, message)
|
||||||
|
|
||||||
|
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
|
||||||
"""准备 LLM 实例(公共逻辑)
|
"""准备 LLM 实例(公共逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -83,6 +92,7 @@ class LLMNode(BaseNode):
|
|||||||
for msg_config in messages_config:
|
for msg_config in messages_config:
|
||||||
role = msg_config.get("role", "user").lower()
|
role = msg_config.get("role", "user").lower()
|
||||||
content_template = msg_config.get("content", "")
|
content_template = msg_config.get("content", "")
|
||||||
|
content_template = self._render_context(content_template, state)
|
||||||
content = self._render_template(content_template, state)
|
content = self._render_template(content_template, state)
|
||||||
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
@@ -153,7 +163,7 @@ class LLMNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
LLM 响应消息
|
LLM 响应消息
|
||||||
"""
|
"""
|
||||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user