diff --git a/api/app/controllers/workflow_controller.py b/api/app/controllers/workflow_controller.py index 429aa67e..c6d9ddab 100644 --- a/api/app/controllers/workflow_controller.py +++ b/api/app/controllers/workflow_controller.py @@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"]) @router.post("/{app_id}/workflow") @cur_workspace_access_guard() async def create_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - config: WorkflowConfigCreate, - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + config: WorkflowConfigCreate, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """创建工作流配置 @@ -96,6 +96,7 @@ async def create_workflow_config( msg=f"创建工作流配置失败: {str(e)}" ) + # # @router.get("/{app_id}/workflow") # async def get_workflow_config( @@ -199,10 +200,10 @@ async def create_workflow_config( @router.delete("/{app_id}/workflow") async def delete_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """删除工作流配置 @@ -243,11 +244,11 @@ async def delete_workflow_config( @router.post("/{app_id}/workflow/validate") async def validate_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)], - for_publish: Annotated[bool, Query(description="是否为发布验证")] = False + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)], + for_publish: Annotated[bool, Query(description="是否为发布验证")] = False ): """验证工作流配置 @@ -312,12 +313,12 @@ async def validate_workflow_config( @router.get("/{app_id}/workflow/executions") async def get_workflow_executions( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)], - limit: Annotated[int, Query(ge=1, le=100)] = 50, - offset: Annotated[int, Query(ge=0)] = 0 + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)], + limit: Annotated[int, Query(ge=1, le=100)] = 50, + offset: Annotated[int, Query(ge=0)] = 0 ): """获取工作流执行记录列表 @@ -365,10 +366,10 @@ async def get_workflow_executions( @router.get("/workflow/executions/{execution_id}") async def get_workflow_execution( - execution_id: Annotated[str, Path(description="执行 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + execution_id: Annotated[str, Path(description="执行 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """获取工作流执行详情 @@ -417,16 +418,14 @@ async def get_workflow_execution( ) - # ==================== 工作流执行 ==================== - @router.post("/{app_id}/workflow/run") async def run_workflow( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - request: WorkflowExecutionRequest, - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + request: WorkflowExecutionRequest, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """执行工作流 @@ -487,22 +486,22 @@ async def run_workflow( """ try: async for event in await service.run_workflow( - app_id=app_id, - input_data=input_data, - triggered_by=current_user.id, - conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, - stream=True + app_id=app_id, + input_data=input_data, + triggered_by=current_user.id, + conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, + stream=True ): # 提取事件类型和数据 event_type = event.get("event", "message") event_data = event.get("data", {}) - + # 转换为标准 SSE 格式(字符串) # event: # data: sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" yield sse_message - + except Exception as e: logger.error(f"流式执行异常: {e}", exc_info=True) # 发送错误事件 @@ -554,10 +553,10 @@ async def run_workflow( @router.post("/workflow/executions/{execution_id}/cancel") async def cancel_workflow_execution( - execution_id: Annotated[str, Path(description="执行 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + execution_id: Annotated[str, Path(description="执行 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """取消工作流执行 @@ -602,7 +601,7 @@ async def cancel_workflow_execution( except BusinessException as e: logger.warning(f"取消工作流执行失败: {e.message}") - return fail(code=e.error_code, msg=e.message) + return fail(code=e.code, msg=e.message) except Exception as e: logger.error(f"取消工作流执行异常: {e}", exc_info=True) return fail( diff --git a/api/app/core/config.py b/api/app/core/config.py index 573c4283..5f4f91c4 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -7,17 +7,18 @@ from dotenv import load_dotenv load_dotenv() + class Settings: ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true" # API Keys Configuration OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "") - + # Neo4j Configuration (记忆系统数据库) NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687") NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j") NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "") - + # Database configuration (Postgres) DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1") DB_PORT: int = int(os.getenv("DB_PORT", "5432")) @@ -37,7 +38,7 @@ class Settings: REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379")) REDIS_DB: int = int(os.getenv("REDIS_DB", "1")) REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "") - + # ElasticSearch configuration ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1") ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200")) @@ -48,7 +49,7 @@ class Settings: ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000")) ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true" ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10")) - + # Xinference configuration XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1") @@ -57,17 +58,17 @@ class Settings: LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true" LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "") LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "") - + # LLM Request Configuration LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0")) LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2")) - + # JWT Token Configuration SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random") ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30")) REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7")) - + # Single Sign-On configuration ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true" @@ -86,19 +87,19 @@ class Settings: LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "") LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "") LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "") - + # Server Configuration SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1") # ======================================================================== # Internal Configuration (not in .env, used by application code) # ======================================================================== - + # Superuser settings (internal defaults) FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com") FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin") FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password") - + # Generic File Upload (internal) GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads") ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true" @@ -123,7 +124,7 @@ class Settings: LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5")) LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true" LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true" - + # Sensitive Data Filtering ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true" @@ -142,7 +143,6 @@ class Settings: LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB - # Celery configuration (internal) CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) @@ -150,15 +150,15 @@ class Settings: HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) - REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) - + REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) + # Memory Cache Regeneration Configuration MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") - + # Tool Management Configuration TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools") TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60")) @@ -167,7 +167,10 @@ class Settings: # official environment system version SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0") - + + # workflow config + WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) + def get_memory_output_path(self, filename: str = "") -> str: """ Get the full path for memory module output files. @@ -182,7 +185,7 @@ class Settings: if filename: return str(base_path / filename) return str(base_path) - + def ensure_memory_output_dir(self) -> None: """ Ensure the memory output directory exists. diff --git a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py b/api/app/core/memory/agent/mcp_server/tools/summary_tools.py index 6d5012f1..0f306572 100644 --- a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py +++ b/api/app/core/memory/agent/mcp_server/tools/summary_tools.py @@ -425,15 +425,9 @@ async def Input_Summary( try: # Extract services from context - template_service = get_context_resource(ctx, "template_service") session_service = get_context_resource(ctx, "session_service") search_service = get_context_resource(ctx, "search_service") - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - # Resolve session ID sessionid = Resolve_username(usermessages) or "" sessionid = sessionid.replace('call_id_', '') @@ -539,31 +533,11 @@ async def Input_Summary( ) retrieve_info, question, raw_results = "", query, [] + # Return retrieved information directly without LLM processing + # Use the raw retrieved info as the answer + aimessages = retrieve_info if retrieve_info else "信息不足,无法回答" - # Render template - system_prompt = await template_service.render_template( - template_name='Retrieve_Summary_prompt.jinja2', - operation_name='input_summary', - query=query, - history=history, - retrieve_info=retrieve_info - ) - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=RetrieveSummaryResponse - ) - aimessages = structured.data.query_answer or "信息不足,无法回答" - except Exception as e: - logger.error( - f"Input_Summary: response_structured failed, using default answer: {e}", - exc_info=True - ) - aimessages = "信息不足,无法回答" - - logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}") + logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...") # Emit intermediate output for frontend return { diff --git a/api/app/core/tools/custom/schema_parser.py b/api/app/core/tools/custom/schema_parser.py index a22e2cfa..ea261dba 100644 --- a/api/app/core/tools/custom/schema_parser.py +++ b/api/app/core/tools/custom/schema_parser.py @@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger logger = get_business_logger() -# 为了兼容性,创建别名 -# SchemaParser = OpenAPISchemaParser = None - class OpenAPISchemaParser: """OpenAPI Schema解析器 - 解析OpenAPI 3.0规范""" @@ -213,7 +210,9 @@ class OpenAPISchemaParser: if not isinstance(operation, dict): continue - + + summary = operation.get("summary", "") + # 生成操作ID operation_id = operation.get("operationId") if not operation_id: @@ -223,7 +222,7 @@ class OpenAPISchemaParser: operations[operation_id] = { "method": method.upper(), "path": path, - "summary": operation.get("summary", ""), + "summary": summary if summary else operation_id, "description": operation.get("description", ""), "parameters": self._extract_parameters(operation), "request_body": self._extract_request_body(operation), diff --git a/api/app/core/tools/langchain_adapter.py b/api/app/core/tools/langchain_adapter.py index ea5fdb96..51415732 100644 --- a/api/app/core/tools/langchain_adapter.py +++ b/api/app/core/tools/langchain_adapter.py @@ -232,7 +232,7 @@ class LangchainAdapter: # 添加验证约束 if param.enum: # 枚举值约束 - field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$" + field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$" if param.minimum is not None: field_kwargs["ge"] = param.minimum @@ -241,7 +241,7 @@ class LangchainAdapter: field_kwargs["le"] = param.maximum if param.pattern: - field_kwargs["regex"] = param.pattern + field_kwargs["pattern"] = param.pattern fields[param.name] = Field(**field_kwargs) annotations[param.name] = python_type diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 2901b7ca..e513a147 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -27,20 +27,22 @@ class SimpleMCPClient: # 确定连接类型 self.is_websocket = server_url.startswith(("ws://", "wss://")) + self.is_sse = "/sse" in server_url.lower() # 连接状态 self._websocket = None self._session = None self._request_id = 0 self._pending_requests = {} + self._server_capabilities = {} + self._endpoint_url = None # SSE endpoint URL + self._sse_task = None async def __aenter__(self): - """异步上下文管理器入口""" await self.connect() return self async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器出口""" await self.disconnect() async def connect(self): @@ -57,47 +59,157 @@ class SimpleMCPClient: async def disconnect(self): """断开连接""" try: + if self._sse_task: + self._sse_task.cancel() if self._websocket: await self._websocket.close() self._websocket = None - if self._session: await self._session.close() self._session = None - except Exception as e: logger.error(f"断开连接失败: {e}") async def _connect_websocket(self): """WebSocket 连接""" headers = self._build_headers() - self._websocket = await websockets.connect( self.server_url, extra_headers=headers, timeout=self.timeout ) - - # 启动消息处理 asyncio.create_task(self._handle_websocket_messages()) - - # 发送初始化消息 await self._send_initialize() async def _connect_http(self): """HTTP 连接""" headers = self._build_headers() timeout = aiohttp.ClientTimeout(total=self.timeout) + self._session = aiohttp.ClientSession(headers=headers, timeout=timeout) - self._session = aiohttp.ClientSession( - headers=headers, - timeout=timeout - ) - - # 对于 ModelScope MCP 服务,需要先发送初始化请求 - if "modelscope.net" in self.server_url: + if self.is_sse: + await self._initialize_sse_session() + elif "modelscope.net" in self.server_url: await self._initialize_modelscope_session() + async def _initialize_sse_session(self): + """初始化 SSE MCP 会话 - 参考 Dify 实现""" + try: + # 建立 SSE 连接 + response = await self._session.get( + self.server_url, + headers={"Accept": "text/event-stream"} + ) + + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}") + + # 启动 SSE 读取任务 + self._sse_task = asyncio.create_task(self._read_sse_stream(response)) + + # 等待获取 endpoint URL + for _ in range(10): + if self._endpoint_url: + break + await asyncio.sleep(1) + + if not self._endpoint_url: + raise MCPConnectionError("未能获取 endpoint URL") + + # 发送 initialize 请求到 endpoint + init_request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} + } + } + + init_response = await self._send_sse_request(init_request) + if "error" in init_response: + raise MCPConnectionError(f"初始化失败: {init_response['error']}") + + result = init_response.get("result", {}) + self._server_capabilities = result.get("capabilities", {}) + + # 发送 initialized 通知 + await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"}) + + except aiohttp.ClientError as e: + raise MCPConnectionError(f"初始化连接失败: {e}") + + async def _read_sse_stream(self, response): + """读取 SSE 流""" + try: + async for line in response.content: + line = line.decode('utf-8').strip() + + if line.startswith('event:'): + continue + + if line.startswith('data:'): + data = line[5:].strip() # 去除 'data:' 后的空格 + if not data or data == '[DONE]': + continue + + try: + # 处理 endpoint 事件(相对路径或绝对路径) + if not self._endpoint_url: + # 如果是相对路径,拼接成完整 URL + if data.startswith('/'): + from urllib.parse import urlparse, urlunparse + parsed = urlparse(self.server_url) + self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}" + else: + self._endpoint_url = data + logger.info(f"获取到 endpoint URL: {self._endpoint_url}") + continue + + # 处理 message 事件 + message = json.loads(data) + request_id = message.get("id") + if request_id and request_id in self._pending_requests: + future = self._pending_requests.pop(request_id) + if not future.done(): + future.set_result(message) + except json.JSONDecodeError: + continue + except Exception as e: + logger.error(f"SSE 流读取错误: {e}") + + async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]: + """通过 SSE endpoint 发送请求""" + if not self._endpoint_url: + raise MCPConnectionError("endpoint URL 未初始化") + + request_id = request["id"] + future = asyncio.Future() + self._pending_requests[request_id] = future + + try: + async with self._session.post(self._endpoint_url, json=request) as response: + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"请求失败 {response.status}: {error_text}") + + return await asyncio.wait_for(future, timeout=self.timeout) + except asyncio.TimeoutError: + self._pending_requests.pop(request_id, None) + raise MCPConnectionError("请求超时") + + async def _send_sse_notification(self, notification: Dict[str, Any]): + """发送通知(无需响应)""" + if not self._endpoint_url: + raise MCPConnectionError("endpoint URL 未初始化") + + async with self._session.post(self._endpoint_url, json=notification) as response: + if response.status != 200: + logger.warning(f"通知发送失败: {response.status}") + async def _initialize_modelscope_session(self): """初始化 ModelScope MCP 会话""" init_request = { @@ -107,18 +219,12 @@ class SimpleMCPClient: "params": { "protocolVersion": "2024-11-05", "capabilities": {"tools": {}}, - "clientInfo": { - "name": "MemoryBear", - "version": "1.0.0" - } + "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} } } try: - async with self._session.post( - self.server_url, - json=init_request - ) as response: + async with self._session.post(self.server_url, json=init_request) as response: if response.status != 200: error_text = await response.text() raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") @@ -127,21 +233,16 @@ class SimpleMCPClient: if "error" in init_response: raise MCPConnectionError(f"初始化失败: {init_response['error']}") - # 获取 session ID session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id") if session_id: self._session.headers.update({"Mcp-Session-Id": session_id}) - # 发送 initialized 通知 initialized_notification = { "jsonrpc": "2.0", "method": "notifications/initialized" } - async with self._session.post( - self.server_url, - json=initialized_notification - ) as notif_response: + async with self._session.post(self.server_url, json=initialized_notification): pass except aiohttp.ClientError as e: @@ -149,12 +250,18 @@ class SimpleMCPClient: def _build_headers(self) -> Dict[str, str]: """构建请求头""" + # 基础 headers headers = { "Content-Type": "application/json", "Accept": "application/json, text/event-stream" } - # 添加认证头 + # 合并 connection_config 中的自定义 headers + custom_headers = self.connection_config.get("headers", {}) + if custom_headers: + headers.update(custom_headers) + + # 处理认证配置(认证 headers 优先级更高) auth_config = self.connection_config.get("auth_config", {}) auth_type = self.connection_config.get("auth_type", "none") @@ -178,7 +285,7 @@ class SimpleMCPClient: return headers async def _send_initialize(self): - """发送初始化消息""" + """发送初始化消息(WebSocket)""" init_message = { "jsonrpc": "2.0", "id": self._get_request_id(), @@ -186,124 +293,90 @@ class SimpleMCPClient: "params": { "protocolVersion": "2024-11-05", "capabilities": {"tools": {}}, - "clientInfo": { - "name": "MemoryBear", - "version": "1.0.0" - } + "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} } } await self._websocket.send(json.dumps(init_message)) + response = await self._websocket.recv() + response_data = json.loads(response) - # 等待初始化响应 - response = await asyncio.wait_for( - self._websocket.recv(), - timeout=self.timeout - ) + if "error" in response_data: + raise MCPConnectionError(f"初始化失败: {response_data['error']}") - init_response = json.loads(response) - if "error" in init_response: - raise MCPConnectionError(f"初始化失败: {init_response['error']}") + result = response_data.get("result", {}) + self._server_capabilities = result.get("capabilities", {}) + + await self._websocket.send(json.dumps({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + })) + + async def list_tools(self) -> List[Dict[str, Any]]: + """获取工具列表""" + request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "tools/list" + } + + if self.is_websocket: + await self._websocket.send(json.dumps(request)) + response = await self._websocket.recv() + response_data = json.loads(response) + elif self.is_sse: + response_data = await self._send_sse_request(request) + else: + async with self._session.post(self.server_url, json=request) as response: + response_data = await response.json() + + if "error" in response_data: + raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}") + + result = response_data.get("result", {}) + return result.get("tools", []) + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: + """调用工具""" + request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "tools/call", + "params": {"name": tool_name, "arguments": arguments} + } + + if self.is_websocket: + await self._websocket.send(json.dumps(request)) + response = await self._websocket.recv() + response_data = json.loads(response) + elif self.is_sse: + response_data = await self._send_sse_request(request) + else: + async with self._session.post(self.server_url, json=request) as response: + response_data = await response.json() + + if "error" in response_data: + error = response_data["error"] + raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}") + + return response_data.get("result", {}) + + def _get_request_id(self) -> int: + """生成请求 ID""" + self._request_id += 1 + return self._request_id async def _handle_websocket_messages(self): """处理 WebSocket 消息""" try: - while self._websocket and not self._websocket.closed: - try: - message = await self._websocket.recv() - data = json.loads(message) - - # 处理响应 - if "id" in data: - request_id = str(data["id"]) - if request_id in self._pending_requests: - future = self._pending_requests.pop(request_id) - if not future.done(): - future.set_result(data) - - except ConnectionClosed: - break - except Exception as e: - logger.error(f"处理WebSocket消息失败: {e}") - + async for message in self._websocket: + data = json.loads(message) + request_id = data.get("id") + if request_id and request_id in self._pending_requests: + future = self._pending_requests.pop(request_id) + if not future.done(): + future.set_result(data) + except ConnectionClosed: + logger.info("WebSocket 连接已关闭") except Exception as e: - logger.error(f"WebSocket消息处理异常: {e}") - - async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: - """调用工具""" - request_data = { - "jsonrpc": "2.0", - "id": self._get_request_id(), - "method": "tools/call", - "params": { - "name": tool_name, - "arguments": arguments - } - } - - if self.is_websocket: - response = await self._send_websocket_request(request_data) - else: - response = await self._send_http_request(request_data) - - if "error" in response: - error = response["error"] - raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}") - - return response.get("result", {}) - - async def list_tools(self) -> List[Dict[str, Any]]: - """获取工具列表""" - request_data = { - "jsonrpc": "2.0", - "id": self._get_request_id(), - "method": "tools/list", - "params": {} - } - - if self.is_websocket: - response = await self._send_websocket_request(request_data) - else: - response = await self._send_http_request(request_data) - - if "error" in response: - error = response["error"] - raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}") - - result = response.get("result", {}) - return result.get("tools", []) - - async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """发送WebSocket请求""" - request_id = str(request_data["id"]) - future = asyncio.Future() - self._pending_requests[request_id] = future - - try: - await self._websocket.send(json.dumps(request_data)) - response = await asyncio.wait_for(future, timeout=self.timeout) - return response - except asyncio.TimeoutError: - self._pending_requests.pop(request_id, None) - raise - - async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """发送HTTP请求""" - try: - async with self._session.post( - self.server_url, - json=request_data - ) as response: - if response.status != 200: - error_text = await response.text() - raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") - - return await response.json() - - except aiohttp.ClientError as e: - raise MCPConnectionError(f"HTTP请求失败: {e}") - - def _get_request_id(self) -> str: - """获取请求ID""" - self._request_id += 1 - return f"req_{self._request_id}_{int(time.time() * 1000)}" \ No newline at end of file + logger.error(f"WebSocket 消息处理错误: {e}") diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 67689935..c048f447 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -74,6 +74,7 @@ class WorkflowExecutor: 初始化的工作流状态 """ user_message = input_data.get("message") or "" + conversation_messages = input_data.get("conv_messages") or [] # 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value) config_variables_list = self.workflow_config.get("variables") or [] @@ -114,7 +115,7 @@ class WorkflowExecutor: } return { - "messages": [('user', user_message)], + "messages": conversation_messages, "variables": variables, "node_outputs": {}, "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index e3bf36c9..72fd0bb5 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -7,13 +7,13 @@ import asyncio import logging from abc import ABC, abstractmethod -from operator import add from typing import Any -from langchain_core.messages import AnyMessage, AIMessage +from langchain_core.messages import AIMessage from langgraph.config import get_stream_writer from typing_extensions import TypedDict, Annotated +from app.core.config import settings from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class WorkflowState(TypedDict): The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. """ # List of messages (append mode) - messages: Annotated[list[tuple[str, str]], add] + messages: list[dict[str, str]] # Set of loop node IDs, used for assigning values in loop nodes cycle_nodes: list @@ -154,7 +154,7 @@ class BaseNode(ABC): Returns: 超时时间 """ - return 60 + return settings.WORKFLOW_NODE_TIMEOUT # return self.error_handling.get("timeout", 60) async def run(self, state: WorkflowState) -> dict[str, Any]: @@ -203,6 +203,7 @@ class BaseNode(ABC): # 返回包装后的输出和运行时变量 return { **wrapped_output, + "messages": state["messages"], "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var @@ -356,6 +357,7 @@ class BaseNode(ABC): # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, + "messages": state["messages"], "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 6195afbd..0cbd9e8e 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -6,7 +6,6 @@ End 节点实现 import logging import re -import asyncio from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import NodeType @@ -38,7 +37,23 @@ class EndNode(BaseNode): # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state, strict=False) + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + { + "role": "assistant", + "content": output + } + ]) else: + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + ]) output = "工作流已完成" # 统计信息(用于日志) @@ -166,6 +181,12 @@ class EndNode(BaseNode): "chunk_index": 1, "is_suffix": False }) + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + } + ]) yield {"__final__": True, "result": output} return @@ -176,7 +197,6 @@ class EndNode(BaseNode): source_node_id = edge.get("source") # Check if the source node is an LLM node for node in self.workflow_config.get("nodes", []): - print("="*50) logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}") if node.get("id") == source_node_id and node.get("type") == NodeType.LLM: direct_upstream_llm_nodes.append(source_node_id) @@ -216,12 +236,24 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + { + "role": "assistant", + "content": output + } + ]) + # yield completion marker yield {"__final__": True, "result": output} return # Has reference to direct upstream LLM node, only output the part after that reference (suffix) - logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") + logger.info( + f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") # Collect suffix parts suffix_parts = [] @@ -258,6 +290,17 @@ class EndNode(BaseNode): # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) full_output = self._render_template(output_template, state, strict=False) + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + { + "role": "assistant", + "content": full_output + } + ]) + logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") @@ -280,7 +323,8 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}") else: - logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") + logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!" + f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") # 统计信息 node_outputs = state.get("node_outputs", {}) diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index 8498fc38..f65d5879 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -11,12 +11,12 @@ class MessageConfig(BaseModel): """消息配置""" role: str = Field( - ..., + default='user', description="消息角色:system, user, assistant" ) content: str = Field( - ..., + default="", description="消息内容,支持模板变量,如:{{ sys.message }}" ) @@ -30,6 +30,23 @@ class MessageConfig(BaseModel): return v.lower() +class MemoryWindowSetting(BaseModel): + enable: bool = Field( + default=False, + description="启用记忆" + ) + + enable_window: bool = Field( + default=False, + description="启用记忆窗口" + ) + + window_size: int = Field( + default=20, + description="记忆窗口大小" + ) + + class LLMNodeConfig(BaseNodeConfig): """LLM 节点配置 @@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig): description="上下文" ) + memory: MemoryWindowSetting = Field( + ..., + description="对话上下文窗口" + ) + # 简单模式 prompt: str | None = Field( default=None, diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index bfa1b99f..e25bd35d 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -85,28 +85,31 @@ class LLMNode(BaseNode): """ # 1. 处理消息格式(优先使用 messages) - messages_config = self.config.get("messages") + messages_config = self.typed_config.messages if messages_config: # 使用 LangChain 消息格式 messages = [] for msg_config in messages_config: - role = msg_config.get("role", "user").lower() - content_template = msg_config.get("content", "") + role = msg_config.role.lower() + content_template = msg_config.content content_template = self._render_context(content_template, state) content = self._render_template(content_template, state) # 根据角色创建对应的消息对象 if role == "system": - messages.append(SystemMessage(content=content)) + messages.append({"role": "system", "content": content}) elif role in ["user", "human"]: - messages.append(HumanMessage(content=content)) + messages.append({"role": "user", "content": content}) elif role in ["ai", "assistant"]: - messages.append(AIMessage(content=content)) + messages.append({"role": "assistant", "content": content}) else: logger.warning(f"未知的消息角色: {role},默认使用 user") - messages.append(HumanMessage(content=content)) + messages.append({"role": "user", "content": content}) + if self.typed_config.memory.enable: + # if self.typed_config.memory.enable_window: + messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] prompt_or_messages = messages else: # 使用简单的 prompt 格式(向后兼容) @@ -189,7 +192,7 @@ class LLMNode(BaseNode): return { "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "messages": [ - {"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content} + {"role": msg.get("role"), "content": msg.get("content", "")} for msg in prompt_or_messages ] if isinstance(prompt_or_messages, list) else None, "config": { diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index f1c99ddb..08a2b280 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -3,8 +3,9 @@ from typing import Any from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig -from app.db import get_db_read, get_db_context +from app.db import get_db_read from app.services.memory_agent_service import MemoryAgentService +from app.tasks import write_message_task class MemoryReadNode(BaseNode): @@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode): async def execute(self, state: WorkflowState) -> Any: self.typed_config = MemoryReadNodeConfig(**self.config) with get_db_read() as db: - workspace_id = self.get_variable('sys.workspace_id', state) end_user_id = self.get_variable("sys.user_id", state) - if not workspace_id: - raise RuntimeError("Workspace id is required") if not end_user_id: raise RuntimeError("End user id is required") @@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode): self.typed_config = MemoryWriteNodeConfig(**self.config) async def execute(self, state: WorkflowState) -> Any: - with get_db_context() as db: - workspace_id = self.get_variable('sys.workspace_id', state) - end_user_id = self.get_variable("sys.user_id", state) + end_user_id = self.get_variable("sys.user_id", state) - if not workspace_id: - raise RuntimeError("Workspace id is required") - if not end_user_id: - raise RuntimeError("End user id is required") + if not end_user_id: + raise RuntimeError("End user id is required") - return await MemoryAgentService().write_memory( - group_id=end_user_id, - message=self._render_template(self.typed_config.message, state), - config_id=str(self.typed_config.config_id), - db=db, - storage_type="neo4j", - user_rag_memory_id="" - ) + write_message_task.delay( + end_user_id, + self._render_template(self.typed_config.message, state), + str(self.typed_config.config_id), + "neo4j", + "" + ) + + return "success" diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 3c00e5a0..35d2e424 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -41,6 +41,7 @@ class ToolConfig(BaseModel): tool_id: Optional[str] = Field(default=None, description="工具ID") operation: Optional[str] = Field(default=None, description="工具特定配置") + class ToolOldConfig(BaseModel): """工具配置""" enabled: bool = Field(default=False, description="是否启用该工具") @@ -348,6 +349,7 @@ class AppChatRequest(BaseModel): variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") + class DraftRunRequest(BaseModel): """试运行请求""" message: str = Field(..., description="用户消息") diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 56400c92..0065c64b 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.db import get_db, get_db_context from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig +from app.schemas import DraftRunRequest from app.services.tool_service import ToolService from app.repositories.tool_repository import ToolRepository from app.db import get_db @@ -59,7 +60,7 @@ class AppChatService: # 获取模型配置ID model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) + api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id) # 处理系统提示词(支持变量替换) system_prompt = config.system_prompt if variables: @@ -210,7 +211,7 @@ class AppChatService: # 获取模型配置ID model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) + api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id) # 处理系统提示词(支持变量替换) system_prompt = config.system_prompt if variables: @@ -511,7 +512,6 @@ class AppChatService: } ) - except (GeneratorExit, asyncio.CancelledError): # 生成器被关闭或任务被取消,正常退出 logger.debug("多 Agent 流式聊天被中断") @@ -537,83 +537,19 @@ class AppChatService: ) -> Dict[str, Any]: """聊天(非流式)""" workflow_service = WorkflowService(self.db) - - input_data = {"message":message, "variables": variables, - "conversation_id": str(conversation_id)} - inconfig = workflow_service.get_workflow_config(app_id) - - # 2. 创建执行记录 - execution = workflow_service.create_execution( - workflow_config_id=inconfig.id, - app_id=app_id, - trigger_type="manual", - triggered_by=None, - conversation_id=conversation_id, - input_data=input_data + payload = DraftRunRequest( + message=message, + variables=variables, + conversation_id=str(conversation_id), + stream=True, + user_id=user_id + ) + return await workflow_service.run( + app_id=app_id, + payload=payload, + config=config, + workspace_id=workspace_id, ) - - # 3. 构建工作流配置字典 - workflow_config_dict = { - "nodes": config.nodes, - "edges": config.edges, - "variables": config.variables, - "execution_config": config.execution_config - } - - # 4. 获取工作空间 ID(从 app 获取) - - # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow - - try: - # 更新状态为运行中 - workflow_service.update_execution_status(execution.execution_id, "running") - - result = await execute_workflow( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id=str(workspace_id), - user_id=user_id - ) - - # 更新执行结果 - if result.get("status") == "completed": - workflow_service.update_execution_status( - execution.execution_id, - "completed", - output_data=result.get("node_outputs", {}) - ) - else: - workflow_service.update_execution_status( - execution.execution_id, - "failed", - error_message=result.get("error") - ) - - # 返回增强的响应结构 - return { - "execution_id": execution.execution_id, - "status": result.get("status"), - "output": result.get("output"), # 最终输出(字符串) - "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) - "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID - "error_message": result.get("error"), - "elapsed_time": result.get("elapsed_time"), - "token_usage": result.get("token_usage") - } - - except Exception as e: - logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) - workflow_service.update_execution_status( - execution.execution_id, - "failed", - error_message=str(e) - ) - raise BusinessException( - code=BizCode.INTERNAL_ERROR, - message=f"工作流执行失败: {str(e)}" - ) async def workflow_chat_stream( self, @@ -632,62 +568,21 @@ class AppChatService: ) -> AsyncGenerator[str, None]: """聊天(流式)""" workflow_service = WorkflowService(self.db) - input_data = {"message": message, "variables": variables, - "conversation_id": str(conversation_id)} - inconfig = workflow_service.get_workflow_config(app_id) - # 2. 创建执行记录 - execution = workflow_service.create_execution( - workflow_config_id=inconfig.id, - app_id=app_id, - trigger_type="manual", - triggered_by=None, - conversation_id=conversation_id, - input_data=input_data + payload = DraftRunRequest( + message=message, + variables=variables, + conversation_id=str(conversation_id), + stream=True, + user_id=user_id ) + async for event in workflow_service.run_stream( + app_id=app_id, + payload=payload, + config=config, + workspace_id=workspace_id, + ): + yield event - # 3. 构建工作流配置字典 - workflow_config_dict = { - "nodes": config.nodes, - "edges": config.edges, - "variables": config.variables, - "execution_config": config.execution_config - } - - # 4. 获取工作空间 ID(从 app 获取) - - # 5. 流式执行工作流 - - try: - # 更新状态为运行中 - workflow_service.update_execution_status(execution.execution_id, "running") - - - # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) - async for event in workflow_service._run_workflow_stream( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id=str(workspace_id), - user_id=user_id - ): - # 直接转发 executor 的事件(已经是正确的格式) - yield event - - except Exception as e: - logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) - workflow_service.update_execution_status( - execution.execution_id, - "failed", - error_message=str(e) - ) - # 发送错误事件 - yield { - "event": "error", - "data": { - "execution_id": execution.execution_id, - "error": str(e) - } - } # ==================== 依赖注入函数 ==================== diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 5011e83e..8f25f477 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -13,12 +13,14 @@ from typing import Any, Dict, List, Optional, Tuple from app.core.logging_config import get_logger from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context +from app.repositories.conversation_repository import ConversationRepository from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.schemas.memory_episodic_schema import type_mapping, EmotionType, EmotionSubject - +from app.services.implicit_memory_service import ImplicitMemoryService from app.services.memory_base_service import MemoryBaseService from app.services.memory_config_service import MemoryConfigService +from app.services.memory_perceptual_service import MemoryPerceptualService +from app.services.memory_short_service import ShortService from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -1198,18 +1200,17 @@ async def analytics_memory_types( end_user_id: Optional[str] = None ) -> List[Dict[str, Any]]: """ - 统计9种记忆类型的数量和百分比 + 统计8种记忆类型的数量和百分比 计算规则: - 1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity - 2. 工作记忆 (WORKING_MEMORY) = chunk + entity - 3. 短期记忆 (SHORT_TERM_MEMORY) = chunk - 4. 长期记忆 (LONG_TERM_MEMORY) = entity - 5. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取) - 6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity - 7. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取) - 8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取) - 9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取) + 1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count + 2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取) + 3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量 + 4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取) + 5. 隐性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一 + 6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取) + 7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取) + 8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取) Args: db: 数据库会话 @@ -1229,7 +1230,6 @@ async def analytics_memory_types( - PERCEPTUAL_MEMORY: 感知记忆 - WORKING_MEMORY: 工作记忆 - SHORT_TERM_MEMORY: 短期记忆 - - LONG_TERM_MEMORY: 长期记忆 - EXPLICIT_MEMORY: 显性记忆 - IMPLICIT_MEMORY: 隐性记忆 - EMOTIONAL_MEMORY: 情绪记忆 @@ -1239,40 +1239,78 @@ async def analytics_memory_types( # 初始化基础服务 base_service = MemoryBaseService() - # 定义需要查询的基础节点类型 - node_types = { - "Statement": "Statement", - "Entity": "ExtractedEntity", - "Chunk": "Chunk" - } + # 初始化感知记忆服务 + perceptual_service = MemoryPerceptualService(db) - # 存储每种节点类型的计数 - node_counts = {} + # 获取感知记忆数量 + if end_user_id: + perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id)) + perceptual_count = perceptual_stats.get("total", 0) + else: + perceptual_count = 0 - # 查询每种节点类型的数量 - for key, node_type in node_types.items(): - if end_user_id: - query = f""" - MATCH (n:{node_type}) + # 获取工作记忆数量(基于会话数量) + work_count = 0 + if end_user_id: + try: + conversation_repo = ConversationRepository(db) + conversations = conversation_repo.get_conversation_by_user_id( + user_id=uuid.UUID(end_user_id), + limit=100, # 获取更多会话以准确统计 + is_activate=True + ) + work_count = len(conversations) + logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})") + except Exception as e: + logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}") + work_count = 0 + + # 获取隐性记忆数量(基于 Statement 节点数量的三分之一) + implicit_count = 0 + if end_user_id: + try: + # 查询 Statement 节点数量 + query = """ + MATCH (n:Statement) WHERE n.group_id = $group_id RETURN count(n) as count """ result = await _neo4j_connector.execute_query(query, group_id=end_user_id) - else: - query = f""" - MATCH (n:{node_type}) - RETURN count(n) as count - """ - result = await _neo4j_connector.execute_query(query) - - # 提取计数结果 - count = result[0]["count"] if result and len(result) > 0 else 0 - node_counts[key] = count + statement_count = result[0]["count"] if result and len(result) > 0 else 0 + # 取三分之一作为隐性记忆数量 + implicit_count = round(statement_count / 3) + logger.debug(f"隐性记忆数量(Statement数量的1/3): {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})") + except Exception as e: + logger.warning(f"获取Statement数量失败,隐性记忆数量设为0: {str(e)}") + implicit_count = 0 - # 获取各节点类型的数量 - statement_count = node_counts.get("Statement", 0) - entity_count = node_counts.get("Entity", 0) - chunk_count = node_counts.get("Chunk", 0) + # 原有的基于行为习惯的统计方式(已注释) + # implicit_count = 0 + # if end_user_id: + # try: + # implicit_service = ImplicitMemoryService(db, end_user_id) + # behavior_habits = await implicit_service.get_behavior_habits( + # user_id=end_user_id + # ) + # implicit_count = len(behavior_habits) + # logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})") + # except Exception as e: + # logger.warning(f"获取行为习惯数量失败,隐性记忆数量设为0: {str(e)}") + # implicit_count = 0 + + # 获取短期记忆数量(基于 /short_term 接口返回的问答对数量) + short_term_count = 0 + if end_user_id: + try: + short_term_service = ShortService(end_user_id) + short_term_data = short_term_service.get_short_databasets() + # 统计 short_term 数组的长度 + if short_term_data: + short_term_count = len(short_term_data) + logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})") + except Exception as e: + logger.warning(f"获取短期记忆数量失败,短期记忆数量设为0: {str(e)}") + short_term_count = 0 # 获取用户的遗忘阈值配置 forgetting_threshold = 0.3 # 默认值 @@ -1298,17 +1336,16 @@ async def analytics_memory_types( # 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量 episodic_count = await base_service.get_episodic_memory_count(end_user_id) explicit_count = await base_service.get_explicit_memory_count(end_user_id) - emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count) + emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count) forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold) - # 按规则计算9种记忆类型的数量(使用英文枚举作为key) + # 按规则计算8种记忆类型的数量(使用英文枚举作为key) memory_counts = { - "PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆 - "WORKING_MEMORY": chunk_count + entity_count, # 工作记忆 - "SHORT_TERM_MEMORY": chunk_count, # 短期记忆 - "LONG_TERM_MEMORY": entity_count, # 长期记忆 + "PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆 + "WORKING_MEMORY": work_count, # 工作记忆(基于会话数量) + "SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量) "EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆) - "IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity) + "IMPLICIT_MEMORY": implicit_count, # 隐性记忆(Statement数量的1/3) "EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计) "EPISODIC_MEMORY": episodic_count, # 情景记忆 "FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值) diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 7d3c784f..f9988352 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -2,12 +2,11 @@ 工作流服务层 """ import datetime -import json import logging import uuid -import datetime from typing import Any, Annotated, AsyncGenerator +from deprecated import deprecated from fastapi import Depends from sqlalchemy.orm import Session @@ -16,15 +15,16 @@ from app.core.exceptions import BusinessException from app.core.workflow.validator import validate_workflow_config from app.db import get_db, get_db_context from app.models.workflow_model import WorkflowConfig, WorkflowExecution +from app.repositories.conversation_repository import MessageRepository +from app.models.conversation_model import Message from app.repositories.end_user_repository import EndUserRepository -from app.services.multi_agent_service import convert_uuids_to_str from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) from app.schemas import DraftRunRequest -from app.utils.sse_utils import format_sse_message +from app.services.multi_agent_service import convert_uuids_to_str logger = logging.getLogger(__name__) @@ -37,6 +37,7 @@ class WorkflowService: self.config_repo = WorkflowConfigRepository(db) self.execution_repo = WorkflowExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db) + self.message_repo = MessageRepository(db) # ==================== 配置管理 ==================== @@ -418,14 +419,13 @@ class WorkflowService: """运行工作流 Args: + workspace_id: + config: + payload: app_id: 应用 ID - input_data: 输入数据(包含 message 和 variables) - triggered_by: 触发用户 ID - conversation_id: 会话 ID(可选) - stream: 是否流式返回 Returns: - 执行结果(非流式)或生成器(流式) + 执行结果(非流式) Raises: BusinessException: 配置不存在或执行失败时抛出 @@ -438,7 +438,8 @@ class WorkflowService: code=BizCode.CONFIG_MISSING, message=f"工作流配置不存在: app_id={app_id}" ) - input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id} + input_data = {"message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id} # 转换 user_id 为 UUID triggered_by_uuid = None @@ -461,7 +462,7 @@ class WorkflowService: workflow_config_id=config.id, app_id=app_id, trigger_type="manual", - triggered_by=triggered_by_uuid, + triggered_by=None, conversation_id=conversation_id_uuid, input_data=input_data ) @@ -500,8 +501,11 @@ class WorkflowService: variables = last_state.get("variables", {}) conv_vars = variables.get("conv", {}) input_data["conv"] = conv_vars + input_data["conv_messages"] = last_state.get("messages") or [] break + init_message_length = len(input_data.get("conv_messages", [])) + result = await execute_workflow( workflow_config=workflow_config_dict, input_data=input_data, @@ -517,6 +521,17 @@ class WorkflowService: "completed", output_data=result ) + final_messages = result.get("messages", [])[init_message_length:] + for message in final_messages: + message_obj = Message( + conversation_id=conversation_id_uuid, + role=message["role"], + content=message["content"], + ) + self.message_repo.add_message(message_obj) + self.db.commit() + logger.info(f"Workflow Run Success, " + f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") else: self.update_execution_status( execution.execution_id, @@ -529,6 +544,7 @@ class WorkflowService: "execution_id": execution.execution_id, "status": result.get("status"), "variables": result.get("variables"), + "messages": result.get("messages"), "output": result.get("output"), # 最终输出(字符串) "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID @@ -559,6 +575,7 @@ class WorkflowService: """运行工作流(流式) Args: + workspace_id: app_id: 应用 ID payload: 请求对象(包含 message, variables, conversation_id 等) config: 存储类型(可选) @@ -601,7 +618,7 @@ class WorkflowService: workflow_config_id=config.id, app_id=app_id, trigger_type="manual", - triggered_by=triggered_by_uuid, + triggered_by=None, conversation_id=conversation_id_uuid, input_data=input_data ) @@ -638,17 +655,46 @@ class WorkflowService: variables = last_state.get("variables", {}) conv_vars = variables.get("conv", {}) input_data["conv"] = conv_vars + input_data["conv_messages"] = last_state.get("messages") or [] break + init_message_length = len(input_data.get("conv_messages", [])) + from app.core.workflow.executor import execute_workflow_stream - # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) - async for event in self._run_workflow_stream( + async for event in execute_workflow_stream( workflow_config=workflow_config_dict, input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), user_id=end_user_id ): - # 直接转发 executor 的事件(已经是正确的格式) + if event.get("event") == "workflow_end": + + status = event.get("data", {}).get("status") + if status == "completed": + self.update_execution_status( + execution.execution_id, + "completed", + output_data=event.get("data") + ) + final_messages = event.get("data", {}).get("messages", [])[init_message_length:] + for message in final_messages: + message_obj = Message( + conversation_id=conversation_id_uuid, + role=message["role"], + content=message["content"], + ) + self.message_repo.add_message(message_obj) + self.db.commit() + logger.info(f"Workflow Run Success, " + f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") + elif status == "failed": + self.update_execution_status( + execution.execution_id, + "failed", + output_data=event.get("data") + ) + else: + logger.error(f"unexpect workflow run status, status: {status}") yield event except Exception as e: @@ -667,6 +713,8 @@ class WorkflowService: } } + @deprecated(reason="This method is deprecated. " + "Please use WorkflowService.run / run_stream instead.") async def run_workflow( self, app_id: uuid.UUID, @@ -819,6 +867,7 @@ class WorkflowService: return clean_value(event) + @deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.") async def _run_workflow_stream( self, workflow_config: dict[str, Any], diff --git a/api/pyproject.toml b/api/pyproject.toml index 2dcc706d..6da684de 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -136,7 +136,8 @@ dependencies = [ "markdown-to-json==2.1.1", "valkey==6.0.2", "python-calamine>=0.4.0", - "xlrd==2.0.2" + "xlrd==2.0.2", + "deprecated>=1.3.1", ] [tool.pytest.ini_options] diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index 11ccb5c3..c90f9208 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -55,7 +55,7 @@ const ChatContent: FC = ({ } {/* 消息气泡框 */} -
= ({
{/* 底部标签(如时间戳、用户名等) */} {labelPosition === 'bottom' && -
+
{labelFormat(item)}
} diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 27986e76..2d2d96a2 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1265,6 +1265,7 @@ export const en = { emotionLine: 'Emotion Changes Over Time', interaction: 'Interaction Frequency & Relationship Stages', timelines_memory: 'All', + Chunk: 'Chunk', MemorySummary: 'Long-term Accumulation', Statement: 'Emotional Memory', ExtractedEntity: 'Episodic Memory', @@ -1786,6 +1787,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re temperature: 'Temperature', max_tokens: 'Max Tokens', context: 'Context', + memory: 'Memory', + enable_window: 'Memory Window', + inner: 'Built-in', }, start: { variables: 'Input Fields', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 8fa73fd3..53e71c84 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1343,6 +1343,7 @@ export const zh = { emotionLine: '情绪随时间变化', interaction: '互动频率 & 关系阶段', timelines_memory: '全部', + Chunk: '工作记忆', MemorySummary: '长期沉淀', Statement: '情绪记忆', ExtractedEntity: '情景记忆', @@ -1883,6 +1884,9 @@ export const zh = { temperature: '温度', max_tokens: '最大令牌数', context: '上下文', + memory: '记忆', + enable_window: '记忆窗口', + inner: '内置', }, start: { variables: '输入字段', diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index ce51c622..81f902cb 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -176,6 +176,9 @@ const Agent = forwardRef((_props, ref) => { if (response?.knowledge_retrieval?.knowledge_bases?.length) { getDefaultKnowledgeList(response) } + if (response?.tools?.length) { + setToolList(response?.tools) + } }).finally(() => { setLoading(false) }) diff --git a/web/src/views/ApplicationConfig/components/ToolList.tsx b/web/src/views/ApplicationConfig/components/ToolList.tsx index 9834b186..fde7286b 100644 --- a/web/src/views/ApplicationConfig/components/ToolList.tsx +++ b/web/src/views/ApplicationConfig/components/ToolList.tsx @@ -79,8 +79,6 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi } }, [data]) - console.log('toolList', toolList) - const handleAddTool = () => { toolModalRef.current?.handleOpen() } diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index d385e1f0..6ccb35ec 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -259,9 +259,10 @@ const Conversation: FC = () => {
+
} - contentClassName="rb:h-[calc(100%-152px)]" + empty={} + contentClassName="rb:h-[calc(100%-152px)] " data={chatList} streamLoading={streamLoading} loading={loading} @@ -290,6 +291,7 @@ const Conversation: FC = () => { +
) diff --git a/web/src/views/Workflow/components/AddChatVariable/index.tsx b/web/src/views/Workflow/components/AddChatVariable/index.tsx index 7ebce7df..d6741bf9 100644 --- a/web/src/views/Workflow/components/AddChatVariable/index.tsx +++ b/web/src/views/Workflow/components/AddChatVariable/index.tsx @@ -1,6 +1,5 @@ -import React, { useState, useImperativeHandle, forwardRef, useRef } from 'react'; -import { Button, Input, Space, Typography, Tooltip, message, List } from 'antd'; -import { PlusOutlined, EditOutlined, DeleteOutlined } from '@ant-design/icons'; +import { useState, useImperativeHandle, forwardRef, useRef } from 'react'; +import { Button, Space, List } from 'antd'; import { useTranslation } from 'react-i18next'; import type { ChatVariable, AddChatVariableRef } from '../../types'; import type { ChatVariableModalRef } from './types' diff --git a/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx b/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx index 251409dd..e0912c0a 100644 --- a/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx +++ b/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx @@ -131,7 +131,7 @@ const EditableTable: React.FC = ({ const AddButton = ({ block = false }: { block?: boolean }) => ( diff --git a/web/src/views/Workflow/components/Properties/index.tsx b/web/src/views/Workflow/components/Properties/index.tsx index a555b4fe..8752121e 100644 --- a/web/src/views/Workflow/components/Properties/index.tsx +++ b/web/src/views/Workflow/components/Properties/index.tsx @@ -22,6 +22,7 @@ import ConditionList from './ConditionList' import CycleVarsList from './CycleVarsList' import AssignmentList from './AssignmentList' import ToolConfig from './ToolConfig' +import MemoryConfig from './MemoryConfig' // import { calculateVariableList } from './utils/variableListCalculator' interface PropertiesProps { @@ -1230,6 +1231,20 @@ const Properties: FC = ({ ) } + if (config.type === 'memoryConfig') { + return ( + + + + ) + } return ( boolean | void; parseEvent: () => boolean | void; handleSave: (flag?: boolean) => Promise; + chatVariables: ChatVariable[]; + setChatVariables: React.Dispatch>; } export const edge_color = '#155EEF'; @@ -54,6 +56,7 @@ export const useWorkflowGraph = ({ const [canRedo, setCanRedo] = useState(false); const [isHandMode, setIsHandMode] = useState(false); const [config, setConfig] = useState(null); + const [chatVariables, setChatVariables] = useState([]) useEffect(() => { getConfig() @@ -63,16 +66,15 @@ export const useWorkflowGraph = ({ getWorkflowConfig(id) .then(res => { const { variables, ...rest } = res as WorkflowConfig - setConfig({ - ...rest, - variables: variables.map(v => { - const { default: _, ...cleanV } = v - return { - ...cleanV, - defaultValue: v.default ?? '' - } - }) + const initChatVariables = variables.map(v => { + const { default: _, ...cleanV } = v + return { + ...cleanV, + defaultValue: v.default ?? '' + } }) + setChatVariables(initChatVariables) + setConfig({ ...rest, variables: initChatVariables }) }) } @@ -94,7 +96,17 @@ export const useWorkflowGraph = ({ if (nodeLibraryConfig?.config) { Object.keys(nodeLibraryConfig.config).forEach(key => { - if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) { + if (key === 'memory' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) { + const { memory, messages } = config as any; + if (memory?.enable && messages && messages.length > 0) { + const lastMessage = messages[messages.length - 1] + nodeLibraryConfig.config[key].defaultValue = { + ...memory, + messages: lastMessage.content + } + nodeLibraryConfig.config.messages.defaultValue.splice(-1, 1) + } + } else if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) { const { query, ...rest } = config nodeLibraryConfig.config[key].defaultValue = { ...rest @@ -917,13 +929,13 @@ export const useWorkflowGraph = ({ const params = { ...config, - variables: config.variables.map(v => { - const { defaultValue, ...cleanV } = v - return { - ...cleanV, - default: defaultValue ?? '' - } - }), + variables: chatVariables.map(v => { + const { defaultValue, ...cleanV } = v + return { + ...cleanV, + default: defaultValue ?? '' + } + }), nodes: nodes.map((node: Node) => { const data = node.getData(); const position = node.getPosition(); @@ -931,7 +943,15 @@ export const useWorkflowGraph = ({ if (data.config) { Object.keys(data.config).forEach(key => { - if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') { + if (key === 'memory' && data.config[key] && 'defaultValue' in data.config[key]) { + const { messages, ...rest } = data.config[key].defaultValue + let memoryMessage = { role: 'USER', content: data.config[key].defaultValue.messages } + itemConfig = { + ...itemConfig, + messages: rest.enable ? [...itemConfig.messages, memoryMessage] : itemConfig.messages, + memory: { ...rest }, + } + } else if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') { let group_variables = data.config.group.defaultValue ? {} : data.config[key].defaultValue if (data.config.group.defaultValue) { data.config[key].defaultValue.map((vo: any) => { @@ -1077,5 +1097,7 @@ export const useWorkflowGraph = ({ copyEvent, parseEvent, handleSave, + chatVariables, + setChatVariables }; }; diff --git a/web/src/views/Workflow/index.tsx b/web/src/views/Workflow/index.tsx index 70803373..ba17a63a 100644 --- a/web/src/views/Workflow/index.tsx +++ b/web/src/views/Workflow/index.tsx @@ -8,7 +8,7 @@ import PortClickHandler from './components/PortClickHandler'; import { useWorkflowGraph } from './hooks/useWorkflowGraph'; import type { WorkflowRef } from '@/views/ApplicationConfig/types' import Chat from './components/Chat/Chat'; -import type { ChatRef, AddChatVariableRef, ChatVariable } from './types' +import type { ChatRef, AddChatVariableRef } from './types' import arrowIcon from '@/assets/images/workflow/arrow.png' import AddChatVariable from './components/AddChatVariable'; @@ -21,7 +21,6 @@ const Workflow = forwardRef((_props, ref) => { // 使用自定义Hook初始化工作流图 const { config, - setConfig, graphRef, selectedNode, setSelectedNode, @@ -38,6 +37,8 @@ const Workflow = forwardRef((_props, ref) => { copyEvent, parseEvent, handleSave, + chatVariables, + setChatVariables } = useWorkflowGraph({ containerRef, miniMapRef }); const onDragOver = (event: React.DragEvent) => { @@ -52,15 +53,6 @@ const Workflow = forwardRef((_props, ref) => { const addVariable = () => { addChatVariableRef.current?.handleOpen() } - const handleUpdateChatVariable = (variables: ChatVariable[]) => { - setConfig(prev => { - if (!prev) return null - return { - ...prev, - variables - } - }) - } useImperativeHandle(ref, () => ({ handleSave, @@ -125,8 +117,8 @@ const Workflow = forwardRef((_props, ref) => { );