From c15a987701d4e740f9f104647fb3ff1c56d2075d Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 22 Dec 2025 14:59:21 +0800 Subject: [PATCH] style(service): workflow --- api/app/schemas/app_schema.py | 87 +++++++++--------- api/app/services/workflow_service.py | 132 +++++++++++++-------------- 2 files changed, 110 insertions(+), 109 deletions(-) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index c387cee9..b6b1de52 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -1,6 +1,7 @@ -import uuid import datetime -from typing import Optional, Any, List, Dict, TYPE_CHECKING +import uuid +from typing import Optional, Any, List, Dict + from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator @@ -20,20 +21,19 @@ class KnowledgeBaseConfig(BaseModel): class KnowledgeRetrievalConfig(BaseModel): """知识库检索配置(支持多个知识库,每个有独立配置)""" knowledge_bases: List[KnowledgeBaseConfig] = Field( - default_factory=list, + default_factory=list, description="关联的知识库列表,每个知识库有独立配置" ) - + # 多知识库融合策略 merge_strategy: str = Field( - default="weighted", + default="weighted", description="多知识库结果融合策略: weighted | rrf | concat" ) reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID") reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数") - class ToolConfig(BaseModel): """工具配置""" enabled: bool = Field(default=False, description="是否启用该工具") @@ -63,7 +63,7 @@ class VariableDefinition(BaseModel): name: str = Field(..., description="变量名称(标识符)") display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)") type: str = Field( - default="string", + default="string", description="变量类型: string(单行文本) | text(多行文本) | number(数字)" ) required: bool = Field(default=False, description="是否必填") @@ -75,32 +75,32 @@ class AgentConfigCreate(BaseModel): """Agent 行为配置""" # 提示词配置 system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则") - + # 模型配置 default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID") model_parameters: ModelParameters = Field( default_factory=ModelParameters, description="模型参数配置(temperature、max_tokens 等)" ) - + # 知识库关联 knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( default=None, description="知识库检索配置" ) - + # 记忆配置 memory: MemoryConfig = Field( default_factory=lambda: MemoryConfig(enabled=True), description="对话历史记忆配置" ) - + # 变量配置 variables: List[VariableDefinition] = Field( default_factory=list, description="Agent 可用的变量列表" ) - + # 工具配置 tools: Dict[str, ToolConfig] = Field( default_factory=dict, @@ -120,7 +120,7 @@ class AppCreate(BaseModel): # only for type=agent agent_config: Optional[AgentConfigCreate] = None - + # only for type=multi_agent multi_agent_config: Optional[Dict[str, Any]] = None @@ -139,23 +139,23 @@ class AgentConfigUpdate(BaseModel): """更新 Agent 行为配置""" # 提示词配置 system_prompt: Optional[str] = Field(default=None, description="系统提示词") - + # 模型配置 default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID") model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置") - + # 知识库关联 knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( default=None, description="知识库检索配置" ) - + # 记忆配置 memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置") - + # 变量配置 variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") - + # 工具配置 tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") @@ -185,7 +185,7 @@ class App(BaseModel): @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -197,26 +197,26 @@ class AgentConfig(BaseModel): id: uuid.UUID app_id: uuid.UUID - + # 提示词 system_prompt: Optional[str] = None - + # 模型配置 default_model_config_id: Optional[uuid.UUID] = None model_parameters: ModelParameters = Field(default_factory=ModelParameters) - + # 知识库检索 knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None - + # 记忆配置 memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True)) - + # 变量配置 variables: List[VariableDefinition] = [] - + # 工具配置 tools: Dict[str, ToolConfig] = {} - + is_active: bool created_at: datetime.datetime updated_at: datetime.datetime @@ -228,7 +228,7 @@ class AgentConfig(BaseModel): if v is None: return ModelParameters() return v - + @field_validator("memory", mode="before") @classmethod def validate_memory(cls, v): @@ -236,7 +236,7 @@ class AgentConfig(BaseModel): if v is None: return MemoryConfig(enabled=True) return v - + @field_validator("variables", mode="before") @classmethod def validate_variables(cls, v): @@ -244,7 +244,7 @@ class AgentConfig(BaseModel): if v is None: return [] return v - + @field_validator("tools", mode="before") @classmethod def validate_tools(cls, v): @@ -256,7 +256,7 @@ class AgentConfig(BaseModel): @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -294,15 +294,15 @@ class AppRelease(BaseModel): @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("published_at", when_used="json") def _serialize_published_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + # ---------- App Share Schemas ---------- @@ -314,7 +314,7 @@ class AppShareCreate(BaseModel): class AppShare(BaseModel): """应用分享输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID source_app_id: uuid.UUID source_workspace_id: uuid.UUID @@ -322,11 +322,11 @@ class AppShare(BaseModel): shared_by: uuid.UUID created_at: datetime.datetime updated_at: datetime.datetime - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -338,6 +338,7 @@ class DraftRunRequest(BaseModel): """试运行请求""" message: str = Field(..., description="用户消息") conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)") + conversation_vars: Optional[dict[str, Any]] = Field(default=None, description="会话变量") user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") @@ -382,14 +383,14 @@ class DraftRunCompareRequest(BaseModel): conversation_id: Optional[str] = Field(None, description="会话ID") user_id: Optional[str] = Field(None, description="用户ID") variables: Optional[Dict[str, Any]] = Field(None, description="变量参数") - + models: List[ModelCompareItem] = Field( ..., min_length=1, max_length=5, description="要对比的模型列表(1-5个)" ) - + parallel: bool = Field(True, description="是否并行执行") stream: bool = Field(False, description="是否流式返回") timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)") @@ -400,14 +401,14 @@ class ModelRunResult(BaseModel): model_config_id: uuid.UUID model_name: str label: Optional[str] = None - + parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数") - + message: Optional[str] = None usage: Optional[Dict[str, Any]] = None elapsed_time: float error: Optional[str] = None - + tokens_per_second: Optional[float] = None cost_estimate: Optional[float] = None conversation_id: Optional[str] = None @@ -416,10 +417,10 @@ class ModelRunResult(BaseModel): class DraftRunCompareResponse(BaseModel): """多模型对比响应""" results: List[ModelRunResult] - + total_elapsed_time: float successful_count: int failed_count: int - + fastest_model: Optional[str] = None cheapest_model: Optional[str] = None diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index ccf0442f..058767d9 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -39,14 +39,14 @@ class WorkflowService: # ==================== 配置管理 ==================== def create_workflow_config( - self, - app_id: uuid.UUID, - nodes: list[dict[str, Any]], - edges: list[dict[str, Any]], - variables: list[dict[str, Any]] | None = None, - execution_config: dict[str, Any] | None = None, - triggers: list[dict[str, Any]] | None = None, - validate: bool = True + self, + app_id: uuid.UUID, + nodes: list[dict[str, Any]], + edges: list[dict[str, Any]], + variables: list[dict[str, Any]] | None = None, + execution_config: dict[str, Any] | None = None, + triggers: list[dict[str, Any]] | None = None, + validate: bool = True ) -> WorkflowConfig: """创建工作流配置 @@ -109,14 +109,14 @@ class WorkflowService: return self.config_repo.get_by_app_id(app_id) def update_workflow_config( - self, - app_id: uuid.UUID, - nodes: list[dict[str, Any]] | None = None, - edges: list[dict[str, Any]] | None = None, - variables: list[dict[str, Any]] | None = None, - execution_config: dict[str, Any] | None = None, - triggers: list[dict[str, Any]] | None = None, - validate: bool = True + self, + app_id: uuid.UUID, + nodes: list[dict[str, Any]] | None = None, + edges: list[dict[str, Any]] | None = None, + variables: list[dict[str, Any]] | None = None, + execution_config: dict[str, Any] | None = None, + triggers: list[dict[str, Any]] | None = None, + validate: bool = True ) -> WorkflowConfig: """更新工作流配置 @@ -226,8 +226,8 @@ class WorkflowService: return config def validate_workflow_config_for_publish( - self, - app_id: uuid.UUID + self, + app_id: uuid.UUID ) -> tuple[bool, list[str]]: """验证工作流配置是否可以发布 @@ -260,13 +260,13 @@ class WorkflowService: # ==================== 执行管理 ==================== def create_execution( - self, - workflow_config_id: uuid.UUID, - app_id: uuid.UUID, - trigger_type: str, - triggered_by: uuid.UUID | None = None, - conversation_id: uuid.UUID | None = None, - input_data: dict[str, Any] | None = None + self, + workflow_config_id: uuid.UUID, + app_id: uuid.UUID, + trigger_type: str, + triggered_by: uuid.UUID | None = None, + conversation_id: uuid.UUID | None = None, + input_data: dict[str, Any] | None = None ) -> WorkflowExecution: """创建工作流执行记录 @@ -314,10 +314,10 @@ class WorkflowService: return self.execution_repo.get_by_execution_id(execution_id) def get_executions_by_app( - self, - app_id: uuid.UUID, - limit: int = 50, - offset: int = 0 + self, + app_id: uuid.UUID, + limit: int = 50, + offset: int = 0 ) -> list[WorkflowExecution]: """获取应用的执行记录列表 @@ -332,12 +332,12 @@ class WorkflowService: return self.execution_repo.get_by_app_id(app_id, limit, offset) def update_execution_status( - self, - execution_id: str, - status: str, - output_data: dict[str, Any] | None = None, - error_message: str | None = None, - error_node_id: str | None = None + self, + execution_id: str, + status: str, + output_data: dict[str, Any] | None = None, + error_message: str | None = None, + error_node_id: str | None = None ) -> WorkflowExecution: """更新执行状态 @@ -407,10 +407,10 @@ class WorkflowService: # ==================== 工作流执行 ==================== async def run( - self, - app_id: uuid.UUID, - payload: DraftRunRequest, - config: WorkflowConfig + self, + app_id: uuid.UUID, + payload: DraftRunRequest, + config: WorkflowConfig ): """运行工作流 @@ -527,10 +527,10 @@ class WorkflowService: ) async def run_stream( - self, - app_id: uuid.UUID, - payload: DraftRunRequest, - config: WorkflowConfig + self, + app_id: uuid.UUID, + payload: DraftRunRequest, + config: WorkflowConfig ): """运行工作流(流式) @@ -600,11 +600,11 @@ class WorkflowService: # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) async for event in self._run_workflow_stream( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id="", - user_id=payload.user_id + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id="", + user_id=payload.user_id ): # 直接转发 executor 的事件(已经是正确的格式) yield event @@ -626,12 +626,12 @@ class WorkflowService: } async def run_workflow( - self, - app_id: uuid.UUID, - input_data: dict[str, Any], - triggered_by: uuid.UUID, - conversation_id: uuid.UUID | None = None, - stream: bool = False + self, + app_id: uuid.UUID, + input_data: dict[str, Any], + triggered_by: uuid.UUID, + conversation_id: uuid.UUID | None = None, + stream: bool = False ) -> AsyncGenerator | dict: """运行工作流 @@ -778,12 +778,12 @@ class WorkflowService: return clean_value(event) async def _run_workflow_stream( - self, - workflow_config: dict[str, Any], - input_data: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str): + self, + workflow_config: dict[str, Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str): """运行工作流(流式,内部方法) Args: @@ -800,11 +800,11 @@ class WorkflowService: try: async for event in execute_workflow_stream( - workflow_config=workflow_config, - input_data=input_data, - execution_id=execution_id, - workspace_id=workspace_id, - user_id=user_id + workflow_config=workflow_config, + input_data=input_data, + execution_id=execution_id, + workspace_id=workspace_id, + user_id=user_id ): # 直接转发事件(executor 已经返回正确格式) yield event @@ -828,7 +828,7 @@ class WorkflowService: # ==================== 依赖注入函数 ==================== def get_workflow_service( - db: Annotated[Session, Depends(get_db)] + db: Annotated[Session, Depends(get_db)] ) -> WorkflowService: """获取工作流服务(依赖注入)""" return WorkflowService(db)