style(workflow): remove unnecessary indentation

This commit is contained in:
mengyonghao
2025-12-22 16:18:25 +08:00
parent c15a987701
commit 75ee591202
6 changed files with 57 additions and 70 deletions

View File

@@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.db import get_db_context from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
@@ -136,7 +137,7 @@ class LLMNode(BaseNode):
base_url=api_base, base_url=api_base,
extra_params=extra_params extra_params=extra_params
), ),
type=model_type type=ModelType(model_type)
) )
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")

View File

@@ -7,6 +7,7 @@
import logging import logging
from typing import Any, Union from typing import Any, Union
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.end import EndNode
@@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.assigner import AssignerNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,6 +28,8 @@ WorkflowNode = Union[
IfElseNode, IfElseNode,
AgentNode, AgentNode,
TransformNode, TransformNode,
AssignerNode,
KnowledgeRetrievalNode,
] ]
@@ -42,7 +46,9 @@ class NodeFactory:
NodeType.LLM: LLMNode, NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode, NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode, NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode NodeType.IF_ELSE: IfElseNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,
} }
@classmethod @classmethod
@@ -82,10 +88,6 @@ class NodeFactory:
""" """
node_type = node_config.get("type") node_type = node_config.get("type")
# 跳过条件节点(由 LangGraph 处理)
if node_type == "condition":
return None
# 获取节点类 # 获取节点类
node_class = cls._node_types.get(node_type) node_class = cls._node_types.get(node_type)
if not node_class: if not node_class:

View File

@@ -10,7 +10,10 @@
""" """
import logging import logging
from typing import Any from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from app.core.workflow.nodes import WorkflowState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -82,7 +85,7 @@ class VariablePool:
>>> pool.set(["conv", "user_name"], "张三") >>> pool.set(["conv", "user_name"], "张三")
""" """
def __init__(self, state: dict[str, Any]): def __init__(self, state: "WorkflowState"):
"""初始化变量池 """初始化变量池
Args: Args:

View File

@@ -15,25 +15,6 @@ class ModelType(StrEnum):
EMBEDDING = "embedding" EMBEDDING = "embedding"
RERANK = "rerank" RERANK = "rerank"
@classmethod
def from_str(cls, value: str) -> "ModelType":
"""
Get a ModelType enum instance from a string value.
Args:
value (str): The string representation of the model type.
Returns:
ModelType: The corresponding ModelType enum object.
Raises:
ValueError: If the given value does not match any ModelType.
"""
try:
return cls(value)
except ValueError:
raise ValueError(f"Invalid ModelType: {value}")
class ModelProvider(StrEnum): class ModelProvider(StrEnum):
"""模型提供商枚举""" """模型提供商枚举"""

View File

@@ -169,7 +169,7 @@ class PromptOptimizerService:
provider=api_config.provider, provider=api_config.provider,
api_key=api_config.api_key, api_key=api_config.api_key,
base_url=api_config.api_base base_url=api_config.api_base
), type=ModelType.from_str(model_config.type)) ), type=ModelType(model_config.type))
# build message # build message
messages = [ messages = [