From 567624c323d57430e22657aa048eb83f4c573815 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 14 Jan 2026 16:35:46 +0800 Subject: [PATCH] feat(workflow): add session context memory support to LLM nodes --- api/app/controllers/workflow_controller.py | 85 ++++++----- api/app/core/config.py | 37 ++--- api/app/core/workflow/executor.py | 3 +- api/app/core/workflow/nodes/base_node.py | 10 +- api/app/core/workflow/nodes/end/node.py | 52 ++++++- api/app/core/workflow/nodes/llm/config.py | 26 +++- api/app/core/workflow/nodes/llm/node.py | 19 ++- api/app/schemas/app_schema.py | 2 + api/app/services/app_chat_service.py | 161 ++++----------------- api/app/services/workflow_service.py | 79 ++++++++-- api/pyproject.toml | 3 +- 11 files changed, 249 insertions(+), 228 deletions(-) diff --git a/api/app/controllers/workflow_controller.py b/api/app/controllers/workflow_controller.py index 429aa67e..c6d9ddab 100644 --- a/api/app/controllers/workflow_controller.py +++ b/api/app/controllers/workflow_controller.py @@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"]) @router.post("/{app_id}/workflow") @cur_workspace_access_guard() async def create_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - config: WorkflowConfigCreate, - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + config: WorkflowConfigCreate, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """创建工作流配置 @@ -96,6 +96,7 @@ async def create_workflow_config( msg=f"创建工作流配置失败: {str(e)}" ) + # # @router.get("/{app_id}/workflow") # async def get_workflow_config( @@ -199,10 +200,10 @@ async def create_workflow_config( @router.delete("/{app_id}/workflow") async def delete_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """删除工作流配置 @@ -243,11 +244,11 @@ async def delete_workflow_config( @router.post("/{app_id}/workflow/validate") async def validate_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)], - for_publish: Annotated[bool, Query(description="是否为发布验证")] = False + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)], + for_publish: Annotated[bool, Query(description="是否为发布验证")] = False ): """验证工作流配置 @@ -312,12 +313,12 @@ async def validate_workflow_config( @router.get("/{app_id}/workflow/executions") async def get_workflow_executions( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)], - limit: Annotated[int, Query(ge=1, le=100)] = 50, - offset: Annotated[int, Query(ge=0)] = 0 + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)], + limit: Annotated[int, Query(ge=1, le=100)] = 50, + offset: Annotated[int, Query(ge=0)] = 0 ): """获取工作流执行记录列表 @@ -365,10 +366,10 @@ async def get_workflow_executions( @router.get("/workflow/executions/{execution_id}") async def get_workflow_execution( - execution_id: Annotated[str, Path(description="执行 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + execution_id: Annotated[str, Path(description="执行 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """获取工作流执行详情 @@ -417,16 +418,14 @@ async def get_workflow_execution( ) - # ==================== 工作流执行 ==================== - @router.post("/{app_id}/workflow/run") async def run_workflow( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - request: WorkflowExecutionRequest, - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + request: WorkflowExecutionRequest, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """执行工作流 @@ -487,22 +486,22 @@ async def run_workflow( """ try: async for event in await service.run_workflow( - app_id=app_id, - input_data=input_data, - triggered_by=current_user.id, - conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, - stream=True + app_id=app_id, + input_data=input_data, + triggered_by=current_user.id, + conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, + stream=True ): # 提取事件类型和数据 event_type = event.get("event", "message") event_data = event.get("data", {}) - + # 转换为标准 SSE 格式(字符串) # event: # data: sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" yield sse_message - + except Exception as e: logger.error(f"流式执行异常: {e}", exc_info=True) # 发送错误事件 @@ -554,10 +553,10 @@ async def run_workflow( @router.post("/workflow/executions/{execution_id}/cancel") async def cancel_workflow_execution( - execution_id: Annotated[str, Path(description="执行 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + execution_id: Annotated[str, Path(description="执行 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """取消工作流执行 @@ -602,7 +601,7 @@ async def cancel_workflow_execution( except BusinessException as e: logger.warning(f"取消工作流执行失败: {e.message}") - return fail(code=e.error_code, msg=e.message) + return fail(code=e.code, msg=e.message) except Exception as e: logger.error(f"取消工作流执行异常: {e}", exc_info=True) return fail( diff --git a/api/app/core/config.py b/api/app/core/config.py index 573c4283..ff7bf2e1 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -7,17 +7,18 @@ from dotenv import load_dotenv load_dotenv() + class Settings: ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true" # API Keys Configuration OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "") - + # Neo4j Configuration (记忆系统数据库) NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687") NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j") NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "") - + # Database configuration (Postgres) DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1") DB_PORT: int = int(os.getenv("DB_PORT", "5432")) @@ -37,7 +38,7 @@ class Settings: REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379")) REDIS_DB: int = int(os.getenv("REDIS_DB", "1")) REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "") - + # ElasticSearch configuration ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1") ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200")) @@ -48,7 +49,7 @@ class Settings: ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000")) ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true" ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10")) - + # Xinference configuration XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1") @@ -57,17 +58,17 @@ class Settings: LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true" LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "") LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "") - + # LLM Request Configuration LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0")) LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2")) - + # JWT Token Configuration SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random") ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30")) REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7")) - + # Single Sign-On configuration ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true" @@ -86,19 +87,19 @@ class Settings: LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "") LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "") LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "") - + # Server Configuration SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1") # ======================================================================== # Internal Configuration (not in .env, used by application code) # ======================================================================== - + # Superuser settings (internal defaults) FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com") FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin") FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password") - + # Generic File Upload (internal) GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads") ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true" @@ -123,7 +124,7 @@ class Settings: LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5")) LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true" LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true" - + # Sensitive Data Filtering ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true" @@ -142,7 +143,6 @@ class Settings: LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB - # Celery configuration (internal) CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) @@ -150,15 +150,15 @@ class Settings: HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) - REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) - + REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) + # Memory Cache Regeneration Configuration MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") - + # Tool Management Configuration TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools") TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60")) @@ -167,7 +167,10 @@ class Settings: # official environment system version SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0") - + + # workflow config + WORKFLOW_NODE_TIMEOUT: int = os.getenv("WORKFLOW_NODE_TIMEOUT", 600) + def get_memory_output_path(self, filename: str = "") -> str: """ Get the full path for memory module output files. @@ -182,7 +185,7 @@ class Settings: if filename: return str(base_path / filename) return str(base_path) - + def ensure_memory_output_dir(self) -> None: """ Ensure the memory output directory exists. diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 67689935..c048f447 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -74,6 +74,7 @@ class WorkflowExecutor: 初始化的工作流状态 """ user_message = input_data.get("message") or "" + conversation_messages = input_data.get("conv_messages") or [] # 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value) config_variables_list = self.workflow_config.get("variables") or [] @@ -114,7 +115,7 @@ class WorkflowExecutor: } return { - "messages": [('user', user_message)], + "messages": conversation_messages, "variables": variables, "node_outputs": {}, "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index e3bf36c9..72fd0bb5 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -7,13 +7,13 @@ import asyncio import logging from abc import ABC, abstractmethod -from operator import add from typing import Any -from langchain_core.messages import AnyMessage, AIMessage +from langchain_core.messages import AIMessage from langgraph.config import get_stream_writer from typing_extensions import TypedDict, Annotated +from app.core.config import settings from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class WorkflowState(TypedDict): The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. """ # List of messages (append mode) - messages: Annotated[list[tuple[str, str]], add] + messages: list[dict[str, str]] # Set of loop node IDs, used for assigning values in loop nodes cycle_nodes: list @@ -154,7 +154,7 @@ class BaseNode(ABC): Returns: 超时时间 """ - return 60 + return settings.WORKFLOW_NODE_TIMEOUT # return self.error_handling.get("timeout", 60) async def run(self, state: WorkflowState) -> dict[str, Any]: @@ -203,6 +203,7 @@ class BaseNode(ABC): # 返回包装后的输出和运行时变量 return { **wrapped_output, + "messages": state["messages"], "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var @@ -356,6 +357,7 @@ class BaseNode(ABC): # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, + "messages": state["messages"], "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 6195afbd..0cbd9e8e 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -6,7 +6,6 @@ End 节点实现 import logging import re -import asyncio from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import NodeType @@ -38,7 +37,23 @@ class EndNode(BaseNode): # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state, strict=False) + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + { + "role": "assistant", + "content": output + } + ]) else: + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + ]) output = "工作流已完成" # 统计信息(用于日志) @@ -166,6 +181,12 @@ class EndNode(BaseNode): "chunk_index": 1, "is_suffix": False }) + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + } + ]) yield {"__final__": True, "result": output} return @@ -176,7 +197,6 @@ class EndNode(BaseNode): source_node_id = edge.get("source") # Check if the source node is an LLM node for node in self.workflow_config.get("nodes", []): - print("="*50) logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}") if node.get("id") == source_node_id and node.get("type") == NodeType.LLM: direct_upstream_llm_nodes.append(source_node_id) @@ -216,12 +236,24 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + { + "role": "assistant", + "content": output + } + ]) + # yield completion marker yield {"__final__": True, "result": output} return # Has reference to direct upstream LLM node, only output the part after that reference (suffix) - logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") + logger.info( + f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") # Collect suffix parts suffix_parts = [] @@ -258,6 +290,17 @@ class EndNode(BaseNode): # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) full_output = self._render_template(output_template, state, strict=False) + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + { + "role": "assistant", + "content": full_output + } + ]) + logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") @@ -280,7 +323,8 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}") else: - logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") + logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!" + f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") # 统计信息 node_outputs = state.get("node_outputs", {}) diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index 8498fc38..f65d5879 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -11,12 +11,12 @@ class MessageConfig(BaseModel): """消息配置""" role: str = Field( - ..., + default='user', description="消息角色:system, user, assistant" ) content: str = Field( - ..., + default="", description="消息内容,支持模板变量,如:{{ sys.message }}" ) @@ -30,6 +30,23 @@ class MessageConfig(BaseModel): return v.lower() +class MemoryWindowSetting(BaseModel): + enable: bool = Field( + default=False, + description="启用记忆" + ) + + enable_window: bool = Field( + default=False, + description="启用记忆窗口" + ) + + window_size: int = Field( + default=20, + description="记忆窗口大小" + ) + + class LLMNodeConfig(BaseNodeConfig): """LLM 节点配置 @@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig): description="上下文" ) + memory: MemoryWindowSetting = Field( + ..., + description="对话上下文窗口" + ) + # 简单模式 prompt: str | None = Field( default=None, diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index bfa1b99f..b9ba3d7b 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -85,28 +85,31 @@ class LLMNode(BaseNode): """ # 1. 处理消息格式(优先使用 messages) - messages_config = self.config.get("messages") + messages_config = self.typed_config.messages if messages_config: # 使用 LangChain 消息格式 messages = [] for msg_config in messages_config: - role = msg_config.get("role", "user").lower() - content_template = msg_config.get("content", "") + role = msg_config.role.lower() + content_template = msg_config.content content_template = self._render_context(content_template, state) content = self._render_template(content_template, state) # 根据角色创建对应的消息对象 if role == "system": - messages.append(SystemMessage(content=content)) + messages.append({"role": "system", "content": content}) elif role in ["user", "human"]: - messages.append(HumanMessage(content=content)) + messages.append({"role": "user", "content": content}) elif role in ["ai", "assistant"]: - messages.append(AIMessage(content=content)) + messages.append({"role": "user", "content": content}) else: logger.warning(f"未知的消息角色: {role},默认使用 user") - messages.append(HumanMessage(content=content)) + messages.append({"role": "user", "content": content}) + if self.typed_config.memory.enable: + # if self.typed_config.memory.enable_window: + messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] prompt_or_messages = messages else: # 使用简单的 prompt 格式(向后兼容) @@ -189,7 +192,7 @@ class LLMNode(BaseNode): return { "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "messages": [ - {"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content} + {"role": msg.get("role"), "content": msg.get("content", "")} for msg in prompt_or_messages ] if isinstance(prompt_or_messages, list) else None, "config": { diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 3c00e5a0..35d2e424 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -41,6 +41,7 @@ class ToolConfig(BaseModel): tool_id: Optional[str] = Field(default=None, description="工具ID") operation: Optional[str] = Field(default=None, description="工具特定配置") + class ToolOldConfig(BaseModel): """工具配置""" enabled: bool = Field(default=False, description="是否启用该工具") @@ -348,6 +349,7 @@ class AppChatRequest(BaseModel): variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") + class DraftRunRequest(BaseModel): """试运行请求""" message: str = Field(..., description="用户消息") diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 56400c92..0065c64b 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.db import get_db, get_db_context from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig +from app.schemas import DraftRunRequest from app.services.tool_service import ToolService from app.repositories.tool_repository import ToolRepository from app.db import get_db @@ -59,7 +60,7 @@ class AppChatService: # 获取模型配置ID model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) + api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id) # 处理系统提示词(支持变量替换) system_prompt = config.system_prompt if variables: @@ -210,7 +211,7 @@ class AppChatService: # 获取模型配置ID model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) + api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id) # 处理系统提示词(支持变量替换) system_prompt = config.system_prompt if variables: @@ -511,7 +512,6 @@ class AppChatService: } ) - except (GeneratorExit, asyncio.CancelledError): # 生成器被关闭或任务被取消,正常退出 logger.debug("多 Agent 流式聊天被中断") @@ -537,83 +537,19 @@ class AppChatService: ) -> Dict[str, Any]: """聊天(非流式)""" workflow_service = WorkflowService(self.db) - - input_data = {"message":message, "variables": variables, - "conversation_id": str(conversation_id)} - inconfig = workflow_service.get_workflow_config(app_id) - - # 2. 创建执行记录 - execution = workflow_service.create_execution( - workflow_config_id=inconfig.id, - app_id=app_id, - trigger_type="manual", - triggered_by=None, - conversation_id=conversation_id, - input_data=input_data + payload = DraftRunRequest( + message=message, + variables=variables, + conversation_id=str(conversation_id), + stream=True, + user_id=user_id + ) + return await workflow_service.run( + app_id=app_id, + payload=payload, + config=config, + workspace_id=workspace_id, ) - - # 3. 构建工作流配置字典 - workflow_config_dict = { - "nodes": config.nodes, - "edges": config.edges, - "variables": config.variables, - "execution_config": config.execution_config - } - - # 4. 获取工作空间 ID(从 app 获取) - - # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow - - try: - # 更新状态为运行中 - workflow_service.update_execution_status(execution.execution_id, "running") - - result = await execute_workflow( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id=str(workspace_id), - user_id=user_id - ) - - # 更新执行结果 - if result.get("status") == "completed": - workflow_service.update_execution_status( - execution.execution_id, - "completed", - output_data=result.get("node_outputs", {}) - ) - else: - workflow_service.update_execution_status( - execution.execution_id, - "failed", - error_message=result.get("error") - ) - - # 返回增强的响应结构 - return { - "execution_id": execution.execution_id, - "status": result.get("status"), - "output": result.get("output"), # 最终输出(字符串) - "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) - "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID - "error_message": result.get("error"), - "elapsed_time": result.get("elapsed_time"), - "token_usage": result.get("token_usage") - } - - except Exception as e: - logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) - workflow_service.update_execution_status( - execution.execution_id, - "failed", - error_message=str(e) - ) - raise BusinessException( - code=BizCode.INTERNAL_ERROR, - message=f"工作流执行失败: {str(e)}" - ) async def workflow_chat_stream( self, @@ -632,62 +568,21 @@ class AppChatService: ) -> AsyncGenerator[str, None]: """聊天(流式)""" workflow_service = WorkflowService(self.db) - input_data = {"message": message, "variables": variables, - "conversation_id": str(conversation_id)} - inconfig = workflow_service.get_workflow_config(app_id) - # 2. 创建执行记录 - execution = workflow_service.create_execution( - workflow_config_id=inconfig.id, - app_id=app_id, - trigger_type="manual", - triggered_by=None, - conversation_id=conversation_id, - input_data=input_data + payload = DraftRunRequest( + message=message, + variables=variables, + conversation_id=str(conversation_id), + stream=True, + user_id=user_id ) + async for event in workflow_service.run_stream( + app_id=app_id, + payload=payload, + config=config, + workspace_id=workspace_id, + ): + yield event - # 3. 构建工作流配置字典 - workflow_config_dict = { - "nodes": config.nodes, - "edges": config.edges, - "variables": config.variables, - "execution_config": config.execution_config - } - - # 4. 获取工作空间 ID(从 app 获取) - - # 5. 流式执行工作流 - - try: - # 更新状态为运行中 - workflow_service.update_execution_status(execution.execution_id, "running") - - - # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) - async for event in workflow_service._run_workflow_stream( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id=str(workspace_id), - user_id=user_id - ): - # 直接转发 executor 的事件(已经是正确的格式) - yield event - - except Exception as e: - logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) - workflow_service.update_execution_status( - execution.execution_id, - "failed", - error_message=str(e) - ) - # 发送错误事件 - yield { - "event": "error", - "data": { - "execution_id": execution.execution_id, - "error": str(e) - } - } # ==================== 依赖注入函数 ==================== diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 7d3c784f..f9988352 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -2,12 +2,11 @@ 工作流服务层 """ import datetime -import json import logging import uuid -import datetime from typing import Any, Annotated, AsyncGenerator +from deprecated import deprecated from fastapi import Depends from sqlalchemy.orm import Session @@ -16,15 +15,16 @@ from app.core.exceptions import BusinessException from app.core.workflow.validator import validate_workflow_config from app.db import get_db, get_db_context from app.models.workflow_model import WorkflowConfig, WorkflowExecution +from app.repositories.conversation_repository import MessageRepository +from app.models.conversation_model import Message from app.repositories.end_user_repository import EndUserRepository -from app.services.multi_agent_service import convert_uuids_to_str from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) from app.schemas import DraftRunRequest -from app.utils.sse_utils import format_sse_message +from app.services.multi_agent_service import convert_uuids_to_str logger = logging.getLogger(__name__) @@ -37,6 +37,7 @@ class WorkflowService: self.config_repo = WorkflowConfigRepository(db) self.execution_repo = WorkflowExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db) + self.message_repo = MessageRepository(db) # ==================== 配置管理 ==================== @@ -418,14 +419,13 @@ class WorkflowService: """运行工作流 Args: + workspace_id: + config: + payload: app_id: 应用 ID - input_data: 输入数据(包含 message 和 variables) - triggered_by: 触发用户 ID - conversation_id: 会话 ID(可选) - stream: 是否流式返回 Returns: - 执行结果(非流式)或生成器(流式) + 执行结果(非流式) Raises: BusinessException: 配置不存在或执行失败时抛出 @@ -438,7 +438,8 @@ class WorkflowService: code=BizCode.CONFIG_MISSING, message=f"工作流配置不存在: app_id={app_id}" ) - input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id} + input_data = {"message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id} # 转换 user_id 为 UUID triggered_by_uuid = None @@ -461,7 +462,7 @@ class WorkflowService: workflow_config_id=config.id, app_id=app_id, trigger_type="manual", - triggered_by=triggered_by_uuid, + triggered_by=None, conversation_id=conversation_id_uuid, input_data=input_data ) @@ -500,8 +501,11 @@ class WorkflowService: variables = last_state.get("variables", {}) conv_vars = variables.get("conv", {}) input_data["conv"] = conv_vars + input_data["conv_messages"] = last_state.get("messages") or [] break + init_message_length = len(input_data.get("conv_messages", [])) + result = await execute_workflow( workflow_config=workflow_config_dict, input_data=input_data, @@ -517,6 +521,17 @@ class WorkflowService: "completed", output_data=result ) + final_messages = result.get("messages", [])[init_message_length:] + for message in final_messages: + message_obj = Message( + conversation_id=conversation_id_uuid, + role=message["role"], + content=message["content"], + ) + self.message_repo.add_message(message_obj) + self.db.commit() + logger.info(f"Workflow Run Success, " + f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") else: self.update_execution_status( execution.execution_id, @@ -529,6 +544,7 @@ class WorkflowService: "execution_id": execution.execution_id, "status": result.get("status"), "variables": result.get("variables"), + "messages": result.get("messages"), "output": result.get("output"), # 最终输出(字符串) "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID @@ -559,6 +575,7 @@ class WorkflowService: """运行工作流(流式) Args: + workspace_id: app_id: 应用 ID payload: 请求对象(包含 message, variables, conversation_id 等) config: 存储类型(可选) @@ -601,7 +618,7 @@ class WorkflowService: workflow_config_id=config.id, app_id=app_id, trigger_type="manual", - triggered_by=triggered_by_uuid, + triggered_by=None, conversation_id=conversation_id_uuid, input_data=input_data ) @@ -638,17 +655,46 @@ class WorkflowService: variables = last_state.get("variables", {}) conv_vars = variables.get("conv", {}) input_data["conv"] = conv_vars + input_data["conv_messages"] = last_state.get("messages") or [] break + init_message_length = len(input_data.get("conv_messages", [])) + from app.core.workflow.executor import execute_workflow_stream - # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) - async for event in self._run_workflow_stream( + async for event in execute_workflow_stream( workflow_config=workflow_config_dict, input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), user_id=end_user_id ): - # 直接转发 executor 的事件(已经是正确的格式) + if event.get("event") == "workflow_end": + + status = event.get("data", {}).get("status") + if status == "completed": + self.update_execution_status( + execution.execution_id, + "completed", + output_data=event.get("data") + ) + final_messages = event.get("data", {}).get("messages", [])[init_message_length:] + for message in final_messages: + message_obj = Message( + conversation_id=conversation_id_uuid, + role=message["role"], + content=message["content"], + ) + self.message_repo.add_message(message_obj) + self.db.commit() + logger.info(f"Workflow Run Success, " + f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") + elif status == "failed": + self.update_execution_status( + execution.execution_id, + "failed", + output_data=event.get("data") + ) + else: + logger.error(f"unexpect workflow run status, status: {status}") yield event except Exception as e: @@ -667,6 +713,8 @@ class WorkflowService: } } + @deprecated(reason="This method is deprecated. " + "Please use WorkflowService.run / run_stream instead.") async def run_workflow( self, app_id: uuid.UUID, @@ -819,6 +867,7 @@ class WorkflowService: return clean_value(event) + @deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.") async def _run_workflow_stream( self, workflow_config: dict[str, Any], diff --git a/api/pyproject.toml b/api/pyproject.toml index 2dcc706d..6da684de 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -136,7 +136,8 @@ dependencies = [ "markdown-to-json==2.1.1", "valkey==6.0.2", "python-calamine>=0.4.0", - "xlrd==2.0.2" + "xlrd==2.0.2", + "deprecated>=1.3.1", ] [tool.pytest.ini_options]