diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index d143c693..1d00532e 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -5,9 +5,11 @@ """ from app.core.workflow.nodes.agent import AgentNode +from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.if_else import IfElseNode +# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode @@ -23,5 +25,7 @@ __all__ = [ "StartNode", "EndNode", "NodeFactory", - "WorkflowNode" + "WorkflowNode", + # "KnowledgeRetrievalNode", + "AssignerNode", ] diff --git a/api/app/core/workflow/nodes/assigner/__init__.py b/api/app/core/workflow/nodes/assigner/__init__.py new file mode 100644 index 00000000..668e1aea --- /dev/null +++ b/api/app/core/workflow/nodes/assigner/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.assigner.config import AssignerNodeConfig +from app.core.workflow.nodes.assigner.node import AssignerNode + +__all__ = ["AssignerNode", "AssignerNodeConfig"] diff --git a/api/app/core/workflow/nodes/assigner/config.py b/api/app/core/workflow/nodes/assigner/config.py new file mode 100644 index 00000000..1cb0def3 --- /dev/null +++ b/api/app/core/workflow/nodes/assigner/config.py @@ -0,0 +1,21 @@ +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.nodes.enums import AssignmentOperator + + +class AssignerNodeConfig(BaseNodeConfig): + variable_selector: str | list[str] = Field( + ..., + description="Variables to be assigned", + ) + + operation: AssignmentOperator = Field( + ..., + description="Operator to assign", + ) + + value: str | list[str] = Field( + ..., + description="Values to assign", + ) diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py new file mode 100644 index 00000000..eb32bf8b --- /dev/null +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -0,0 +1,80 @@ +import logging +from typing import Any + +from app.core.workflow.expression_evaluator import ExpressionEvaluator +from app.core.workflow.nodes.assigner.config import AssignerNodeConfig +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.enums import AssignmentOperator +from app.core.workflow.nodes.operators import AssignmentOperatorInstance +from app.core.workflow.variable_pool import VariablePool + +logger = logging.getLogger(__name__) + + +class AssignerNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = AssignerNodeConfig(**self.config) + + async def execute(self, state: WorkflowState) -> Any: + """ + Execute the assignment operation defined by this node. + + Args: + state: The current workflow state, including conversation variables, + node outputs, and system variables. + + Returns: + None or the result of the assignment operation. + """ + # Initialize a variable pool for accessing conversation, node, and system variables + pool = VariablePool(state) + + # Get the target variable selector (e.g., "conv.test") + variable_selector = self.typed_config.variable_selector + if isinstance(variable_selector, str): + # Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"] + variable_selector = variable_selector.split('.') + + # Only conversation variables ('conv') are allowed + if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature) + raise ValueError("Only conversation variables can be assigned.") + + # Get the value or expression to assign + value = self.typed_config.value + if isinstance(value, list): + value = '.'.join(value) + value = ExpressionEvaluator.evaluate( + expression=value, + variables=pool.get_all_conversation_vars(), + node_outputs=pool.get_all_node_outputs(), + system_vars=pool.get_all_system_vars(), + ) + + # Select the appropriate assignment operator instance based on the target variable type + operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( + pool, variable_selector, value + ) + + # Execute the configured assignment operation + match self.typed_config.operation: + case AssignmentOperator.ASSIGN: + operator.assign() + case AssignmentOperator.CLEAR: + operator.clear() + case AssignmentOperator.ADD: + operator.add() + case AssignmentOperator.SUBTRACT: + operator.subtract() + case AssignmentOperator.MULTIPLY: + operator.multiply() + case AssignmentOperator.DIVIDE: + operator.divide() + case AssignmentOperator.APPEND: + operator.append() + case AssignmentOperator.REMOVE_FIRST: + operator.remove_first() + case AssignmentOperator.REMOVE_LAST: + operator.remove_last() + case _: + raise ValueError(f"Invalid Operator: {self.typed_config.operation}") diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index 15ab0ce9..ecded070 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -14,6 +14,8 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.agent.config import AgentNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig +# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.assigner.config import AssignerNodeConfig __all__ = [ # 基础类 @@ -28,4 +30,6 @@ __all__ = [ "AgentNodeConfig", "TransformNodeConfig", "IfElseNodeConfig", + # "KnowledgeRetrievalNodeConfig", + "AssignerNodeConfig", ] diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 2e5758eb..3cece96b 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -33,7 +33,7 @@ class EndNode(BaseNode): # 获取配置的输出模板 output_template = self.config.get("output") - + # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state) @@ -45,17 +45,17 @@ class EndNode(BaseNode): total_nodes = len(node_outputs) logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") - + return output - + def _extract_referenced_nodes(self, template: str) -> list[str]: """从模板中提取引用的节点 ID - + 例如:'结果:{{llm_qa.output}}' -> ['llm_qa'] - + Args: template: 模板字符串 - + Returns: 引用的节点 ID 列表 """ @@ -63,44 +63,44 @@ class EndNode(BaseNode): pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}' matches = re.findall(pattern, template) return list(set(matches)) # 去重 - + def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]: """解析模板,分离静态文本和动态引用 - + 例如:'你好 {{llm.output}}, 这是后缀' 返回:[ {"type": "static", "content": "你好 "}, {"type": "dynamic", "node_id": "llm", "field": "output"}, {"type": "static", "content": ", 这是后缀"} ] - + Args: template: 模板字符串 state: 工作流状态 - + Returns: 模板部分列表 """ import re - + parts = [] last_end = 0 - + # 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格) pattern = r'\{\{\s*([^}]+?)\s*\}\}' - + for match in re.finditer(pattern, template): start, end = match.span() - + # 添加前面的静态文本 if start > last_end: static_text = template[last_end:start] if static_text: parts.append({"type": "static", "content": static_text}) - + # 解析动态引用 ref = match.group(1).strip() - + # 检查是否是节点引用(如 llm.output 或 llm_qa.output) if '.' in ref: node_id, field = ref.split('.', 1) @@ -115,62 +115,62 @@ class EndNode(BaseNode): # 直接渲染这部分 rendered = self._render_template(f"{{{{{ref}}}}}", state) parts.append({"type": "static", "content": rendered}) - + last_end = end - + # 添加最后的静态文本 if last_end < len(template): static_text = template[last_end:] if static_text: parts.append({"type": "static", "content": static_text}) - + return parts - + async def execute_stream(self, state: WorkflowState): """流式执行 end 节点业务逻辑 - + 智能输出策略: 1. 检测模板中是否引用了直接上游节点 2. 如果引用了,只输出该引用**之后**的部分(后缀) 3. 前缀和引用内容已经在上游节点流式输出时发送了 - + 示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' - 直接上游节点是 llm_qa - 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 - LLM 内容在 LLM 节点流式输出 - End 节点只输出 ' lalalalala a'(后缀,一次性输出) - + Args: state: 工作流状态 - + Yields: 完成标记 """ logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") - + # 获取配置的输出模板 output_template = self.config.get("output") - + if not output_template: output = "工作流已完成" yield {"__final__": True, "result": output} return - + # 找到直接上游节点 direct_upstream_nodes = [] for edge in self.workflow_config.get("edges", []): if edge.get("target") == self.node_id: source_node_id = edge.get("source") direct_upstream_nodes.append(source_node_id) - + logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") - + # 解析模板部分 parts = self._parse_template_parts(output_template, state) logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") for i, part in enumerate(parts): logger.info(f"[模板解析] part[{i}]: {part}") - + # 找到第一个引用直接上游节点的动态引用 upstream_ref_index = None for i, part in enumerate(parts): @@ -178,12 +178,12 @@ class EndNode(BaseNode): upstream_ref_index = i logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") break - + if upstream_ref_index is None: # 没有引用直接上游节点,输出完整模板内容 output = self._render_template(output_template, state) logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'") - + # 通过 writer 发送完整内容(作为一个 message chunk) from langgraph.config import get_stream_writer writer = get_stream_writer() @@ -196,14 +196,14 @@ class EndNode(BaseNode): "is_suffix": False }) logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") - + # yield 完成标记 yield {"__final__": True, "result": output} return - + # 有引用直接上游节点,只输出该引用之后的部分(后缀) logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") - + # 收集后缀部分 suffix_parts = [] logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}") @@ -214,7 +214,7 @@ class EndNode(BaseNode): # 静态文本 logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'") suffix_parts.append(part["content"]) - + elif part["type"] == "dynamic": # Other dynamic references (if there are multiple references) node_id = part["node_id"] @@ -229,21 +229,21 @@ class EndNode(BaseNode): except Exception as e: logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}") content = "" - + # Convert to string if not None suffix_parts.append(str(content) if content is not None else "") # 拼接后缀 suffix = "".join(suffix_parts) - + # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) full_output = self._render_template(output_template, state) - + logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") logger.info(f"[后缀调试] 后缀是否为空: {not suffix}") - + if suffix: logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})") # 一次性输出后缀(作为单个 chunk) @@ -266,8 +266,8 @@ class EndNode(BaseNode): # 统计信息 node_outputs = state.get("node_outputs", {}) total_nodes = len(node_outputs) - + logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点") - + # yield 完成标记(包含完整输出) yield {"__final__": True, "result": full_output} diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index af5ddbaa..82ecad5d 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -1,5 +1,14 @@ from enum import StrEnum +from app.core.workflow.nodes.operators import ( + StringOperator, + NumberOperator, + AssignmentOperatorType, + BooleanOperator, + ArrayOperator, + ObjectOperator +) + class NodeType(StrEnum): START = "start" @@ -14,6 +23,7 @@ class NodeType(StrEnum): HTTP_REQUEST = "http-request" TOOL = "tool" AGENT = "agent" + ASSIGNER = "assigner" class ComparisonOperator(StrEnum): @@ -34,3 +44,32 @@ class ComparisonOperator(StrEnum): class LogicOperator(StrEnum): AND = "and" OR = "or" + + +class AssignmentOperator(StrEnum): + ASSIGN = "assign" + CLEAR = "clear" + + ADD = "add" # += + SUBTRACT = "subtract" # -= + MULTIPLY = "multiply" # *= + DIVIDE = "divide" # /= + + APPEND = "append" + REMOVE_LAST = "remove_last" + REMOVE_FIRST = "remove_first" + + @classmethod + def get_operator(cls, obj) -> AssignmentOperatorType: + if isinstance(obj, str): + return StringOperator + elif isinstance(obj, bool): + return BooleanOperator + elif isinstance(obj, (int, float)): + return NumberOperator + elif isinstance(obj, list): + return ArrayOperator + elif isinstance(obj, dict): + return ObjectOperator + + raise TypeError(f"Unsupported variable type ({type(obj)})") diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index ed3dbbd6..aedf0727 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -1,7 +1,7 @@ import logging from typing import Any -from app.core.workflow.nodes import BaseNode, WorkflowState +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import ComparisonOperator from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else.config import ConditionDetail diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 8f809923..65826d84 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.models import RedBearLLM, RedBearModelConfig from app.db import get_db_context +from app.models import ModelType from app.services.model_service import ModelConfigService from app.core.exceptions import BusinessException @@ -136,7 +137,7 @@ class LLMNode(BaseNode): base_url=api_base, extra_params=extra_params ), - type=model_type + type=ModelType(model_type) ) logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 1abace67..93364083 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -7,6 +7,7 @@ import logging from typing import Any, Union +# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.end import EndNode @@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.assigner import AssignerNode logger = logging.getLogger(__name__) @@ -26,6 +28,8 @@ WorkflowNode = Union[ IfElseNode, AgentNode, TransformNode, + AssignerNode, + # KnowledgeRetrievalNode, ] @@ -42,7 +46,9 @@ class NodeFactory: NodeType.LLM: LLMNode, NodeType.AGENT: AgentNode, NodeType.TRANSFORM: TransformNode, - NodeType.IF_ELSE: IfElseNode + NodeType.IF_ELSE: IfElseNode, + # NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.ASSIGNER: AssignerNode, } @classmethod @@ -82,10 +88,6 @@ class NodeFactory: """ node_type = node_config.get("type") - # 跳过条件节点(由 LangGraph 处理) - if node_type == "condition": - return None - # 获取节点类 node_class = cls._node_types.get(node_type) if not node_class: diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py new file mode 100644 index 00000000..a80cf326 --- /dev/null +++ b/api/app/core/workflow/nodes/operators.py @@ -0,0 +1,146 @@ +from abc import ABC +from typing import Union, Type + +from app.core.workflow.variable_pool import VariablePool + + +class OperatorBase(ABC): + def __init__(self, pool: VariablePool, left_selector, right): + self.pool = pool + self.left_selector = left_selector + self.right = right + + self.type_limit: type[str, int, dict, list] = None + + def check(self, no_right=False): + left = self.pool.get(self.left_selector) + if not isinstance(left, self.type_limit): + raise TypeError(f"The variable to be operated on must be of {self.type_limit} type") + + if not no_right and not isinstance(self.right, self.type_limit): + raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type") + + +class StringOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = str + + def assign(self) -> None: + self.check() + self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, '') + + +class NumberOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = (float, int) + + def assign(self) -> None: + self.check() + self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, 0) + + def add(self) -> None: + self.check() + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin + self.right) + + def subtract(self) -> None: + self.check() + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin - self.right) + + def multiply(self) -> None: + self.check() + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin * self.right) + + def divide(self) -> None: + self.check() + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin / self.right) + + +class BooleanOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = bool + + def assign(self) -> None: + self.check() + self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, False) + + +class ArrayOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = list + + def assign(self) -> None: + self.check() + self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, list()) + + def append(self) -> None: + self.check(no_right=True) + # TODO:require type limit in list + origin = self.pool.get(self.left_selector) + origin.append(self.right) + self.pool.set(self.left_selector, origin) + + def extend(self) -> None: + self.check(no_right=True) + origin = self.pool.get(self.left_selector) + origin.extend(self.right) + self.pool.set(self.left_selector, origin) + + def remove_last(self) -> None: + self.check(no_right=True) + origin = self.pool.get(self.left_selector) + origin.pop() + self.pool.set(self.left_selector, origin) + + def remove_first(self) -> None: + self.check(no_right=True) + origin = self.pool.get(self.left_selector) + origin.pop(0) + self.pool.set(self.left_selector, origin) + + +class ObjectOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = object + + def assign(self) -> None: + self.check() + self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, dict()) + + +AssignmentOperatorInstance = Union[ + StringOperator, + NumberOperator, + BooleanOperator, + ArrayOperator, + ObjectOperator +] +AssignmentOperatorType = Type[AssignmentOperatorInstance] diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index 1f589dab..0f97c349 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -10,7 +10,10 @@ """ import logging -from typing import Any +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from app.core.workflow.nodes import WorkflowState logger = logging.getLogger(__name__) @@ -82,7 +85,7 @@ class VariablePool: >>> pool.set(["conv", "user_name"], "张三") """ - def __init__(self, state: dict[str, Any]): + def __init__(self, state: "WorkflowState"): """初始化变量池 Args: diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 91c1d9c7..2e60ef1c 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -15,25 +15,6 @@ class ModelType(StrEnum): EMBEDDING = "embedding" RERANK = "rerank" - @classmethod - def from_str(cls, value: str) -> "ModelType": - """ - Get a ModelType enum instance from a string value. - - Args: - value (str): The string representation of the model type. - - Returns: - ModelType: The corresponding ModelType enum object. - - Raises: - ValueError: If the given value does not match any ModelType. - """ - try: - return cls(value) - except ValueError: - raise ValueError(f"Invalid ModelType: {value}") - class ModelProvider(StrEnum): """模型提供商枚举""" diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index c387cee9..52c5ae81 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -1,6 +1,7 @@ -import uuid import datetime -from typing import Optional, Any, List, Dict, TYPE_CHECKING +import uuid +from typing import Optional, Any, List, Dict + from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator @@ -20,20 +21,19 @@ class KnowledgeBaseConfig(BaseModel): class KnowledgeRetrievalConfig(BaseModel): """知识库检索配置(支持多个知识库,每个有独立配置)""" knowledge_bases: List[KnowledgeBaseConfig] = Field( - default_factory=list, + default_factory=list, description="关联的知识库列表,每个知识库有独立配置" ) - + # 多知识库融合策略 merge_strategy: str = Field( - default="weighted", + default="weighted", description="多知识库结果融合策略: weighted | rrf | concat" ) reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID") reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数") - class ToolConfig(BaseModel): """工具配置""" enabled: bool = Field(default=False, description="是否启用该工具") @@ -63,7 +63,7 @@ class VariableDefinition(BaseModel): name: str = Field(..., description="变量名称(标识符)") display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)") type: str = Field( - default="string", + default="string", description="变量类型: string(单行文本) | text(多行文本) | number(数字)" ) required: bool = Field(default=False, description="是否必填") @@ -75,32 +75,32 @@ class AgentConfigCreate(BaseModel): """Agent 行为配置""" # 提示词配置 system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则") - + # 模型配置 default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID") model_parameters: ModelParameters = Field( default_factory=ModelParameters, description="模型参数配置(temperature、max_tokens 等)" ) - + # 知识库关联 knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( default=None, description="知识库检索配置" ) - + # 记忆配置 memory: MemoryConfig = Field( default_factory=lambda: MemoryConfig(enabled=True), description="对话历史记忆配置" ) - + # 变量配置 variables: List[VariableDefinition] = Field( default_factory=list, description="Agent 可用的变量列表" ) - + # 工具配置 tools: Dict[str, ToolConfig] = Field( default_factory=dict, @@ -120,7 +120,7 @@ class AppCreate(BaseModel): # only for type=agent agent_config: Optional[AgentConfigCreate] = None - + # only for type=multi_agent multi_agent_config: Optional[Dict[str, Any]] = None @@ -139,23 +139,23 @@ class AgentConfigUpdate(BaseModel): """更新 Agent 行为配置""" # 提示词配置 system_prompt: Optional[str] = Field(default=None, description="系统提示词") - + # 模型配置 default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID") model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置") - + # 知识库关联 knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( default=None, description="知识库检索配置" ) - + # 记忆配置 memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置") - + # 变量配置 variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") - + # 工具配置 tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") @@ -185,7 +185,7 @@ class App(BaseModel): @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -197,26 +197,26 @@ class AgentConfig(BaseModel): id: uuid.UUID app_id: uuid.UUID - + # 提示词 system_prompt: Optional[str] = None - + # 模型配置 default_model_config_id: Optional[uuid.UUID] = None model_parameters: ModelParameters = Field(default_factory=ModelParameters) - + # 知识库检索 knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None - + # 记忆配置 memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True)) - + # 变量配置 variables: List[VariableDefinition] = [] - + # 工具配置 tools: Dict[str, ToolConfig] = {} - + is_active: bool created_at: datetime.datetime updated_at: datetime.datetime @@ -228,7 +228,7 @@ class AgentConfig(BaseModel): if v is None: return ModelParameters() return v - + @field_validator("memory", mode="before") @classmethod def validate_memory(cls, v): @@ -236,7 +236,7 @@ class AgentConfig(BaseModel): if v is None: return MemoryConfig(enabled=True) return v - + @field_validator("variables", mode="before") @classmethod def validate_variables(cls, v): @@ -244,7 +244,7 @@ class AgentConfig(BaseModel): if v is None: return [] return v - + @field_validator("tools", mode="before") @classmethod def validate_tools(cls, v): @@ -256,7 +256,7 @@ class AgentConfig(BaseModel): @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -294,15 +294,15 @@ class AppRelease(BaseModel): @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("published_at", when_used="json") def _serialize_published_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + # ---------- App Share Schemas ---------- @@ -314,7 +314,7 @@ class AppShareCreate(BaseModel): class AppShare(BaseModel): """应用分享输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID source_app_id: uuid.UUID source_workspace_id: uuid.UUID @@ -322,11 +322,11 @@ class AppShare(BaseModel): shared_by: uuid.UUID created_at: datetime.datetime updated_at: datetime.datetime - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -382,14 +382,14 @@ class DraftRunCompareRequest(BaseModel): conversation_id: Optional[str] = Field(None, description="会话ID") user_id: Optional[str] = Field(None, description="用户ID") variables: Optional[Dict[str, Any]] = Field(None, description="变量参数") - + models: List[ModelCompareItem] = Field( ..., min_length=1, max_length=5, description="要对比的模型列表(1-5个)" ) - + parallel: bool = Field(True, description="是否并行执行") stream: bool = Field(False, description="是否流式返回") timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)") @@ -400,14 +400,14 @@ class ModelRunResult(BaseModel): model_config_id: uuid.UUID model_name: str label: Optional[str] = None - + parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数") - + message: Optional[str] = None usage: Optional[Dict[str, Any]] = None elapsed_time: float error: Optional[str] = None - + tokens_per_second: Optional[float] = None cost_estimate: Optional[float] = None conversation_id: Optional[str] = None @@ -416,10 +416,10 @@ class ModelRunResult(BaseModel): class DraftRunCompareResponse(BaseModel): """多模型对比响应""" results: List[ModelRunResult] - + total_elapsed_time: float successful_count: int failed_count: int - + fastest_model: Optional[str] = None cheapest_model: Optional[str] = None diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 5355474f..6af794b1 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -169,7 +169,7 @@ class PromptOptimizerService: provider=api_config.provider, api_key=api_config.api_key, base_url=api_config.api_base - ), type=ModelType.from_str(model_config.type)) + ), type=ModelType(model_config.type)) # build message messages = [ diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index ccf0442f..058767d9 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -39,14 +39,14 @@ class WorkflowService: # ==================== 配置管理 ==================== def create_workflow_config( - self, - app_id: uuid.UUID, - nodes: list[dict[str, Any]], - edges: list[dict[str, Any]], - variables: list[dict[str, Any]] | None = None, - execution_config: dict[str, Any] | None = None, - triggers: list[dict[str, Any]] | None = None, - validate: bool = True + self, + app_id: uuid.UUID, + nodes: list[dict[str, Any]], + edges: list[dict[str, Any]], + variables: list[dict[str, Any]] | None = None, + execution_config: dict[str, Any] | None = None, + triggers: list[dict[str, Any]] | None = None, + validate: bool = True ) -> WorkflowConfig: """创建工作流配置 @@ -109,14 +109,14 @@ class WorkflowService: return self.config_repo.get_by_app_id(app_id) def update_workflow_config( - self, - app_id: uuid.UUID, - nodes: list[dict[str, Any]] | None = None, - edges: list[dict[str, Any]] | None = None, - variables: list[dict[str, Any]] | None = None, - execution_config: dict[str, Any] | None = None, - triggers: list[dict[str, Any]] | None = None, - validate: bool = True + self, + app_id: uuid.UUID, + nodes: list[dict[str, Any]] | None = None, + edges: list[dict[str, Any]] | None = None, + variables: list[dict[str, Any]] | None = None, + execution_config: dict[str, Any] | None = None, + triggers: list[dict[str, Any]] | None = None, + validate: bool = True ) -> WorkflowConfig: """更新工作流配置 @@ -226,8 +226,8 @@ class WorkflowService: return config def validate_workflow_config_for_publish( - self, - app_id: uuid.UUID + self, + app_id: uuid.UUID ) -> tuple[bool, list[str]]: """验证工作流配置是否可以发布 @@ -260,13 +260,13 @@ class WorkflowService: # ==================== 执行管理 ==================== def create_execution( - self, - workflow_config_id: uuid.UUID, - app_id: uuid.UUID, - trigger_type: str, - triggered_by: uuid.UUID | None = None, - conversation_id: uuid.UUID | None = None, - input_data: dict[str, Any] | None = None + self, + workflow_config_id: uuid.UUID, + app_id: uuid.UUID, + trigger_type: str, + triggered_by: uuid.UUID | None = None, + conversation_id: uuid.UUID | None = None, + input_data: dict[str, Any] | None = None ) -> WorkflowExecution: """创建工作流执行记录 @@ -314,10 +314,10 @@ class WorkflowService: return self.execution_repo.get_by_execution_id(execution_id) def get_executions_by_app( - self, - app_id: uuid.UUID, - limit: int = 50, - offset: int = 0 + self, + app_id: uuid.UUID, + limit: int = 50, + offset: int = 0 ) -> list[WorkflowExecution]: """获取应用的执行记录列表 @@ -332,12 +332,12 @@ class WorkflowService: return self.execution_repo.get_by_app_id(app_id, limit, offset) def update_execution_status( - self, - execution_id: str, - status: str, - output_data: dict[str, Any] | None = None, - error_message: str | None = None, - error_node_id: str | None = None + self, + execution_id: str, + status: str, + output_data: dict[str, Any] | None = None, + error_message: str | None = None, + error_node_id: str | None = None ) -> WorkflowExecution: """更新执行状态 @@ -407,10 +407,10 @@ class WorkflowService: # ==================== 工作流执行 ==================== async def run( - self, - app_id: uuid.UUID, - payload: DraftRunRequest, - config: WorkflowConfig + self, + app_id: uuid.UUID, + payload: DraftRunRequest, + config: WorkflowConfig ): """运行工作流 @@ -527,10 +527,10 @@ class WorkflowService: ) async def run_stream( - self, - app_id: uuid.UUID, - payload: DraftRunRequest, - config: WorkflowConfig + self, + app_id: uuid.UUID, + payload: DraftRunRequest, + config: WorkflowConfig ): """运行工作流(流式) @@ -600,11 +600,11 @@ class WorkflowService: # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) async for event in self._run_workflow_stream( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id="", - user_id=payload.user_id + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id="", + user_id=payload.user_id ): # 直接转发 executor 的事件(已经是正确的格式) yield event @@ -626,12 +626,12 @@ class WorkflowService: } async def run_workflow( - self, - app_id: uuid.UUID, - input_data: dict[str, Any], - triggered_by: uuid.UUID, - conversation_id: uuid.UUID | None = None, - stream: bool = False + self, + app_id: uuid.UUID, + input_data: dict[str, Any], + triggered_by: uuid.UUID, + conversation_id: uuid.UUID | None = None, + stream: bool = False ) -> AsyncGenerator | dict: """运行工作流 @@ -778,12 +778,12 @@ class WorkflowService: return clean_value(event) async def _run_workflow_stream( - self, - workflow_config: dict[str, Any], - input_data: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str): + self, + workflow_config: dict[str, Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str): """运行工作流(流式,内部方法) Args: @@ -800,11 +800,11 @@ class WorkflowService: try: async for event in execute_workflow_stream( - workflow_config=workflow_config, - input_data=input_data, - execution_id=execution_id, - workspace_id=workspace_id, - user_id=user_id + workflow_config=workflow_config, + input_data=input_data, + execution_id=execution_id, + workspace_id=workspace_id, + user_id=user_id ): # 直接转发事件(executor 已经返回正确格式) yield event @@ -828,7 +828,7 @@ class WorkflowService: # ==================== 依赖注入函数 ==================== def get_workflow_service( - db: Annotated[Session, Depends(get_db)] + db: Annotated[Session, Depends(get_db)] ) -> WorkflowService: """获取工作流服务(依赖注入)""" return WorkflowService(db)