Merge pull request #142 from SuanmoSuanyangTechnology/feature/workflow-release
Fix workflow release issues and enhance token metrics & loop node outputs
This commit is contained in:
@@ -8,9 +8,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
@@ -19,7 +20,8 @@ from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
logger = get_business_logger()
|
||||
@@ -183,7 +185,6 @@ def get_embed_code(
|
||||
return success(data=embed_code)
|
||||
|
||||
|
||||
|
||||
# ---------- 会话管理接口 ----------
|
||||
|
||||
@router.get(
|
||||
@@ -561,13 +562,15 @@ async def chat(
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
|
||||
config = workflow_config_4_app_release(release)
|
||||
if not config.id:
|
||||
with get_db_read() as db:
|
||||
source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id)
|
||||
config.id = source_config.id
|
||||
config.id = uuid.UUID(config.id)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
@@ -578,7 +581,8 @@ async def chat(
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
@@ -610,7 +614,8 @@ async def chat(
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
|
||||
@@ -242,8 +242,9 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.app_id,
|
||||
workspace_id=workspace_id
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
@@ -274,8 +275,9 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.app_id,
|
||||
workspace_id=workspace_id
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
@@ -53,11 +54,11 @@ class WorkflowExecutor:
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
self.checkpoint_config = {
|
||||
"configurable": {
|
||||
self.checkpoint_config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""准备初始状态(注入系统变量和会话变量)
|
||||
@@ -214,13 +215,13 @@ class WorkflowExecutor:
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"variables": result.get("variables", {}),
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
"variables": result.get("variables", {}),
|
||||
}
|
||||
|
||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||
@@ -326,11 +327,10 @@ class WorkflowExecutor:
|
||||
}
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph(True)
|
||||
graph = self.build_graph(stream=True)
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. Execute workflow
|
||||
try:
|
||||
chunk_count = 0
|
||||
@@ -346,14 +346,16 @@ class WorkflowExecutor:
|
||||
mode, data = event
|
||||
else:
|
||||
# Unexpected format, log and skip
|
||||
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
|
||||
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}"
|
||||
f"- execution_id: {self.execution_id}")
|
||||
continue
|
||||
|
||||
if mode == "custom":
|
||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||
chunk_count += 1
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
|
||||
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
|
||||
f"- execution_id: {self.execution_id}")
|
||||
yield {
|
||||
"event": event_type, # "message" or "node_chunk"
|
||||
"data": {
|
||||
@@ -380,7 +382,8 @@ class WorkflowExecutor:
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] Node starts execution: {node_name}")
|
||||
logger.info(f"[NODE-START] Node starts execution: {node_name} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_start",
|
||||
@@ -399,7 +402,8 @@ class WorkflowExecutor:
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] Node execution completed: {node_name}")
|
||||
logger.info(f"[NODE-END] Node execution completed: {node_name} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_end",
|
||||
@@ -407,13 +411,15 @@ class WorkflowExecutor:
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": execution_id,
|
||||
"timestamp": data.get("timestamp")
|
||||
"timestamp": data.get("timestamp"),
|
||||
"state": result.get("node_outputs", {}).get(node_name),
|
||||
}
|
||||
}
|
||||
|
||||
elif mode == "updates":
|
||||
# Handle state updates - store final state
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
@@ -421,7 +427,7 @@ class WorkflowExecutor:
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s"
|
||||
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
||||
)
|
||||
|
||||
# 发送 workflow_end 事件
|
||||
@@ -449,7 +455,8 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||
@staticmethod
|
||||
def _extract_final_output(node_outputs: dict[str, Any]) -> str | None:
|
||||
"""从节点输出中提取最终输出
|
||||
|
||||
优先级:
|
||||
@@ -473,7 +480,8 @@ class WorkflowExecutor:
|
||||
|
||||
return None
|
||||
|
||||
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
@staticmethod
|
||||
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""聚合所有节点的 token 使用情况
|
||||
|
||||
Args:
|
||||
|
||||
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# List of messages (append mode)
|
||||
messages: list[dict[str, str]]
|
||||
messages: Annotated[list[dict[str, str]], lambda x, y: y]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
|
||||
@@ -21,6 +21,7 @@ class IterationRuntime:
|
||||
optional parallel execution, flattening of output, and loop control via
|
||||
the workflow state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: CompiledStateGraph,
|
||||
@@ -87,6 +88,7 @@ class IterationRuntime:
|
||||
self.result.append(output)
|
||||
if not result["looping"]:
|
||||
self.looping = False
|
||||
return result
|
||||
|
||||
def _create_iteration_tasks(self, array_obj, idx):
|
||||
"""
|
||||
@@ -124,7 +126,7 @@ class IterationRuntime:
|
||||
array_obj = VariablePool(self.state).get(input_expression)
|
||||
if not isinstance(array_obj, list):
|
||||
raise RuntimeError("Cannot iterate over a non-list variable")
|
||||
|
||||
child_state = []
|
||||
idx = 0
|
||||
if self.typed_config.parallel:
|
||||
# Execute iterations in parallel batches
|
||||
@@ -132,15 +134,14 @@ class IterationRuntime:
|
||||
tasks = self._create_iteration_tasks(array_obj, idx)
|
||||
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||
idx += self.typed_config.parallel_count
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return self.result
|
||||
child_state.extend(await asyncio.gather(*tasks))
|
||||
else:
|
||||
# Execute iterations sequentially
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||
child_state.append(result)
|
||||
output = VariablePool(result).get(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
@@ -151,4 +152,7 @@ class IterationRuntime:
|
||||
idx += 1
|
||||
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return self.result
|
||||
return {
|
||||
"output": self.result,
|
||||
"__child_state": child_state
|
||||
}
|
||||
|
||||
@@ -67,7 +67,9 @@ class LoopRuntime:
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
|
||||
)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
self.state["node_outputs"][self.node_id] = {
|
||||
@@ -76,7 +78,9 @@ class LoopRuntime:
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
|
||||
)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
loopstate = WorkflowState(
|
||||
@@ -171,10 +175,11 @@ class LoopRuntime:
|
||||
"""
|
||||
loopstate = self._init_loop_state()
|
||||
loop_time = self.typed_config.max_loop
|
||||
child_state = []
|
||||
while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
await self.graph.ainvoke(loopstate)
|
||||
child_state.append(await self.graph.ainvoke(loopstate))
|
||||
loop_time -= 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return loopstate["runtime_vars"][self.node_id]
|
||||
return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state}
|
||||
|
||||
@@ -10,9 +10,8 @@ from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
from app.services import knowledge_service, knowledgeshare_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -96,7 +95,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
|
||||
|
||||
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
|
||||
share_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
@@ -105,7 +104,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
|
||||
]
|
||||
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
@@ -66,7 +66,7 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
)
|
||||
|
||||
memory: MemoryWindowSetting = Field(
|
||||
...,
|
||||
default_factory=MemoryWindowSetting,
|
||||
description="对话上下文窗口"
|
||||
)
|
||||
|
||||
|
||||
@@ -85,6 +85,7 @@ class LLMNode(BaseNode):
|
||||
"""
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
@@ -167,7 +168,7 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
# self.typed_config = LLMNodeConfig(**self.config)
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
@@ -269,12 +270,16 @@ class LLMNode(BaseNode):
|
||||
chunk_count = 0
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
async for chunk in llm.astream(prompt_or_messages):
|
||||
last_meta_data = {}
|
||||
async for chunk in llm.astream(prompt_or_messages, stream_usage=True):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = chunk.content
|
||||
else:
|
||||
content = str(chunk)
|
||||
if hasattr(chunk, 'response_metadata'):
|
||||
if chunk.response_metadata:
|
||||
last_meta_data = chunk.response_metadata
|
||||
|
||||
# 只有当内容不为空时才处理
|
||||
if content:
|
||||
@@ -288,13 +293,10 @@ class LLMNode(BaseNode):
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
if isinstance(last_chunk, AIMessage):
|
||||
final_message = AIMessage(
|
||||
content=full_response,
|
||||
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
|
||||
response_metadata=last_meta_data
|
||||
)
|
||||
else:
|
||||
final_message = AIMessage(content=full_response)
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": final_message}
|
||||
|
||||
@@ -75,6 +75,14 @@ class WorkflowExecution(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
|
||||
release_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("app_releases.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
|
||||
app_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("apps.id", ondelete="CASCADE"),
|
||||
|
||||
@@ -527,6 +527,7 @@ class AppChatService:
|
||||
conversation_id: uuid.UUID,
|
||||
config: WorkflowConfig,
|
||||
app_id: uuid.UUID,
|
||||
release_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
@@ -549,6 +550,7 @@ class AppChatService:
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release_id,
|
||||
)
|
||||
|
||||
async def workflow_chat_stream(
|
||||
@@ -557,6 +559,7 @@ class AppChatService:
|
||||
conversation_id: uuid.UUID,
|
||||
config: WorkflowConfig,
|
||||
app_id: uuid.UUID,
|
||||
release_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: str = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
@@ -565,7 +568,7 @@ class AppChatService:
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
|
||||
) -> AsyncGenerator[str, None]:
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""聊天(流式)"""
|
||||
workflow_service = WorkflowService(self.db)
|
||||
payload = DraftRunRequest(
|
||||
@@ -580,6 +583,7 @@ class AppChatService:
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release_id
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
@@ -227,7 +227,6 @@ class AppService:
|
||||
if not model_api_key:
|
||||
raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id))
|
||||
|
||||
|
||||
# 3. 检查子 Agent 配置
|
||||
if not multi_agent_config.sub_agents or len(multi_agent_config.sub_agents) == 0:
|
||||
raise BusinessException(
|
||||
@@ -759,8 +758,7 @@ class AppService:
|
||||
)
|
||||
|
||||
# 构建查询条件
|
||||
filters = []
|
||||
filters.append(App.is_active == True)
|
||||
filters = [App.is_active == True]
|
||||
if type:
|
||||
filters.append(App.type == type)
|
||||
if visibility:
|
||||
@@ -875,7 +873,8 @@ class AppService:
|
||||
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active==True).order_by(AgentConfig.updated_at.desc())
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(
|
||||
AgentConfig.updated_at.desc())
|
||||
agent_cfg: Optional[AgentConfig] = self.db.scalars(stmt).first()
|
||||
now = datetime.datetime.now()
|
||||
|
||||
@@ -948,7 +947,12 @@ class AppService:
|
||||
# 只读操作,允许访问共享应用
|
||||
self._validate_app_accessible(app, workspace_id)
|
||||
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(AgentConfig.updated_at.desc())
|
||||
stmt = select(AgentConfig).where(
|
||||
AgentConfig.app_id == app_id,
|
||||
AgentConfig.is_active.is_(True)
|
||||
).order_by(
|
||||
AgentConfig.updated_at.desc()
|
||||
)
|
||||
config = self.db.scalars(stmt).first()
|
||||
|
||||
if config:
|
||||
@@ -1200,7 +1204,8 @@ class AppService:
|
||||
default_model_config_id = None
|
||||
|
||||
if app.type == AppType.AGENT:
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(AgentConfig.updated_at.desc())
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(
|
||||
AgentConfig.updated_at.desc())
|
||||
agent_cfg = self.db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING)
|
||||
@@ -1237,7 +1242,6 @@ class AppService:
|
||||
|
||||
# 4. 构建配置快照
|
||||
|
||||
|
||||
config = {
|
||||
"model_parameters": model_parameters_to_dict(multi_agent_cfg.model_parameters),
|
||||
"master_agent_id": str(multi_agent_cfg.master_agent_id),
|
||||
@@ -1264,6 +1268,7 @@ class AppService:
|
||||
raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING)
|
||||
|
||||
config = {
|
||||
"id": str(workflow_cfg.id),
|
||||
"nodes": workflow_cfg.nodes,
|
||||
"edges": workflow_cfg.edges,
|
||||
"variables": workflow_cfg.variables,
|
||||
@@ -1457,7 +1462,6 @@ class AppService:
|
||||
BusinessException: 当应用不在指定工作空间或目标工作空间无效时
|
||||
"""
|
||||
|
||||
|
||||
logger.info(
|
||||
"分享应用",
|
||||
extra={
|
||||
@@ -2009,7 +2013,8 @@ def create_app(db: Session, *, user_id: uuid.UUID, workspace_id: uuid.UUID, data
|
||||
return service.create_app(user_id=user_id, workspace_id=workspace_id, data=data)
|
||||
|
||||
|
||||
def update_app(db: Session, *, app_id: uuid.UUID, data: app_schema.AppUpdate, workspace_id: uuid.UUID | None = None) -> App:
|
||||
def update_app(db: Session, *, app_id: uuid.UUID, data: app_schema.AppUpdate,
|
||||
workspace_id: uuid.UUID | None = None) -> App:
|
||||
"""更新应用(向后兼容接口)"""
|
||||
service = AppService(db)
|
||||
return service.update_app(app_id=app_id, data=data, workspace_id=workspace_id)
|
||||
@@ -2021,12 +2026,15 @@ def delete_app(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None
|
||||
return service.delete_app(app_id=app_id, workspace_id=workspace_id)
|
||||
|
||||
|
||||
def update_agent_config(db: Session, *, app_id: uuid.UUID, data: app_schema.AgentConfigUpdate, workspace_id: uuid.UUID | None = None) -> AgentConfig:
|
||||
def update_agent_config(db: Session, *, app_id: uuid.UUID, data: app_schema.AgentConfigUpdate,
|
||||
workspace_id: uuid.UUID | None = None) -> AgentConfig:
|
||||
"""更新 Agent 配置(向后兼容接口)"""
|
||||
service = AppService(db)
|
||||
return service.update_agent_config(app_id=app_id, data=data, workspace_id=workspace_id)
|
||||
|
||||
def update_workflow_config(db: Session, *, app_id: uuid.UUID, data: WorkflowConfigUpdate, workspace_id: uuid.UUID | None = None) -> WorkflowConfig:
|
||||
|
||||
def update_workflow_config(db: Session, *, app_id: uuid.UUID, data: WorkflowConfigUpdate,
|
||||
workspace_id: uuid.UUID | None = None) -> WorkflowConfig:
|
||||
"""更新 Agent 配置(向后兼容接口)"""
|
||||
service = AppService(db)
|
||||
return service.update_workflow_config(app_id=app_id, data=data, workspace_id=workspace_id)
|
||||
@@ -2040,6 +2048,7 @@ def get_agent_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID
|
||||
service = AppService(db)
|
||||
return service.get_agent_config(app_id=app_id, workspace_id=workspace_id)
|
||||
|
||||
|
||||
def get_workflow_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> WorkflowConfig:
|
||||
"""获取 Agent 配置(向后兼容接口)
|
||||
|
||||
@@ -2049,13 +2058,20 @@ def get_workflow_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UU
|
||||
return service.get_workflow_config(app_id=app_id, workspace_id=workspace_id)
|
||||
|
||||
|
||||
def publish(db: Session, *, app_id: uuid.UUID, publisher_id: uuid.UUID, workspace_id: uuid.UUID | None = None,version_name:str, release_notes: Optional[str] = None) -> AppRelease:
|
||||
def publish(db: Session, *, app_id: uuid.UUID, publisher_id: uuid.UUID, workspace_id: uuid.UUID | None = None,
|
||||
version_name: str, release_notes: Optional[str] = None) -> AppRelease:
|
||||
"""发布应用(向后兼容接口)"""
|
||||
service = AppService(db)
|
||||
return service.publish(app_id=app_id, publisher_id=publisher_id,version_name = version_name, workspace_id=workspace_id, release_notes=release_notes)
|
||||
return service.publish(app_id=app_id, publisher_id=publisher_id, version_name=version_name,
|
||||
workspace_id=workspace_id, release_notes=release_notes)
|
||||
|
||||
|
||||
def get_current_release(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> Optional[AppRelease]:
|
||||
def get_current_release(
|
||||
db: Session,
|
||||
*,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID | None = None
|
||||
) -> Optional[AppRelease]:
|
||||
"""获取当前发布版本(向后兼容接口)"""
|
||||
service = AppService(db)
|
||||
return service.get_current_release(app_id=app_id, workspace_id=workspace_id)
|
||||
@@ -2156,8 +2172,6 @@ async def draft_run_stream(
|
||||
yield event
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
def get_app_service(
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Annotated, AsyncGenerator
|
||||
from typing import Any, Annotated, AsyncGenerator, Optional
|
||||
|
||||
from deprecated import deprecated
|
||||
from fastapi import Depends
|
||||
@@ -14,15 +14,14 @@ from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.db import get_db
|
||||
from app.models.conversation_model import Message
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories.conversation_repository import MessageRepository
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository
|
||||
)
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -36,7 +35,7 @@ class WorkflowService:
|
||||
self.config_repo = WorkflowConfigRepository(db)
|
||||
self.execution_repo = WorkflowExecutionRepository(db)
|
||||
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||
self.message_repo = MessageRepository(db)
|
||||
self.conversation_service = ConversationService(db)
|
||||
|
||||
# ==================== 配置管理 ====================
|
||||
|
||||
@@ -266,6 +265,7 @@ class WorkflowService:
|
||||
workflow_config_id: uuid.UUID,
|
||||
app_id: uuid.UUID,
|
||||
trigger_type: str,
|
||||
release_id: uuid.UUID | None = None,
|
||||
triggered_by: uuid.UUID | None = None,
|
||||
conversation_id: uuid.UUID | None = None,
|
||||
input_data: dict[str, Any] | None = None
|
||||
@@ -273,6 +273,7 @@ class WorkflowService:
|
||||
"""创建工作流执行记录
|
||||
|
||||
Args:
|
||||
release_id: 应用发布 ID
|
||||
workflow_config_id: 工作流配置 ID
|
||||
app_id: 应用 ID
|
||||
trigger_type: 触发类型
|
||||
@@ -289,6 +290,7 @@ class WorkflowService:
|
||||
execution = WorkflowExecution(
|
||||
workflow_config_id=workflow_config_id,
|
||||
app_id=app_id,
|
||||
release_id=release_id,
|
||||
conversation_id=conversation_id,
|
||||
execution_id=execution_id,
|
||||
trigger_type=trigger_type,
|
||||
@@ -337,6 +339,7 @@ class WorkflowService:
|
||||
self,
|
||||
execution_id: str,
|
||||
status: str,
|
||||
token_usage: int | None = None,
|
||||
output_data: dict[str, Any] | None = None,
|
||||
error_message: str | None = None,
|
||||
error_node_id: str | None = None
|
||||
@@ -346,6 +349,7 @@ class WorkflowService:
|
||||
Args:
|
||||
execution_id: 执行 ID
|
||||
status: 状态
|
||||
token_usage: token消耗
|
||||
output_data: 输出数据
|
||||
error_message: 错误信息
|
||||
error_node_id: 出错节点 ID
|
||||
@@ -364,6 +368,8 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
execution.status = status
|
||||
if token_usage is not None:
|
||||
execution.token_usage = token_usage
|
||||
if output_data is not None:
|
||||
execution.output_data = convert_uuids_to_str(output_data)
|
||||
if error_message is not None:
|
||||
@@ -414,12 +420,14 @@ class WorkflowService:
|
||||
payload: DraftRunRequest,
|
||||
config: WorkflowConfig,
|
||||
workspace_id: uuid.UUID,
|
||||
release_id: uuid.UUID | None = None,
|
||||
):
|
||||
"""运行工作流
|
||||
|
||||
Args:
|
||||
workspace_id:
|
||||
config:
|
||||
release_id: 发布 ID
|
||||
workspace_id:工作空间 ID
|
||||
config: 配置
|
||||
payload:
|
||||
app_id: 应用 ID
|
||||
|
||||
@@ -463,7 +471,8 @@ class WorkflowService:
|
||||
trigger_type="manual",
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id_uuid,
|
||||
input_data=input_data
|
||||
input_data=input_data,
|
||||
release_id=release_id,
|
||||
)
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
@@ -507,20 +516,20 @@ class WorkflowService:
|
||||
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
token_usage = result.get("token_usage", {}) or {}
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result
|
||||
output_data=result,
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
final_messages = result.get("messages", [])[init_message_length:]
|
||||
for message in final_messages:
|
||||
message_obj = Message(
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
content=message["content"]
|
||||
)
|
||||
self.message_repo.add_message(message_obj)
|
||||
self.db.commit()
|
||||
logger.info(f"Workflow Run Success, "
|
||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||
else:
|
||||
@@ -562,10 +571,12 @@ class WorkflowService:
|
||||
payload: DraftRunRequest,
|
||||
config: WorkflowConfig,
|
||||
workspace_id: uuid.UUID,
|
||||
release_id: Optional[uuid.UUID] = None,
|
||||
):
|
||||
"""运行工作流(流式)
|
||||
|
||||
Args:
|
||||
release_id: 发布id
|
||||
workspace_id:
|
||||
app_id: 应用 ID
|
||||
payload: 请求对象(包含 message, variables, conversation_id 等)
|
||||
@@ -611,7 +622,8 @@ class WorkflowService:
|
||||
trigger_type="manual",
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id_uuid,
|
||||
input_data=input_data
|
||||
input_data=input_data,
|
||||
release_id=release_id,
|
||||
)
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
@@ -653,21 +665,21 @@ class WorkflowService:
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
status = event.get("data", {}).get("status")
|
||||
token_usage = event.get("data", {}).get("token_usage", {}) or {}
|
||||
if status == "completed":
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=event.get("data")
|
||||
output_data=event.get("data"),
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
||||
for message in final_messages:
|
||||
message_obj = Message(
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
content=message["content"]
|
||||
)
|
||||
self.message_repo.add_message(message_obj)
|
||||
self.db.commit()
|
||||
logger.info(f"Workflow Run Success, "
|
||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||
elif status == "failed":
|
||||
@@ -784,10 +796,12 @@ class WorkflowService:
|
||||
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
token_usage = result.get("data").get("token_usage", {}) or {}
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result.get("node_outputs", {})
|
||||
output_data=result.get("node_outputs", {}),
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
else:
|
||||
self.update_execution_status(
|
||||
@@ -882,13 +896,14 @@ class WorkflowService:
|
||||
):
|
||||
# 直接转发事件(executor 已经返回正确格式)
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
token_usage = event.get("data").get("token_usage", {}) or {}
|
||||
status = event.get("data", {}).get("status")
|
||||
if status == "completed":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"completed",
|
||||
output_data=event.get("data")
|
||||
output_data=event.get("data"),
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
elif status == "failed":
|
||||
self.update_execution_status(
|
||||
|
||||
@@ -120,12 +120,9 @@ def multi_agent_config_4_app_release(release: AppRelease) -> MultiAgentConfig:
|
||||
|
||||
def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig:
|
||||
config_dict = release.config
|
||||
with get_db_read() as db:
|
||||
source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id)
|
||||
source_config_id = source_config.id
|
||||
|
||||
config = WorkflowConfig(
|
||||
id=source_config_id,
|
||||
id=config_dict.get("id"),
|
||||
app_id=release.app_id,
|
||||
nodes=config_dict.get("nodes", []),
|
||||
edges=config_dict.get("edges", []),
|
||||
|
||||
Reference in New Issue
Block a user