style(workflow): enforce PEP8 style and remove redundant imports
This commit is contained in:
@@ -1,3 +1,3 @@
|
|||||||
from app.core.workflow.nodes.code.node import CodeNode
|
from app.core.workflow.nodes.code.node import CodeNode
|
||||||
|
|
||||||
__all__ = ["CodeNode"]
|
__all__ = ["CodeNode"]
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from textwrap import dedent
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from sympy.physics.vector import vlatex
|
|
||||||
|
|
||||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||||
from app.core.workflow.nodes.base_config import VariableType
|
from app.core.workflow.nodes.base_config import VariableType
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from langgraph.graph.state import CompiledStateGraph
|
|||||||
|
|
||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
|
||||||
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
||||||
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class IfElseNode(BaseNode):
|
class IfElseNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: IfElseNodeConfig | None= None
|
self.typed_config: IfElseNodeConfig | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import uuid
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
|
|||||||
return await MemoryAgentService().read_memory(
|
return await MemoryAgentService().read_memory(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
message=self._render_template(self.typed_config.message, state),
|
message=self._render_template(self.typed_config.message, state),
|
||||||
config_id=str(self.typed_config.config_id),
|
config_id=self.typed_config.config_id,
|
||||||
search_switch=self.typed_config.search_switch,
|
search_switch=self.typed_config.search_switch,
|
||||||
history=[],
|
history=[],
|
||||||
db=db,
|
db=db,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from pydantic import Field, BaseModel
|
|||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
class ClassifierConfig(BaseModel):
|
class ClassifierConfig(BaseModel):
|
||||||
"""分类器节点配置"""
|
"""分类器节点配置"""
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ class ClassifierConfig(BaseModel):
|
|||||||
|
|
||||||
class QuestionClassifierNodeConfig(BaseNodeConfig):
|
class QuestionClassifierNodeConfig(BaseNodeConfig):
|
||||||
"""问题分类器节点配置"""
|
"""问题分类器节点配置"""
|
||||||
|
|
||||||
model_id: uuid.UUID = Field(..., description="LLM模型ID")
|
model_id: uuid.UUID = Field(..., description="LLM模型ID")
|
||||||
input_variable: str = Field(default="{{sys.message}}", description="输入变量选择器(用户问题)")
|
input_variable: str = Field(default="{{sys.message}}", description="输入变量选择器(用户问题)")
|
||||||
user_supplement_prompt: Optional[str] = Field(default=None, description="用户补充提示词,额外分类指令")
|
user_supplement_prompt: Optional[str] = Field(default=None, description="用户补充提示词,额外分类指令")
|
||||||
|
|||||||
@@ -18,30 +18,30 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
|
|||||||
|
|
||||||
class QuestionClassifierNode(BaseNode):
|
class QuestionClassifierNode(BaseNode):
|
||||||
"""问题分类器节点"""
|
"""问题分类器节点"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||||
self.category_to_case_map = {}
|
self.category_to_case_map = {}
|
||||||
|
|
||||||
def _get_llm_instance(self) -> RedBearLLM:
|
def _get_llm_instance(self) -> RedBearLLM:
|
||||||
"""获取LLM实例"""
|
"""获取LLM实例"""
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
|
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
|
||||||
|
|
||||||
if not config:
|
if not config:
|
||||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
if not config.api_keys or len(config.api_keys) == 0:
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
api_config = config.api_keys[0]
|
api_config = config.api_keys[0]
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
base_url = api_config.api_base
|
base_url = api_config.api_base
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
return RedBearLLM(
|
return RedBearLLM(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@@ -64,7 +64,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
|
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
|
||||||
category_map[category_name] = case_tag
|
category_map[category_name] = case_tag
|
||||||
return category_map
|
return category_map
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict:
|
async def execute(self, state: WorkflowState) -> dict:
|
||||||
"""执行问题分类"""
|
"""执行问题分类"""
|
||||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||||
@@ -74,11 +74,12 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
categories = self.typed_config.categories or []
|
categories = self.typed_config.categories or []
|
||||||
category_names = [class_item.class_name.strip() for class_item in categories]
|
category_names = [class_item.class_name.strip() for class_item in categories]
|
||||||
category_count = len(category_names)
|
category_count = len(category_names)
|
||||||
|
|
||||||
if not question:
|
if not question:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
|
f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
|
||||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE}"
|
||||||
|
f"分类总数: {category_count})"
|
||||||
)
|
)
|
||||||
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||||
if category_count > 0:
|
if category_count > 0:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||||
from app.core.workflow.nodes.tool.node import ToolNode
|
from app.core.workflow.nodes.tool.node import ToolNode
|
||||||
|
|
||||||
__all__ = ["ToolNode", "ToolNodeConfig"]
|
__all__ = ["ToolNode", "ToolNodeConfig"]
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
|||||||
|
|
||||||
class ToolNode(BaseNode):
|
class ToolNode(BaseNode):
|
||||||
"""工具节点"""
|
"""工具节点"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: ToolNodeConfig | None = None
|
self.typed_config: ToolNodeConfig | None = None
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""执行工具"""
|
"""执行工具"""
|
||||||
self.typed_config = ToolNodeConfig(**self.config)
|
self.typed_config = ToolNodeConfig(**self.config)
|
||||||
@@ -28,21 +28,21 @@ class ToolNode(BaseNode):
|
|||||||
tenant_id = self.get_variable("sys.tenant_id", state)
|
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||||
user_id = self.get_variable("sys.user_id", state)
|
user_id = self.get_variable("sys.user_id", state)
|
||||||
workspace_id = self.get_variable("sys.workspace_id", state)
|
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||||
|
|
||||||
# 如果没有租户ID,尝试从工作流ID获取
|
# 如果没有租户ID,尝试从工作流ID获取
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
|
||||||
|
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"data": "缺少租户ID"
|
"data": "缺少租户ID"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 渲染工具参数
|
# 渲染工具参数
|
||||||
rendered_parameters = {}
|
rendered_parameters = {}
|
||||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||||
@@ -55,9 +55,9 @@ class ToolNode(BaseNode):
|
|||||||
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||||
rendered_value = param_template
|
rendered_value = param_template
|
||||||
rendered_parameters[param_name] = rendered_value
|
rendered_parameters[param_name] = rendered_value
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||||
|
|
||||||
# 执行工具
|
# 执行工具
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
tool_service = ToolService(db)
|
tool_service = ToolService(db)
|
||||||
@@ -79,7 +79,7 @@ class ToolNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||||
return {
|
return {
|
||||||
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
|
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
|
||||||
"error_code": result.error_code,
|
"error_code": result.error_code,
|
||||||
"execution_time": result.execution_time
|
"execution_time": result.execution_time
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user