diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index f55ea5b5..43f177ef 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -11,15 +11,16 @@ from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user, cur_workspace_access_guard from app.models import User -from app.models.app_model import AppType, App +from app.models.app_model import AppType from app.repositories import knowledge_repository +from app.repositories.end_user_repository import EndUserRepository from app.schemas import app_schema from app.schemas.response_schema import PageData, PageMeta +from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema from app.schemas.workflow_schema import WorkflowConfigUpdate from app.services import app_service, workspace_service from app.services.agent_config_helper import enrich_agent_config from app.services.app_service import AppService -from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema from app.services.workflow_service import WorkflowService, get_workflow_service router = APIRouter(prefix="/apps", tags=["Apps"]) @@ -405,6 +406,15 @@ async def draft_run( # 只读操作,允许访问共享应用 service._validate_app_accessible(app, workspace_id) + if payload.user_id is None: + end_user_repo = EndUserRepository(db) + new_end_user = end_user_repo.get_or_create_end_user( + app_id=app_id, + other_id=str(current_user.id), + original_user_id=str(current_user.id) # Save original user_id to other_id + ) + payload.user_id = str(new_end_user.id) + # 处理会话ID(创建或验证) conversation_id = await draft_service._ensure_conversation( conversation_id=payload.conversation_id, diff --git a/api/app/controllers/multi_agent_controller.py b/api/app/controllers/multi_agent_controller.py index 55614dea..dbcc2536 100644 --- a/api/app/controllers/multi_agent_controller.py +++ b/api/app/controllers/multi_agent_controller.py @@ -74,7 +74,7 @@ def get_multi_agent_configs( "app_id": str(app_id), "default_model_config_id": None, "model_parameters": None, - "orchestration_mode": "conditional", + "orchestration_mode": "supervisor", "sub_agents": [], "routing_rules": [], "execution_config": { diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 354a58ef..04da05df 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -466,7 +466,7 @@ async def chat( conversation_id=conversation.id, # 使用已创建的会话 ID user_id=str(new_end_user.id), # 转换为字符串 variables=payload.variables, - config= payload.agent_config, + config=agent_config, web_search=payload.web_search, memory=payload.memory, storage_type=storage_type, @@ -565,11 +565,12 @@ async def chat( config = workflow_config_4_app_release(release) if payload.stream: async def event_generator(): + async for event in app_chat_service.workflow_chat_stream( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=new_end_user.id, # 转换为字符串 + user_id=end_user_id, # 转换为字符串 variables=payload.variables, config=config, web_search=payload.web_search, @@ -601,7 +602,7 @@ async def chat( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=new_end_user.id, # 转换为字符串 + user_id=end_user_id, # 转换为字符串 variables=payload.variables, config=config, web_search=payload.web_search, diff --git a/api/app/controllers/workflow_controller.py b/api/app/controllers/workflow_controller.py index 429aa67e..c6d9ddab 100644 --- a/api/app/controllers/workflow_controller.py +++ b/api/app/controllers/workflow_controller.py @@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"]) @router.post("/{app_id}/workflow") @cur_workspace_access_guard() async def create_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - config: WorkflowConfigCreate, - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + config: WorkflowConfigCreate, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """创建工作流配置 @@ -96,6 +96,7 @@ async def create_workflow_config( msg=f"创建工作流配置失败: {str(e)}" ) + # # @router.get("/{app_id}/workflow") # async def get_workflow_config( @@ -199,10 +200,10 @@ async def create_workflow_config( @router.delete("/{app_id}/workflow") async def delete_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """删除工作流配置 @@ -243,11 +244,11 @@ async def delete_workflow_config( @router.post("/{app_id}/workflow/validate") async def validate_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)], - for_publish: Annotated[bool, Query(description="是否为发布验证")] = False + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)], + for_publish: Annotated[bool, Query(description="是否为发布验证")] = False ): """验证工作流配置 @@ -312,12 +313,12 @@ async def validate_workflow_config( @router.get("/{app_id}/workflow/executions") async def get_workflow_executions( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)], - limit: Annotated[int, Query(ge=1, le=100)] = 50, - offset: Annotated[int, Query(ge=0)] = 0 + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)], + limit: Annotated[int, Query(ge=1, le=100)] = 50, + offset: Annotated[int, Query(ge=0)] = 0 ): """获取工作流执行记录列表 @@ -365,10 +366,10 @@ async def get_workflow_executions( @router.get("/workflow/executions/{execution_id}") async def get_workflow_execution( - execution_id: Annotated[str, Path(description="执行 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + execution_id: Annotated[str, Path(description="执行 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """获取工作流执行详情 @@ -417,16 +418,14 @@ async def get_workflow_execution( ) - # ==================== 工作流执行 ==================== - @router.post("/{app_id}/workflow/run") async def run_workflow( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - request: WorkflowExecutionRequest, - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + request: WorkflowExecutionRequest, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """执行工作流 @@ -487,22 +486,22 @@ async def run_workflow( """ try: async for event in await service.run_workflow( - app_id=app_id, - input_data=input_data, - triggered_by=current_user.id, - conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, - stream=True + app_id=app_id, + input_data=input_data, + triggered_by=current_user.id, + conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, + stream=True ): # 提取事件类型和数据 event_type = event.get("event", "message") event_data = event.get("data", {}) - + # 转换为标准 SSE 格式(字符串) # event: # data: sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" yield sse_message - + except Exception as e: logger.error(f"流式执行异常: {e}", exc_info=True) # 发送错误事件 @@ -554,10 +553,10 @@ async def run_workflow( @router.post("/workflow/executions/{execution_id}/cancel") async def cancel_workflow_execution( - execution_id: Annotated[str, Path(description="执行 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] + execution_id: Annotated[str, Path(description="执行 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] ): """取消工作流执行 @@ -602,7 +601,7 @@ async def cancel_workflow_execution( except BusinessException as e: logger.warning(f"取消工作流执行失败: {e.message}") - return fail(code=e.error_code, msg=e.message) + return fail(code=e.code, msg=e.message) except Exception as e: logger.error(f"取消工作流执行异常: {e}", exc_info=True) return fail( diff --git a/api/app/core/config.py b/api/app/core/config.py index 573c4283..5f4f91c4 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -7,17 +7,18 @@ from dotenv import load_dotenv load_dotenv() + class Settings: ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true" # API Keys Configuration OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "") - + # Neo4j Configuration (记忆系统数据库) NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687") NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j") NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "") - + # Database configuration (Postgres) DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1") DB_PORT: int = int(os.getenv("DB_PORT", "5432")) @@ -37,7 +38,7 @@ class Settings: REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379")) REDIS_DB: int = int(os.getenv("REDIS_DB", "1")) REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "") - + # ElasticSearch configuration ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1") ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200")) @@ -48,7 +49,7 @@ class Settings: ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000")) ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true" ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10")) - + # Xinference configuration XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1") @@ -57,17 +58,17 @@ class Settings: LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true" LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "") LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "") - + # LLM Request Configuration LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0")) LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2")) - + # JWT Token Configuration SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random") ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30")) REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7")) - + # Single Sign-On configuration ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true" @@ -86,19 +87,19 @@ class Settings: LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "") LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "") LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "") - + # Server Configuration SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1") # ======================================================================== # Internal Configuration (not in .env, used by application code) # ======================================================================== - + # Superuser settings (internal defaults) FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com") FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin") FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password") - + # Generic File Upload (internal) GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads") ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true" @@ -123,7 +124,7 @@ class Settings: LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5")) LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true" LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true" - + # Sensitive Data Filtering ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true" @@ -142,7 +143,6 @@ class Settings: LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB - # Celery configuration (internal) CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) @@ -150,15 +150,15 @@ class Settings: HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) - REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) - + REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) + # Memory Cache Regeneration Configuration MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") - + # Tool Management Configuration TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools") TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60")) @@ -167,7 +167,10 @@ class Settings: # official environment system version SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0") - + + # workflow config + WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) + def get_memory_output_path(self, filename: str = "") -> str: """ Get the full path for memory module output files. @@ -182,7 +185,7 @@ class Settings: if filename: return str(base_path / filename) return str(base_path) - + def ensure_memory_output_dir(self) -> None: """ Ensure the memory output directory exists. diff --git a/api/app/core/error_codes.py b/api/app/core/error_codes.py index 23023ca4..cb0084b7 100644 --- a/api/app/core/error_codes.py +++ b/api/app/core/error_codes.py @@ -110,24 +110,24 @@ HTTP_MAPPING = { BizCode.TOKEN_EXPIRED: 401, BizCode.TOKEN_BLACKLISTED: 401, BizCode.FORBIDDEN: 403, - BizCode.TENANT_NOT_FOUND: 404, + BizCode.TENANT_NOT_FOUND: 400, BizCode.WORKSPACE_NO_ACCESS: 403, - BizCode.NOT_FOUND: 404, + BizCode.NOT_FOUND: 400, BizCode.USER_NOT_FOUND: 200, - BizCode.WORKSPACE_NOT_FOUND: 404, - BizCode.MODEL_NOT_FOUND: 404, - BizCode.KNOWLEDGE_NOT_FOUND: 404, - BizCode.DOCUMENT_NOT_FOUND: 404, - BizCode.FILE_NOT_FOUND: 404, - BizCode.APP_NOT_FOUND: 404, - BizCode.RELEASE_NOT_FOUND: 404, + BizCode.WORKSPACE_NOT_FOUND: 400, + BizCode.MODEL_NOT_FOUND: 400, + BizCode.KNOWLEDGE_NOT_FOUND: 400, + BizCode.DOCUMENT_NOT_FOUND: 400, + BizCode.FILE_NOT_FOUND: 400, + BizCode.APP_NOT_FOUND: 400, + BizCode.RELEASE_NOT_FOUND: 400, BizCode.DUPLICATE_NAME: 409, BizCode.RESOURCE_ALREADY_EXISTS: 409, BizCode.VERSION_ALREADY_EXISTS: 409, BizCode.STATE_CONFLICT: 409, BizCode.PUBLISH_FAILED: 500, BizCode.NO_DRAFT_TO_PUBLISH: 400, - BizCode.ROLLBACK_TARGET_NOT_FOUND: 404, + BizCode.ROLLBACK_TARGET_NOT_FOUND: 400, BizCode.APP_TYPE_NOT_SUPPORTED: 400, BizCode.AGENT_CONFIG_MISSING: 400, BizCode.SHARE_DISABLED: 403, diff --git a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py b/api/app/core/memory/agent/mcp_server/tools/summary_tools.py index 6d5012f1..0f306572 100644 --- a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py +++ b/api/app/core/memory/agent/mcp_server/tools/summary_tools.py @@ -425,15 +425,9 @@ async def Input_Summary( try: # Extract services from context - template_service = get_context_resource(ctx, "template_service") session_service = get_context_resource(ctx, "session_service") search_service = get_context_resource(ctx, "search_service") - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - # Resolve session ID sessionid = Resolve_username(usermessages) or "" sessionid = sessionid.replace('call_id_', '') @@ -539,31 +533,11 @@ async def Input_Summary( ) retrieve_info, question, raw_results = "", query, [] + # Return retrieved information directly without LLM processing + # Use the raw retrieved info as the answer + aimessages = retrieve_info if retrieve_info else "信息不足,无法回答" - # Render template - system_prompt = await template_service.render_template( - template_name='Retrieve_Summary_prompt.jinja2', - operation_name='input_summary', - query=query, - history=history, - retrieve_info=retrieve_info - ) - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=RetrieveSummaryResponse - ) - aimessages = structured.data.query_answer or "信息不足,无法回答" - except Exception as e: - logger.error( - f"Input_Summary: response_structured failed, using default answer: {e}", - exc_info=True - ) - aimessages = "信息不足,无法回答" - - logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}") + logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...") # Emit intermediate output for frontend return { diff --git a/api/app/core/tools/custom/schema_parser.py b/api/app/core/tools/custom/schema_parser.py index a22e2cfa..ea261dba 100644 --- a/api/app/core/tools/custom/schema_parser.py +++ b/api/app/core/tools/custom/schema_parser.py @@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger logger = get_business_logger() -# 为了兼容性,创建别名 -# SchemaParser = OpenAPISchemaParser = None - class OpenAPISchemaParser: """OpenAPI Schema解析器 - 解析OpenAPI 3.0规范""" @@ -213,7 +210,9 @@ class OpenAPISchemaParser: if not isinstance(operation, dict): continue - + + summary = operation.get("summary", "") + # 生成操作ID operation_id = operation.get("operationId") if not operation_id: @@ -223,7 +222,7 @@ class OpenAPISchemaParser: operations[operation_id] = { "method": method.upper(), "path": path, - "summary": operation.get("summary", ""), + "summary": summary if summary else operation_id, "description": operation.get("description", ""), "parameters": self._extract_parameters(operation), "request_body": self._extract_request_body(operation), diff --git a/api/app/core/tools/langchain_adapter.py b/api/app/core/tools/langchain_adapter.py index ea5fdb96..51415732 100644 --- a/api/app/core/tools/langchain_adapter.py +++ b/api/app/core/tools/langchain_adapter.py @@ -232,7 +232,7 @@ class LangchainAdapter: # 添加验证约束 if param.enum: # 枚举值约束 - field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$" + field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$" if param.minimum is not None: field_kwargs["ge"] = param.minimum @@ -241,7 +241,7 @@ class LangchainAdapter: field_kwargs["le"] = param.maximum if param.pattern: - field_kwargs["regex"] = param.pattern + field_kwargs["pattern"] = param.pattern fields[param.name] = Field(**field_kwargs) annotations[param.name] = python_type diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 2901b7ca..c082b314 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -27,20 +27,22 @@ class SimpleMCPClient: # 确定连接类型 self.is_websocket = server_url.startswith(("ws://", "wss://")) + self.is_sse = "/sse" in server_url.lower() # 连接状态 self._websocket = None self._session = None self._request_id = 0 self._pending_requests = {} + self._server_capabilities = {} + self._endpoint_url = None # SSE endpoint URL + self._sse_task = None async def __aenter__(self): - """异步上下文管理器入口""" await self.connect() return self async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器出口""" await self.disconnect() async def connect(self): @@ -57,47 +59,154 @@ class SimpleMCPClient: async def disconnect(self): """断开连接""" try: + if self._sse_task: + self._sse_task.cancel() if self._websocket: await self._websocket.close() self._websocket = None - if self._session: await self._session.close() self._session = None - except Exception as e: logger.error(f"断开连接失败: {e}") async def _connect_websocket(self): """WebSocket 连接""" headers = self._build_headers() - self._websocket = await websockets.connect( self.server_url, extra_headers=headers, timeout=self.timeout ) - - # 启动消息处理 asyncio.create_task(self._handle_websocket_messages()) - - # 发送初始化消息 await self._send_initialize() async def _connect_http(self): """HTTP 连接""" headers = self._build_headers() timeout = aiohttp.ClientTimeout(total=self.timeout) + self._session = aiohttp.ClientSession(headers=headers, timeout=timeout) - self._session = aiohttp.ClientSession( - headers=headers, - timeout=timeout - ) - - # 对于 ModelScope MCP 服务,需要先发送初始化请求 - if "modelscope.net" in self.server_url: + if self.is_sse: + await self._initialize_sse_session() + elif "modelscope.net" in self.server_url: await self._initialize_modelscope_session() + async def _initialize_sse_session(self): + """初始化 SSE MCP 会话 - 参考 Dify 实现""" + try: + # 建立 SSE 连接 + response = await self._session.get(self.server_url) + + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}") + + # 启动 SSE 读取任务 + self._sse_task = asyncio.create_task(self._read_sse_stream(response)) + + # 等待获取 endpoint URL + for _ in range(10): + if self._endpoint_url: + break + await asyncio.sleep(1) + + if not self._endpoint_url: + raise MCPConnectionError("未能获取 endpoint URL") + + # 发送 initialize 请求到 endpoint + init_request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} + } + } + + init_response = await self._send_sse_request(init_request) + if "error" in init_response: + raise MCPConnectionError(f"初始化失败: {init_response['error']}") + + result = init_response.get("result", {}) + self._server_capabilities = result.get("capabilities", {}) + + # 发送 initialized 通知 + await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"}) + + except aiohttp.ClientError as e: + raise MCPConnectionError(f"初始化连接失败: {e}") + + async def _read_sse_stream(self, response): + """读取 SSE 流""" + try: + async for line in response.content: + line = line.decode('utf-8').strip() + + if line.startswith('event:'): + continue + + if line.startswith('data:'): + data = line[5:].strip() # 去除 'data:' 后的空格 + if not data or data == '[DONE]': + continue + + try: + # 处理 endpoint 事件(相对路径或绝对路径) + if not self._endpoint_url: + # 如果是相对路径,拼接成完整 URL + if data.startswith('/'): + from urllib.parse import urlparse, urlunparse + parsed = urlparse(self.server_url) + self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}" + else: + self._endpoint_url = data + logger.info(f"获取到 endpoint URL: {self._endpoint_url}") + continue + + # 处理 message 事件 + message = json.loads(data) + request_id = message.get("id") + if request_id and request_id in self._pending_requests: + future = self._pending_requests.pop(request_id) + if not future.done(): + future.set_result(message) + except json.JSONDecodeError: + continue + except Exception as e: + logger.error(f"SSE 流读取错误: {e}") + + async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]: + """通过 SSE endpoint 发送请求""" + if not self._endpoint_url: + raise MCPConnectionError("endpoint URL 未初始化") + + request_id = request["id"] + future = asyncio.Future() + self._pending_requests[request_id] = future + + try: + async with self._session.post(self._endpoint_url, json=request) as response: + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"请求失败 {response.status}: {error_text}") + + return await asyncio.wait_for(future, timeout=self.timeout) + except asyncio.TimeoutError: + self._pending_requests.pop(request_id, None) + raise MCPConnectionError("请求超时") + + async def _send_sse_notification(self, notification: Dict[str, Any]): + """发送通知(无需响应)""" + if not self._endpoint_url: + raise MCPConnectionError("endpoint URL 未初始化") + + async with self._session.post(self._endpoint_url, json=notification) as response: + if response.status != 200: + logger.warning(f"通知发送失败: {response.status}") + async def _initialize_modelscope_session(self): """初始化 ModelScope MCP 会话""" init_request = { @@ -107,18 +216,12 @@ class SimpleMCPClient: "params": { "protocolVersion": "2024-11-05", "capabilities": {"tools": {}}, - "clientInfo": { - "name": "MemoryBear", - "version": "1.0.0" - } + "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} } } try: - async with self._session.post( - self.server_url, - json=init_request - ) as response: + async with self._session.post(self.server_url, json=init_request) as response: if response.status != 200: error_text = await response.text() raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") @@ -127,21 +230,16 @@ class SimpleMCPClient: if "error" in init_response: raise MCPConnectionError(f"初始化失败: {init_response['error']}") - # 获取 session ID session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id") if session_id: self._session.headers.update({"Mcp-Session-Id": session_id}) - # 发送 initialized 通知 initialized_notification = { "jsonrpc": "2.0", "method": "notifications/initialized" } - async with self._session.post( - self.server_url, - json=initialized_notification - ) as notif_response: + async with self._session.post(self.server_url, json=initialized_notification): pass except aiohttp.ClientError as e: @@ -149,12 +247,18 @@ class SimpleMCPClient: def _build_headers(self) -> Dict[str, str]: """构建请求头""" + # 基础 headers headers = { "Content-Type": "application/json", "Accept": "application/json, text/event-stream" } - # 添加认证头 + # 合并 connection_config 中的自定义 headers + custom_headers = self.connection_config.get("headers", {}) + if custom_headers: + headers.update(custom_headers) + + # 处理认证配置(认证 headers 优先级更高) auth_config = self.connection_config.get("auth_config", {}) auth_type = self.connection_config.get("auth_type", "none") @@ -178,7 +282,7 @@ class SimpleMCPClient: return headers async def _send_initialize(self): - """发送初始化消息""" + """发送初始化消息(WebSocket)""" init_message = { "jsonrpc": "2.0", "id": self._get_request_id(), @@ -186,124 +290,90 @@ class SimpleMCPClient: "params": { "protocolVersion": "2024-11-05", "capabilities": {"tools": {}}, - "clientInfo": { - "name": "MemoryBear", - "version": "1.0.0" - } + "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} } } await self._websocket.send(json.dumps(init_message)) + response = await self._websocket.recv() + response_data = json.loads(response) - # 等待初始化响应 - response = await asyncio.wait_for( - self._websocket.recv(), - timeout=self.timeout - ) + if "error" in response_data: + raise MCPConnectionError(f"初始化失败: {response_data['error']}") - init_response = json.loads(response) - if "error" in init_response: - raise MCPConnectionError(f"初始化失败: {init_response['error']}") + result = response_data.get("result", {}) + self._server_capabilities = result.get("capabilities", {}) + + await self._websocket.send(json.dumps({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + })) + + async def list_tools(self) -> List[Dict[str, Any]]: + """获取工具列表""" + request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "tools/list" + } + + if self.is_websocket: + await self._websocket.send(json.dumps(request)) + response = await self._websocket.recv() + response_data = json.loads(response) + elif self.is_sse: + response_data = await self._send_sse_request(request) + else: + async with self._session.post(self.server_url, json=request) as response: + response_data = await response.json() + + if "error" in response_data: + raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}") + + result = response_data.get("result", {}) + return result.get("tools", []) + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: + """调用工具""" + request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "tools/call", + "params": {"name": tool_name, "arguments": arguments} + } + + if self.is_websocket: + await self._websocket.send(json.dumps(request)) + response = await self._websocket.recv() + response_data = json.loads(response) + elif self.is_sse: + response_data = await self._send_sse_request(request) + else: + async with self._session.post(self.server_url, json=request) as response: + response_data = await response.json() + + if "error" in response_data: + error = response_data["error"] + raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}") + + return response_data.get("result", {}) + + def _get_request_id(self) -> int: + """生成请求 ID""" + self._request_id += 1 + return self._request_id async def _handle_websocket_messages(self): """处理 WebSocket 消息""" try: - while self._websocket and not self._websocket.closed: - try: - message = await self._websocket.recv() - data = json.loads(message) - - # 处理响应 - if "id" in data: - request_id = str(data["id"]) - if request_id in self._pending_requests: - future = self._pending_requests.pop(request_id) - if not future.done(): - future.set_result(data) - - except ConnectionClosed: - break - except Exception as e: - logger.error(f"处理WebSocket消息失败: {e}") - + async for message in self._websocket: + data = json.loads(message) + request_id = data.get("id") + if request_id and request_id in self._pending_requests: + future = self._pending_requests.pop(request_id) + if not future.done(): + future.set_result(data) + except ConnectionClosed: + logger.info("WebSocket 连接已关闭") except Exception as e: - logger.error(f"WebSocket消息处理异常: {e}") - - async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: - """调用工具""" - request_data = { - "jsonrpc": "2.0", - "id": self._get_request_id(), - "method": "tools/call", - "params": { - "name": tool_name, - "arguments": arguments - } - } - - if self.is_websocket: - response = await self._send_websocket_request(request_data) - else: - response = await self._send_http_request(request_data) - - if "error" in response: - error = response["error"] - raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}") - - return response.get("result", {}) - - async def list_tools(self) -> List[Dict[str, Any]]: - """获取工具列表""" - request_data = { - "jsonrpc": "2.0", - "id": self._get_request_id(), - "method": "tools/list", - "params": {} - } - - if self.is_websocket: - response = await self._send_websocket_request(request_data) - else: - response = await self._send_http_request(request_data) - - if "error" in response: - error = response["error"] - raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}") - - result = response.get("result", {}) - return result.get("tools", []) - - async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """发送WebSocket请求""" - request_id = str(request_data["id"]) - future = asyncio.Future() - self._pending_requests[request_id] = future - - try: - await self._websocket.send(json.dumps(request_data)) - response = await asyncio.wait_for(future, timeout=self.timeout) - return response - except asyncio.TimeoutError: - self._pending_requests.pop(request_id, None) - raise - - async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """发送HTTP请求""" - try: - async with self._session.post( - self.server_url, - json=request_data - ) as response: - if response.status != 200: - error_text = await response.text() - raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") - - return await response.json() - - except aiohttp.ClientError as e: - raise MCPConnectionError(f"HTTP请求失败: {e}") - - def _get_request_id(self) -> str: - """获取请求ID""" - self._request_id += 1 - return f"req_{self._request_id}_{int(time.time() * 1000)}" \ No newline at end of file + logger.error(f"WebSocket 消息处理错误: {e}") diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 67689935..c048f447 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -74,6 +74,7 @@ class WorkflowExecutor: 初始化的工作流状态 """ user_message = input_data.get("message") or "" + conversation_messages = input_data.get("conv_messages") or [] # 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value) config_variables_list = self.workflow_config.get("variables") or [] @@ -114,7 +115,7 @@ class WorkflowExecutor: } return { - "messages": [('user', user_message)], + "messages": conversation_messages, "variables": variables, "node_outputs": {}, "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index e3bf36c9..72fd0bb5 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -7,13 +7,13 @@ import asyncio import logging from abc import ABC, abstractmethod -from operator import add from typing import Any -from langchain_core.messages import AnyMessage, AIMessage +from langchain_core.messages import AIMessage from langgraph.config import get_stream_writer from typing_extensions import TypedDict, Annotated +from app.core.config import settings from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class WorkflowState(TypedDict): The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. """ # List of messages (append mode) - messages: Annotated[list[tuple[str, str]], add] + messages: list[dict[str, str]] # Set of loop node IDs, used for assigning values in loop nodes cycle_nodes: list @@ -154,7 +154,7 @@ class BaseNode(ABC): Returns: 超时时间 """ - return 60 + return settings.WORKFLOW_NODE_TIMEOUT # return self.error_handling.get("timeout", 60) async def run(self, state: WorkflowState) -> dict[str, Any]: @@ -203,6 +203,7 @@ class BaseNode(ABC): # 返回包装后的输出和运行时变量 return { **wrapped_output, + "messages": state["messages"], "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var @@ -356,6 +357,7 @@ class BaseNode(ABC): # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, + "messages": state["messages"], "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 6195afbd..0cbd9e8e 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -6,7 +6,6 @@ End 节点实现 import logging import re -import asyncio from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import NodeType @@ -38,7 +37,23 @@ class EndNode(BaseNode): # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state, strict=False) + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + { + "role": "assistant", + "content": output + } + ]) else: + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + ]) output = "工作流已完成" # 统计信息(用于日志) @@ -166,6 +181,12 @@ class EndNode(BaseNode): "chunk_index": 1, "is_suffix": False }) + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + } + ]) yield {"__final__": True, "result": output} return @@ -176,7 +197,6 @@ class EndNode(BaseNode): source_node_id = edge.get("source") # Check if the source node is an LLM node for node in self.workflow_config.get("nodes", []): - print("="*50) logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}") if node.get("id") == source_node_id and node.get("type") == NodeType.LLM: direct_upstream_llm_nodes.append(source_node_id) @@ -216,12 +236,24 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + { + "role": "assistant", + "content": output + } + ]) + # yield completion marker yield {"__final__": True, "result": output} return # Has reference to direct upstream LLM node, only output the part after that reference (suffix) - logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") + logger.info( + f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") # Collect suffix parts suffix_parts = [] @@ -258,6 +290,17 @@ class EndNode(BaseNode): # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) full_output = self._render_template(output_template, state, strict=False) + state['messages'].extend([ + { + "role": "user", + "content": self.get_variable("sys.message", state) + }, + { + "role": "assistant", + "content": full_output + } + ]) + logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") @@ -280,7 +323,8 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}") else: - logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") + logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!" + f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") # 统计信息 node_outputs = state.get("node_outputs", {}) diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index 8498fc38..f65d5879 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -11,12 +11,12 @@ class MessageConfig(BaseModel): """消息配置""" role: str = Field( - ..., + default='user', description="消息角色:system, user, assistant" ) content: str = Field( - ..., + default="", description="消息内容,支持模板变量,如:{{ sys.message }}" ) @@ -30,6 +30,23 @@ class MessageConfig(BaseModel): return v.lower() +class MemoryWindowSetting(BaseModel): + enable: bool = Field( + default=False, + description="启用记忆" + ) + + enable_window: bool = Field( + default=False, + description="启用记忆窗口" + ) + + window_size: int = Field( + default=20, + description="记忆窗口大小" + ) + + class LLMNodeConfig(BaseNodeConfig): """LLM 节点配置 @@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig): description="上下文" ) + memory: MemoryWindowSetting = Field( + ..., + description="对话上下文窗口" + ) + # 简单模式 prompt: str | None = Field( default=None, diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index bfa1b99f..e25bd35d 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -85,28 +85,31 @@ class LLMNode(BaseNode): """ # 1. 处理消息格式(优先使用 messages) - messages_config = self.config.get("messages") + messages_config = self.typed_config.messages if messages_config: # 使用 LangChain 消息格式 messages = [] for msg_config in messages_config: - role = msg_config.get("role", "user").lower() - content_template = msg_config.get("content", "") + role = msg_config.role.lower() + content_template = msg_config.content content_template = self._render_context(content_template, state) content = self._render_template(content_template, state) # 根据角色创建对应的消息对象 if role == "system": - messages.append(SystemMessage(content=content)) + messages.append({"role": "system", "content": content}) elif role in ["user", "human"]: - messages.append(HumanMessage(content=content)) + messages.append({"role": "user", "content": content}) elif role in ["ai", "assistant"]: - messages.append(AIMessage(content=content)) + messages.append({"role": "assistant", "content": content}) else: logger.warning(f"未知的消息角色: {role},默认使用 user") - messages.append(HumanMessage(content=content)) + messages.append({"role": "user", "content": content}) + if self.typed_config.memory.enable: + # if self.typed_config.memory.enable_window: + messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] prompt_or_messages = messages else: # 使用简单的 prompt 格式(向后兼容) @@ -189,7 +192,7 @@ class LLMNode(BaseNode): return { "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "messages": [ - {"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content} + {"role": msg.get("role"), "content": msg.get("content", "")} for msg in prompt_or_messages ] if isinstance(prompt_or_messages, list) else None, "config": { diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index f1c99ddb..08a2b280 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -3,8 +3,9 @@ from typing import Any from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig -from app.db import get_db_read, get_db_context +from app.db import get_db_read from app.services.memory_agent_service import MemoryAgentService +from app.tasks import write_message_task class MemoryReadNode(BaseNode): @@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode): async def execute(self, state: WorkflowState) -> Any: self.typed_config = MemoryReadNodeConfig(**self.config) with get_db_read() as db: - workspace_id = self.get_variable('sys.workspace_id', state) end_user_id = self.get_variable("sys.user_id", state) - if not workspace_id: - raise RuntimeError("Workspace id is required") if not end_user_id: raise RuntimeError("End user id is required") @@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode): self.typed_config = MemoryWriteNodeConfig(**self.config) async def execute(self, state: WorkflowState) -> Any: - with get_db_context() as db: - workspace_id = self.get_variable('sys.workspace_id', state) - end_user_id = self.get_variable("sys.user_id", state) + end_user_id = self.get_variable("sys.user_id", state) - if not workspace_id: - raise RuntimeError("Workspace id is required") - if not end_user_id: - raise RuntimeError("End user id is required") + if not end_user_id: + raise RuntimeError("End user id is required") - return await MemoryAgentService().write_memory( - group_id=end_user_id, - message=self._render_template(self.typed_config.message, state), - config_id=str(self.typed_config.config_id), - db=db, - storage_type="neo4j", - user_rag_memory_id="" - ) + write_message_task.delay( + end_user_id, + self._render_template(self.typed_config.message, state), + str(self.typed_config.config_id), + "neo4j", + "" + ) + + return "success" diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 3c00e5a0..35d2e424 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -41,6 +41,7 @@ class ToolConfig(BaseModel): tool_id: Optional[str] = Field(default=None, description="工具ID") operation: Optional[str] = Field(default=None, description="工具特定配置") + class ToolOldConfig(BaseModel): """工具配置""" enabled: bool = Field(default=False, description="是否启用该工具") @@ -348,6 +349,7 @@ class AppChatRequest(BaseModel): variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") + class DraftRunRequest(BaseModel): """试运行请求""" message: str = Field(..., description="用户消息") diff --git a/api/app/schemas/memory_episodic_schema.py b/api/app/schemas/memory_episodic_schema.py index 7b3f3d2d..832bf34b 100644 --- a/api/app/schemas/memory_episodic_schema.py +++ b/api/app/schemas/memory_episodic_schema.py @@ -1,9 +1,51 @@ """ 情景记忆的请求和响应模型 """ +from abc import ABC from pydantic import BaseModel, Field from typing import Optional +type_mapping = { + "Person": "人物实体节点", + "Organization": "组织实体节点", + "ORG": "组织实体节点", + "Location": "地点实体节点", + "LOC": "地点实体节点", + "Event": "事件实体节点", + "Concept": "概念实体节点", + "Time": "时间实体节点", + "Position": "职位实体节点", + "WorkRole": "职业实体节点", + "System": "系统实体节点", + "Policy": "政策实体节点", + "HistoricalPeriod": "历史时期实体节点", + "HistoricalState": "历史国家实体节点", + "HistoricalEvent": "历史事件实体节点", + "EconomicFactor": "经济因素实体节点", + "Condition": "条件实体节点", + "Numeric": "数值实体节点" + } +class EmotionType(ABC): + JOY_TYPE = "joy" + SURPRISE_TYPE = "surprise" + SANDROWNESS_TYPE = "sadness" + FEAR_TYPE = "fear" + ANGET_TYPE="anger" + NEUTRAL_TYPE="neutral" + EMOTION_MAPPING={ + "joy":"愉快", + "surprise":"惊喜", + "sadness":"悲伤", + "fear":"恐惧", + "anger":"生气", + "neutral":"中性" + } +class EmotionSubject(ABC): + SUBJECT_MAPPING={ + "self":"自己", + "other":"别人", + "object":"事物对象" + } class EpisodicMemoryOverviewRequest(BaseModel): """情景记忆总览查询请求""" diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 56400c92..bc2d6ca3 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.db import get_db, get_db_context from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig +from app.schemas import DraftRunRequest from app.services.tool_service import ToolService from app.repositories.tool_repository import ToolRepository from app.db import get_db @@ -59,7 +60,7 @@ class AppChatService: # 获取模型配置ID model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) + api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id) # 处理系统提示词(支持变量替换) system_prompt = config.system_prompt if variables: @@ -210,7 +211,7 @@ class AppChatService: # 获取模型配置ID model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) + api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id) # 处理系统提示词(支持变量替换) system_prompt = config.system_prompt if variables: @@ -511,7 +512,6 @@ class AppChatService: } ) - except (GeneratorExit, asyncio.CancelledError): # 生成器被关闭或任务被取消,正常退出 logger.debug("多 Agent 流式聊天被中断") @@ -537,83 +537,19 @@ class AppChatService: ) -> Dict[str, Any]: """聊天(非流式)""" workflow_service = WorkflowService(self.db) - - input_data = {"message":message, "variables": variables, - "conversation_id": str(conversation_id)} - inconfig = workflow_service.get_workflow_config(app_id) - - # 2. 创建执行记录 - execution = workflow_service.create_execution( - workflow_config_id=inconfig.id, - app_id=app_id, - trigger_type="manual", - triggered_by=None, - conversation_id=conversation_id, - input_data=input_data + payload = DraftRunRequest( + message=message, + variables=variables, + conversation_id=str(conversation_id), + stream=True, + user_id=user_id + ) + return await workflow_service.run( + app_id=app_id, + payload=payload, + config=config, + workspace_id=workspace_id, ) - - # 3. 构建工作流配置字典 - workflow_config_dict = { - "nodes": config.nodes, - "edges": config.edges, - "variables": config.variables, - "execution_config": config.execution_config - } - - # 4. 获取工作空间 ID(从 app 获取) - - # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow - - try: - # 更新状态为运行中 - workflow_service.update_execution_status(execution.execution_id, "running") - - result = await execute_workflow( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id=str(workspace_id), - user_id=user_id - ) - - # 更新执行结果 - if result.get("status") == "completed": - workflow_service.update_execution_status( - execution.execution_id, - "completed", - output_data=result.get("node_outputs", {}) - ) - else: - workflow_service.update_execution_status( - execution.execution_id, - "failed", - error_message=result.get("error") - ) - - # 返回增强的响应结构 - return { - "execution_id": execution.execution_id, - "status": result.get("status"), - "output": result.get("output"), # 最终输出(字符串) - "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) - "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID - "error_message": result.get("error"), - "elapsed_time": result.get("elapsed_time"), - "token_usage": result.get("token_usage") - } - - except Exception as e: - logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) - workflow_service.update_execution_status( - execution.execution_id, - "failed", - error_message=str(e) - ) - raise BusinessException( - code=BizCode.INTERNAL_ERROR, - message=f"工作流执行失败: {str(e)}" - ) async def workflow_chat_stream( self, @@ -622,7 +558,7 @@ class AppChatService: config: WorkflowConfig, app_id: uuid.UUID, workspace_id: uuid.UUID, - user_id: Optional[str] = None, + user_id: str = None, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, @@ -632,62 +568,21 @@ class AppChatService: ) -> AsyncGenerator[str, None]: """聊天(流式)""" workflow_service = WorkflowService(self.db) - input_data = {"message": message, "variables": variables, - "conversation_id": str(conversation_id)} - inconfig = workflow_service.get_workflow_config(app_id) - # 2. 创建执行记录 - execution = workflow_service.create_execution( - workflow_config_id=inconfig.id, - app_id=app_id, - trigger_type="manual", - triggered_by=None, - conversation_id=conversation_id, - input_data=input_data + payload = DraftRunRequest( + message=message, + variables=variables, + conversation_id=str(conversation_id), + stream=True, + user_id=user_id ) + async for event in workflow_service.run_stream( + app_id=app_id, + payload=payload, + config=config, + workspace_id=workspace_id, + ): + yield event - # 3. 构建工作流配置字典 - workflow_config_dict = { - "nodes": config.nodes, - "edges": config.edges, - "variables": config.variables, - "execution_config": config.execution_config - } - - # 4. 获取工作空间 ID(从 app 获取) - - # 5. 流式执行工作流 - - try: - # 更新状态为运行中 - workflow_service.update_execution_status(execution.execution_id, "running") - - - # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) - async for event in workflow_service._run_workflow_stream( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id=str(workspace_id), - user_id=user_id - ): - # 直接转发 executor 的事件(已经是正确的格式) - yield event - - except Exception as e: - logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) - workflow_service.update_execution_status( - execution.execution_id, - "failed", - error_message=str(e) - ) - # 发送错误事件 - yield { - "event": "error", - "data": { - "execution_id": execution.execution_id, - "error": str(e) - } - } # ==================== 依赖注入函数 ==================== diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 3695a222..275d6413 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -516,8 +516,16 @@ class ConversationService: conversation_messages = self.get_conversation_history( conversation_id=conversation_id, - max_history=30 + max_history=20 ) + if len(conversation_messages) == 0: + return ConversationOut( + theme="", + question=[], + summary="", + takeaways=[], + info_score=0, + ) with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f: system_prompt = f.read() @@ -536,6 +544,7 @@ class ConversationService: ] logger.info(f"Invoking LLM for conversation_id={conversation_id}") model_resp = await llm.ainvoke(messages) + try: if isinstance(model_resp.content, str): result = json_repair.repair_json(model_resp.content, return_objects=True) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 569684d5..50934226 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -245,7 +245,8 @@ class DraftRunService: storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, web_search: bool = True, - memory: bool = True + memory: bool = True, + sub_agent: bool = False ) -> Dict[str, Any]: """执行试运行(使用 LangChain Agent) @@ -435,7 +436,7 @@ class DraftRunService: elapsed_time = time.time() - start_time # 8. 保存会话消息 - if agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 2d78d796..f0756764 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -9,7 +9,7 @@ import os import re import time import uuid -from threading import Lock + from typing import Any, AsyncGenerator, Dict, List, Optional import redis @@ -51,9 +51,7 @@ _neo4j_connector = Neo4jConnector() class MemoryAgentService: """Service for memory agent operations""" - def __init__(self): - self.user_locks: Dict[str, Lock] = {} - self.locks_lock = Lock() + def writer_messages_deal(self,messages,start_time,group_id,config_id,message): messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') @@ -83,12 +81,7 @@ class MemoryAgentService: raise ValueError(f"写入失败: {messages}") - def get_group_lock(self, group_id: str) -> Lock: - """Get lock for specific group to prevent concurrent processing""" - with self.locks_lock: - if group_id not in self.user_locks: - self.user_locks[group_id] = Lock() - return self.user_locks[group_id] + def extract_tool_call_info(self, event: Dict) -> bool: """Extract tool call information from event""" @@ -417,241 +410,236 @@ class MemoryAgentService: except ImportError: audit_logger = None - # Get group lock to prevent concurrent processing - group_lock = self.get_group_lock(group_id) + try: + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="MemoryAgentService" + ) + logger.info(f"Configuration loaded successfully: {memory_config.config_name}") + except ConfigurationError as e: + error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" + logger.error(error_msg) - with group_lock: - # Step 1: Load configuration from database only - try: - config_service = MemoryConfigService(db) - memory_config = config_service.load_memory_config( + # Log failed operation + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", config_id=config_id, - service_name="MemoryAgentService" + group_id=group_id, + success=False, + duration=duration, + error=error_msg ) - logger.info(f"Configuration loaded successfully: {memory_config.config_name}") - except ConfigurationError as e: - error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" - logger.error(error_msg) - # Log failed operation - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=False, - duration=duration, - error=error_msg - ) + raise ValueError(error_msg) - raise ValueError(error_msg) + # Step 2: Prepare history + history.append({"role": "user", "content": message}) + logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") - # Step 2: Prepare history - history.append({"role": "user", "content": message}) - logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") + # Step 3: Initialize MCP client and execute read workflow + mcp_config = get_mcp_server_config() + client = MultiServerMCPClient(mcp_config) - # Step 3: Initialize MCP client and execute read workflow - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) + async with client.session('data_flow') as session: + session_start = time.time() + logger.debug("Connected to MCP Server: data_flow") - async with client.session('data_flow') as session: - session_start = time.time() - logger.debug("Connected to MCP Server: data_flow") - - tools_start = time.time() - tools = await load_mcp_tools(session) - tools_time = time.time() - tools_start - logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s") - - outputs = [] - intermediate_outputs = [] - seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates + tools_start = time.time() + tools = await load_mcp_tools(session) + tools_time = time.time() - tools_start + logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s") - # Pass memory_config to the graph workflow - graph_start = time.time() - async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: - graph_init_time = time.time() - graph_start - logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s") - - start = time.time() - config = {"configurable": {"thread_id": group_id}} - workflow_errors = [] # Track errors from workflow - - event_count = 0 - async for event in graph.astream( - {"messages": history, "memory_config": memory_config, "errors": []}, - stream_mode="values", - config=config - ): - event_count += 1 - event_start = time.time() - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) + outputs = [] + intermediate_outputs = [] + seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates - for msg in messages: - msg_content = msg.content - msg_role = msg.__class__.__name__.lower().replace("message", "") - outputs.append({ - "role": msg_role, - "content": msg_content - }) + # Pass memory_config to the graph workflow + graph_start = time.time() + async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: + graph_init_time = time.time() - graph_start + logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s") - # Extract intermediate outputs - if hasattr(msg, 'content'): - try: - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - content_to_parse = msg_content - if isinstance(msg_content, list): - for block in msg_content: - if isinstance(block, dict) and block.get('type') == 'text': - content_to_parse = block.get('text', '') - break - else: - continue # No text block found + start = time.time() + config = {"configurable": {"thread_id": group_id}} + workflow_errors = [] # Track errors from workflow - # Try to parse content as JSON - if isinstance(content_to_parse, str): - try: - parsed = json.loads(content_to_parse) - if isinstance(parsed, dict): - # Check for single intermediate output - if '_intermediate' in parsed: - intermediate_data = parsed['_intermediate'] + event_count = 0 + async for event in graph.astream( + {"messages": history, "memory_config": memory_config, "errors": []}, + stream_mode="values", + config=config + ): + event_count += 1 + event_start = time.time() + messages = event.get('messages') + # Capture any errors from the state + if event.get('errors'): + workflow_errors.extend(event.get('errors', [])) + + for msg in messages: + msg_content = msg.content + msg_role = msg.__class__.__name__.lower().replace("message", "") + outputs.append({ + "role": msg_role, + "content": msg_content + }) + + # Extract intermediate outputs + if hasattr(msg, 'content'): + try: + # Handle MCP content format: [{'type': 'text', 'text': '...'}] + content_to_parse = msg_content + if isinstance(msg_content, list): + for block in msg_content: + if isinstance(block, dict) and block.get('type') == 'text': + content_to_parse = block.get('text', '') + break + else: + continue # No text block found + + # Try to parse content as JSON + if isinstance(content_to_parse, str): + try: + parsed = json.loads(content_to_parse) + if isinstance(parsed, dict): + # Check for single intermediate output + if '_intermediate' in parsed: + intermediate_data = parsed['_intermediate'] + output_key = self._create_intermediate_key(intermediate_data) + + if output_key not in seen_intermediates: + seen_intermediates.add(output_key) + intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) + + # Check for multiple intermediate outputs (from Retrieve) + if '_intermediates' in parsed: + for intermediate_data in parsed['_intermediates']: output_key = self._create_intermediate_key(intermediate_data) if output_key not in seen_intermediates: seen_intermediates.add(output_key) intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) + except (json.JSONDecodeError, ValueError): + pass + except Exception as e: + logger.debug(f"Failed to extract intermediate output: {e}") - # Check for multiple intermediate outputs (from Retrieve) - if '_intermediates' in parsed: - for intermediate_data in parsed['_intermediates']: - output_key = self._create_intermediate_key(intermediate_data) + event_time = time.time() - event_start + logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - except (json.JSONDecodeError, ValueError): - pass - except Exception as e: - logger.debug(f"Failed to extract intermediate output: {e}") - - event_time = time.time() - event_start - logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") + workflow_duration = time.time() - start + session_duration = time.time() - session_start + logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") + logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") + logger.info(f"[PERF] Total events processed: {event_count}") + # Extract final answer + final_answer = "" + for messages in outputs: + if messages['role'] == 'tool': + message = messages['content'] - workflow_duration = time.time() - start - session_duration = time.time() - session_start - logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") - logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") - logger.info(f"[PERF] Total events processed: {event_count}") - # Extract final answer - final_answer = "" - for messages in outputs: - if messages['role'] == 'tool': - message = messages['content'] + # Handle MCP content format: [{'type': 'text', 'text': '...'}] + if isinstance(message, list): + # Extract text from MCP content blocks + for block in message: + if isinstance(block, dict) and block.get('type') == 'text': + message = block.get('text', '') + break + else: + continue # No text block found - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - if isinstance(message, list): - # Extract text from MCP content blocks - for block in message: - if isinstance(block, dict) and block.get('type') == 'text': - message = block.get('text', '') - break - else: - continue # No text block found + try: + parsed = json.loads(message) if isinstance(message, str) else message + if isinstance(parsed, dict): + if parsed.get('status') == 'success': + summary_result = parsed.get('summary_result') + if summary_result: + final_answer = summary_result + except (json.JSONDecodeError, ValueError): + pass - try: - parsed = json.loads(message) if isinstance(message, str) else message - if isinstance(parsed, dict): - if parsed.get('status') == 'success': - summary_result = parsed.get('summary_result') - if summary_result: - final_answer = summary_result - except (json.JSONDecodeError, ValueError): - pass + # 记录成功的操作 + total_duration = time.time() - start_time - # 记录成功的操作 - total_duration = time.time() - start_time + # Check for workflow errors + if workflow_errors: + error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) + logger.warning(f"Read workflow completed with errors: {error_details}") - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.warning(f"Read workflow completed with errors: {error_details}") - - if audit_logger: - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=False, - duration=total_duration, - error=error_details, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer), - "errors": workflow_errors - } - ) - - # Raise error if no answer was produced - if not final_answer: - raise ValueError(f"Read workflow failed: {error_details}") - - if audit_logger and not workflow_errors: + if audit_logger: audit_logger.log_operation( operation="READ", config_id=config_id, group_id=group_id, - success=True, + success=False, duration=total_duration, + error=error_details, details={ "search_switch": search_switch, "history_length": len(history), "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer) + "has_answer": bool(final_answer), + "errors": workflow_errors } ) - retrieved_content=[] - repo = ShortTermMemoryRepository(db) - if str(search_switch)!="2": - for intermediate in intermediate_outputs: - print(intermediate) - intermediate_type=intermediate['type'] - if intermediate_type=="search_result": - query=intermediate['query'] - raw_results=intermediate['raw_results'] - reranked_results=raw_results.get('reranked_results',[]) - try: - statements=[statement['statement'] for statement in reranked_results.get('statements', [])] - except Exception: - statements=[] - statements=list(set(statements)) - retrieved_content.append({query:statements}) - if retrieved_content==[]: - retrieved_content='' - if '信息不足,无法回答。' != str(final_answer) :#and retrieved_content!=[] - # 使用 upsert 方法 - repo.upsert( - end_user_id=end_user_id, # 确保这个变量在作用域内 - messages=ori_message, - aimessages=final_answer, - retrieved_content=retrieved_content, - search_switch=str(search_switch) - ) - print("写入成功") + + # Raise error if no answer was produced + if not final_answer: + raise ValueError(f"Read workflow failed: {error_details}") + + if audit_logger and not workflow_errors: + audit_logger.log_operation( + operation="READ", + config_id=config_id, + group_id=group_id, + success=True, + duration=total_duration, + details={ + "search_switch": search_switch, + "history_length": len(history), + "intermediate_outputs_count": len(intermediate_outputs), + "has_answer": bool(final_answer) + } + ) + retrieved_content=[] + repo = ShortTermMemoryRepository(db) + if str(search_switch)!="2": + for intermediate in intermediate_outputs: + print(intermediate) + intermediate_type=intermediate['type'] + if intermediate_type=="search_result": + query=intermediate['query'] + raw_results=intermediate['raw_results'] + reranked_results=raw_results.get('reranked_results',[]) + try: + statements=[statement['statement'] for statement in reranked_results.get('statements', [])] + except Exception: + statements=[] + statements=list(set(statements)) + retrieved_content.append({query:statements}) + if retrieved_content==[]: + retrieved_content='' + if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[] + # 使用 upsert 方法 + repo.upsert( + end_user_id=end_user_id, # 确保这个变量在作用域内 + messages=ori_message, + aimessages=final_answer, + retrieved_content=retrieved_content, + search_switch=str(search_switch) + ) + print("写入成功") - return { - "answer": final_answer, - "intermediate_outputs": intermediate_outputs - } - + return { + "answer": final_answer, + "intermediate_outputs": intermediate_outputs + } + def _create_intermediate_key(self, output: Dict) -> str: """ Create a unique key for an intermediate output to detect duplicates. diff --git a/api/app/services/memory_entity_relationship_service.py b/api/app/services/memory_entity_relationship_service.py index f650217d..eedb7c29 100644 --- a/api/app/services/memory_entity_relationship_service.py +++ b/api/app/services/memory_entity_relationship_service.py @@ -15,6 +15,8 @@ from neo4j.time import DateTime as Neo4jDateTime import json from datetime import datetime +from app.schemas.memory_episodic_schema import EmotionType + logger = logging.getLogger(__name__) class MemoryEntityService: @@ -123,7 +125,7 @@ class MemoryEntityService: extracted_entity_list = self._deduplicate_dict_list(extracted_entity_list) # 合并所有数据并处理相同text的合并 - all_timeline_data = memory_summary_list + statement_list + extracted_entity_list + all_timeline_data = memory_summary_list + statement_list all_timeline_data = self._merge_same_text_items(all_timeline_data) result = { @@ -496,11 +498,11 @@ class MemoryEmotion: length_data.append(emotion_intensity) if emotion_type is not None and emotion_intensity is not None and formatted_created_at is not None: # 使用(emotion_type, created_at)作为分组键 - if emotion_type in {"joy", "surprise"}: + if emotion_type in {EmotionType.JOY_TYPE, EmotionType.SURPRISE_TYPE}: emotion_type='positive' - elif emotion_type in {"sadness", "fear", "anger"}: + elif emotion_type in {EmotionType.SANDROWNESS_TYPE, EmotionType.FEAR_TYPE, EmotionType.ANGET_TYPE}: emotion_type='negative' - elif emotion_type=='neutral': + elif emotion_type==EmotionType.NEUTRAL_TYPE: emotion_type='neutral' group_key = (emotion_type, formatted_created_at) # 累加emotion_intensity @@ -595,7 +597,7 @@ class MemoryInteraction: group_id = ori_data[0]['group_id'] Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id) if not Space_User: - return '不存在用户' + return [] user_id=Space_User[0]['id'] results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id) diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index 8979682d..2db4cdc7 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -267,14 +267,14 @@ class MemoryForgetService: elif node_type_label == 'memorysummary': node_type_label = 'summary' - # 将 Neo4j DateTime 对象转换为时间戳 + # 将 Neo4j DateTime 对象转换为时间戳(毫秒) last_access_time = result['last_access_time'] last_access_dt = convert_neo4j_datetime_to_python(last_access_time) # 确保 datetime 带有时区信息(假定为 UTC),避免 naive datetime 导致的时区偏差 if last_access_dt: if last_access_dt.tzinfo is None: last_access_dt = last_access_dt.replace(tzinfo=timezone.utc) - last_access_timestamp = int(last_access_dt.timestamp()) + last_access_timestamp = int(last_access_dt.timestamp() * 1000) else: last_access_timestamp = 0 @@ -520,7 +520,7 @@ class MemoryForgetService: 'average_activation_value': result['average_activation'], 'low_activation_nodes': result['low_activation_nodes'] or 0, 'forgetting_threshold': forgetting_threshold, - 'timestamp': int(datetime.now().timestamp()) + 'timestamp': int(datetime.now().timestamp() * 1000) } else: activation_metrics = { @@ -530,7 +530,7 @@ class MemoryForgetService: 'average_activation_value': None, 'low_activation_nodes': 0, 'forgetting_threshold': forgetting_threshold, - 'timestamp': int(datetime.now().timestamp()) + 'timestamp': int(datetime.now().timestamp() * 1000) } # 收集节点类型分布 @@ -620,7 +620,7 @@ class MemoryForgetService: 'merged_count': record.merged_count, 'average_activation': record.average_activation_value, 'total_nodes': record.total_nodes, - 'execution_time': int(record.execution_time.timestamp()) + 'execution_time': int(record.execution_time.timestamp() * 1000) }) api_logger.info(f"成功获取最近 {len(recent_trends)} 个日期的历史趋势数据") @@ -661,7 +661,7 @@ class MemoryForgetService: 'node_distribution': node_distribution, 'recent_trends': recent_trends, 'pending_nodes': pending_nodes, - 'timestamp': int(datetime.now().timestamp()) + 'timestamp': int(datetime.now().timestamp() * 1000) } api_logger.info( diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index b0c7a957..1972f344 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -1327,7 +1327,8 @@ class MultiAgentOrchestrator: web_search=web_search, memory=memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + sub_agent=True ) return result diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 59bbc211..9221ab06 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -13,10 +13,15 @@ from typing import Any, Dict, List, Optional, Tuple from app.core.logging_config import get_logger from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context +from app.repositories.conversation_repository import ConversationRepository from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping +from app.services.implicit_memory_service import ImplicitMemoryService from app.services.memory_base_service import MemoryBaseService from app.services.memory_config_service import MemoryConfigService +from app.services.memory_perceptual_service import MemoryPerceptualService +from app.services.memory_short_service import ShortService from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -1196,18 +1201,17 @@ async def analytics_memory_types( end_user_id: Optional[str] = None ) -> List[Dict[str, Any]]: """ - 统计9种记忆类型的数量和百分比 + 统计8种记忆类型的数量和百分比 计算规则: - 1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity - 2. 工作记忆 (WORKING_MEMORY) = chunk + entity - 3. 短期记忆 (SHORT_TERM_MEMORY) = chunk - 4. 长期记忆 (LONG_TERM_MEMORY) = entity - 5. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取) - 6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity - 7. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取) - 8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取) - 9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取) + 1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count + 2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取) + 3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量 + 4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取) + 5. 隐性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一 + 6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取) + 7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取) + 8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取) Args: db: 数据库会话 @@ -1227,7 +1231,6 @@ async def analytics_memory_types( - PERCEPTUAL_MEMORY: 感知记忆 - WORKING_MEMORY: 工作记忆 - SHORT_TERM_MEMORY: 短期记忆 - - LONG_TERM_MEMORY: 长期记忆 - EXPLICIT_MEMORY: 显性记忆 - IMPLICIT_MEMORY: 隐性记忆 - EMOTIONAL_MEMORY: 情绪记忆 @@ -1237,40 +1240,78 @@ async def analytics_memory_types( # 初始化基础服务 base_service = MemoryBaseService() - # 定义需要查询的基础节点类型 - node_types = { - "Statement": "Statement", - "Entity": "ExtractedEntity", - "Chunk": "Chunk" - } + # 初始化感知记忆服务 + perceptual_service = MemoryPerceptualService(db) - # 存储每种节点类型的计数 - node_counts = {} + # 获取感知记忆数量 + if end_user_id: + perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id)) + perceptual_count = perceptual_stats.get("total", 0) + else: + perceptual_count = 0 - # 查询每种节点类型的数量 - for key, node_type in node_types.items(): - if end_user_id: - query = f""" - MATCH (n:{node_type}) + # 获取工作记忆数量(基于会话数量) + work_count = 0 + if end_user_id: + try: + conversation_repo = ConversationRepository(db) + conversations = conversation_repo.get_conversation_by_user_id( + user_id=uuid.UUID(end_user_id), + limit=100, # 获取更多会话以准确统计 + is_activate=True + ) + work_count = len(conversations) + logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})") + except Exception as e: + logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}") + work_count = 0 + + # 获取隐性记忆数量(基于 Statement 节点数量的三分之一) + implicit_count = 0 + if end_user_id: + try: + # 查询 Statement 节点数量 + query = """ + MATCH (n:Statement) WHERE n.group_id = $group_id RETURN count(n) as count """ result = await _neo4j_connector.execute_query(query, group_id=end_user_id) - else: - query = f""" - MATCH (n:{node_type}) - RETURN count(n) as count - """ - result = await _neo4j_connector.execute_query(query) - - # 提取计数结果 - count = result[0]["count"] if result and len(result) > 0 else 0 - node_counts[key] = count + statement_count = result[0]["count"] if result and len(result) > 0 else 0 + # 取三分之一作为隐性记忆数量 + implicit_count = round(statement_count / 3) + logger.debug(f"隐性记忆数量(Statement数量的1/3): {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})") + except Exception as e: + logger.warning(f"获取Statement数量失败,隐性记忆数量设为0: {str(e)}") + implicit_count = 0 - # 获取各节点类型的数量 - statement_count = node_counts.get("Statement", 0) - entity_count = node_counts.get("Entity", 0) - chunk_count = node_counts.get("Chunk", 0) + # 原有的基于行为习惯的统计方式(已注释) + # implicit_count = 0 + # if end_user_id: + # try: + # implicit_service = ImplicitMemoryService(db, end_user_id) + # behavior_habits = await implicit_service.get_behavior_habits( + # user_id=end_user_id + # ) + # implicit_count = len(behavior_habits) + # logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})") + # except Exception as e: + # logger.warning(f"获取行为习惯数量失败,隐性记忆数量设为0: {str(e)}") + # implicit_count = 0 + + # 获取短期记忆数量(基于 /short_term 接口返回的问答对数量) + short_term_count = 0 + if end_user_id: + try: + short_term_service = ShortService(end_user_id) + short_term_data = short_term_service.get_short_databasets() + # 统计 short_term 数组的长度 + if short_term_data: + short_term_count = len(short_term_data) + logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})") + except Exception as e: + logger.warning(f"获取短期记忆数量失败,短期记忆数量设为0: {str(e)}") + short_term_count = 0 # 获取用户的遗忘阈值配置 forgetting_threshold = 0.3 # 默认值 @@ -1296,17 +1337,16 @@ async def analytics_memory_types( # 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量 episodic_count = await base_service.get_episodic_memory_count(end_user_id) explicit_count = await base_service.get_explicit_memory_count(end_user_id) - emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count) + emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count) forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold) - # 按规则计算9种记忆类型的数量(使用英文枚举作为key) + # 按规则计算8种记忆类型的数量(使用英文枚举作为key) memory_counts = { - "PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆 - "WORKING_MEMORY": chunk_count + entity_count, # 工作记忆 - "SHORT_TERM_MEMORY": chunk_count, # 短期记忆 - "LONG_TERM_MEMORY": entity_count, # 长期记忆 + "PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆 + "WORKING_MEMORY": work_count, # 工作记忆(基于会话数量) + "SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量) "EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆) - "IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity) + "IMPLICIT_MEMORY": implicit_count, # 隐性记忆(Statement数量的1/3) "EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计) "EPISODIC_MEMORY": episodic_count, # 情景记忆 "FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值) @@ -1332,7 +1372,7 @@ async def analytics_graph_data( db: Session, end_user_id: str, node_types: Optional[List[str]] = None, - limit: int = 100, + limit: int = 130, depth: int = 1, center_node_id: Optional[str] = None ) -> Dict[str, Any]: @@ -1416,12 +1456,14 @@ async def analytics_graph_data( elementId(n) as id, labels(n)[0] as label, properties(n) as properties + LIMIT $limit """ node_params = { "group_id": end_user_id, - # "limit": limit + "limit": limit } + # 执行节点查询 node_results = await _neo4j_connector.execute_query(node_query, **node_params) @@ -1576,10 +1618,15 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_ for field in allowed_fields: if field in properties: value = properties[field] + if str(field) == 'entity_type': + value=type_mapping.get(value,'') + if str(field)=="emotion_type": + value=EmotionType.EMOTION_MAPPING.get(value) + if str(field)=="emotion_subject": + value=EmotionSubject.SUBJECT_MAPPING.get(value) # 清理 Neo4j 特殊类型 filtered_props[field] = _clean_neo4j_value(value) filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0] - print(filtered_props) return filtered_props diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 7d3c784f..974d5418 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -2,29 +2,28 @@ 工作流服务层 """ import datetime -import json import logging import uuid -import datetime from typing import Any, Annotated, AsyncGenerator +from deprecated import deprecated from fastapi import Depends from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.workflow.validator import validate_workflow_config -from app.db import get_db, get_db_context +from app.db import get_db +from app.models.conversation_model import Message from app.models.workflow_model import WorkflowConfig, WorkflowExecution -from app.repositories.end_user_repository import EndUserRepository -from app.services.multi_agent_service import convert_uuids_to_str +from app.repositories.conversation_repository import MessageRepository from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) from app.schemas import DraftRunRequest -from app.utils.sse_utils import format_sse_message +from app.services.multi_agent_service import convert_uuids_to_str logger = logging.getLogger(__name__) @@ -37,6 +36,7 @@ class WorkflowService: self.config_repo = WorkflowConfigRepository(db) self.execution_repo = WorkflowExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db) + self.message_repo = MessageRepository(db) # ==================== 配置管理 ==================== @@ -418,14 +418,13 @@ class WorkflowService: """运行工作流 Args: + workspace_id: + config: + payload: app_id: 应用 ID - input_data: 输入数据(包含 message 和 variables) - triggered_by: 触发用户 ID - conversation_id: 会话 ID(可选) - stream: 是否流式返回 Returns: - 执行结果(非流式)或生成器(流式) + 执行结果(非流式) Raises: BusinessException: 配置不存在或执行失败时抛出 @@ -438,7 +437,8 @@ class WorkflowService: code=BizCode.CONFIG_MISSING, message=f"工作流配置不存在: app_id={app_id}" ) - input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id} + input_data = {"message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id} # 转换 user_id 为 UUID triggered_by_uuid = None @@ -461,7 +461,7 @@ class WorkflowService: workflow_config_id=config.id, app_id=app_id, trigger_type="manual", - triggered_by=triggered_by_uuid, + triggered_by=None, conversation_id=conversation_id_uuid, input_data=input_data ) @@ -482,14 +482,6 @@ class WorkflowService: try: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") - with get_db_context() as db: - end_user_repo = EndUserRepository(db) - new_end_user = end_user_repo.get_or_create_end_user( - app_id=app_id, - other_id=payload.user_id, - original_user_id=payload.user_id # Save original user_id to other_id - ) - end_user_id = str(new_end_user.id) executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) @@ -500,14 +492,17 @@ class WorkflowService: variables = last_state.get("variables", {}) conv_vars = variables.get("conv", {}) input_data["conv"] = conv_vars + input_data["conv_messages"] = last_state.get("messages") or [] break + init_message_length = len(input_data.get("conv_messages", [])) + result = await execute_workflow( workflow_config=workflow_config_dict, input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=end_user_id + user_id=payload.user_id ) # 更新执行结果 @@ -517,6 +512,17 @@ class WorkflowService: "completed", output_data=result ) + final_messages = result.get("messages", [])[init_message_length:] + for message in final_messages: + message_obj = Message( + conversation_id=conversation_id_uuid, + role=message["role"], + content=message["content"], + ) + self.message_repo.add_message(message_obj) + self.db.commit() + logger.info(f"Workflow Run Success, " + f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") else: self.update_execution_status( execution.execution_id, @@ -529,6 +535,7 @@ class WorkflowService: "execution_id": execution.execution_id, "status": result.get("status"), "variables": result.get("variables"), + "messages": result.get("messages"), "output": result.get("output"), # 最终输出(字符串) "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID @@ -559,6 +566,7 @@ class WorkflowService: """运行工作流(流式) Args: + workspace_id: app_id: 应用 ID payload: 请求对象(包含 message, variables, conversation_id 等) config: 存储类型(可选) @@ -601,7 +609,7 @@ class WorkflowService: workflow_config_id=config.id, app_id=app_id, trigger_type="manual", - triggered_by=triggered_by_uuid, + triggered_by=None, conversation_id=conversation_id_uuid, input_data=input_data ) @@ -621,14 +629,6 @@ class WorkflowService: try: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") - with get_db_context() as db: - end_user_repo = EndUserRepository(db) - new_end_user = end_user_repo.get_or_create_end_user( - app_id=app_id, - other_id=payload.user_id, - original_user_id=payload.user_id # Save original user_id to other_id - ) - end_user_id = str(new_end_user.id) executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) for exec_res in executions: @@ -638,17 +638,46 @@ class WorkflowService: variables = last_state.get("variables", {}) conv_vars = variables.get("conv", {}) input_data["conv"] = conv_vars + input_data["conv_messages"] = last_state.get("messages") or [] break + init_message_length = len(input_data.get("conv_messages", [])) + from app.core.workflow.executor import execute_workflow_stream - # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) - async for event in self._run_workflow_stream( + async for event in execute_workflow_stream( workflow_config=workflow_config_dict, input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=end_user_id + user_id=payload.user_id ): - # 直接转发 executor 的事件(已经是正确的格式) + if event.get("event") == "workflow_end": + + status = event.get("data", {}).get("status") + if status == "completed": + self.update_execution_status( + execution.execution_id, + "completed", + output_data=event.get("data") + ) + final_messages = event.get("data", {}).get("messages", [])[init_message_length:] + for message in final_messages: + message_obj = Message( + conversation_id=conversation_id_uuid, + role=message["role"], + content=message["content"], + ) + self.message_repo.add_message(message_obj) + self.db.commit() + logger.info(f"Workflow Run Success, " + f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") + elif status == "failed": + self.update_execution_status( + execution.execution_id, + "failed", + output_data=event.get("data") + ) + else: + logger.error(f"unexpect workflow run status, status: {status}") yield event except Exception as e: @@ -667,6 +696,8 @@ class WorkflowService: } } + @deprecated(reason="This method is deprecated. " + "Please use WorkflowService.run / run_stream instead.") async def run_workflow( self, app_id: uuid.UUID, @@ -819,6 +850,7 @@ class WorkflowService: return clean_value(event) + @deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.") async def _run_workflow_stream( self, workflow_config: dict[str, Any], diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index 834d22af..4a35a4cc 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -8,9 +8,11 @@ import uuid from typing import Dict, Any, Optional, Union from datetime import datetime +from app.db import get_db_read from app.models import AppRelease, WorkflowConfig from app.models.agent_app_config_model import AgentConfig from app.models.multi_agent_model import MultiAgentConfig +from app.repositories.workflow_repository import WorkflowConfigRepository def model_parameters_to_dict(model_parameters: Any) -> Optional[Dict[str, Any]]: @@ -24,18 +26,18 @@ def model_parameters_to_dict(model_parameters: Any) -> Optional[Dict[str, Any]]: """ if model_parameters is None: return None - + if isinstance(model_parameters, dict): return model_parameters - + # Pydantic v2 if hasattr(model_parameters, 'model_dump'): return model_parameters.model_dump() - + # Pydantic v1 if hasattr(model_parameters, 'dict'): return model_parameters.dict() - + # 其他情况尝试转换 try: return dict(model_parameters) @@ -54,17 +56,18 @@ def dict_to_model_parameters(data: Optional[Dict[str, Any]]) -> Optional[Any]: """ if data is None: return None - + from app.schemas import ModelParameters - + if isinstance(data, ModelParameters): return data - + if isinstance(data, dict): return ModelParameters(**data) - + return None + class AgentConfigProxy: """Proxy class for AgentConfig (legacy compatibility)""" @@ -78,8 +81,7 @@ class AgentConfigProxy: self.default_model_config_id = release.default_model_config_id -def agent_config_4_app_release(release: AppRelease ) -> AgentConfig: - +def agent_config_4_app_release(release: AppRelease) -> AgentConfig: config_dict = release.config agent_config = AgentConfig( @@ -95,18 +97,17 @@ def agent_config_4_app_release(release: AppRelease ) -> AgentConfig: return agent_config -def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig: +def multi_agent_config_4_app_release(release: AppRelease) -> MultiAgentConfig: config_dict = release.config - agent_config = MultiAgentConfig( app_id=release.app_id, default_model_config_id=release.default_model_config_id, model_parameters=config_dict.get("model_parameters"), master_agent_id=config_dict.get("master_agent_id"), master_agent_name=config_dict.get("master_agent_name"), - orchestration_mode=config_dict.get("orchestration_mode", "conditional"), + orchestration_mode=config_dict.get("orchestration_mode", "supervisor"), sub_agents=config_dict.get("sub_agents", []), routing_rules=config_dict.get("routing_rules"), execution_config=config_dict.get("execution_config", {}), @@ -116,24 +117,26 @@ def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig: return agent_config -def workflow_config_4_app_release(release: AppRelease ) -> WorkflowConfig: +def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig: config_dict = release.config - + with get_db_read() as db: + source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id) + source_config_id = source_config.id config = WorkflowConfig( - id=release.id, + id=source_config_id, app_id=release.app_id, nodes=config_dict.get("nodes", []), edges=config_dict.get("edges", []), variables=config_dict.get("variables", []), execution_config=config_dict.get("execution_config", {}), triggers=config_dict.get("triggers", []) - ) return config + def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uuid.UUID] = None): """Convert dict to MultiAgentConfig model object @@ -149,7 +152,7 @@ def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uui ... "app_id": "uuid-here", ... "master_agent_id": "master-uuid", ... "master_agent_name": "Master Agent", - ... "orchestration_mode": "conditional", + ... "orchestration_mode": "supervisor", ... "sub_agents": [ ... {"agent_id": "sub1-uuid", "name": "Sub Agent 1", "role": "specialist", "priority": 1}, ... {"agent_id": "sub2-uuid", "name": "Sub Agent 2", "role": "specialist", "priority": 2} @@ -186,7 +189,7 @@ def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uui app_id=final_app_id, master_agent_id=master_agent_id, master_agent_name=config_dict.get("master_agent_name"), - orchestration_mode=config_dict.get("orchestration_mode", "conditional"), + orchestration_mode=config_dict.get("orchestration_mode", "supervisor"), sub_agents=config_dict.get("sub_agents", []), routing_rules=config_dict.get("routing_rules"), execution_config=config_dict.get("execution_config", {}), @@ -276,7 +279,8 @@ def agent_config_to_dict(agent_config) -> Dict[str, Any]: "id": str(agent_config.id), "app_id": str(agent_config.app_id), "system_prompt": agent_config.system_prompt, - "default_model_config_id": str(agent_config.default_model_config_id) if agent_config.default_model_config_id else None, + "default_model_config_id": str( + agent_config.default_model_config_id) if agent_config.default_model_config_id else None, "model_parameters": agent_config.model_parameters, "knowledge_retrieval": agent_config.knowledge_retrieval, "memory": agent_config.memory, @@ -338,6 +342,3 @@ def workflow_config_to_dict(workflow_config) -> Dict[str, Any]: "created_at": workflow_config.created_at.isoformat() if workflow_config.created_at else None, "updated_at": workflow_config.updated_at.isoformat() if workflow_config.updated_at else None } - - - diff --git a/api/pyproject.toml b/api/pyproject.toml index 2dcc706d..6da684de 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -136,7 +136,8 @@ dependencies = [ "markdown-to-json==2.1.1", "valkey==6.0.2", "python-calamine>=0.4.0", - "xlrd==2.0.2" + "xlrd==2.0.2", + "deprecated>=1.3.1", ] [tool.pytest.ini_options] diff --git a/web/src/assets/images/empty/chatEmpty.png b/web/src/assets/images/empty/chatEmpty.png new file mode 100644 index 00000000..8ce1f719 Binary files /dev/null and b/web/src/assets/images/empty/chatEmpty.png differ diff --git a/web/src/assets/images/menu/helpCenter.svg b/web/src/assets/images/menu/helpCenter.svg new file mode 100644 index 00000000..504e309c --- /dev/null +++ b/web/src/assets/images/menu/helpCenter.svg @@ -0,0 +1,14 @@ + + + 使用帮助备份 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/helpCenter_active.svg b/web/src/assets/images/menu/helpCenter_active.svg new file mode 100644 index 00000000..2840c421 --- /dev/null +++ b/web/src/assets/images/menu/helpCenter_active.svg @@ -0,0 +1,14 @@ + + + 使用帮助 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index 11ccb5c3..c90f9208 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -55,7 +55,7 @@ const ChatContent: FC = ({ } {/* 消息气泡框 */} -
= ({
{/* 底部标签(如时间戳、用户名等) */} {labelPosition === 'bottom' && -
+
{labelFormat(item)}
} diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 0f2543ab..cad52e1e 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -71,6 +71,7 @@ export const en = { stepTwoDescription: 'Here you can create and manage spaces to organize models and data for different use cases.Once your spaces are ready, head to User Management to invite members and manage access.👉 Click User Management in the left menu to continue.', stepThree: 'This is User Management', stepThreeDescription: 'Here you can create users, assign roles, and manage access for your team.Once users are set up, the basic configuration is complete and you’re ready to start using the platform 🎉', + finishButtonText: 'Get Started', }, menu: { home: 'Home', @@ -91,6 +92,7 @@ export const en = { memberManagement: 'Member Management', memorySummary: 'Memory Summary', memoryConversation: 'Memory Validation', + helpCenter: 'Help Center', memorySummaryHandlers: 'Memory Summary Handlers', createMemorySummary: 'Create Memory Summary', memoryManagement: 'Memory Management', @@ -183,14 +185,15 @@ export const en = { createNewMemorySummary: 'Create New Memory Entry', createNewApplication: 'Create New Application', - createNewApplicationDesc: 'Create a new application for this space', + createNewApplicationDesc: 'Build an app in just 3 minutes with zero-code drag-and-drop.', createNewKnowledge: 'Create New Knowledge', - createNewKnowledgeDesc: 'Create a new memory entry', + createNewKnowledgeDesc: 'Transform your data into a fully searchable, dedicated knowledge base in seconds.', memoryConversation: 'Memory Conversation', - memoryConversationDesc: 'Create a new memory conversation', - + memoryConversationDesc: 'The more you use it, the better AI understands you.', + helpCenter: 'Help Center', + helpCenterDesc: 'One-stop support to answer your questions and get you started fast.', memorySummary: 'View Memory Summary', memorySummaryDesc: 'View Memory Summary Report', @@ -618,6 +621,7 @@ export const en = { retrieve:'Retrieve', processing: 'Processing', processingMode: 'Processing Mode', + processMsg: 'Processing Message', dataSize: 'Data Size', createUpdateTime: 'Create/Update Time', operation: 'Operation', @@ -1449,6 +1453,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re }, memoryConversation: { searchPlaceholder: 'Enter user ID...', + chatEmpty:'Is there anything I can help you with?', userID: 'User ID', testMemoryConversation: 'Test Memory Conversation', conversationContent: 'Conversation Content', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index b1a245d3..028202d1 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -71,6 +71,7 @@ export const zh = { stepTwoDescription: '你可以在这里创建和管理不同的空间,把模型和数据组织到具体的使用场景中。空间创建完成后,可以去 User Management 邀请成员、分配权限,一起协作使用。👉 点击左侧 User Management 继续。', stepThree: '这里是用户管理页', stepThreeDescription: '你可以在这里创建用户、分配角色,并管理团队成员的访问权限。完成用户设置后,基础配置就准备好了,可以开始实际使用平台的各项功能了 🎉', + finishButtonText: '开始使用', }, menu: { home: '首页', @@ -782,14 +783,15 @@ export const zh = { createNewMemorySummary: '创建新记忆条目', createNewApplication: '创建新应用', - createNewApplicationDesc: '创建新空间应用', + createNewApplicationDesc: '零代码拖拽3分钟创应用', - createNewKnowledge: '创建新知识', - createNewKnowledgeDesc: '创建新记忆条目', + createNewKnowledge: '创建知识库', + createNewKnowledgeDesc: '秒变可搜索的专属知识库', memoryConversation: '记忆对话', - memoryConversationDesc: '记忆对话', - + memoryConversationDesc: '让AI越用越懂你', + helpCenter: '帮助中心', + helpCenterDesc: '一站式解决疑问快速上手', memorySummary: '查看记忆摘要', memorySummaryDesc: '查看记忆摘要报告', @@ -1524,6 +1526,7 @@ export const zh = { deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体' }, memoryConversation: { + chatEmpty:'有什么我可以帮您的吗?', searchPlaceholder: '输入用户ID...', userID: '用户ID', testMemoryConversation: '测试记忆对话', diff --git a/web/src/store/locale.ts b/web/src/store/locale.ts index e0393b9d..4fbd79ed 100644 --- a/web/src/store/locale.ts +++ b/web/src/store/locale.ts @@ -1,3 +1,11 @@ +/* + * @Description: + * @Version: 0.0.1 + * @Author: yujiangping + * @Date: 2026-01-05 17:22:23 + * @LastEditors: yujiangping + * @LastEditTime: 2026-01-15 21:02:43 + */ import { create } from 'zustand' import enUS from 'antd/locale/en_US'; import zhCN from 'antd/locale/zh_CN'; @@ -12,6 +20,28 @@ import { timezoneToAntdLocaleMap } from '@/utils/timezones'; dayjs.extend(utc); dayjs.extend(timezone); +// 自定义中文 locale,修改 Tour 组件的按钮文字 +const customZhCN: Locale = { + ...zhCN, + Tour: { + ...zhCN.Tour, + Next: '下一步', + Previous: '上一步', + Finish: '立即体验', + }, +}; + +// 自定义英文 locale,修改 Tour 组件的按钮文字 +const customEnUS: Locale = { + ...enUS, + Tour: { + ...enUS.Tour, + Next: 'Next', + Previous: 'Previous', + Finish: 'Try it now', + }, +}; + interface I18nState { language: string; @@ -23,7 +53,7 @@ interface I18nState { const initialTimeZone = localStorage.getItem('timeZone') || 'Asia/Shanghai' const initialLanguage = localStorage.getItem('language') || 'en' -const initialLocale = initialLanguage === 'en' ? enUS : zhCN +const initialLocale = initialLanguage === 'en' ? customEnUS : customZhCN i18n.changeLanguage(initialLanguage) export const useI18n = create((set, get) => ({ @@ -32,7 +62,7 @@ export const useI18n = create((set, get) => ({ timeZone: initialTimeZone, changeLanguage: (language: string) => { i18n.changeLanguage(language) - const localeName = timezoneToAntdLocaleMap[language] || enUS; + const localeName = language === 'en' ? customEnUS : customZhCN; set({ language: language, locale: localeName }) }, changeTimeZone: (timeZone: string) => { diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index d385e1f0..12d17cda 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -11,6 +11,7 @@ import Empty from '@/components/Empty' import { formatDateTime } from '@/utils/format'; import { randomString } from '@/utils/common' import BgImg from '@/assets/images/conversation/bg.png' +import ChatEmpty from '@/assets/images/empty/chatEmpty.png' import Chat from '@/components/Chat' import type { ChatItem } from '@/components/Chat/types' import ButtonCheckbox from '@/components/ButtonCheckbox' @@ -259,9 +260,10 @@ const Conversation: FC = () => {
+
} - contentClassName="rb:h-[calc(100%-152px)]" + empty={} + contentClassName="rb:h-[calc(100%-152px)] " data={chatList} streamLoading={streamLoading} loading={loading} @@ -290,6 +292,7 @@ const Conversation: FC = () => { +
) diff --git a/web/src/views/Home/components/QuickOperation.tsx b/web/src/views/Home/components/QuickOperation.tsx index 892dd8a0..d894417a 100644 --- a/web/src/views/Home/components/QuickOperation.tsx +++ b/web/src/views/Home/components/QuickOperation.tsx @@ -1,3 +1,11 @@ +/* + * @Description: + * @Version: 0.0.1 + * @Author: yujiangping + * @Date: 2026-01-05 17:22:23 + * @LastEditors: yujiangping + * @LastEditTime: 2026-01-15 14:55:51 + */ import { type FC } from 'react' import { useTranslation } from 'react-i18next' import { useNavigate } from 'react-router-dom'; @@ -5,33 +13,49 @@ import Card from './Card'; import applicationIcon from '@/assets/images/menu/application_active.svg'; import knowledgeIcon from '@/assets/images/menu/knowledge_active.svg'; import memoryConversationIcon from '@/assets/images/menu/memoryConversation_active.svg'; +import helpCenterIcon from '@/assets/images/menu/helpCenter_active.svg' import arrowTopRight from '@/assets/images/home/arrow_top_right.svg'; const quickOperations = [ { key: 'createNewApplication', url: '/application' }, { key: 'createNewKnowledge', url: '/knowledge-base' }, { key: 'memoryConversation', url: '/memory-conversation' }, + { key: 'helpCenter', url: '' }, ] const quickOperationIcons: {[key: string]: string | undefined} = { createNewApplication: applicationIcon, createNewKnowledge: knowledgeIcon, memoryConversation: memoryConversationIcon, + helpCenter: helpCenterIcon } const QuickOperation:FC = () => { - const { t } = useTranslation() + const { t, i18n } = useTranslation() const navigate = useNavigate(); const handleJump = (url: string | null) => { if (url) { navigate(url) + }else{ + const currentLang = i18n.language; + const lang = currentLang === 'zh' ? 'zh' : 'en'; + const helpUrl = `https://docs.redbearai.com/s/${lang}-memorybear`; + + // 创建隐藏的 a 标签来避免弹窗拦截 + const link = document.createElement('a'); + link.href = helpUrl; + link.target = '_blank'; + link.rel = 'noopener noreferrer'; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); } } return ( -
+
{quickOperations.map(item => (
handleJump(item.url)}>
diff --git a/web/src/views/Index/components/GuideCard.tsx b/web/src/views/Index/components/GuideCard.tsx index d60eae36..a8560136 100644 --- a/web/src/views/Index/components/GuideCard.tsx +++ b/web/src/views/Index/components/GuideCard.tsx @@ -1,3 +1,11 @@ +/* + * @Description: + * @Version: 0.0.1 + * @Author: yujiangping + * @Date: 2026-01-13 11:44:06 + * @LastEditors: yujiangping + * @LastEditTime: 2026-01-15 20:59:57 + */ import React, { useState, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import { useNavigate } from 'react-router-dom'; diff --git a/web/src/views/Index/components/QuickActions.tsx b/web/src/views/Index/components/QuickActions.tsx index edf5166e..063014df 100644 --- a/web/src/views/Index/components/QuickActions.tsx +++ b/web/src/views/Index/components/QuickActions.tsx @@ -47,7 +47,7 @@ const QuickActions: FC = ({ onNavigate }) => { key: 'space-management', icon: spaceIcon, title: t('quickActions.spaceManagement'), - onClick: () => onNavigate?.('/spce') + onClick: () => onNavigate?.('/space') }, // { // key: 'workflow-orchestration', diff --git a/web/src/views/KnowledgeBase/[knowledgeBaseId]/Private.tsx b/web/src/views/KnowledgeBase/[knowledgeBaseId]/Private.tsx index 8087e596..382deac0 100644 --- a/web/src/views/KnowledgeBase/[knowledgeBaseId]/Private.tsx +++ b/web/src/views/KnowledgeBase/[knowledgeBaseId]/Private.tsx @@ -2,7 +2,7 @@ import { useEffect, useState, useRef, useCallback, type FC } from 'react'; import { useNavigate, useParams, useLocation } from 'react-router-dom'; import { useTranslation } from 'react-i18next'; -import { Switch, Button, Dropdown, Space, Modal, message, Radio } from 'antd'; +import { Switch, Button, Dropdown, Space, Modal, message, Radio, Tooltip } from 'antd'; import type { MenuProps } from 'antd'; import SearchInput from '@/components/SearchInput' import Table, { type TableRef } from '@/components/Table' @@ -564,6 +564,37 @@ const Private: FC = () => { ); } + },{ + title: t('knowledgeBase.processMsg'), + dataIndex: 'progress_msg', + key: 'progress_msg', + width: 320, + render: (value: string) => { + if (!value) return '-'; + + // 解析日志格式,将 \n 转换为换行 + const formattedText = value.replace(/\\n/g, '\n'); + + return ( + {formattedText}} placement="topLeft"> +
+ {formattedText} +
+
+ ); + } }, { title: t('knowledgeBase.processingMode'), diff --git a/web/src/views/KnowledgeBase/components/KnowledgeGraph.tsx b/web/src/views/KnowledgeBase/components/KnowledgeGraph.tsx index 8ec367c5..55f8fcfc 100644 --- a/web/src/views/KnowledgeBase/components/KnowledgeGraph.tsx +++ b/web/src/views/KnowledgeBase/components/KnowledgeGraph.tsx @@ -292,7 +292,7 @@ const KnowledgeGraph: FC = ({ data, loading = false }) => { if (params.dataType === 'node') { const node = params.data as KnowledgeNode return ` -
+
${node.entity_name}
类型: ${node.entity_type}
重要度: ${(node.pagerank * 100).toFixed(2)}%
@@ -301,10 +301,10 @@ const KnowledgeGraph: FC = ({ data, loading = false }) => { } else if (params.dataType === 'edge') { const edge = params.data as KnowledgeEdge return ` -
+
关系
权重: ${edge.weight}
-
${edge.description}
+
${edge.description}
` }