style(workflow): enforce PEP8 style and remove redundant imports

This commit is contained in:
Eternity
2026-01-27 17:51:27 +08:00
parent 166d05afe9
commit 2a10e9f7ee
11 changed files with 25 additions and 28 deletions

View File

@@ -1,3 +1,3 @@
from app.core.workflow.nodes.code.node import CodeNode from app.core.workflow.nodes.code.node import CodeNode
__all__ = ["CodeNode"] __all__ = ["CodeNode"]

View File

@@ -7,7 +7,6 @@ from textwrap import dedent
from typing import Any from typing import Any
import httpx import httpx
from sympy.physics.vector import vlatex
from app.core.workflow.nodes import BaseNode, WorkflowState from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.base_config import VariableType from app.core.workflow.nodes.base_config import VariableType

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import copy
import logging import logging
import re import re
from typing import Any from typing import Any

View File

@@ -6,7 +6,6 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType

View File

@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode): class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.typed_config: IfElseNodeConfig | None= None self.typed_config: IfElseNodeConfig | None = None
@staticmethod @staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any: def _evaluate(operator, instance: CompareOperatorInstance) -> Any:

View File

@@ -1,8 +1,6 @@
import uuid
from uuid import UUID from uuid import UUID
from pydantic import Field from pydantic import Field
from typing import Literal
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig

View File

@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
return await MemoryAgentService().read_memory( return await MemoryAgentService().read_memory(
end_user_id=end_user_id, end_user_id=end_user_id,
message=self._render_template(self.typed_config.message, state), message=self._render_template(self.typed_config.message, state),
config_id=str(self.typed_config.config_id), config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,
history=[], history=[],
db=db, db=db,

View File

@@ -5,6 +5,7 @@ from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
class ClassifierConfig(BaseModel): class ClassifierConfig(BaseModel):
"""分类器节点配置""" """分类器节点配置"""
@@ -13,7 +14,7 @@ class ClassifierConfig(BaseModel):
class QuestionClassifierNodeConfig(BaseNodeConfig): class QuestionClassifierNodeConfig(BaseNodeConfig):
"""问题分类器节点配置""" """问题分类器节点配置"""
model_id: uuid.UUID = Field(..., description="LLM模型ID") model_id: uuid.UUID = Field(..., description="LLM模型ID")
input_variable: str = Field(default="{{sys.message}}", description="输入变量选择器(用户问题)") input_variable: str = Field(default="{{sys.message}}", description="输入变量选择器(用户问题)")
user_supplement_prompt: Optional[str] = Field(default=None, description="用户补充提示词,额外分类指令") user_supplement_prompt: Optional[str] = Field(default=None, description="用户补充提示词,额外分类指令")

View File

@@ -18,30 +18,30 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode): class QuestionClassifierNode(BaseNode):
"""问题分类器节点""" """问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.typed_config: QuestionClassifierNodeConfig | None = None self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {} self.category_to_case_map = {}
def _get_llm_instance(self) -> RedBearLLM: def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例""" """获取LLM实例"""
with get_db_read() as db: with get_db_read() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id) config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
if not config: if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0: if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0] api_config = config.api_keys[0]
model_name = api_config.model_name model_name = api_config.model_name
provider = api_config.provider provider = api_config.provider
api_key = api_config.api_key api_key = api_config.api_key
base_url = api_config.api_base base_url = api_config.api_base
model_type = config.type model_type = config.type
return RedBearLLM( return RedBearLLM(
RedBearModelConfig( RedBearModelConfig(
model_name=model_name, model_name=model_name,
@@ -64,7 +64,7 @@ class QuestionClassifierNode(BaseNode):
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}" case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
category_map[category_name] = case_tag category_map[category_name] = case_tag
return category_map return category_map
async def execute(self, state: WorkflowState) -> dict: async def execute(self, state: WorkflowState) -> dict:
"""执行问题分类""" """执行问题分类"""
self.typed_config = QuestionClassifierNodeConfig(**self.config) self.typed_config = QuestionClassifierNodeConfig(**self.config)
@@ -74,11 +74,12 @@ class QuestionClassifierNode(BaseNode):
categories = self.typed_config.categories or [] categories = self.typed_config.categories or []
category_names = [class_item.class_name.strip() for class_item in categories] category_names = [class_item.class_name.strip() for class_item in categories]
category_count = len(category_names) category_count = len(category_names)
if not question: if not question:
logger.warning( logger.warning(
f"节点 {self.node_id} 未获取到输入问题,使用默认分支" f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
f"默认分支{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count}" f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE}"
f"分类总数: {category_count})"
) )
# 若分类列表为空返回默认unknown分支否则返回CASE1 # 若分类列表为空返回默认unknown分支否则返回CASE1
if category_count > 0: if category_count > 0:

View File

@@ -1,4 +1,4 @@
from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.tool.node import ToolNode from app.core.workflow.nodes.tool.node import ToolNode
__all__ = ["ToolNode", "ToolNodeConfig"] __all__ = ["ToolNode", "ToolNodeConfig"]

View File

@@ -16,11 +16,11 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
class ToolNode(BaseNode): class ToolNode(BaseNode):
"""工具节点""" """工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.typed_config: ToolNodeConfig | None = None self.typed_config: ToolNodeConfig | None = None
async def execute(self, state: WorkflowState) -> dict[str, Any]: async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行工具""" """执行工具"""
self.typed_config = ToolNodeConfig(**self.config) self.typed_config = ToolNodeConfig(**self.config)
@@ -28,21 +28,21 @@ class ToolNode(BaseNode):
tenant_id = self.get_variable("sys.tenant_id", state) tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state) user_id = self.get_variable("sys.user_id", state)
workspace_id = self.get_variable("sys.workspace_id", state) workspace_id = self.get_variable("sys.workspace_id", state)
# 如果没有租户ID尝试从工作流ID获取 # 如果没有租户ID尝试从工作流ID获取
if not tenant_id: if not tenant_id:
if workspace_id: if workspace_id:
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
with get_db_read() as db: with get_db_read() as db:
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
if not tenant_id: if not tenant_id:
logger.error(f"节点 {self.node_id} 缺少租户ID") logger.error(f"节点 {self.node_id} 缺少租户ID")
return { return {
"success": False, "success": False,
"data": "缺少租户ID" "data": "缺少租户ID"
} }
# 渲染工具参数 # 渲染工具参数
rendered_parameters = {} rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items(): for param_name, param_template in self.typed_config.tool_parameters.items():
@@ -55,9 +55,9 @@ class ToolNode(BaseNode):
# 非模板参数(数字/布尔/普通字符串)直接保留原值 # 非模板参数(数字/布尔/普通字符串)直接保留原值
rendered_value = param_template rendered_value = param_template
rendered_parameters[param_name] = rendered_value rendered_parameters[param_name] = rendered_value
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}") logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
# 执行工具 # 执行工具
with get_db_read() as db: with get_db_read() as db:
tool_service = ToolService(db) tool_service = ToolService(db)
@@ -79,7 +79,7 @@ class ToolNode(BaseNode):
else: else:
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}") logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return { return {
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False), "data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
"error_code": result.error_code, "error_code": result.error_code,
"execution_time": result.execution_time "execution_time": result.execution_time
} }