feat(workflow): add memory read and write node (#24)

This commit is contained in:
Eternity
2026-01-05 15:57:04 +08:00
committed by GitHub
parent ab0e465760
commit 78207aca34
8 changed files with 240 additions and 120 deletions

View File

@@ -29,9 +29,9 @@ logger = get_business_logger()
@router.post("", summary="创建应用(可选创建 Agent 配置)") @router.post("", summary="创建应用(可选创建 Agent 配置)")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def create_app( def create_app(
payload: app_schema.AppCreate, payload: app_schema.AppCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload) app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload)
@@ -41,15 +41,15 @@ def create_app(
@router.get("", summary="应用列表(分页)") @router.get("", summary="应用列表(分页)")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def list_apps( def list_apps(
type: str | None = None, type: str | None = None,
visibility: str | None = None, visibility: str | None = None,
status: str | None = None, status: str | None = None,
search: str | None = None, search: str | None = None,
include_shared: bool = True, include_shared: bool = True,
page: int = 1, page: int = 1,
pagesize: int = 10, pagesize: int = 10,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""列出应用 """列出应用
@@ -75,12 +75,13 @@ def list_apps(
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items)) return success(data=PageData(page=meta, items=items))
@router.get("/{app_id}", summary="获取应用详情") @router.get("/{app_id}", summary="获取应用详情")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def get_app( def get_app(
app_id: uuid.UUID, app_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""获取应用详细信息 """获取应用详细信息
@@ -99,10 +100,10 @@ def get_app(
@router.put("/{app_id}", summary="更新应用基本信息") @router.put("/{app_id}", summary="更新应用基本信息")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def update_app( def update_app(
app_id: uuid.UUID, app_id: uuid.UUID,
payload: app_schema.AppUpdate, payload: app_schema.AppUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id) app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id)
@@ -112,9 +113,9 @@ def update_app(
@router.delete("/{app_id}", summary="删除应用") @router.delete("/{app_id}", summary="删除应用")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def delete_app( def delete_app(
app_id: uuid.UUID, app_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""删除应用 """删除应用
@@ -141,10 +142,10 @@ def delete_app(
@router.post("/{app_id}/copy", summary="复制应用") @router.post("/{app_id}/copy", summary="复制应用")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def copy_app( def copy_app(
app_id: uuid.UUID, app_id: uuid.UUID,
new_name: Optional[str] = None, new_name: Optional[str] = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""复制应用(包括基础信息和配置) """复制应用(包括基础信息和配置)
@@ -178,10 +179,10 @@ def copy_app(
@router.put("/{app_id}/config", summary="更新 Agent 配置") @router.put("/{app_id}/config", summary="更新 Agent 配置")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def update_agent_config( def update_agent_config(
app_id: uuid.UUID, app_id: uuid.UUID,
payload: app_schema.AgentConfigUpdate, payload: app_schema.AgentConfigUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id) cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
@@ -192,9 +193,9 @@ def update_agent_config(
@router.get("/{app_id}/config", summary="获取 Agent 配置") @router.get("/{app_id}/config", summary="获取 Agent 配置")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def get_agent_config( def get_agent_config(
app_id: uuid.UUID, app_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id) cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
@@ -206,10 +207,10 @@ def get_agent_config(
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)") @router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def publish_app( def publish_app(
app_id: uuid.UUID, app_id: uuid.UUID,
payload: app_schema.PublishRequest, payload: app_schema.PublishRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
release = app_service.publish( release = app_service.publish(
@@ -217,7 +218,7 @@ def publish_app(
app_id=app_id, app_id=app_id,
publisher_id=current_user.id, publisher_id=current_user.id,
workspace_id=workspace_id, workspace_id=workspace_id,
version_name = payload.version_name, version_name=payload.version_name,
release_notes=payload.release_notes release_notes=payload.release_notes
) )
return success(data=app_schema.AppRelease.model_validate(release)) return success(data=app_schema.AppRelease.model_validate(release))
@@ -226,9 +227,9 @@ def publish_app(
@router.get("/{app_id}/release", summary="获取当前发布版本") @router.get("/{app_id}/release", summary="获取当前发布版本")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def get_current_release( def get_current_release(
app_id: uuid.UUID, app_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id) release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id)
@@ -240,9 +241,9 @@ def get_current_release(
@router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)") @router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def list_releases( def list_releases(
app_id: uuid.UUID, app_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id) releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id)
@@ -253,10 +254,10 @@ def list_releases(
@router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本") @router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def rollback( def rollback(
app_id: uuid.UUID, app_id: uuid.UUID,
version: int, version: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id) release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id)
@@ -266,10 +267,10 @@ def rollback(
@router.post("/{app_id}/share", summary="分享应用到其他工作空间") @router.post("/{app_id}/share", summary="分享应用到其他工作空间")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def share_app( def share_app(
app_id: uuid.UUID, app_id: uuid.UUID,
payload: app_schema.AppShareCreate, payload: app_schema.AppShareCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""分享应用到其他工作空间 """分享应用到其他工作空间
@@ -294,10 +295,10 @@ def share_app(
@router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享") @router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def unshare_app( def unshare_app(
app_id: uuid.UUID, app_id: uuid.UUID,
target_workspace_id: uuid.UUID, target_workspace_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""取消应用分享 """取消应用分享
@@ -318,9 +319,9 @@ def unshare_app(
@router.get("/{app_id}/shares", summary="列出应用的分享记录") @router.get("/{app_id}/shares", summary="列出应用的分享记录")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def list_app_shares( def list_app_shares(
app_id: uuid.UUID, app_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""列出应用的所有分享记录 """列出应用的所有分享记录
@@ -337,14 +338,15 @@ def list_app_shares(
data = [app_schema.AppShare.model_validate(s) for s in shares] data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data) return success(data=data)
@router.post("/{app_id}/draft/run", summary="试运行 Agent使用当前草稿配置") @router.post("/{app_id}/draft/run", summary="试运行 Agent使用当前草稿配置")
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def draft_run( async def draft_run(
app_id: uuid.UUID, app_id: uuid.UUID,
payload: app_schema.DraftRunRequest, payload: app_schema.DraftRunRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
): ):
""" """
试运行 Agent使用当前的草稿配置未发布的配置 试运行 Agent使用当前的草稿配置未发布的配置
@@ -374,7 +376,6 @@ async def draft_run(
if knowledge: if knowledge:
user_rag_memory_id = str(knowledge.id) user_rag_memory_id = str(knowledge.id)
# 提前验证和准备(在流式响应开始前完成) # 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService from app.services.app_service import AppService
from app.services.multi_agent_service import MultiAgentService from app.services.multi_agent_service import MultiAgentService
@@ -396,11 +397,11 @@ async def draft_run(
# 处理会话ID创建或验证 # 处理会话ID创建或验证
conversation_id = await draft_service._ensure_conversation( conversation_id = await draft_service._ensure_conversation(
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
app_id=app_id, app_id=app_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=payload.user_id user_id=payload.user_id
) )
payload.conversation_id = conversation_id payload.conversation_id = conversation_id
if app.type == AppType.AGENT: if app.type == AppType.AGENT:
@@ -424,17 +425,16 @@ async def draft_run(
if payload.stream: if payload.stream:
async def event_generator(): async def event_generator():
async for event in draft_service.run_stream( async for event in draft_service.run_stream(
agent_config=agent_cfg, agent_config=agent_cfg,
model_config=model_config, model_config=model_config,
message=payload.message, message=payload.message,
workspace_id=workspace_id, workspace_id=workspace_id,
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id), user_id=payload.user_id or str(current_user.id),
variables=payload.variables, variables=payload.variables,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
yield event yield event
@@ -528,10 +528,10 @@ async def draft_run(
# 调用多智能体服务的流式方法 # 调用多智能体服务的流式方法
async for event in multiservice.run_stream( async for event in multiservice.run_stream(
app_id=app_id, app_id=app_id,
request=multi_agent_request, request=multi_agent_request,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
yield event yield event
@@ -571,7 +571,7 @@ async def draft_run(
data=result, data=result,
msg="多 Agent 任务执行成功" msg="多 Agent 任务执行成功"
) )
elif app.type == AppType.WORKFLOW: #工作流 elif app.type == AppType.WORKFLOW: # 工作流
config = workflow_service.check_config(app_id) config = workflow_service.check_config(app_id)
# 3. 流式返回 # 3. 流式返回
if payload.stream: if payload.stream:
@@ -643,14 +643,13 @@ async def draft_run(
) )
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行") @router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def draft_run_compare( async def draft_run_compare(
app_id: uuid.UUID, app_id: uuid.UUID,
payload: app_schema.DraftRunCompareRequest, payload: app_schema.DraftRunCompareRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
""" """
多模型对比试运行 多模型对比试运行
@@ -748,19 +747,19 @@ async def draft_run_compare(
from app.services.draft_run_service import DraftRunService from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db) draft_service = DraftRunService(db)
async for event in draft_service.run_compare_stream( async for event in draft_service.run_compare_stream(
agent_config=agent_cfg, agent_config=agent_cfg,
models=model_configs, models=model_configs,
message=payload.message, message=payload.message,
workspace_id=workspace_id, workspace_id=workspace_id,
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id), user_id=payload.user_id or str(current_user.id),
variables=payload.variables, variables=payload.variables,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
web_search=True, web_search=True,
memory=True, memory=True,
parallel=payload.parallel, parallel=payload.parallel,
timeout=payload.timeout or 60 timeout=payload.timeout or 60
): ):
yield event yield event
@@ -822,15 +821,15 @@ async def get_workflow_config(
# 配置总是存在(不存在时返回默认模板) # 配置总是存在(不存在时返回默认模板)
return success(data=WorkflowConfigSchema.model_validate(cfg)) return success(data=WorkflowConfigSchema.model_validate(cfg))
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置") @router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def update_workflow_config( async def update_workflow_config(
app_id: uuid.UUID, app_id: uuid.UUID,
payload: WorkflowConfigUpdate, payload: WorkflowConfigUpdate,
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)] current_user: Annotated[User, Depends(get_current_user)]
): ):
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id) cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg)) return success(data=WorkflowConfigSchema.model_validate(cfg))

View File

@@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [ __all__ = [
@@ -45,6 +46,8 @@ __all__ = [
"ParameterExtractorNodeConfig", "ParameterExtractorNodeConfig",
"LoopNodeConfig", "LoopNodeConfig",
"IterationNodeConfig", "IterationNodeConfig",
"QuestionClassifierNodeConfig" "QuestionClassifierNodeConfig",
"ToolNodeConfig" "ToolNodeConfig",
"MemoryReadNodeConfig",
"MemoryWriteNodeConfig"
] ]

View File

@@ -22,6 +22,8 @@ class NodeType(StrEnum):
ITERATION = "iteration" ITERATION = "iteration"
CYCLE_START = "cycle-start" CYCLE_START = "cycle-start"
BREAK = "break" BREAK = "break"
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.nodes.memory.node import MemoryReadNode, MemoryWriteNode
__all__ = ["MemoryReadNodeConfig", "MemoryReadNode", "MemoryWriteNodeConfig", "MemoryWriteNode"]

View File

@@ -0,0 +1,31 @@
import uuid
from pydantic import Field
from typing import Literal
from app.core.workflow.nodes.base_config import BaseNodeConfig
class MemoryReadNodeConfig(BaseNodeConfig):
message: str = Field(
...
)
config_id: str = Field(
...
)
search_switch: str = Field(
"0",
description="Search mode: 0=verify, 1=direct, 2=context"
)
class MemoryWriteNodeConfig(BaseNodeConfig):
message: str = Field(
...
)
config_id: str = Field(
...
)

View File

@@ -0,0 +1,59 @@
from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.db import get_db_read, get_db_context
from app.services.memory_agent_service import MemoryAgentService
class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = MemoryReadNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
with get_db_read() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().read_memory(
group_id=end_user_id,
message=self.typed_config.message,
config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch,
history=[],
db=db,
storage_type="neo4j",
user_rag_memory_id=""
)
class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = MemoryWriteNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
with get_db_context() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().write_memory(
group_id=end_user_id,
message=self.typed_config.message,
config_id=self.typed_config.config_id,
db=db,
storage_type="neo4j",
user_rag_memory_id=""
)

View File

@@ -18,6 +18,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.jinja_render import JinjaRenderNode from app.core.workflow.nodes.jinja_render import JinjaRenderNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.transform import TransformNode
@@ -46,7 +47,9 @@ WorkflowNode = Union[
BreakNode, BreakNode,
ParameterExtractorNode, ParameterExtractorNode,
QuestionClassifierNode, QuestionClassifierNode,
ToolNode ToolNode,
MemoryReadNode,
MemoryWriteNode
] ]
@@ -76,6 +79,8 @@ class NodeFactory:
NodeType.BREAK: BreakNode, NodeType.BREAK: BreakNode,
NodeType.CYCLE_START: StartNode, NodeType.CYCLE_START: StartNode,
NodeType.TOOL: ToolNode, NodeType.TOOL: ToolNode,
NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode,
} }
@classmethod @classmethod

View File

@@ -14,8 +14,9 @@ from sqlalchemy.orm import Session
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.workflow.validator import validate_workflow_config from app.core.workflow.validator import validate_workflow_config
from app.db import get_db from app.db import get_db, get_db_context
from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.workflow_repository import ( from app.repositories.workflow_repository import (
WorkflowConfigRepository, WorkflowConfigRepository,
WorkflowExecutionRepository, WorkflowExecutionRepository,
@@ -480,13 +481,21 @@ class WorkflowService:
try: try:
# 更新状态为运行中 # 更新状态为运行中
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
with get_db_context() as db:
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
other_id=payload.user_id,
original_user_id=payload.user_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
result = await execute_workflow( result = await execute_workflow(
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id=str(workspace_id), workspace_id=str(workspace_id),
user_id=payload.user_id user_id=end_user_id
) )
# 更新执行结果 # 更新执行结果
@@ -599,6 +608,14 @@ class WorkflowService:
try: try:
# 更新状态为运行中 # 更新状态为运行中
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
with get_db_context() as db:
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
other_id=payload.user_id,
original_user_id=payload.user_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件) # 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in self._run_workflow_stream( async for event in self._run_workflow_stream(
@@ -606,7 +623,7 @@ class WorkflowService:
input_data=input_data, input_data=input_data,
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id=str(workspace_id), workspace_id=str(workspace_id),
user_id=payload.user_id user_id=end_user_id
): ):
# 直接转发 executor 的事件(已经是正确的格式) # 直接转发 executor 的事件(已经是正确的格式)
yield event yield event