style(service): workflow
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import uuid
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Optional, Any, List, Dict, TYPE_CHECKING
|
import uuid
|
||||||
|
from typing import Optional, Any, List, Dict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||||
|
|
||||||
|
|
||||||
@@ -20,20 +21,19 @@ class KnowledgeBaseConfig(BaseModel):
|
|||||||
class KnowledgeRetrievalConfig(BaseModel):
|
class KnowledgeRetrievalConfig(BaseModel):
|
||||||
"""知识库检索配置(支持多个知识库,每个有独立配置)"""
|
"""知识库检索配置(支持多个知识库,每个有独立配置)"""
|
||||||
knowledge_bases: List[KnowledgeBaseConfig] = Field(
|
knowledge_bases: List[KnowledgeBaseConfig] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="关联的知识库列表,每个知识库有独立配置"
|
description="关联的知识库列表,每个知识库有独立配置"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 多知识库融合策略
|
# 多知识库融合策略
|
||||||
merge_strategy: str = Field(
|
merge_strategy: str = Field(
|
||||||
default="weighted",
|
default="weighted",
|
||||||
description="多知识库结果融合策略: weighted | rrf | concat"
|
description="多知识库结果融合策略: weighted | rrf | concat"
|
||||||
)
|
)
|
||||||
reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID")
|
reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID")
|
||||||
reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数")
|
reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ToolConfig(BaseModel):
|
class ToolConfig(BaseModel):
|
||||||
"""工具配置"""
|
"""工具配置"""
|
||||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||||
@@ -63,7 +63,7 @@ class VariableDefinition(BaseModel):
|
|||||||
name: str = Field(..., description="变量名称(标识符)")
|
name: str = Field(..., description="变量名称(标识符)")
|
||||||
display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)")
|
display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)")
|
||||||
type: str = Field(
|
type: str = Field(
|
||||||
default="string",
|
default="string",
|
||||||
description="变量类型: string(单行文本) | text(多行文本) | number(数字)"
|
description="变量类型: string(单行文本) | text(多行文本) | number(数字)"
|
||||||
)
|
)
|
||||||
required: bool = Field(default=False, description="是否必填")
|
required: bool = Field(default=False, description="是否必填")
|
||||||
@@ -75,32 +75,32 @@ class AgentConfigCreate(BaseModel):
|
|||||||
"""Agent 行为配置"""
|
"""Agent 行为配置"""
|
||||||
# 提示词配置
|
# 提示词配置
|
||||||
system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则")
|
system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则")
|
||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID")
|
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID")
|
||||||
model_parameters: ModelParameters = Field(
|
model_parameters: ModelParameters = Field(
|
||||||
default_factory=ModelParameters,
|
default_factory=ModelParameters,
|
||||||
description="模型参数配置(temperature、max_tokens 等)"
|
description="模型参数配置(temperature、max_tokens 等)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 知识库关联
|
# 知识库关联
|
||||||
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
|
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="知识库检索配置"
|
description="知识库检索配置"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 记忆配置
|
# 记忆配置
|
||||||
memory: MemoryConfig = Field(
|
memory: MemoryConfig = Field(
|
||||||
default_factory=lambda: MemoryConfig(enabled=True),
|
default_factory=lambda: MemoryConfig(enabled=True),
|
||||||
description="对话历史记忆配置"
|
description="对话历史记忆配置"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 变量配置
|
# 变量配置
|
||||||
variables: List[VariableDefinition] = Field(
|
variables: List[VariableDefinition] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="Agent 可用的变量列表"
|
description="Agent 可用的变量列表"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 工具配置
|
# 工具配置
|
||||||
tools: Dict[str, ToolConfig] = Field(
|
tools: Dict[str, ToolConfig] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
@@ -120,7 +120,7 @@ class AppCreate(BaseModel):
|
|||||||
|
|
||||||
# only for type=agent
|
# only for type=agent
|
||||||
agent_config: Optional[AgentConfigCreate] = None
|
agent_config: Optional[AgentConfigCreate] = None
|
||||||
|
|
||||||
# only for type=multi_agent
|
# only for type=multi_agent
|
||||||
multi_agent_config: Optional[Dict[str, Any]] = None
|
multi_agent_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@@ -139,23 +139,23 @@ class AgentConfigUpdate(BaseModel):
|
|||||||
"""更新 Agent 行为配置"""
|
"""更新 Agent 行为配置"""
|
||||||
# 提示词配置
|
# 提示词配置
|
||||||
system_prompt: Optional[str] = Field(default=None, description="系统提示词")
|
system_prompt: Optional[str] = Field(default=None, description="系统提示词")
|
||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID")
|
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID")
|
||||||
model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置")
|
model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置")
|
||||||
|
|
||||||
# 知识库关联
|
# 知识库关联
|
||||||
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
|
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="知识库检索配置"
|
description="知识库检索配置"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 记忆配置
|
# 记忆配置
|
||||||
memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置")
|
memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置")
|
||||||
|
|
||||||
# 变量配置
|
# 变量配置
|
||||||
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
|
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
|
||||||
|
|
||||||
# 工具配置
|
# 工具配置
|
||||||
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置")
|
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置")
|
||||||
|
|
||||||
@@ -185,7 +185,7 @@ class App(BaseModel):
|
|||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
@field_serializer("updated_at", when_used="json")
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
@@ -197,26 +197,26 @@ class AgentConfig(BaseModel):
|
|||||||
|
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
app_id: uuid.UUID
|
app_id: uuid.UUID
|
||||||
|
|
||||||
# 提示词
|
# 提示词
|
||||||
system_prompt: Optional[str] = None
|
system_prompt: Optional[str] = None
|
||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
default_model_config_id: Optional[uuid.UUID] = None
|
default_model_config_id: Optional[uuid.UUID] = None
|
||||||
model_parameters: ModelParameters = Field(default_factory=ModelParameters)
|
model_parameters: ModelParameters = Field(default_factory=ModelParameters)
|
||||||
|
|
||||||
# 知识库检索
|
# 知识库检索
|
||||||
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None
|
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None
|
||||||
|
|
||||||
# 记忆配置
|
# 记忆配置
|
||||||
memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True))
|
memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True))
|
||||||
|
|
||||||
# 变量配置
|
# 变量配置
|
||||||
variables: List[VariableDefinition] = []
|
variables: List[VariableDefinition] = []
|
||||||
|
|
||||||
# 工具配置
|
# 工具配置
|
||||||
tools: Dict[str, ToolConfig] = {}
|
tools: Dict[str, ToolConfig] = {}
|
||||||
|
|
||||||
is_active: bool
|
is_active: bool
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
@@ -228,7 +228,7 @@ class AgentConfig(BaseModel):
|
|||||||
if v is None:
|
if v is None:
|
||||||
return ModelParameters()
|
return ModelParameters()
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("memory", mode="before")
|
@field_validator("memory", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_memory(cls, v):
|
def validate_memory(cls, v):
|
||||||
@@ -236,7 +236,7 @@ class AgentConfig(BaseModel):
|
|||||||
if v is None:
|
if v is None:
|
||||||
return MemoryConfig(enabled=True)
|
return MemoryConfig(enabled=True)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("variables", mode="before")
|
@field_validator("variables", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_variables(cls, v):
|
def validate_variables(cls, v):
|
||||||
@@ -244,7 +244,7 @@ class AgentConfig(BaseModel):
|
|||||||
if v is None:
|
if v is None:
|
||||||
return []
|
return []
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("tools", mode="before")
|
@field_validator("tools", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_tools(cls, v):
|
def validate_tools(cls, v):
|
||||||
@@ -256,7 +256,7 @@ class AgentConfig(BaseModel):
|
|||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
@field_serializer("updated_at", when_used="json")
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
@@ -294,15 +294,15 @@ class AppRelease(BaseModel):
|
|||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
@field_serializer("updated_at", when_used="json")
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
@field_serializer("published_at", when_used="json")
|
@field_serializer("published_at", when_used="json")
|
||||||
def _serialize_published_at(self, dt: datetime.datetime):
|
def _serialize_published_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
# ---------- App Share Schemas ----------
|
# ---------- App Share Schemas ----------
|
||||||
|
|
||||||
@@ -314,7 +314,7 @@ class AppShareCreate(BaseModel):
|
|||||||
class AppShare(BaseModel):
|
class AppShare(BaseModel):
|
||||||
"""应用分享输出"""
|
"""应用分享输出"""
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
source_app_id: uuid.UUID
|
source_app_id: uuid.UUID
|
||||||
source_workspace_id: uuid.UUID
|
source_workspace_id: uuid.UUID
|
||||||
@@ -322,11 +322,11 @@ class AppShare(BaseModel):
|
|||||||
shared_by: uuid.UUID
|
shared_by: uuid.UUID
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
@field_serializer("updated_at", when_used="json")
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
@@ -338,6 +338,7 @@ class DraftRunRequest(BaseModel):
|
|||||||
"""试运行请求"""
|
"""试运行请求"""
|
||||||
message: str = Field(..., description="用户消息")
|
message: str = Field(..., description="用户消息")
|
||||||
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
||||||
|
conversation_vars: Optional[dict[str, Any]] = Field(default=None, description="会话变量")
|
||||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
@@ -382,14 +383,14 @@ class DraftRunCompareRequest(BaseModel):
|
|||||||
conversation_id: Optional[str] = Field(None, description="会话ID")
|
conversation_id: Optional[str] = Field(None, description="会话ID")
|
||||||
user_id: Optional[str] = Field(None, description="用户ID")
|
user_id: Optional[str] = Field(None, description="用户ID")
|
||||||
variables: Optional[Dict[str, Any]] = Field(None, description="变量参数")
|
variables: Optional[Dict[str, Any]] = Field(None, description="变量参数")
|
||||||
|
|
||||||
models: List[ModelCompareItem] = Field(
|
models: List[ModelCompareItem] = Field(
|
||||||
...,
|
...,
|
||||||
min_length=1,
|
min_length=1,
|
||||||
max_length=5,
|
max_length=5,
|
||||||
description="要对比的模型列表(1-5个)"
|
description="要对比的模型列表(1-5个)"
|
||||||
)
|
)
|
||||||
|
|
||||||
parallel: bool = Field(True, description="是否并行执行")
|
parallel: bool = Field(True, description="是否并行执行")
|
||||||
stream: bool = Field(False, description="是否流式返回")
|
stream: bool = Field(False, description="是否流式返回")
|
||||||
timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)")
|
timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)")
|
||||||
@@ -400,14 +401,14 @@ class ModelRunResult(BaseModel):
|
|||||||
model_config_id: uuid.UUID
|
model_config_id: uuid.UUID
|
||||||
model_name: str
|
model_name: str
|
||||||
label: Optional[str] = None
|
label: Optional[str] = None
|
||||||
|
|
||||||
parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数")
|
parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数")
|
||||||
|
|
||||||
message: Optional[str] = None
|
message: Optional[str] = None
|
||||||
usage: Optional[Dict[str, Any]] = None
|
usage: Optional[Dict[str, Any]] = None
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
tokens_per_second: Optional[float] = None
|
tokens_per_second: Optional[float] = None
|
||||||
cost_estimate: Optional[float] = None
|
cost_estimate: Optional[float] = None
|
||||||
conversation_id: Optional[str] = None
|
conversation_id: Optional[str] = None
|
||||||
@@ -416,10 +417,10 @@ class ModelRunResult(BaseModel):
|
|||||||
class DraftRunCompareResponse(BaseModel):
|
class DraftRunCompareResponse(BaseModel):
|
||||||
"""多模型对比响应"""
|
"""多模型对比响应"""
|
||||||
results: List[ModelRunResult]
|
results: List[ModelRunResult]
|
||||||
|
|
||||||
total_elapsed_time: float
|
total_elapsed_time: float
|
||||||
successful_count: int
|
successful_count: int
|
||||||
failed_count: int
|
failed_count: int
|
||||||
|
|
||||||
fastest_model: Optional[str] = None
|
fastest_model: Optional[str] = None
|
||||||
cheapest_model: Optional[str] = None
|
cheapest_model: Optional[str] = None
|
||||||
|
|||||||
@@ -39,14 +39,14 @@ class WorkflowService:
|
|||||||
# ==================== 配置管理 ====================
|
# ==================== 配置管理 ====================
|
||||||
|
|
||||||
def create_workflow_config(
|
def create_workflow_config(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
nodes: list[dict[str, Any]],
|
nodes: list[dict[str, Any]],
|
||||||
edges: list[dict[str, Any]],
|
edges: list[dict[str, Any]],
|
||||||
variables: list[dict[str, Any]] | None = None,
|
variables: list[dict[str, Any]] | None = None,
|
||||||
execution_config: dict[str, Any] | None = None,
|
execution_config: dict[str, Any] | None = None,
|
||||||
triggers: list[dict[str, Any]] | None = None,
|
triggers: list[dict[str, Any]] | None = None,
|
||||||
validate: bool = True
|
validate: bool = True
|
||||||
) -> WorkflowConfig:
|
) -> WorkflowConfig:
|
||||||
"""创建工作流配置
|
"""创建工作流配置
|
||||||
|
|
||||||
@@ -109,14 +109,14 @@ class WorkflowService:
|
|||||||
return self.config_repo.get_by_app_id(app_id)
|
return self.config_repo.get_by_app_id(app_id)
|
||||||
|
|
||||||
def update_workflow_config(
|
def update_workflow_config(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
nodes: list[dict[str, Any]] | None = None,
|
nodes: list[dict[str, Any]] | None = None,
|
||||||
edges: list[dict[str, Any]] | None = None,
|
edges: list[dict[str, Any]] | None = None,
|
||||||
variables: list[dict[str, Any]] | None = None,
|
variables: list[dict[str, Any]] | None = None,
|
||||||
execution_config: dict[str, Any] | None = None,
|
execution_config: dict[str, Any] | None = None,
|
||||||
triggers: list[dict[str, Any]] | None = None,
|
triggers: list[dict[str, Any]] | None = None,
|
||||||
validate: bool = True
|
validate: bool = True
|
||||||
) -> WorkflowConfig:
|
) -> WorkflowConfig:
|
||||||
"""更新工作流配置
|
"""更新工作流配置
|
||||||
|
|
||||||
@@ -226,8 +226,8 @@ class WorkflowService:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
def validate_workflow_config_for_publish(
|
def validate_workflow_config_for_publish(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID
|
app_id: uuid.UUID
|
||||||
) -> tuple[bool, list[str]]:
|
) -> tuple[bool, list[str]]:
|
||||||
"""验证工作流配置是否可以发布
|
"""验证工作流配置是否可以发布
|
||||||
|
|
||||||
@@ -260,13 +260,13 @@ class WorkflowService:
|
|||||||
# ==================== 执行管理 ====================
|
# ==================== 执行管理 ====================
|
||||||
|
|
||||||
def create_execution(
|
def create_execution(
|
||||||
self,
|
self,
|
||||||
workflow_config_id: uuid.UUID,
|
workflow_config_id: uuid.UUID,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
trigger_type: str,
|
trigger_type: str,
|
||||||
triggered_by: uuid.UUID | None = None,
|
triggered_by: uuid.UUID | None = None,
|
||||||
conversation_id: uuid.UUID | None = None,
|
conversation_id: uuid.UUID | None = None,
|
||||||
input_data: dict[str, Any] | None = None
|
input_data: dict[str, Any] | None = None
|
||||||
) -> WorkflowExecution:
|
) -> WorkflowExecution:
|
||||||
"""创建工作流执行记录
|
"""创建工作流执行记录
|
||||||
|
|
||||||
@@ -314,10 +314,10 @@ class WorkflowService:
|
|||||||
return self.execution_repo.get_by_execution_id(execution_id)
|
return self.execution_repo.get_by_execution_id(execution_id)
|
||||||
|
|
||||||
def get_executions_by_app(
|
def get_executions_by_app(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
offset: int = 0
|
offset: int = 0
|
||||||
) -> list[WorkflowExecution]:
|
) -> list[WorkflowExecution]:
|
||||||
"""获取应用的执行记录列表
|
"""获取应用的执行记录列表
|
||||||
|
|
||||||
@@ -332,12 +332,12 @@ class WorkflowService:
|
|||||||
return self.execution_repo.get_by_app_id(app_id, limit, offset)
|
return self.execution_repo.get_by_app_id(app_id, limit, offset)
|
||||||
|
|
||||||
def update_execution_status(
|
def update_execution_status(
|
||||||
self,
|
self,
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
status: str,
|
status: str,
|
||||||
output_data: dict[str, Any] | None = None,
|
output_data: dict[str, Any] | None = None,
|
||||||
error_message: str | None = None,
|
error_message: str | None = None,
|
||||||
error_node_id: str | None = None
|
error_node_id: str | None = None
|
||||||
) -> WorkflowExecution:
|
) -> WorkflowExecution:
|
||||||
"""更新执行状态
|
"""更新执行状态
|
||||||
|
|
||||||
@@ -407,10 +407,10 @@ class WorkflowService:
|
|||||||
# ==================== 工作流执行 ====================
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
payload: DraftRunRequest,
|
payload: DraftRunRequest,
|
||||||
config: WorkflowConfig
|
config: WorkflowConfig
|
||||||
):
|
):
|
||||||
"""运行工作流
|
"""运行工作流
|
||||||
|
|
||||||
@@ -527,10 +527,10 @@ class WorkflowService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def run_stream(
|
async def run_stream(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
payload: DraftRunRequest,
|
payload: DraftRunRequest,
|
||||||
config: WorkflowConfig
|
config: WorkflowConfig
|
||||||
):
|
):
|
||||||
"""运行工作流(流式)
|
"""运行工作流(流式)
|
||||||
|
|
||||||
@@ -600,11 +600,11 @@ class WorkflowService:
|
|||||||
|
|
||||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||||
async for event in self._run_workflow_stream(
|
async for event in self._run_workflow_stream(
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id="",
|
workspace_id="",
|
||||||
user_id=payload.user_id
|
user_id=payload.user_id
|
||||||
):
|
):
|
||||||
# 直接转发 executor 的事件(已经是正确的格式)
|
# 直接转发 executor 的事件(已经是正确的格式)
|
||||||
yield event
|
yield event
|
||||||
@@ -626,12 +626,12 @@ class WorkflowService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def run_workflow(
|
async def run_workflow(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
triggered_by: uuid.UUID,
|
triggered_by: uuid.UUID,
|
||||||
conversation_id: uuid.UUID | None = None,
|
conversation_id: uuid.UUID | None = None,
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
) -> AsyncGenerator | dict:
|
) -> AsyncGenerator | dict:
|
||||||
"""运行工作流
|
"""运行工作流
|
||||||
|
|
||||||
@@ -778,12 +778,12 @@ class WorkflowService:
|
|||||||
return clean_value(event)
|
return clean_value(event)
|
||||||
|
|
||||||
async def _run_workflow_stream(
|
async def _run_workflow_stream(
|
||||||
self,
|
self,
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str):
|
user_id: str):
|
||||||
"""运行工作流(流式,内部方法)
|
"""运行工作流(流式,内部方法)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -800,11 +800,11 @@ class WorkflowService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async for event in execute_workflow_stream(
|
async for event in execute_workflow_stream(
|
||||||
workflow_config=workflow_config,
|
workflow_config=workflow_config,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
):
|
):
|
||||||
# 直接转发事件(executor 已经返回正确格式)
|
# 直接转发事件(executor 已经返回正确格式)
|
||||||
yield event
|
yield event
|
||||||
@@ -828,7 +828,7 @@ class WorkflowService:
|
|||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
def get_workflow_service(
|
def get_workflow_service(
|
||||||
db: Annotated[Session, Depends(get_db)]
|
db: Annotated[Session, Depends(get_db)]
|
||||||
) -> WorkflowService:
|
) -> WorkflowService:
|
||||||
"""获取工作流服务(依赖注入)"""
|
"""获取工作流服务(依赖注入)"""
|
||||||
return WorkflowService(db)
|
return WorkflowService(db)
|
||||||
|
|||||||
Reference in New Issue
Block a user