diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index a0df7d67..698f061d 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -29,9 +29,9 @@ logger = get_business_logger() @router.post("", summary="创建应用(可选创建 Agent 配置)") @cur_workspace_access_guard() def create_app( - payload: app_schema.AppCreate, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + payload: app_schema.AppCreate, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): workspace_id = current_user.current_workspace_id app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload) @@ -41,15 +41,15 @@ def create_app( @router.get("", summary="应用列表(分页)") @cur_workspace_access_guard() def list_apps( - type: str | None = None, - visibility: str | None = None, - status: str | None = None, - search: str | None = None, - include_shared: bool = True, - page: int = 1, - pagesize: int = 10, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + type: str | None = None, + visibility: str | None = None, + status: str | None = None, + search: str | None = None, + include_shared: bool = True, + page: int = 1, + pagesize: int = 10, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): """列出应用 @@ -75,12 +75,13 @@ def list_apps( meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) return success(data=PageData(page=meta, items=items)) + @router.get("/{app_id}", summary="获取应用详情") @cur_workspace_access_guard() def get_app( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): """获取应用详细信息 @@ -99,10 +100,10 @@ def get_app( @router.put("/{app_id}", summary="更新应用基本信息") @cur_workspace_access_guard() def update_app( - app_id: uuid.UUID, - payload: app_schema.AppUpdate, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + payload: app_schema.AppUpdate, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): workspace_id = current_user.current_workspace_id app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id) @@ -112,9 +113,9 @@ def update_app( @router.delete("/{app_id}", summary="删除应用") @cur_workspace_access_guard() def delete_app( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): """删除应用 @@ -141,10 +142,10 @@ def delete_app( @router.post("/{app_id}/copy", summary="复制应用") @cur_workspace_access_guard() def copy_app( - app_id: uuid.UUID, - new_name: Optional[str] = None, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + new_name: Optional[str] = None, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): """复制应用(包括基础信息和配置) @@ -178,10 +179,10 @@ def copy_app( @router.put("/{app_id}/config", summary="更新 Agent 配置") @cur_workspace_access_guard() def update_agent_config( - app_id: uuid.UUID, - payload: app_schema.AgentConfigUpdate, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + payload: app_schema.AgentConfigUpdate, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): workspace_id = current_user.current_workspace_id cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id) @@ -192,9 +193,9 @@ def update_agent_config( @router.get("/{app_id}/config", summary="获取 Agent 配置") @cur_workspace_access_guard() def get_agent_config( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): workspace_id = current_user.current_workspace_id cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id) @@ -206,10 +207,10 @@ def get_agent_config( @router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)") @cur_workspace_access_guard() def publish_app( - app_id: uuid.UUID, - payload: app_schema.PublishRequest, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + payload: app_schema.PublishRequest, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): workspace_id = current_user.current_workspace_id release = app_service.publish( @@ -217,7 +218,7 @@ def publish_app( app_id=app_id, publisher_id=current_user.id, workspace_id=workspace_id, - version_name = payload.version_name, + version_name=payload.version_name, release_notes=payload.release_notes ) return success(data=app_schema.AppRelease.model_validate(release)) @@ -226,9 +227,9 @@ def publish_app( @router.get("/{app_id}/release", summary="获取当前发布版本") @cur_workspace_access_guard() def get_current_release( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): workspace_id = current_user.current_workspace_id release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id) @@ -240,9 +241,9 @@ def get_current_release( @router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)") @cur_workspace_access_guard() def list_releases( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): workspace_id = current_user.current_workspace_id releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id) @@ -253,10 +254,10 @@ def list_releases( @router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本") @cur_workspace_access_guard() def rollback( - app_id: uuid.UUID, - version: int, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + version: int, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): workspace_id = current_user.current_workspace_id release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id) @@ -266,10 +267,10 @@ def rollback( @router.post("/{app_id}/share", summary="分享应用到其他工作空间") @cur_workspace_access_guard() def share_app( - app_id: uuid.UUID, - payload: app_schema.AppShareCreate, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + payload: app_schema.AppShareCreate, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): """分享应用到其他工作空间 @@ -294,10 +295,10 @@ def share_app( @router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享") @cur_workspace_access_guard() def unshare_app( - app_id: uuid.UUID, - target_workspace_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + target_workspace_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): """取消应用分享 @@ -318,9 +319,9 @@ def unshare_app( @router.get("/{app_id}/shares", summary="列出应用的分享记录") @cur_workspace_access_guard() def list_app_shares( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): """列出应用的所有分享记录 @@ -337,14 +338,15 @@ def list_app_shares( data = [app_schema.AppShare.model_validate(s) for s in shares] return success(data=data) + @router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)") @cur_workspace_access_guard() async def draft_run( - app_id: uuid.UUID, - payload: app_schema.DraftRunRequest, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), - workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None + app_id: uuid.UUID, + payload: app_schema.DraftRunRequest, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), + workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None ): """ 试运行 Agent,使用当前的草稿配置(未发布的配置) @@ -361,7 +363,7 @@ async def draft_run( workspace_id=workspace_id, user=current_user ) - if storage_type is None: + if storage_type is None: storage_type = 'neo4j' user_rag_memory_id = '' if workspace_id: @@ -371,10 +373,9 @@ async def draft_run( name="USER_RAG_MERORY", workspace_id=workspace_id ) - if knowledge: + if knowledge: user_rag_memory_id = str(knowledge.id) - # 提前验证和准备(在流式响应开始前完成) from app.services.app_service import AppService from app.services.multi_agent_service import MultiAgentService @@ -396,11 +397,11 @@ async def draft_run( # 处理会话ID(创建或验证) conversation_id = await draft_service._ensure_conversation( - conversation_id=payload.conversation_id, - app_id=app_id, - workspace_id=workspace_id, - user_id=payload.user_id - ) + conversation_id=payload.conversation_id, + app_id=app_id, + workspace_id=workspace_id, + user_id=payload.user_id + ) payload.conversation_id = conversation_id if app.type == AppType.AGENT: @@ -424,17 +425,16 @@ async def draft_run( if payload.stream: async def event_generator(): - async for event in draft_service.run_stream( - agent_config=agent_cfg, - model_config=model_config, - message=payload.message, - workspace_id=workspace_id, - conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), - variables=payload.variables, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + agent_config=agent_cfg, + model_config=model_config, + message=payload.message, + workspace_id=workspace_id, + conversation_id=payload.conversation_id, + user_id=payload.user_id or str(current_user.id), + variables=payload.variables, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): yield event @@ -528,10 +528,10 @@ async def draft_run( # 调用多智能体服务的流式方法 async for event in multiservice.run_stream( - app_id=app_id, - request=multi_agent_request, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + app_id=app_id, + request=multi_agent_request, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): yield event @@ -571,7 +571,7 @@ async def draft_run( data=result, msg="多 Agent 任务执行成功" ) - elif app.type == AppType.WORKFLOW: #工作流 + elif app.type == AppType.WORKFLOW: # 工作流 config = workflow_service.check_config(app_id) # 3. 流式返回 if payload.stream: @@ -592,7 +592,7 @@ async def draft_run( data: """ import json - + # 调用工作流服务的流式方法 async for event in workflow_service.run_stream( app_id=app_id, @@ -603,7 +603,7 @@ async def draft_run( # 提取事件类型和数据 event_type = event.get("event", "message") event_data = event.get("data", {}) - + # 转换为标准 SSE 格式(字符串) sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" yield sse_message @@ -643,14 +643,13 @@ async def draft_run( ) - @router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行") @cur_workspace_access_guard() async def draft_run_compare( - app_id: uuid.UUID, - payload: app_schema.DraftRunCompareRequest, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), + app_id: uuid.UUID, + payload: app_schema.DraftRunCompareRequest, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), ): """ 多模型对比试运行 @@ -675,7 +674,7 @@ async def draft_run_compare( workspace_id=workspace_id, user=current_user ) - if storage_type is None: + if storage_type is None: storage_type = 'neo4j' user_rag_memory_id = '' if workspace_id: @@ -684,7 +683,7 @@ async def draft_run_compare( name="USER_RAG_MERORY", workspace_id=workspace_id ) - if knowledge: + if knowledge: user_rag_memory_id = str(knowledge.id) logger.info( @@ -748,19 +747,19 @@ async def draft_run_compare( from app.services.draft_run_service import DraftRunService draft_service = DraftRunService(db) async for event in draft_service.run_compare_stream( - agent_config=agent_cfg, - models=model_configs, - message=payload.message, - workspace_id=workspace_id, - conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), - variables=payload.variables, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - web_search=True, - memory=True, - parallel=payload.parallel, - timeout=payload.timeout or 60 + agent_config=agent_cfg, + models=model_configs, + message=payload.message, + workspace_id=workspace_id, + conversation_id=payload.conversation_id, + user_id=payload.user_id or str(current_user.id), + variables=payload.variables, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + web_search=True, + memory=True, + parallel=payload.parallel, + timeout=payload.timeout or 60 ): yield event @@ -822,15 +821,15 @@ async def get_workflow_config( # 配置总是存在(不存在时返回默认模板) return success(data=WorkflowConfigSchema.model_validate(cfg)) + @router.put("/{app_id}/workflow", summary="更新 Workflow 配置") @cur_workspace_access_guard() async def update_workflow_config( - app_id: uuid.UUID, - payload: WorkflowConfigUpdate, - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)] + app_id: uuid.UUID, + payload: WorkflowConfigUpdate, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)] ): workspace_id = current_user.current_workspace_id cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id) return success(data=WorkflowConfigSchema.model_validate(cfg)) - diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index 6e9c2c51..4d31efaa 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig +from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig __all__ = [ @@ -45,6 +46,8 @@ __all__ = [ "ParameterExtractorNodeConfig", "LoopNodeConfig", "IterationNodeConfig", - "QuestionClassifierNodeConfig" - "ToolNodeConfig" + "QuestionClassifierNodeConfig", + "ToolNodeConfig", + "MemoryReadNodeConfig", + "MemoryWriteNodeConfig" ] diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index b1c9d687..fbbbf845 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -22,6 +22,8 @@ class NodeType(StrEnum): ITERATION = "iteration" CYCLE_START = "cycle-start" BREAK = "break" + MEMORY_READ = "memory-read" + MEMORY_WRITE = "memory-write" class ComparisonOperator(StrEnum): diff --git a/api/app/core/workflow/nodes/memory/__init__.py b/api/app/core/workflow/nodes/memory/__init__.py new file mode 100644 index 00000000..c22c2816 --- /dev/null +++ b/api/app/core/workflow/nodes/memory/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig +from app.core.workflow.nodes.memory.node import MemoryReadNode, MemoryWriteNode + +__all__ = ["MemoryReadNodeConfig", "MemoryReadNode", "MemoryWriteNodeConfig", "MemoryWriteNode"] diff --git a/api/app/core/workflow/nodes/memory/config.py b/api/app/core/workflow/nodes/memory/config.py new file mode 100644 index 00000000..317dc507 --- /dev/null +++ b/api/app/core/workflow/nodes/memory/config.py @@ -0,0 +1,31 @@ +import uuid + +from pydantic import Field +from typing import Literal + +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class MemoryReadNodeConfig(BaseNodeConfig): + message: str = Field( + ... + ) + + config_id: str = Field( + ... + ) + + search_switch: str = Field( + "0", + description="Search mode: 0=verify, 1=direct, 2=context" + ) + + +class MemoryWriteNodeConfig(BaseNodeConfig): + message: str = Field( + ... + ) + + config_id: str = Field( + ... + ) diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py new file mode 100644 index 00000000..09c9fc68 --- /dev/null +++ b/api/app/core/workflow/nodes/memory/node.py @@ -0,0 +1,59 @@ +from typing import Any + +from app.core.workflow.nodes import WorkflowState +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig +from app.db import get_db_read, get_db_context +from app.services.memory_agent_service import MemoryAgentService + + +class MemoryReadNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = MemoryReadNodeConfig(**self.config) + + async def execute(self, state: WorkflowState) -> Any: + with get_db_read() as db: + workspace_id = self.get_variable('sys.workspace_id', state) + end_user_id = self.get_variable("sys.user_id", state) + + if not workspace_id: + raise RuntimeError("Workspace id is required") + if not end_user_id: + raise RuntimeError("End user id is required") + + return await MemoryAgentService().read_memory( + group_id=end_user_id, + message=self.typed_config.message, + config_id=self.typed_config.config_id, + search_switch=self.typed_config.search_switch, + history=[], + db=db, + storage_type="neo4j", + user_rag_memory_id="" + ) + + +class MemoryWriteNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = MemoryWriteNodeConfig(**self.config) + + async def execute(self, state: WorkflowState) -> Any: + with get_db_context() as db: + workspace_id = self.get_variable('sys.workspace_id', state) + end_user_id = self.get_variable("sys.user_id", state) + + if not workspace_id: + raise RuntimeError("Workspace id is required") + if not end_user_id: + raise RuntimeError("End user id is required") + + return await MemoryAgentService().write_memory( + group_id=end_user_id, + message=self.typed_config.message, + config_id=self.typed_config.config_id, + db=db, + storage_type="neo4j", + user_rag_memory_id="" + ) diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 497529e5..9fca8d7a 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -18,6 +18,7 @@ from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.jinja_render import JinjaRenderNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode +from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode @@ -46,7 +47,9 @@ WorkflowNode = Union[ BreakNode, ParameterExtractorNode, QuestionClassifierNode, - ToolNode + ToolNode, + MemoryReadNode, + MemoryWriteNode ] @@ -76,6 +79,8 @@ class NodeFactory: NodeType.BREAK: BreakNode, NodeType.CYCLE_START: StartNode, NodeType.TOOL: ToolNode, + NodeType.MEMORY_READ: MemoryReadNode, + NodeType.MEMORY_WRITE: MemoryWriteNode, } @classmethod diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 917a40f9..d96efdf7 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -14,8 +14,9 @@ from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.workflow.validator import validate_workflow_config -from app.db import get_db +from app.db import get_db, get_db_context from app.models.workflow_model import WorkflowConfig, WorkflowExecution +from app.repositories.end_user_repository import EndUserRepository from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, @@ -480,13 +481,21 @@ class WorkflowService: try: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") + with get_db_context() as db: + end_user_repo = EndUserRepository(db) + new_end_user = end_user_repo.get_or_create_end_user( + app_id=app_id, + other_id=payload.user_id, + original_user_id=payload.user_id # Save original user_id to other_id + ) + end_user_id = str(new_end_user.id) result = await execute_workflow( workflow_config=workflow_config_dict, input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=payload.user_id + user_id=end_user_id ) # 更新执行结果 @@ -599,6 +608,14 @@ class WorkflowService: try: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") + with get_db_context() as db: + end_user_repo = EndUserRepository(db) + new_end_user = end_user_repo.get_or_create_end_user( + app_id=app_id, + other_id=payload.user_id, + original_user_id=payload.user_id # Save original user_id to other_id + ) + end_user_id = str(new_end_user.id) # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) async for event in self._run_workflow_stream( @@ -606,7 +623,7 @@ class WorkflowService: input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=payload.user_id + user_id=end_user_id ): # 直接转发 executor 的事件(已经是正确的格式) yield event