diff --git a/api/LICENSE b/LICENSE similarity index 100% rename from api/LICENSE rename to LICENSE diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 807c59f4..58c89f8f 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -1,5 +1,6 @@ import os import platform +import re from datetime import timedelta from urllib.parse import quote @@ -11,21 +12,24 @@ from app.core.logging_config import get_logger logger = get_logger(__name__) + +def _mask_url(url: str) -> str: + """隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议""" + return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url) + # macOS fork() safety - must be set before any Celery initialization if platform.system() == 'Darwin': os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') # 创建 Celery 应用实例 -# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定) -# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定) +# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议), +# 未配置则回退到 Redis 方案 +# backend: 结果存储(使用 Redis) # NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND, # 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md -# Build canonical broker/backend URLs and force them into os.environ so that -# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first) -# cannot be overridden by stray env vars. -# See: https://github.com/celery/celery/issues/4284 -_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" +_broker_url = os.getenv("CELERY_BROKER_URL") or \ + f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" _backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}" os.environ["CELERY_BROKER_URL"] = _broker_url os.environ["CELERY_RESULT_BACKEND"] = _backend_url @@ -45,8 +49,8 @@ celery_app = Celery( logger.info( "Celery app initialized", extra={ - "broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"), - "backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"), + "broker": _mask_url(_broker_url), + "backend": _mask_url(_backend_url), }, ) # Default queue for unrouted tasks @@ -77,6 +81,7 @@ celery_app.conf.update( # Worker 设置 (per-worker settings are in docker-compose command line) worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution + worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING # 结果过期时间 result_expires=3600, # 结果保存1小时 diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index e9b539df..3ba9c3a9 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -57,6 +57,7 @@ def list_apps( page: int = 1, pagesize: int = 10, ids: Optional[str] = None, + api_key: Optional[str] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): @@ -65,10 +66,25 @@ def list_apps( - 默认包含本工作空间的应用和分享给本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用 - 当提供 ids 参数时,按逗号分割获取指定应用,不分页 + - 当提供 api_key 参数时,查找该 API Key 关联的应用 """ + from sqlalchemy import select as sa_select + from app.models.api_key_model import ApiKey + workspace_id = current_user.current_workspace_id service = app_service.AppService(db) + # 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程 + if api_key: + matched_id = db.execute( + sa_select(ApiKey.resource_id).where( + ApiKey.workspace_id == workspace_id, + ApiKey.api_key == api_key, + ApiKey.resource_id.isnot(None), + ) + ).scalar_one_or_none() + ids = str(matched_id) if matched_id else "" + # 当 ids 存在且不为 None 时,根据 ids 获取应用 if ids is not None: app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py index 55149cce..4e1ba74c 100644 --- a/api/app/controllers/file_storage_controller.py +++ b/api/app/controllers/file_storage_controller.py @@ -14,6 +14,9 @@ Routes: import os import uuid from typing import Any +import httpx +import mimetypes +from urllib.parse import urlparse, unquote from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse, RedirectResponse @@ -290,6 +293,101 @@ async def upload_file_with_share_token( ) +@router.get("/files/info-by-url", response_model=ApiResponse) +async def get_file_info_by_url( + url: str, +): + """ + Get file information by network URL (no authentication required). + + Fetches file metadata from a remote URL via HTTP HEAD request. + Falls back to GET request if HEAD is not supported. + Returns file type, name, and size. + + Args: + url: The network URL of the file. + + Returns: + ApiResponse with file information. + """ + api_logger.info(f"File info by URL request: url={url}") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # Try HEAD request first + response = await client.head(url, follow_redirects=True) + + # If HEAD fails, try GET request (some servers don't support HEAD) + if response.status_code != 200: + api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request") + response = await client.get(url, follow_redirects=True) + + if response.status_code != 200: + api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unable to access file: HTTP {response.status_code}" + ) + + # Get file size from Content-Length header or actual content + file_size = response.headers.get("Content-Length") + if file_size: + file_size = int(file_size) + elif hasattr(response, 'content'): + file_size = len(response.content) + else: + file_size = None + + # Get content type from Content-Type header + content_type = response.headers.get("Content-Type", "application/octet-stream") + # Remove charset and other parameters from content type + content_type = content_type.split(';')[0].strip() + + # Extract filename from Content-Disposition or URL + file_name = None + content_disposition = response.headers.get("Content-Disposition") + if content_disposition and "filename=" in content_disposition: + parts = content_disposition.split("filename=") + if len(parts) > 1: + file_name = parts[1].strip('"').strip("'") + + if not file_name: + parsed_url = urlparse(url) + file_name = unquote(os.path.basename(parsed_url.path)) or "unknown" + + # Extract file extension from filename + _, file_ext = os.path.splitext(file_name) + + # If no extension found, infer from content type + if not file_ext: + ext = mimetypes.guess_extension(content_type) + if ext: + file_ext = ext + file_name = f"{file_name}{file_ext}" + + api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}") + + return success( + data={ + "url": url, + "file_name": file_name, + "file_ext": file_ext.lower() if file_ext else "", + "file_size": file_size, + "content_type": content_type, + }, + msg="File information retrieved successfully" + ) + + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Unexpected error: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve file information: {str(e)}" + ) + + @router.get("/files/{file_id}", response_model=Any) async def download_file( request: Request, @@ -476,8 +574,12 @@ async def get_file_url( # For local storage, generate signed URL with expiration url = generate_signed_url(str(file_id), expires) else: - # For remote storage (OSS/S3), get presigned URL - url = await storage_service.get_file_url(file_key, expires=expires) + # For remote storage (OSS/S3), get presigned URL with forced download + url = await storage_service.get_file_url( + file_key, + expires=expires, + file_name=file_metadata.file_name, + ) url = _match_scheme(request, url) api_logger.info(f"Generated file URL: file_id={file_id}") @@ -688,7 +790,7 @@ async def permanent_download_file( # For remote storage, redirect to presigned URL with long expiration try: # Use a very long expiration (7 days max for most cloud providers) - presigned_url = await storage_service.get_file_url(file_key, expires=604800) + presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name) presigned_url = _match_scheme(request, presigned_url) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) except Exception as e: @@ -697,3 +799,44 @@ async def permanent_download_file( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to retrieve file: {str(e)}" ) + + +@router.get("/files/{file_id}/status", response_model=ApiResponse) +async def get_file_status( + file_id: uuid.UUID, + db: Session = Depends(get_db), +): + """ + Get file upload/processing status (no authentication required). + + This endpoint is used to check if a file (e.g., TTS audio) is ready. + Returns status: pending, completed, or failed. + + Args: + file_id: The UUID of the file. + db: Database session. + + Returns: + ApiResponse with file status and metadata. + """ + api_logger.info(f"File status request: file_id={file_id}") + + # Query file metadata from database + file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first() + if not file_metadata: + api_logger.warning(f"File not found in database: file_id={file_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The file does not exist" + ) + + return success( + data={ + "file_id": str(file_id), + "status": file_metadata.status, + "file_name": file_metadata.file_name, + "file_size": file_metadata.file_size, + "content_type": file_metadata.content_type, + }, + msg="File status retrieved successfully" + ) diff --git a/api/app/controllers/mcp_market_config_controller.py b/api/app/controllers/mcp_market_config_controller.py index 0f2da3b0..6f27d87a 100644 --- a/api/app/controllers/mcp_market_config_controller.py +++ b/api/app/controllers/mcp_market_config_controller.py @@ -91,9 +91,11 @@ async def get_mcp_servers( try: cookies = api.get_cookies(token) + headers=api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {token}' r = api.session.put( url=api.mcp_base_url, - headers=api.builder_headers(api.headers), + headers=headers, json=body, cookies=cookies) raise_for_http_status(r) @@ -173,6 +175,7 @@ async def get_operational_mcp_servers( url = f'{api.mcp_base_url}/operational' headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {token}' try: cookies = api.get_cookies(access_token=token, cookies_required=True) @@ -260,7 +263,9 @@ async def create_mcp_market_config( api.login(create_data.token) body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} cookies = api.get_cookies(create_data.token) - r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) + headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {create_data.token}' + r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies) raise_for_http_status(r) except Exception as e: api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") @@ -290,9 +295,11 @@ async def create_mcp_market_config( 'search': "" } cookies = api.get_cookies(token) + headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {token}' r = api.session.put( url=api.mcp_base_url, - headers=api.builder_headers(api.headers), + headers=headers, json=body, cookies=cookies) raise_for_http_status(r) @@ -393,7 +400,9 @@ async def update_mcp_market_config( api.login(update_data.token) body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} cookies = api.get_cookies(update_data.token) - r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) + headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {update_data.token}' + r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies) raise_for_http_status(r) except Exception as e: api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index e3d2bf92..aa4d48e3 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -118,142 +118,142 @@ async def download_log( return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e)) -@router.post("/writer_service", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Write service endpoint - processes write operations synchronously - - Args: - user_input: Write request containing message and end_user_id - language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 - - Returns: - Response with write operation status - """ - # 使用集中化的语言校验 - language = get_language_from_header(language_type) - - config_id = user_input.config_id - workspace_id = current_user.current_workspace_id - api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - - # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id - if storage_type == 'rag': - if workspace_id: - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: - user_rag_memory_id = str(knowledge.id) - else: - api_logger.warning( - f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") - storage_type = 'neo4j' - else: - api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") - storage_type = 'neo4j' - - api_logger.info( - f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") - try: - messages_list = memory_agent_service.get_messages_list(user_input) - result = await memory_agent_service.write_memory( - user_input.end_user_id, - messages_list, - config_id, - db, - storage_type, - user_rag_memory_id, - language - ) - - return success(data=result, msg="写入成功") - except BaseException as e: - # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup - if hasattr(e, 'exceptions'): - error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] - detailed_error = "; ".join(error_messages) - api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) - api_logger.error(f"Write operation error: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) - - -@router.post("/writer_service_async", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server_async( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Async write service endpoint - enqueues write processing to Celery - - Args: - user_input: Write request containing message and end_user_id - language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 - - Returns: - Task ID for tracking async operation - Use GET /memory/write_result/{task_id} to check task status and get result - """ - # 使用集中化的语言校验 - language = get_language_from_header(language_type) - - config_id = user_input.config_id - workspace_id = current_user.current_workspace_id - api_logger.info( - f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - if workspace_id: - - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: user_rag_memory_id = str(knowledge.id) - api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - try: - # 获取标准化的消息列表 - messages_list = memory_agent_service.get_messages_list(user_input) - - task = celery_app.send_task( - "app.core.memory.agent.write_message", - args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] - ) - api_logger.info(f"Write task queued: {task.id}") - - return success(data={"task_id": task.id}, msg="写入任务已提交") - except Exception as e: - api_logger.error(f"Async write operation failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) +# @router.post("/writer_service", response_model=ApiResponse) +# @cur_workspace_access_guard() +# async def write_server( +# user_input: Write_UserInput, +# language_type: str = Header(default=None, alias="X-Language-Type"), +# db: Session = Depends(get_db), +# current_user: User = Depends(get_current_user) +# ): +# """ +# Write service endpoint - processes write operations synchronously +# +# Args: +# user_input: Write request containing message and end_user_id +# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 +# +# Returns: +# Response with write operation status +# """ +# # 使用集中化的语言校验 +# language = get_language_from_header(language_type) +# +# config_id = user_input.config_id +# workspace_id = current_user.current_workspace_id +# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") +# +# # 获取 storage_type,如果为 None 则使用默认值 +# storage_type = workspace_service.get_workspace_storage_type( +# db=db, +# workspace_id=workspace_id, +# user=current_user +# ) +# if storage_type is None: storage_type = 'neo4j' +# user_rag_memory_id = '' +# +# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id +# if storage_type == 'rag': +# if workspace_id: +# knowledge = knowledge_repository.get_knowledge_by_name( +# db=db, +# name="USER_RAG_MERORY", +# workspace_id=workspace_id +# ) +# if knowledge: +# user_rag_memory_id = str(knowledge.id) +# else: +# api_logger.warning( +# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") +# storage_type = 'neo4j' +# else: +# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") +# storage_type = 'neo4j' +# +# api_logger.info( +# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") +# try: +# messages_list = memory_agent_service.get_messages_list(user_input) +# result = await memory_agent_service.write_memory( +# user_input.end_user_id, +# messages_list, +# config_id, +# db, +# storage_type, +# user_rag_memory_id, +# language +# ) +# +# return success(data=result, msg="写入成功") +# except BaseException as e: +# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup +# if hasattr(e, 'exceptions'): +# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] +# detailed_error = "; ".join(error_messages) +# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) +# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) +# api_logger.error(f"Write operation error: {str(e)}", exc_info=True) +# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) +# +# +# @router.post("/writer_service_async", response_model=ApiResponse) +# @cur_workspace_access_guard() +# async def write_server_async( +# user_input: Write_UserInput, +# language_type: str = Header(default=None, alias="X-Language-Type"), +# db: Session = Depends(get_db), +# current_user: User = Depends(get_current_user) +# ): +# """ +# Async write service endpoint - enqueues write processing to Celery +# +# Args: +# user_input: Write request containing message and end_user_id +# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 +# +# Returns: +# Task ID for tracking async operation +# Use GET /memory/write_result/{task_id} to check task status and get result +# """ +# # 使用集中化的语言校验 +# language = get_language_from_header(language_type) +# +# config_id = user_input.config_id +# workspace_id = current_user.current_workspace_id +# api_logger.info( +# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") +# +# # 获取 storage_type,如果为 None 则使用默认值 +# storage_type = workspace_service.get_workspace_storage_type( +# db=db, +# workspace_id=workspace_id, +# user=current_user +# ) +# if storage_type is None: storage_type = 'neo4j' +# user_rag_memory_id = '' +# if workspace_id: +# +# knowledge = knowledge_repository.get_knowledge_by_name( +# db=db, +# name="USER_RAG_MERORY", +# workspace_id=workspace_id +# ) +# if knowledge: user_rag_memory_id = str(knowledge.id) +# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") +# try: +# # 获取标准化的消息列表 +# messages_list = memory_agent_service.get_messages_list(user_input) +# +# task = celery_app.send_task( +# "app.core.memory.agent.write_message", +# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] +# ) +# api_logger.info(f"Write task queued: {task.id}") +# +# return success(data={"task_id": task.id}, msg="写入任务已提交") +# except Exception as e: +# api_logger.error(f"Async write operation failed: {str(e)}") +# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) @router.post("/read_service", response_model=ApiResponse) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index cc0efab3..fe4337d1 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -663,9 +663,12 @@ async def dashboard_data( rag_data["total_memory"] = total_chunk # total_app: 统计当前空间下的所有app数量 - from app.repositories import app_repository - apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) - rag_data["total_app"] = len(apps_orm) + # 包含自有app + 被分享给本工作空间的app + from app.services import app_service as _app_svc + _, total_app = _app_svc.AppService(db).list_apps( + workspace_id=workspace_id, include_shared=True, pagesize=1 + ) + rag_data["total_app"] = total_app # total_knowledge: 使用 total_kb(总知识库数) total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user) @@ -687,7 +690,7 @@ async def dashboard_data( api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}") rag_data["total_api_call"] = 0 - api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}") + api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}") except Exception as e: api_logger.warning(f"获取RAG相关数据失败: {str(e)}") diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index d91dfc36..d8b39325 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -54,8 +54,8 @@ router = APIRouter( @router.get("/info", response_model=ApiResponse) async def get_storage_info( - storage_id: str, - current_user: User = Depends(get_current_user) + storage_id: str, + current_user: User = Depends(get_current_user) ): """ Example wrapper endpoint - retrieves storage information @@ -75,24 +75,19 @@ async def get_storage_info( return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) - - - - - -@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 +@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 def create_config( - payload: ConfigParamsCreate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), + payload: ConfigParamsCreate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), ) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}") try: # 将 workspace_id 注入到 payload 中(保持为 UUID 类型) @@ -107,9 +102,11 @@ def create_config( api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") lang = get_language_from_header(x_language_type) if lang == "en": - msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", + f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") else: - msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", + f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {err_str}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) @@ -119,9 +116,11 @@ def create_config( api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") lang = get_language_from_header(x_language_type) if lang == "en": - msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", + f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") else: - msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", + f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) @@ -129,10 +128,10 @@ def create_config( @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) def delete_config( - config_id: UUID|int, - force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + config_id: UUID | int, + force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """删除记忆配置(带终端用户保护) @@ -145,24 +144,24 @@ def delete_config( force: 设置为 true 可强制删除(即使有终端用户正在使用) """ workspace_id = current_user.current_workspace_id - config_id=resolve_config_id(config_id, db) + config_id = resolve_config_id(config_id, db) # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info( f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: " f"config_id={config_id}, force={force}" ) - + try: # 使用带保护的删除服务 from app.services.memory_config_service import MemoryConfigService - + config_service = MemoryConfigService(db) result = config_service.delete_config(config_id=config_id, force=force) - + if result["status"] == "error": api_logger.warning( f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}" @@ -172,7 +171,7 @@ def delete_config( msg=result["message"], data={"config_id": str(config_id), "is_default": result.get("is_default", False)} ) - + if result["status"] == "warning": api_logger.warning( f"记忆配置正在使用,无法删除: config_id={config_id}, " @@ -186,7 +185,7 @@ def delete_config( "force_required": result["force_required"] } ) - + api_logger.info( f"记忆配置删除成功: config_id={config_id}, " f"affected_users={result['affected_users']}" @@ -195,7 +194,7 @@ def delete_config( msg=result["message"], data={"affected_users": result["affected_users"]} ) - + except Exception as e: api_logger.error(f"Delete config failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e)) @@ -203,9 +202,9 @@ def delete_config( @router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc def update_config( - payload: ConfigUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id payload.config_id = resolve_config_id(payload.config_id, db) @@ -213,12 +212,13 @@ def update_config( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + # 校验至少有一个字段需要更新 if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") - return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") - + return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", + "config_name, config_desc, scene_id 均为空") + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") try: svc = DataConfigService(db) @@ -231,9 +231,9 @@ def update_config( @router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选 def update_config_extracted( - payload: ConfigUpdateExtracted, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigUpdateExtracted, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id payload.config_id = resolve_config_id(payload.config_id, db) @@ -241,7 +241,7 @@ def update_config_extracted( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}") try: svc = DataConfigService(db) @@ -256,11 +256,11 @@ def update_config_extracted( # 遗忘引擎配置接口已迁移到 memory_forget_controller.py # 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config -@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 +@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 def read_config_extracted( - config_id: UUID | int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + config_id: UUID | int, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id config_id = resolve_config_id(config_id, db) @@ -268,7 +268,7 @@ def read_config_extracted( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}") try: svc = DataConfigService(db) @@ -278,18 +278,19 @@ def read_config_extracted( api_logger.error(f"Read config extracted failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) -@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 + +@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 def read_all_config( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id - + # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置") try: svc = DataConfigService(db) @@ -303,14 +304,14 @@ def read_all_config( @router.post("/pilot_run", response_model=None) async def pilot_run( - payload: ConfigPilotRun, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigPilotRun, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> StreamingResponse: # 使用集中化的语言校验 language = get_language_from_header(language_type) - + api_logger.info( f"Pilot run requested: config_id={payload.config_id}, " f"dialogue_text_length={len(payload.dialogue_text)}, " @@ -333,9 +334,9 @@ async def pilot_run( @router.get("/search/kb_type_distribution", response_model=ApiResponse) async def get_kb_type_distribution( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}") try: result = await kb_type_distribution(end_user_id) @@ -344,12 +345,12 @@ async def get_kb_type_distribution( api_logger.error(f"KB type distribution failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e)) - + @router.get("/search/dialogue", response_model=ApiResponse) async def search_dialogues_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}") try: result = await search_dialogue(end_user_id) @@ -361,9 +362,9 @@ async def search_dialogues_num( @router.get("/search/chunk", response_model=ApiResponse) async def search_chunks_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}") try: result = await search_chunk(end_user_id) @@ -375,9 +376,9 @@ async def search_chunks_num( @router.get("/search/statement", response_model=ApiResponse) async def search_statements_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search statement requested for end_user_id: {end_user_id}") try: result = await search_statement(end_user_id) @@ -389,9 +390,9 @@ async def search_statements_num( @router.get("/search/entity", response_model=ApiResponse) async def search_entities_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search entity requested for end_user_id: {end_user_id}") try: result = await search_entity(end_user_id) @@ -403,9 +404,9 @@ async def search_entities_num( @router.get("/search", response_model=ApiResponse) async def search_all_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search all requested for end_user_id: {end_user_id}") try: result = await search_all(end_user_id) @@ -417,9 +418,9 @@ async def search_all_num( @router.get("/search/detials", response_model=ApiResponse) async def search_entities_detials( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search details requested for end_user_id: {end_user_id}") try: result = await search_detials(end_user_id) @@ -431,9 +432,9 @@ async def search_entities_detials( @router.get("/search/edges", response_model=ApiResponse) async def search_entity_edges( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search edges requested for end_user_id: {end_user_id}") try: result = await search_edges(end_user_id) @@ -443,14 +444,12 @@ async def search_entity_edges( return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) - - @router.get("/analytics/hot_memory_tags", response_model=ApiResponse) async def get_hot_memory_tags_api( - limit: int = 10, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), - ) -> dict: + limit: int = 10, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +) -> dict: """ 获取热门记忆标签(带Redis缓存) @@ -461,18 +460,18 @@ async def get_hot_memory_tags_api( - 缓存未命中:~600-800ms(取决于LLM速度) """ workspace_id = current_user.current_workspace_id - + # 构建缓存键 cache_key = f"hot_memory_tags:{workspace_id}:{limit}" - + api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}") - + try: # 尝试从Redis缓存获取 import json from app.aioRedis import aio_redis_get, aio_redis_set - + cached_result = await aio_redis_get(cache_key) if cached_result: api_logger.info(f"Cache hit for key: {cache_key}") @@ -481,11 +480,11 @@ async def get_hot_memory_tags_api( return success(data=data, msg="查询成功(缓存)") except json.JSONDecodeError: api_logger.warning(f"Failed to parse cached data, will refresh") - + # 缓存未命中,执行查询 api_logger.info(f"Cache miss for key: {cache_key}, executing query") result = await analytics_hot_memory_tags(db, current_user, limit) - + # 写入缓存(过期时间:5分钟) # 注意:result是列表,需要转换为JSON字符串 try: @@ -495,9 +494,9 @@ async def get_hot_memory_tags_api( except Exception as cache_error: # 缓存写入失败不影响主流程 api_logger.warning(f"Failed to cache result: {str(cache_error)}") - + return success(data=result, msg="查询成功") - + except Exception as e: api_logger.error(f"Hot memory tags failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) @@ -505,8 +504,8 @@ async def get_hot_memory_tags_api( @router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse) async def clear_hot_memory_tags_cache( - current_user: User = Depends(get_current_user), - ) -> dict: + current_user: User = Depends(get_current_user), +) -> dict: """ 清除热门标签缓存 @@ -516,12 +515,12 @@ async def clear_hot_memory_tags_cache( - 数据更新后立即生效 """ workspace_id = current_user.current_workspace_id - + api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}") - + try: from app.aioRedis import aio_redis_delete - + # 清除所有limit的缓存(常见的limit值) cleared_count = 0 for limit in [5, 10, 15, 20, 30, 50]: @@ -530,12 +529,12 @@ async def clear_hot_memory_tags_cache( if result: cleared_count += 1 api_logger.info(f"Cleared cache for key: {cache_key}") - + return success( - data={"cleared_count": cleared_count}, + data={"cleared_count": cleared_count}, msg=f"成功清除 {cleared_count} 个缓存" ) - + except Exception as e: api_logger.error(f"Clear cache failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e)) @@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache( @router.get("/analytics/recent_activity_stats", response_model=ApiResponse) async def get_recent_activity_stats_api( - current_user: User = Depends(get_current_user), + current_user: User = Depends(get_current_user), ) -> dict: workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") @@ -553,4 +552,3 @@ async def get_recent_activity_stats_api( except Exception as e: api_logger.error(f"Recent activity stats failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) - diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 6204a745..71fd41ad 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -42,6 +42,7 @@ def get_model_strategies(): @router.get("", response_model=ApiResponse) def get_model_list( type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), + capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"), @@ -74,10 +75,21 @@ def get_model_list( unique_flat_type = list(dict.fromkeys(flat_type)) type_list = [ModelType(t.lower()) for t in unique_flat_type] + capability_list = [] + if capability is not None: + flat_capability = [] + for item in capability: + split_items = [c.strip() for c in item.split(', ') if c.strip()] + flat_capability.extend(split_items) + + unique_flat_capability = list(dict.fromkeys(flat_capability)) + capability_list = unique_flat_capability + api_logger.error(f"获取模型type_list: {type_list}") query = model_schema.ModelConfigQuery( type=type_list, provider=provider, + capability=capability_list, is_active=is_active, is_public=is_public, search=search, diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 33d7b60c..f5284b46 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -669,6 +669,7 @@ async def config_query( content = { "app_type": release.app.type, "variables": release.config.get("variables"), + "memory": release.config.get("memory", {}).get("enabled"), "features": release.config.get("features") } elif release.app.type == AppType.MULTI_AGENT: diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index be796ff9..3ce1df6e 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -5,7 +5,7 @@ from typing import Optional import datetime from sqlalchemy.orm import Session -from fastapi import APIRouter, Depends,Header +from fastapi import APIRouter, Depends, Header from app.db import get_db from app.core.language_utils import get_language_from_header @@ -19,7 +19,7 @@ from app.services.user_memory_service import ( analytics_graph_data, analytics_community_graph_data, ) -from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction +from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest from app.repositories.workspace_repository import WorkspaceRepository @@ -45,9 +45,9 @@ router = APIRouter( @router.get("/analytics/memory_insight/report", response_model=ApiResponse) async def get_memory_insight_report_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 获取缓存的记忆洞察报告 @@ -73,10 +73,10 @@ async def get_memory_insight_report_api( @router.get("/analytics/user_summary", response_model=ApiResponse) async def get_user_summary_api( - end_user_id: str, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 获取缓存的用户摘要 @@ -90,7 +90,7 @@ async def get_user_summary_api( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) @@ -102,7 +102,7 @@ async def get_user_summary_api( api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 - result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language) + result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language) if result["is_cached"]: api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") @@ -117,10 +117,10 @@ async def get_user_summary_api( @router.post("/analytics/generate_cache", response_model=ApiResponse) async def generate_cache_api( - request: GenerateCacheRequest, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + request: GenerateCacheRequest, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 手动触发缓存生成 @@ -134,7 +134,7 @@ async def generate_cache_api( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -155,10 +155,12 @@ async def generate_cache_api( api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}") # 生成记忆洞察 - insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language) + insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, + language=language) # 生成用户摘要 - summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language) + summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, + language=language) # 构建响应 result = { @@ -209,9 +211,9 @@ async def generate_cache_api( @router.get("/analytics/node_statistics", response_model=ApiResponse) async def get_node_statistics_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -220,7 +222,8 @@ async def get_node_statistics_api( api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") + api_logger.info( + f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") try: # 调用新的记忆类型统计函数 @@ -228,21 +231,23 @@ async def get_node_statistics_api( # 计算总数用于日志 total_count = sum(item["count"] for item in result) - api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") + api_logger.info( + f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") return success(data=result, msg="查询成功") except Exception as e: api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e)) + @router.get("/analytics/graph_data", response_model=ApiResponse) async def get_graph_data_api( - end_user_id: str, - node_types: Optional[str] = None, - limit: int = 100, - depth: int = 1, - center_node_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + node_types: Optional[str] = None, + limit: int = 100, + depth: int = 1, + center_node_id: Optional[str] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -298,9 +303,9 @@ async def get_graph_data_api( @router.get("/analytics/community_graph", response_model=ApiResponse) async def get_community_graph_data_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -334,9 +339,9 @@ async def get_community_graph_data_api( @router.get("/read_end_user/profile", response_model=ApiResponse) async def get_end_user_profile( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) @@ -385,9 +390,9 @@ async def get_end_user_profile( @router.post("/updated_end_user/profile", response_model=ApiResponse) async def update_end_user_profile( - profile_update: EndUserProfileUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + profile_update: EndUserProfileUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 更新终端用户的基本信息 @@ -417,7 +422,7 @@ async def update_end_user_profile( else: error_msg = result["error"] api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}") - + # 根据错误类型映射到合适的业务错误码 if error_msg == "终端用户不存在": return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg) @@ -427,15 +432,18 @@ async def update_end_user_profile( # 只有未预期的错误才使用 INTERNAL_ERROR return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg) + @router.get("/memory_space/timeline_memories", response_model=ApiResponse) -async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ): +async def memory_space_timeline_of_shared_memories( + id: str, label: str, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): # 使用集中化的语言校验 language = get_language_from_header(language_type) - - workspace_id=current_user.current_workspace_id + + workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) @@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_ timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language) return success(data=timeline_memories_result, msg="共同记忆时间线") + + @router.get("/memory_space/relationship_evolution", response_model=ApiResponse) async def memory_space_relationship_evolution(id: str, label: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ): + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ): try: api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}") diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 88b6371c..464a668a 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -598,8 +598,10 @@ class LangChainAgent: for msg in reversed(output_messages): if isinstance(msg, AIMessage): response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get("total_tokens", - 0) if response_meta else 0 + total_tokens = response_meta.get("token_usage", {}).get( + "total_tokens", + 0 + ) if response_meta else 0 yield total_tokens break if memory_flag: diff --git a/api/app/core/config.py b/api/app/core/config.py index 4a944557..64c5520e 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -231,8 +231,8 @@ class Settings: # Celery configuration (internal) # NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持 # 详见 docs/celery-env-bug-report.md - # 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离 - # 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰 + # 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离 + # 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3")) REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4")) diff --git a/api/app/core/logging_config.py b/api/app/core/logging_config.py index 28a98a46..d0dda84b 100644 --- a/api/app/core/logging_config.py +++ b/api/app/core/logging_config.py @@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") - # Fallback to console only if file write fails print(f"Warning: Could not write to timing log: {e}") - # Always print to console (backward compatible behavior) - print(f"✓ {step_name}: {duration:.2f}s") + # Always log at INFO level (avoids Celery treating stdout as WARNING) + _timing_logger = logging.getLogger(__name__) + _timing_logger.info(f"✓ {step_name}: {duration:.2f}s") def get_agent_logger(name: str = "agent_service", diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 6176caf5..2074b6ca 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) elif int(is_end_user_id) == int(scope): logger.info('写入长期记忆NEO4J') - formatted_messages = (redis_messages) + formatted_messages = redis_messages # Get config_id (if memory_config is an object, extract config_id; otherwise use directly) if hasattr(memory_config, 'config_id'): config_id = memory_config.config_id diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index 3b06defe..4c667061 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -11,7 +11,7 @@ async def get_chunked_dialogs( chunker_strategy: str = "RecursiveChunker", end_user_id: str = "group_1", messages: list = None, - ref_id: str = "wyl_20251027", + ref_id: str = "", config_id: str = None ) -> List[DialogData]: """Generate chunks from structured messages using the specified chunker strategy. @@ -40,12 +40,13 @@ async def get_chunked_dialogs( role = msg['role'] content = msg['content'] + files = msg.get("file_content", []) if role not in ['user', 'assistant']: raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") if content.strip(): - conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) + conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files)) if not conversation_messages: raise ValueError("Message list cannot be empty after filtering") diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index b62eb50a..6829cf57 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -6,6 +6,7 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally. """ import asyncio import time +import uuid from datetime import datetime from dotenv import load_dotenv @@ -13,28 +14,28 @@ from dotenv import load_dotenv from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \ + memory_summary_generation from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write +from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig - load_dotenv() logger = get_agent_logger(__name__) async def write( - end_user_id: str, - memory_config: MemoryConfig, - messages: list, - ref_id: str = "wyl20251027", - language: str = "zh", + end_user_id: str, + memory_config: MemoryConfig, + messages: list, + ref_id: str = "", + language: str = "zh", ) -> None: """ Execute the complete knowledge extraction pipeline. @@ -43,9 +44,11 @@ async def write( end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration messages: Structured message list [{"role": "user", "content": "..."}, ...] - ref_id: Reference ID, defaults to "wyl20251027" + ref_id: Reference ID, defaults to "" language: 语言类型 ("zh" 中文, "en" 英文),默认中文 """ + if not ref_id: + ref_id = uuid.uuid4().hex # Extract config values embedding_model_id = str(memory_config.embedding_model_id) chunker_strategy = memory_config.chunker_strategy @@ -99,14 +102,14 @@ async def write( if memory_config.scene_id: try: from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene - + with get_db_context() as db: ontology_types = load_ontology_types_for_scene( scene_id=memory_config.scene_id, workspace_id=memory_config.workspace_id, db=db ) - + if ontology_types: logger.info( f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}" @@ -135,9 +138,11 @@ async def write( all_chunk_nodes, all_statement_nodes, all_entity_nodes, + all_perceptual_nodes, all_statement_chunk_edges, all_statement_entity_edges, all_entity_entity_edges, + all_perceptual_edges, all_dedup_details, ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) @@ -162,18 +167,21 @@ async def write( chunk_nodes=all_chunk_nodes, statement_nodes=all_statement_nodes, entity_nodes=all_entity_nodes, + perceptual_nodes=all_perceptual_nodes, statement_chunk_edges=all_statement_chunk_edges, statement_entity_edges=all_statement_entity_edges, entity_edges=all_entity_entity_edges, + perceptual_edges=all_perceptual_edges, connector=neo4j_connector, ) if success: logger.info("Successfully saved all data to Neo4j") - # 写入成功后,异步触发聚类(不阻塞写入响应) - schedule_clustering_after_write( + # 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突) + await _trigger_clustering_sync( all_entity_nodes, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, - embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, + embedding_model_id=str( + memory_config.embedding_model_id) if memory_config.embedding_model_id else None, ) break else: @@ -208,9 +216,8 @@ async def write( summaries = await memory_summary_generation( chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language ) - + ms_connector = Neo4jConnector() try: - ms_connector = Neo4jConnector() await add_memory_summary_nodes(summaries, ms_connector) await add_memory_summary_statement_edges(summaries, ms_connector) finally: @@ -251,4 +258,4 @@ async def write( logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") \ No newline at end of file + logger.info(f"Total execution time: {total_time:.2f} seconds") diff --git a/api/app/core/memory/llm_tools/chunker_client.py b/api/app/core/memory/llm_tools/chunker_client.py index 93a2df82..51d15aab 100644 --- a/api/app/core/memory/llm_tools/chunker_client.py +++ b/api/app/core/memory/llm_tools/chunker_client.py @@ -1,10 +1,10 @@ -from typing import Any, List -import re -import os import asyncio import json -import numpy as np import logging +import os +from typing import Any, List + +import numpy as np # Fix tokenizer parallelism warning os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -246,6 +246,7 @@ class ChunkerClient: "total_sub_chunks": len(sub_chunks), "chunker_strategy": self.chunker_config.chunker_strategy, }, + files=msg.files ) dialogue.chunks.append(chunk) else: @@ -258,6 +259,7 @@ class ChunkerClient: "message_role": msg.role, "chunker_strategy": self.chunker_config.chunker_strategy, }, + files=msg.files ) dialogue.chunks.append(chunk) diff --git a/api/app/core/memory/llm_tools/openai_embedder.py b/api/app/core/memory/llm_tools/openai_embedder.py index 2d6fccbc..6ae87887 100644 --- a/api/app/core/memory/llm_tools/openai_embedder.py +++ b/api/app/core/memory/llm_tools/openai_embedder.py @@ -2,6 +2,7 @@ OpenAI Embedder 客户端实现 基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。 +自动支持火山引擎的多模态 Embedding。 """ from typing import List @@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import ( ) from app.core.models.base import RedBearModelConfig from app.core.models.embedding import RedBearEmbeddings +from app.models.models_model import ModelProvider logger = logging.getLogger(__name__) @@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient): - 批量文本嵌入 - 自动重试机制 - 错误处理 + - 火山引擎多模态 Embedding(自动识别) """ def __init__(self, model_config: RedBearModelConfig): @@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient): """ super().__init__(model_config) - # 初始化 RedBearEmbeddings 模型 + # 初始化 RedBearEmbeddings(自动支持火山引擎多模态) self.model = RedBearEmbeddings( RedBearModelConfig( model_name=self.model_name, @@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient): timeout=self.timeout, ) ) + self.is_multimodal = self.model.is_multimodal_supported() - logger.info("OpenAI Embedder 客户端初始化完成") + logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})") async def response( self, @@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient): return [] # 生成嵌入向量 - embeddings = await self.model.aembed_documents(texts) + if self.is_multimodal: + # 火山引擎多模态 Embedding + embeddings = await self.model.aembed_multimodal( + [{"type": "text", "text": text} for text in texts] + ) + else: + # 普通 Embedding + embeddings = await self.model.aembed_documents(texts) logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量") return embeddings diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 1880b9ab..1b8c9d52 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -44,21 +44,21 @@ def parse_historical_datetime(v): """ if v is None: return v - + # 处理 Neo4j DateTime 对象 if hasattr(v, 'to_native'): return v.to_native() - + # 处理 Python datetime 对象 if isinstance(v, datetime): return v - + if isinstance(v, str): # 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM] # 支持1-4位年份 pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?' match = re.match(pattern, v) - + if match: try: year = int(match.group(1)) @@ -68,31 +68,31 @@ def parse_historical_datetime(v): minute = int(match.group(5)) if match.group(5) else 0 second = int(match.group(6)) if match.group(6) else 0 microsecond = 0 - + # 处理微秒 if match.group(7): # 补齐或截断到6位 us_str = match.group(7).ljust(6, '0')[:6] microsecond = int(us_str) - + # 处理时区 tzinfo = None if 'Z' in v or match.group(8): tzinfo = timezone.utc - + # 创建 datetime 对象 return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo) - + except (ValueError, OverflowError): # 日期值无效(如月份13、日期32等) return None - + # 如果不匹配模式,尝试使用 fromisoformat(用于标准格式) try: return datetime.fromisoformat(v.replace('Z', '+00:00')) except Exception: return None - + return v @@ -114,7 +114,7 @@ class Edge(BaseModel): end_user_id: str = Field(..., description="The end user ID of the edge.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") - expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") + expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.") class ChunkEdge(Edge): @@ -167,7 +167,7 @@ class EntityEntityEdge(Edge): source_statement_id: str = Field(..., description="Statement where this relationship was extracted") valid_at: Optional[datetime] = Field(None, description="Temporal validity start") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - + @field_validator('valid_at', 'invalid_at', mode='before') @classmethod def validate_datetime(cls, v): @@ -175,6 +175,12 @@ class EntityEntityEdge(Edge): return parse_historical_datetime(v) +class PerceptualEdge(Edge): + """Edge connecting perceptual nodes to their source chunks + """ + pass + + class Node(BaseModel): """Base class for all graph nodes in the knowledge graph. @@ -206,7 +212,8 @@ class DialogueNode(Node): ref_id: str = Field(..., description="Reference identifier of the dialog") content: str = Field(..., description="Dialogue content") dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)") + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this dialogue (integer or string)") class StatementNode(Node): @@ -241,17 +248,17 @@ class StatementNode(Node): chunk_id: str = Field(..., description="ID of the parent chunk") stmt_type: str = Field(..., description="Type of the statement") statement: str = Field(..., description="The statement text content") - + # Speaker identification speaker: Optional[str] = Field( None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses" ) - + # Emotion fields (ordered as requested, emotion_intensity first for display) emotion_intensity: Optional[float] = Field( - None, - ge=0.0, + None, + ge=0.0, le=1.0, description="Emotion intensity: 0.0-1.0 (displayed on node)" ) @@ -264,25 +271,26 @@ class StatementNode(Node): description="Emotion subject: self/other/object" ) emotion_type: Optional[str] = Field( - None, + None, description="Emotion type: joy/sadness/anger/fear/surprise/neutral" ) emotion_keywords: Optional[List[str]] = Field( default_factory=list, description="Emotion keywords list, max 3 items" ) - + # Temporal fields temporal_info: TemporalInfo = Field(..., description="Temporal information") valid_at: Optional[datetime] = Field(None, description="Temporal validity start") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - + # Embedding and other fields statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this statement (integer or string)") + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, @@ -309,13 +317,13 @@ class StatementNode(Node): ge=0, description="Total number of times this node has been accessed" ) - + @field_validator('valid_at', 'invalid_at', mode='before') @classmethod def validate_datetime(cls, v): """使用通用的历史日期解析函数""" return parse_historical_datetime(v) - + @field_validator('emotion_type', mode='before') @classmethod def validate_emotion_type(cls, v): @@ -326,7 +334,7 @@ class StatementNode(Node): if v not in valid_types: raise ValueError(f"emotion_type must be one of {valid_types}, got {v}") return v - + @field_validator('emotion_subject', mode='before') @classmethod def validate_emotion_subject(cls, v): @@ -337,7 +345,7 @@ class StatementNode(Node): if v not in valid_subjects: raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}") return v - + @field_validator('emotion_keywords', mode='before') @classmethod def validate_emotion_keywords(cls, v): @@ -405,19 +413,20 @@ class ExtractedEntityNode(Node): entity_type: str = Field(..., description="Type of the entity") description: str = Field(..., description="Entity description") example: str = Field( - default="", + default="", description="A concise example (around 20 characters) to help understand the entity" ) aliases: List[str] = Field( - default_factory=list, + default_factory=list, description="Entity aliases - alternative names for this entity" ) name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # fact_summary: str = Field(default="", description="Summary of the fact about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this entity (integer or string)") + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, @@ -444,16 +453,16 @@ class ExtractedEntityNode(Node): ge=0, description="Total number of times this node has been accessed" ) - + # Explicit Memory Classification is_explicit_memory: bool = Field( default=False, description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)" ) - + @field_validator('aliases', mode='before') @classmethod - def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 + def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 """Validate and clean aliases field using utility function. This validator ensures that the aliases field is always a valid list of strings. @@ -507,8 +516,9 @@ class MemorySummaryNode(Node): memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this summary (integer or string)") + # ACT-R Forgetting Engine Properties original_statement_id: Optional[str] = Field( None, @@ -522,7 +532,7 @@ class MemorySummaryNode(Node): None, description="Timestamp when the nodes were merged" ) - + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, @@ -549,3 +559,18 @@ class MemorySummaryNode(Node): ge=0, description="Total number of times this node has been accessed (reset to 1 on creation)" ) + + +class PerceptualNode(Node): + """Node representing a multimodal message in the knowledge graph. + """ + perceptual_type: int + file_path: str + file_name: str + file_ext: str + summary: str + keywords: list[str] + topic: str + domain: str + file_type: str + summary_embedding: list[float] | None diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index 2f8660af..66203067 100644 --- a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -30,6 +30,7 @@ class ConversationMessage(BaseModel): """ role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") msg: str = Field(..., description="The text content of the message.") + files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True) class TemporalValidityRange(BaseModel): @@ -130,7 +131,8 @@ class Chunk(BaseModel): content: str = Field(..., description="The content of the chunk as a string.") speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") - chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") + files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.") + chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") @classmethod diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index d9c04f8b..0fa6a833 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -71,13 +71,11 @@ class LabelPropagationEngine: connector: Neo4jConnector, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, ): self.connector = connector self.repo = CommunityRepository(connector) self.llm_model_id = llm_model_id self.embedding_model_id = embedding_model_id - self.embedding_model_id = embedding_model_id # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -239,6 +237,7 @@ class LabelPropagationEngine: await self.repo.upsert_community(new_cid, end_user_id, member_count=1) await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}") + await self._generate_community_metadata([new_cid], end_user_id) return # 统计邻居社区分布 @@ -273,7 +272,8 @@ class LabelPropagationEngine: await self._evaluate_merge( list(community_ids_in_neighbors), end_user_id ) - await self._generate_community_metadata([target_cid], end_user_id) + # 新实体加入后成员变化,强制重新生成元数据 + await self._generate_community_metadata([target_cid], end_user_id, force=True) async def _evaluate_merge( self, community_ids: List[str], end_user_id: str @@ -453,7 +453,7 @@ class LabelPropagationEngine: return lines async def _generate_community_metadata( - self, community_ids: List[str], end_user_id: str + self, community_ids: List[str], end_user_id: str, force: bool = False ) -> None: """ 为一个或多个社区生成并写入元数据。 @@ -462,69 +462,82 @@ class LabelPropagationEngine: 1. 逐个社区调 LLM 生成 name / summary(串行) 2. 收集所有 summary,一次性批量 embed 3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata - """ - if not community_ids: - return + Args: + force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后) + """ from app.db import get_db_context from app.core.memory.utils.llm.llm_utils import MemoryClientFactory - # --- 阶段1:并发调 LLM 生成每个社区的 name / summary --- - async def _build_one(cid: str): - members = await self.repo.get_community_members(cid, end_user_id) - if not members: + async def _build_one(cid: str) -> Optional[Dict]: + try: + if not force: + check_embedding = bool(self.embedding_model_id) + if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding): + return None + + members = await self.repo.get_community_members(cid, end_user_id) + if not members: + logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成") + return None + + sorted_members = sorted( + members, + key=lambda m: m.get("activation_value") or 0, + reverse=True, + ) + core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] + all_names = [m["name"] for m in members if m.get("name")] + + name = "、".join(core_entities[:3]) if core_entities else cid[:8] + summary = f"包含实体:{', '.join(all_names)}" + + if self.llm_model_id: + try: + entity_list_str = "\n".join(self._build_entity_lines(members)) + relationships = await self.repo.get_community_relationships(cid, end_user_id) + rel_lines = [ + f"- {r['subject']} → {r['predicate']} → {r['object']}" + for r in relationships + if r.get("subject") and r.get("predicate") and r.get("object") + ] + rel_section = ( + f"\n实体间关系:\n" + "\n".join(rel_lines) + if rel_lines else "" + ) + prompt = ( + f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过80个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) + with get_db_context() as db: + llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) + response = await llm_client.chat([{"role": "user", "content": prompt}]) + text = response.content if hasattr(response, "content") else str(response) + + for line in text.strip().splitlines(): + if line.startswith("名称:"): + name = line[3:].strip() + elif line.startswith("摘要:"): + summary = line[3:].strip() + except Exception as e: + logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}") + + return { + "community_id": cid, + "end_user_id": end_user_id, + "name": name, + "summary": summary, + "core_entities": core_entities, + "summary_embedding": None, + } + except Exception as e: + logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True) return None - sorted_members = sorted( - members, - key=lambda m: m.get("activation_value") or 0, - reverse=True, - ) - core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] - - entity_list_str = "\n".join(self._build_entity_lines(members)) - - # 方案四:注入社区内实体间关系三元组 - relationships = await self.repo.get_community_relationships(cid, end_user_id) - rel_lines = [ - f"- {r['subject']} → {r['predicate']} → {r['object']}" - for r in relationships - if r.get("subject") and r.get("predicate") and r.get("object") - ] - rel_section = ( - f"\n实体间关系:\n" + "\n".join(rel_lines) - if rel_lines else "" - ) - - prompt = ( - f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" - f"请为这组实体所代表的主题:\n" - f"1. 起一个简洁的中文名称(不超过10个字)\n" - f"2. 写一句话摘要(不超过80个字)\n\n" - f"严格按以下格式输出,不要有其他内容:\n" - f"名称:<名称>\n摘要:<摘要>" - ) - with get_db_context() as db: - llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) - response = await llm_client.chat([{"role": "user", "content": prompt}]) - text = response.content if hasattr(response, "content") else str(response) - - name, summary = "", "" - for line in text.strip().splitlines(): - if line.startswith("名称:"): - name = line[3:].strip() - elif line.startswith("摘要:"): - summary = line[3:].strip() - - return { - "community_id": cid, - "end_user_id": end_user_id, - "name": name, - "summary": summary, - "core_entities": core_entities, - "summary_embedding": None, - } - results = await asyncio.gather( *[_build_one(cid) for cid in community_ids], return_exceptions=True, @@ -537,15 +550,20 @@ class LabelPropagationEngine: metadata_list.append(res) if not metadata_list: + logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}") return # --- 阶段2:批量生成 summary_embedding --- - summaries = [m["summary"] for m in metadata_list] - with get_db_context() as db: - embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) - embeddings = await embedder.response(summaries) - for i, meta in enumerate(metadata_list): - meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + if self.embedding_model_id: + try: + summaries = [m["summary"] for m in metadata_list] + with get_db_context() as db: + embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) + embeddings = await embedder.response(summaries) + for i, meta in enumerate(metadata_list): + meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + except Exception as e: + logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True) # --- 阶段3:写入(单个 or 批量)--- if len(metadata_list) == 1: @@ -558,17 +576,13 @@ class LabelPropagationEngine: core_entities=m["core_entities"], summary_embedding=m["summary_embedding"], ) - if result: - logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...") - else: - logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False") + if not result: + logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败") else: ok = await self.repo.batch_update_community_metadata(metadata_list) - if ok: - logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功") - else: - logger.warning(f"[Clustering] 批量写入社区元数据失败") + if not ok: + logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败") @staticmethod def _new_community_id() -> str: - return str(uuid.uuid4()) + return str(uuid.uuid4()) \ No newline at end of file diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 248067e7..967f529e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -9,6 +9,7 @@ """ import asyncio +import logging import os import hashlib import json @@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene ScenePatterns ) +logger = logging.getLogger(__name__) + class DialogExtractionResponse(BaseModel): """对话级一次性抽取的结构化返回,用于加速剪枝。 @@ -706,7 +709,7 @@ class SemanticPruner: # 阈值保护:最高0.9 proportion = float(self.config.pruning_threshold) if proportion > 0.9: - print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9") + logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9") proportion = 0.9 if proportion < 0.0: proportion = 0.0 @@ -905,7 +908,7 @@ class SemanticPruner: # Safety: avoid empty dataset if not result: - print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") + logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") return dialogs return result @@ -915,8 +918,7 @@ class SemanticPruner: try: self.run_logs.append(msg) except Exception: - # 任何异常都不影响打印 pass - print(msg) + logger.debug(msg) diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py index f28b8a5f..4b9c5718 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py @@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector async def dedup_layers_and_merge_and_return( - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - pipeline_config: ExtractionPipelineConfig, - connector: Optional[Neo4jConnector] = None, - llm_client = None, + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + dialog_data_list: List[DialogData], + pipeline_config: ExtractionPipelineConfig, + connector: Optional[Neo4jConnector] = None, + llm_client=None, ) -> Tuple[ List[DialogueNode], List[ChunkNode], @@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return( List[StatementChunkEdge], List[StatementEntityEdge], List[EntityEntityEdge], - dict, # 新增:返回去重详情 + dict ]: """ 执行两层实体去重与融合: diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 00d06f72..e0b86d8c 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -32,10 +32,11 @@ from app.core.memory.models.graph_models import ( StatementChunkEdge, StatementEntityEdge, StatementNode, + PerceptualEdge, + PerceptualNode ) from app.core.memory.models.message_models import DialogData from app.core.memory.models.ontology_extraction_models import OntologyTypeList -from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.models.variate_config import ( ExtractionPipelineConfig, ) @@ -46,7 +47,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb embedding_generation, generate_entity_embeddings_from_triplets, ) - # 导入各个提取模块 from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import ( StatementExtractor, @@ -90,16 +90,16 @@ class ExtractionOrchestrator: """ def __init__( - self, - llm_client: LLMClient, - embedder_client: OpenAIEmbedderClient, - connector: Neo4jConnector, - config: Optional[ExtractionPipelineConfig] = None, - progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, - embedding_id: Optional[str] = None, - ontology_types: Optional[OntologyTypeList] = None, - enable_general_types: bool = True, - language: str = "zh", + self, + llm_client: LLMClient, + embedder_client: OpenAIEmbedderClient, + connector: Neo4jConnector, + config: Optional[ExtractionPipelineConfig] = None, + progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, + embedding_id: Optional[str] = None, + ontology_types: Optional[OntologyTypeList] = None, + enable_general_types: bool = True, + language: str = "zh", ): """ 初始化流水线编排器 @@ -123,7 +123,7 @@ class ExtractionOrchestrator: self.progress_callback = progress_callback # 保存进度回调函数 self.embedding_id = embedding_id # 保存嵌入模型ID self.language = language # 保存语言配置 - + # 处理本体类型配置 # 根据 enable_general_types 参数决定是否将通用本体类型与场景特定类型合并 # 如果启用合并且配置中开启了通用本体功能,则使用 OntologyTypeMerger 进行融合 @@ -146,7 +146,7 @@ class ExtractionOrchestrator: self.ontology_types = ontology_types if not enable_general_types and ontology_types: logger.info("enable_general_types=False,仅使用场景类型") - + # 保存去重消歧的详细记录(内存中的数据结构) self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录 self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录 @@ -157,19 +157,27 @@ class ExtractionOrchestrator: llm_client=llm_client, config=self.config.statement_extraction, ) - self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language) + self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types, + language=language) self.temporal_extractor = TemporalExtractor(llm_client=llm_client) logger.info("ExtractionOrchestrator 初始化完成") async def run( - self, - dialog_data_list: List[DialogData], - is_pilot_run: bool = False, - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], + self, + dialog_data_list: List[DialogData], + is_pilot_run: bool = False, + ) -> tuple[ + list[DialogueNode], + list[ChunkNode], + list[StatementNode], + list[ExtractedEntityNode], + list[PerceptualNode], + list[StatementChunkEdge], + list[StatementEntityEdge], + list[EntityEntityEdge], + list[PerceptualEdge], + dict ]: """ 运行完整的知识提取流水线(优化版:并行执行) @@ -202,13 +210,12 @@ class ExtractionOrchestrator: # 步骤 1: 陈述句提取 logger.info("步骤 1/6: 陈述句提取(全局分块级并行)") dialog_data_list = await self._extract_statements(dialog_data_list) - + # 收集陈述句内容和统计数量 all_statements_list = [] for dialog in dialog_data_list: for chunk in dialog.chunks: all_statements_list.extend(chunk.statements) - len(all_statements_list) # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") @@ -220,7 +227,7 @@ class ExtractionOrchestrator: chunk_embedding_maps, dialog_embeddings, ) = await self._parallel_extract_and_embed(dialog_data_list) - + # 收集实体和三元组内容,并统计数量 all_entities_list = [] all_triplets_list = [] @@ -229,10 +236,6 @@ class ExtractionOrchestrator: if triplet_info: all_entities_list.extend(triplet_info.entities) all_triplets_list.extend(triplet_info.triplets) - - len(all_entities_list) - len(all_triplets_list) - sum(len(temporal_map) for temporal_map in temporal_maps) # 步骤 3: 生成实体嵌入(依赖三元组提取结果) logger.info("步骤 3/6: 生成实体嵌入") @@ -252,17 +255,19 @@ class ExtractionOrchestrator: # 步骤 5: 创建节点和边 logger.info("步骤 5/6: 创建节点和边") - + # 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送 - + ( dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes, + perceptual_nodes, statement_chunk_edges, statement_entity_edges, entity_entity_edges, + perceptual_edges ) = await self._create_nodes_and_edges(dialog_data_list) # 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总) @@ -273,10 +278,19 @@ class ExtractionOrchestrator: logger.info("步骤 6/6: 去重和消歧(试运行模式:仅第一层去重)") else: logger.info("步骤 6/6: 两阶段去重和消歧") - + # 注意:deduplication 消息已在创建节点和边完成后立即发送 - - result = await self._run_dedup_and_write_summary( + + ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + dialog_data_list, + ) = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, statement_nodes, @@ -287,17 +301,26 @@ class ExtractionOrchestrator: dialog_data_list, ) - - logger.info(f"知识提取流水线运行完成({mode_str})") - return result + return ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + perceptual_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + perceptual_edges, + dialog_data_list, + ) except Exception as e: logger.error(f"知识提取流水线运行失败: {e}", exc_info=True) raise async def _extract_statements( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[DialogData]: """ 从对话中提取陈述句(流式输出版本:边提取边发送进度) @@ -313,7 +336,7 @@ class ExtractionOrchestrator: # 收集所有分块及其元数据 all_chunks = [] chunk_metadata = [] # (dialog_idx, chunk_idx) - + for d_idx, dialog in enumerate(dialog_data_list): dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None for c_idx, chunk in enumerate(dialog.chunks): @@ -321,7 +344,7 @@ class ExtractionOrchestrator: chunk_metadata.append((d_idx, c_idx)) logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") - + # 用于跟踪已完成的分块数量 completed_chunks = 0 total_chunks = len(all_chunks) @@ -332,7 +355,7 @@ class ExtractionOrchestrator: chunk, end_user_id, dialogue_content = chunk_data try: statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content) - + # 流式输出:每提取完一个分块的陈述句,立即发送进度 # 注意:只在试运行模式下发送陈述句详情,正式模式不发送 completed_chunks += 1 @@ -347,11 +370,11 @@ class ExtractionOrchestrator: "statement_index_in_chunk": idx + 1 } await self.progress_callback( - "knowledge_extraction_result", - f"陈述句提取中 ({completed_chunks}/{total_chunks})", + "knowledge_extraction_result", + f"陈述句提取中 ({completed_chunks}/{total_chunks})", stmt_result ) - + return statements except Exception as e: logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}") @@ -381,7 +404,7 @@ class ExtractionOrchestrator: # 保存陈述句到文件(试运行和正式模式都需要) self.statement_extractor.save_statements(all_statements) - + logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句") # 试运行模式下,所有分块提取完成后发送完成事件 @@ -395,7 +418,7 @@ class ExtractionOrchestrator: return dialog_data_list async def _extract_triplets( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取三元组(流式输出版本:边提取边发送进度) @@ -411,7 +434,7 @@ class ExtractionOrchestrator: # 收集所有陈述句及其元数据 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id, chunk_content) - + for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: @@ -419,7 +442,7 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组") - + # 用于跟踪已完成的陈述句数量 completed_statements = 0 len(all_statements) @@ -430,11 +453,11 @@ class ExtractionOrchestrator: statement, chunk_content = stmt_data try: triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content) - + # 注意:不再发送三元组提取的流式输出 # 三元组提取在后台执行,但不向前端发送详细信息 completed_statements += 1 - + return triplet_info except Exception as e: logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}") @@ -450,7 +473,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 triplet_maps = [{} for _ in dialog_data_list] all_responses = [] - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -478,7 +501,7 @@ class ExtractionOrchestrator: return triplet_maps async def _extract_temporal( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取时间信息(流式输出版本:边提取边发送进度) @@ -502,13 +525,13 @@ class ExtractionOrchestrator: temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None) temporal_maps.append(temporal_map) return temporal_maps - + logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)") # 收集所有需要提取时间的陈述句 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id, ref_dates) - + for d_idx, dialog in enumerate(dialog_data_list): # 获取参考日期 ref_dates = {} @@ -517,11 +540,11 @@ class ExtractionOrchestrator: ref_dates['conversation_date'] = dialog.metadata['conversation_date'] if 'publication_date' in dialog.metadata: ref_dates['publication_date'] = dialog.metadata['publication_date'] - + if not ref_dates: from datetime import datetime ref_dates = {"today": datetime.now().strftime("%Y-%m-%d")} - + for chunk in dialog.chunks: for statement in chunk.statements: # 跳过 ATEMPORAL 类型的陈述句 @@ -531,7 +554,7 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取") - + # 用于跟踪已完成的时间提取数量 completed_temporal = 0 len(all_statements) @@ -542,11 +565,11 @@ class ExtractionOrchestrator: statement, ref_dates = stmt_data try: temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates) - + # 注意:不再发送时间提取的流式输出 # 时间提取在后台执行,但不向前端发送详细信息 completed_temporal += 1 - + return temporal_range except Exception as e: logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}") @@ -559,7 +582,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 temporal_maps = [{} for _ in dialog_data_list] - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -585,7 +608,7 @@ class ExtractionOrchestrator: return temporal_maps async def _extract_emotions( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) @@ -601,36 +624,36 @@ class ExtractionOrchestrator: # 收集所有陈述句及其配置 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id) - + # 获取第一个对话的config_id来加载配置 config_id = None if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): config_id = dialog_data_list[0].config_id - + # 加载MemoryConfig memory_config = None if config_id: try: from app.db import SessionLocal from app.repositories.memory_config_repository import MemoryConfigRepository - + db = SessionLocal() try: memory_config = MemoryConfigRepository.get_by_id(db, config_id) finally: db.close() - + if memory_config and not memory_config.emotion_enabled: logger.info("情绪提取已在配置中禁用,跳过情绪提取") return [{} for _ in dialog_data_list] - + except Exception as e: logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取") return [{} for _ in dialog_data_list] else: logger.info("未找到config_id,跳过情绪提取") return [{} for _ in dialog_data_list] - + # 如果配置未启用情绪提取,直接返回空映射 if not memory_config or not memory_config.emotion_enabled: logger.info("情绪提取未启用,跳过") @@ -639,7 +662,7 @@ class ExtractionOrchestrator: # 收集所有陈述句(只收集 speaker 为 "user" 的) total_statements = 0 filtered_statements = 0 - + for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: @@ -655,12 +678,12 @@ class ExtractionOrchestrator: # 初始化情绪提取服务 # 如果 emotion_model_id 为空,回退到工作空间默认 LLM from app.services.emotion_extraction_service import EmotionExtractionService - + emotion_model_id = memory_config.emotion_model_id if not emotion_model_id and memory_config.workspace_id: from app.repositories.workspace_repository import get_workspace_models_configs from app.db import SessionLocal - + db = SessionLocal() try: workspace_models = get_workspace_models_configs(db, memory_config.workspace_id) @@ -669,7 +692,7 @@ class ExtractionOrchestrator: logger.info(f"emotion_model_id 为空,使用工作空间默认 LLM: {emotion_model_id}") finally: db.close() - + emotion_service = EmotionExtractionService( llm_id=emotion_model_id if emotion_model_id else None ) @@ -689,7 +712,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 emotion_maps = [{} for _ in dialog_data_list] successful_extractions = 0 - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -706,7 +729,7 @@ class ExtractionOrchestrator: return emotion_maps async def _parallel_extract_and_embed( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[ List[Dict[str, Any]], List[Dict[str, Any]], @@ -757,7 +780,7 @@ class ExtractionOrchestrator: triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list] temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list] emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list] - + if isinstance(results[3], Exception): logger.error(f"基础嵌入生成失败: {results[3]}") statement_embedding_maps = [{} for _ in dialog_data_list] @@ -777,7 +800,7 @@ class ExtractionOrchestrator: ) async def _generate_basic_embeddings( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]: """ 生成基础嵌入向量(陈述句、分块、对话) @@ -810,7 +833,7 @@ class ExtractionOrchestrator: if not self.embedding_id: logger.error("embedding_id is required but was not provided to ExtractionOrchestrator") raise ValueError("embedding_id is required but was not provided") - + # 只生成陈述句、分块和对话的嵌入(不包括实体) statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation( dialog_data_list, self.embedding_id @@ -836,7 +859,7 @@ class ExtractionOrchestrator: ) async def _generate_entity_embeddings( - self, triplet_maps: List[Dict[str, Any]] + self, triplet_maps: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: """ 生成实体嵌入向量 @@ -861,7 +884,7 @@ class ExtractionOrchestrator: if not self.embedding_id: logger.error("embedding_id is required but was not provided to ExtractionOrchestrator") return triplet_maps - + # 生成实体嵌入 updated_triplet_maps = await generate_entity_embeddings_from_triplets( triplet_maps, self.embedding_id @@ -874,17 +897,15 @@ class ExtractionOrchestrator: logger.error(f"实体嵌入生成失败: {e}", exc_info=True) return triplet_maps - - async def _assign_extracted_data( - self, - dialog_data_list: List[DialogData], - temporal_maps: List[Dict[str, Any]], - triplet_maps: List[Dict[str, Any]], - emotion_maps: List[Dict[str, Any]], - statement_embedding_maps: List[Dict[str, List[float]]], - chunk_embedding_maps: List[Dict[str, List[float]]], - dialog_embeddings: List[List[float]], + self, + dialog_data_list: List[DialogData], + temporal_maps: List[Dict[str, Any]], + triplet_maps: List[Dict[str, Any]], + emotion_maps: List[Dict[str, Any]], + statement_embedding_maps: List[Dict[str, List[float]]], + chunk_embedding_maps: List[Dict[str, List[float]]], + dialog_embeddings: List[List[float]], ) -> List[DialogData]: """ 将提取的数据赋值到语句 @@ -906,12 +927,12 @@ class ExtractionOrchestrator: # 确保列表长度匹配 expected_length = len(dialog_data_list) if ( - len(temporal_maps) != expected_length - or len(triplet_maps) != expected_length - or len(emotion_maps) != expected_length - or len(statement_embedding_maps) != expected_length - or len(chunk_embedding_maps) != expected_length - or len(dialog_embeddings) != expected_length + len(temporal_maps) != expected_length + or len(triplet_maps) != expected_length + or len(emotion_maps) != expected_length + or len(statement_embedding_maps) != expected_length + or len(chunk_embedding_maps) != expected_length + or len(dialog_embeddings) != expected_length ): logger.warning( f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " @@ -999,15 +1020,17 @@ class ExtractionOrchestrator: return dialog_data_list async def _create_nodes_and_edges( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[ List[DialogueNode], List[ChunkNode], List[StatementNode], List[ExtractedEntityNode], + List[PerceptualNode], List[StatementChunkEdge], List[StatementEntityEdge], List[EntityEntityEdge], + List[PerceptualEdge] ]: """ 创建图数据库节点和边 @@ -1021,7 +1044,7 @@ class ExtractionOrchestrator: 包含所有节点和边的元组 """ logger.info("开始创建节点和边") - + # 注意:开始消息已在 run 方法中发送,这里不再重复发送 dialogue_nodes = [] @@ -1031,10 +1054,12 @@ class ExtractionOrchestrator: statement_chunk_edges = [] statement_entity_edges = [] entity_entity_edges = [] + perceptual_nodes = [] + perceptual_edges = [] # 用于去重的集合 entity_id_set = set() - + # 用于跟踪进度 total_dialogs = len(dialog_data_list) processed_dialogs = 0 @@ -1075,6 +1100,45 @@ class ExtractionOrchestrator: ) chunk_nodes.append(chunk_node) + for p, file_type in chunk.files: + + meta = p.meta_data or {} + content_meta = meta.get("content", {}) + + # 生成 summary embedding(如果有 embedder_client) + summary_embedding = None + if self.embedder_client and p.summary: + try: + summary_embedding = (await self.embedder_client.response([p.summary]))[0] + except Exception as emb_err: + print(f"Failed to embed perceptual summary: {emb_err}") + + perceptual = PerceptualNode( + name=f"Perceptual_{p.id}", + **{ + "id": str(p.id), + "end_user_id": str(p.end_user_id), + "perceptual_type": p.perceptual_type, + "file_path": p.file_path or "", + "file_name": p.file_name or "", + "file_ext": p.file_ext or "", + "summary": p.summary or "", + "keywords": content_meta.get("keywords", []), + "topic": content_meta.get("topic", ""), + "domain": content_meta.get("domain", ""), + "created_at": p.created_time.isoformat() if p.created_time else None, + "file_type": file_type, + "summary_embedding": summary_embedding, + }) + perceptual_nodes.append(perceptual) + perceptual_edges.append(PerceptualEdge( + source=perceptual.id, + target=chunk.id, + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + created_at=dialog_data.created_at, + )) + # 处理每个陈述句 for statement in chunk.statements: # 创建陈述句节点 @@ -1083,15 +1147,19 @@ class ExtractionOrchestrator: name=f"Statement_{statement.id}", # 添加必需的 name 字段 chunk_id=chunk.id, stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 - temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 - connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 + temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), + # 添加必需的 temporal_info 字段 + connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', + # 添加必需的 connect_strength 字段 end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id statement=statement.statement, speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 statement_embedding=statement.statement_embedding, - valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, - invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, + valid_at=statement.temporal_validity.valid_at if hasattr(statement, + 'temporal_validity') and statement.temporal_validity else None, + invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, + 'temporal_validity') and statement.temporal_validity else None, created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, @@ -1120,7 +1188,7 @@ class ExtractionOrchestrator: # 创建实体索引到ID的映射(支持多种索引方式) entity_idx_to_id = {} - + # 创建实体节点 for entity_idx, entity in enumerate(triplet_info.entities): # 映射实体索引到实体ID(使用多个键以提高容错性) @@ -1128,7 +1196,7 @@ class ExtractionOrchestrator: entity_idx_to_id[entity.entity_idx] = entity.id # 2. 使用枚举索引(从0开始) entity_idx_to_id[entity_idx] = entity.id - + if entity.id not in entity_id_set: entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') entity_node = ExtractedEntityNode( @@ -1141,7 +1209,8 @@ class ExtractionOrchestrator: example=getattr(entity, 'example', ''), # 新增:传递示例字段 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 - connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 + connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', + # 添加必需的 connect_strength 字段 aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 @@ -1171,7 +1240,7 @@ class ExtractionOrchestrator: # 将三元组中的整数索引映射到实体ID subject_entity_id = entity_idx_to_id.get(triplet.subject_id) object_entity_id = entity_idx_to_id.get(triplet.object_id) - + # 只有当两个实体ID都存在时才创建边 if subject_entity_id and object_entity_id: entity_entity_edge = EntityEntityEdge( @@ -1186,7 +1255,7 @@ class ExtractionOrchestrator: expired_at=dialog_data.expired_at, ) entity_entity_edges.append(entity_entity_edge) - + # 流式输出:每创建一个关系边,立即发送进度(限制发送数量) if self.progress_callback and len(entity_entity_edges) <= 10: # 获取实体名称 @@ -1202,8 +1271,8 @@ class ExtractionOrchestrator: "dialog_progress": f"{processed_dialogs}/{total_dialogs}" } await self.progress_callback( - "creating_nodes_edges_result", - f"关系创建中 ({processed_dialogs}/{total_dialogs})", + "creating_nodes_edges_result", + f"关系创建中 ({processed_dialogs}/{total_dialogs})", relationship_result ) else: @@ -1211,7 +1280,7 @@ class ExtractionOrchestrator: missing_subject = "subject" if not subject_entity_id else "" missing_object = "object" if not object_entity_id else "" missing_both = " and " if (not subject_entity_id and not object_entity_id) else "" - + logger.debug( f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: " f"subject_id={triplet.subject_id} ({triplet.subject_name}), " @@ -1228,7 +1297,7 @@ class ExtractionOrchestrator: f"陈述句-实体边: {len(statement_entity_edges)}, " f"实体-实体边: {len(entity_entity_edges)}" ) - + # 进度回调:创建节点和边完成,传递结果统计 # 注意:具体的关系创建结果已经在创建过程中实时发送了 if self.progress_callback: @@ -1248,25 +1317,32 @@ class ExtractionOrchestrator: chunk_nodes, statement_nodes, entity_nodes, + perceptual_nodes, statement_chunk_edges, statement_entity_edges, entity_entity_edges, + perceptual_edges ) async def _run_dedup_and_write_summary( - self, - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], + self, + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + dialog_data_list: List[DialogData], + ) -> tuple[ + list[DialogueNode], + list[ChunkNode], + list[StatementNode], + list[ExtractedEntityNode], + list[StatementChunkEdge], + list[StatementEntityEdge], + list[EntityEntityEdge], + dict ]: """ 执行两阶段去重并写入汇总 @@ -1288,11 +1364,11 @@ class ExtractionOrchestrator: - 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表) """ logger.info("开始两阶段实体去重和消歧") - + # 进度回调:发送去重消歧开始消息 if self.progress_callback: await self.progress_callback("deduplication", "正在去重消歧...") - + logger.info( f"去重前: {len(entity_nodes)} 个实体节点, " f"{len(statement_entity_edges)} 条陈述句-实体边, " @@ -1307,7 +1383,7 @@ class ExtractionOrchestrator: from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( deduplicate_entities_and_edges, ) - + dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges( entity_nodes, statement_entity_edges, @@ -1317,10 +1393,10 @@ class ExtractionOrchestrator: dedup_config=self.config.deduplication, llm_client=self.llm_client, ) - + # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes) - + result_tuple = ( dialogue_nodes, chunk_nodes, @@ -1330,7 +1406,7 @@ class ExtractionOrchestrator: dedup_statement_entity_edges, dedup_entity_entity_edges, ) - + final_entity_nodes = dedup_entity_nodes final_statement_entity_edges = dedup_statement_entity_edges final_entity_entity_edges = dedup_entity_entity_edges @@ -1361,7 +1437,7 @@ class ExtractionOrchestrator: final_entity_entity_edges, dedup_details, ) = result_tuple - + # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) @@ -1375,12 +1451,12 @@ class ExtractionOrchestrator: f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, " f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}" ) - + # 流式输出:实时输出去重消歧的具体结果 if self.progress_callback: # 分析实体合并情况(使用内存中的记录) merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes) - + # 逐个输出去重合并的实体示例 for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果 dedup_result = { @@ -1391,10 +1467,10 @@ class ExtractionOrchestrator: "message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并" } await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result) - + # 分析实体消歧情况(使用内存中的记录) disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes) - + # 逐个输出实体消歧的结果 for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果 disamb_result = { @@ -1407,14 +1483,13 @@ class ExtractionOrchestrator: "message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}" } await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result) - + # 进度回调:去重消歧完成,传递去重和消歧的具体效果 await self._send_dedup_progress_callback( len(entity_nodes), len(final_entity_nodes), len(statement_entity_edges), len(final_statement_entity_edges), len(entity_entity_edges), len(final_entity_entity_edges) ) - # 写入提取结果汇总(试运行和正式模式都需要生成) try: @@ -1436,10 +1511,10 @@ class ExtractionOrchestrator: raise def _save_dedup_details( - self, - dedup_details: Dict[str, Any], - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + dedup_details: Dict[str, Any], + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ): """ 保存去重消歧的详细记录到实例变量(基于内存数据结构) @@ -1452,7 +1527,7 @@ class ExtractionOrchestrator: try: # 保存ID重定向映射 self.id_redirect_map = dedup_details.get("id_redirect", {}) - + # 处理精确匹配的合并记录 exact_merge_map = dedup_details.get("exact_merge_map", {}) for key, info in exact_merge_map.items(): @@ -1466,7 +1541,7 @@ class ExtractionOrchestrator: "merged_count": len(merged_ids), "merged_ids": list(merged_ids) }) - + # 处理模糊匹配的合并记录 fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", []) for record in fuzzy_merge_records: @@ -1486,7 +1561,7 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}") - + # 处理LLM去重的合并记录 llm_decision_records = dedup_details.get("llm_decision_records", []) for record in llm_decision_records: @@ -1505,7 +1580,7 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}") - + # 处理消歧记录 disamb_records = dedup_details.get("disamb_records", []) for record in disamb_records: @@ -1520,14 +1595,14 @@ class ExtractionOrchestrator: entity1_type = match.group(2) match.group(3).strip() entity2_type = match.group(4) - + # 提取置信度和原因 conf_match = re.search(r"conf=([0-9.]+)", str(record)) confidence = conf_match.group(1) if conf_match else "unknown" - + reason_match = re.search(r"reason=([^|]+)", str(record)) reason = reason_match.group(1).strip() if reason_match else "" - + self.dedup_disamb_records.append({ "entity_name": entity1_name, "disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}", @@ -1536,16 +1611,17 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析消歧记录失败: {record}, 错误: {e}") - - logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") - + + logger.info( + f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") + except Exception as e: logger.error(f"保存去重消歧详情失败: {e}", exc_info=True) async def _analyze_entity_merges( - self, - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ) -> List[Dict[str, Any]]: """ 分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件) @@ -1566,28 +1642,28 @@ class ExtractionOrchestrator: key=lambda x: x.get("merged_count", 0), reverse=True ) - + merge_info = [] for record in sorted_records: merge_info.append({ "main_entity_name": record.get("entity_name", "未知实体"), "merged_count": record.get("merged_count", 1) }) - + return merge_info - + # 如果没有保存的记录,返回空列表 logger.info("未找到实体合并记录") return [] - + except Exception as e: logger.error(f"分析实体合并情况失败: {e}", exc_info=True) return [] async def _analyze_entity_disambiguation( - self, - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ) -> List[Dict[str, Any]]: """ 分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件) @@ -1603,11 +1679,11 @@ class ExtractionOrchestrator: # 直接使用保存的消歧记录 if self.dedup_disamb_records: return self.dedup_disamb_records - + # 如果没有保存的记录,返回空列表 logger.info("未找到实体消歧记录") return [] - + except Exception as e: logger.error(f"分析实体消歧情况失败: {e}", exc_info=True) return [] @@ -1624,7 +1700,7 @@ class ExtractionOrchestrator: """ type_mapping = { "Person": "人物实体节点", - "Organization": "组织实体节点", + "Organization": "组织实体节点", "ORG": "组织实体节点", "Location": "地点实体节点", "LOC": "地点实体节点", @@ -1645,9 +1721,9 @@ class ExtractionOrchestrator: return type_mapping.get(entity_type, f"{entity_type}实体节点") async def _output_relationship_creation_results( - self, - entity_entity_edges: List[EntityEntityEdge], - entity_nodes: List[ExtractedEntityNode] + self, + entity_entity_edges: List[EntityEntityEdge], + entity_nodes: List[ExtractedEntityNode] ): """ 输出关系创建结果 @@ -1659,13 +1735,13 @@ class ExtractionOrchestrator: try: # 创建实体ID到名称的映射 entity_id_to_name = {node.id: node.name for node in entity_nodes} - + # 输出关系创建结果 for i, edge in enumerate(entity_entity_edges[:10]): # 只输出前10个关系 source_name = entity_id_to_name.get(edge.source, f"Entity_{edge.source}") target_name = entity_id_to_name.get(edge.target, f"Entity_{edge.target}") relation_type = edge.relation_type - + relationship_result = { "result_type": "relationship_creation", "relationship_index": i + 1, @@ -1674,20 +1750,20 @@ class ExtractionOrchestrator: "target_entity": target_name, "relationship_text": f"{source_name} -[{relation_type}]-> {target_name}" } - + await self.progress_callback("creating_nodes_edges_result", "关系创建", relationship_result) - + except Exception as e: logger.error(f"输出关系创建结果失败: {e}", exc_info=True) async def _send_dedup_progress_callback( - self, - original_entities: int, - final_entities: int, - original_stmt_edges: int, - final_stmt_edges: int, - original_ent_edges: int, - final_ent_edges: int, + self, + original_entities: int, + final_entities: int, + original_stmt_edges: int, + final_stmt_edges: int, + original_ent_edges: int, + final_ent_edges: int, ): """ 发送去重消歧完成的进度回调,传递具体的去重和消歧效果 @@ -1703,19 +1779,20 @@ class ExtractionOrchestrator: try: # 解析去重消歧报告文件,获取具体的去重和消歧效果 dedup_details = await self._parse_dedup_report() - + # 计算去重效果统计 entities_reduced = original_entities - final_entities stmt_edges_reduced = original_stmt_edges - final_stmt_edges ent_edges_reduced = original_ent_edges - final_ent_edges - + # 构建进度回调数据 dedup_stats = { "entities": { "original_count": original_entities, "final_count": final_entities, "reduced_count": entities_reduced, - "reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0, + "reduction_rate": round(entities_reduced / original_entities * 100, + 1) if original_entities > 0 else 0, }, "statement_entity_edges": { "original_count": original_stmt_edges, @@ -1734,9 +1811,9 @@ class ExtractionOrchestrator: "total_disambiguations": dedup_details.get("total_disambiguations", 0), } } - + await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats) - + except Exception as e: logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True) # 即使解析失败,也发送基本的统计信息 @@ -1766,12 +1843,12 @@ class ExtractionOrchestrator: disamb_examples = [] total_merges = 0 total_disambiguations = 0 - + # 处理合并记录 for record in self.dedup_merge_records: merge_count = record.get("merged_count", 0) total_merges += merge_count - + dedup_examples.append({ "type": record.get("type", "未知"), "entity_name": record.get("entity_name", "未知实体"), @@ -1779,30 +1856,31 @@ class ExtractionOrchestrator: "merge_count": merge_count, "description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个" }) - + # 处理消歧记录 for record in self.dedup_disamb_records: total_disambiguations += 1 - + # 从消歧类型中提取实体类型信息 disamb_type = record.get("disamb_type", "") entity_name = record.get("entity_name", "未知实体") - + disamb_examples.append({ "entity1_name": entity_name, - "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知", + "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", + "").strip() if "vs" in disamb_type else "未知", "entity2_name": entity_name, "entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知", "description": f"{entity_name},消歧区分成功" }) - + return { "dedup_examples": dedup_examples[:5], # 只返回前5个示例 "disamb_examples": disamb_examples[:5], # 只返回前5个示例 "total_merges": total_merges, "total_disambiguations": total_disambiguations, } - + except Exception as e: logger.error(f"获取去重报告失败: {e}", exc_info=True) return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0} @@ -1815,9 +1893,9 @@ class ExtractionOrchestrator: async def get_chunked_dialogs( - chunker_strategy: str = "RecursiveChunker", - end_user_id: str = "group_1", - indices: Optional[List[int]] = None, + chunker_strategy: str = "RecursiveChunker", + end_user_id: str = "group_1", + indices: Optional[List[int]] = None, ) -> List[DialogData]: """从测试数据生成分块对话 @@ -1831,7 +1909,7 @@ async def get_chunked_dialogs( """ import json import re - + # 加载测试数据 testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json") with open(testdata_path, "r", encoding="utf-8") as f: @@ -1845,7 +1923,7 @@ async def get_chunked_dialogs( else: # 默认使用所有数据 selected_data = test_data - + for data in selected_data: # 解析对话上下文 context_text = data["context"] @@ -1861,7 +1939,7 @@ async def get_chunked_dialogs( if m: y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3)) conv_date = f"{y:04d}-{mo:02d}-{d:02d}" - + dialog_metadata: Dict[str, Any] = {} if conv_date: dialog_metadata["conversation_date"] = conv_date @@ -1890,7 +1968,7 @@ async def get_chunked_dialogs( end_user_id=end_user_id, metadata=dialog_metadata, ) - + # 创建分块器并处理对话 from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( DialogueChunker, @@ -1913,7 +1991,7 @@ async def get_chunked_dialogs( from app.core.config import settings settings.ensure_memory_output_dir() output_path = settings.get_memory_output_path("chunker_test_output.txt") - + import json with open(output_path, "w", encoding="utf-8") as f: json.dump( @@ -1924,10 +2002,10 @@ async def get_chunked_dialogs( def preprocess_data( - input_path: Optional[str] = None, - output_path: Optional[str] = None, - skip_cleaning: bool = True, - indices: Optional[List[int]] = None + input_path: Optional[str] = None, + output_path: Optional[str] = None, + skip_cleaning: bool = True, + indices: Optional[List[int]] = None ) -> List[DialogData]: """数据预处理 @@ -1946,7 +2024,8 @@ def preprocess_data( ) preprocessor = DataPreprocessor() try: - cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) + cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, + skip_cleaning=skip_cleaning, indices=indices) logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") return cleaned_data except Exception as e: @@ -1955,9 +2034,9 @@ def preprocess_data( async def get_chunked_dialogs_from_preprocessed( - data: List[DialogData], - chunker_strategy: str = "RecursiveChunker", - llm_client: Optional[Any] = None, + data: List[DialogData], + chunker_strategy: str = "RecursiveChunker", + llm_client: Optional[Any] = None, ) -> List[DialogData]: """从预处理后的数据中生成分块 @@ -1972,31 +2051,31 @@ async def get_chunked_dialogs_from_preprocessed( logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===") if not data: raise ValueError("预处理数据为空,无法进行分块") - + all_chunked_dialogs: List[DialogData] = [] from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( DialogueChunker, ) - + for dialog_data in data: chunker = DialogueChunker(chunker_strategy, llm_client=llm_client) chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = chunks all_chunked_dialogs.append(dialog_data) - + return all_chunked_dialogs async def get_chunked_dialogs_with_preprocessing( - chunker_strategy: str = "RecursiveChunker", - end_user_id: str = "default", - user_id: str = "default", - apply_id: str = "default", - indices: Optional[List[int]] = None, - input_data_path: Optional[str] = None, - llm_client: Optional[Any] = None, - skip_cleaning: bool = True, - pruning_config: Optional[Dict] = None, + chunker_strategy: str = "RecursiveChunker", + end_user_id: str = "default", + user_id: str = "default", + apply_id: str = "default", + indices: Optional[List[int]] = None, + input_data_path: Optional[str] = None, + llm_client: Optional[Any] = None, + skip_cleaning: bool = True, + pruning_config: Optional[Dict] = None, ) -> List[DialogData]: """包含数据预处理步骤的完整分块流程 @@ -2020,7 +2099,7 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path = os.path.join( os.path.dirname(__file__), "../../data", "testdata.json" ) - + # 步骤1: 数据预处理(包含索引筛选) from app.core.config import settings settings.ensure_memory_output_dir() @@ -2030,37 +2109,38 @@ async def get_chunked_dialogs_with_preprocessing( skip_cleaning=skip_cleaning, indices=indices, ) - + # 设置 end_user_id for dd in preprocessed_data: dd.end_user_id = end_user_id - + # 步骤2: 语义剪枝 try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( SemanticPruner, ) from app.core.memory.models.config_models import PruningConfig - + # 构建剪枝配置 if pruning_config: # 使用传入的配置 config = PruningConfig(**pruning_config) - logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") + logger.debug( + f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") else: # 使用默认配置(关闭剪枝) config = None logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)") - + pruner = SemanticPruner(config=config, llm_client=llm_client) - + # 记录单对话场景下剪枝前的消息数量 single_dialog_original_msgs = None if len(preprocessed_data) == 1 and preprocessed_data[0].context: single_dialog_original_msgs = len(preprocessed_data[0].context.msgs) preprocessed_data = await pruner.prune_dataset(preprocessed_data) - + # 单对话:打印清洗与剪枝信息 if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None: remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0 @@ -2071,7 +2151,7 @@ async def get_chunked_dialogs_with_preprocessing( ) else: logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话") - + # 保存剪枝后的数据 try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import ( @@ -2084,7 +2164,7 @@ async def get_chunked_dialogs_with_preprocessing( logger.error(f"保存剪枝结果失败:{se}") except Exception as e: logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}") - + # 步骤3: 对话分块 return await get_chunked_dialogs_from_preprocessed( preprocessed_data, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py index 72f3641e..33838061 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py @@ -5,8 +5,11 @@ """ import asyncio +import logging from typing import Any, Dict, List, Tuple +logger = logging.getLogger(__name__) + from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.models.message_models import DialogData from app.core.models.base import RedBearModelConfig @@ -48,9 +51,9 @@ class EmbeddingGenerator: return await self.embedder_client.response(texts) # 分批并行处理 - print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理") + logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理") batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)] - print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本") + logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本") # 并行发送所有批次 batch_results = await asyncio.gather(*[ @@ -62,7 +65,7 @@ class EmbeddingGenerator: for batch_result in batch_results: embeddings.extend(batch_result) - print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量") + logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量") return embeddings async def generate_statement_embeddings( @@ -77,7 +80,7 @@ class EmbeddingGenerator: Returns: 每个对话的陈述句嵌入向量映射列表 """ - print("\n=== 生成陈述句嵌入向量 ===") + logger.debug("=== 生成陈述句嵌入向量 ===") # 收集所有陈述句 all_statements = [] @@ -102,7 +105,7 @@ class EmbeddingGenerator: stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id stmt_embedding_maps[d_idx][stmt_id] = embedding - print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量") + logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量") return stmt_embedding_maps async def generate_chunk_embeddings( @@ -117,7 +120,7 @@ class EmbeddingGenerator: Returns: 每个对话的分块嵌入向量映射列表 """ - print("\n=== 生成分块嵌入向量 ===") + logger.debug("=== 生成分块嵌入向量 ===") # 收集所有分块 all_chunks = [] @@ -138,7 +141,7 @@ class EmbeddingGenerator: chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id chunk_embedding_maps[d_idx][chunk_id] = embedding - print(f"为 {len(all_chunks)} 个分块生成了嵌入向量") + logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量") return chunk_embedding_maps async def generate_dialog_embeddings( @@ -172,7 +175,7 @@ class EmbeddingGenerator: Returns: (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表) """ - print("\n=== 生成所有嵌入向量 ===") + logger.debug("=== 生成所有嵌入向量 ===") # 并发生成陈述句和分块嵌入向量 stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather( @@ -183,9 +186,7 @@ class EmbeddingGenerator: # 对话嵌入向量(当前跳过) dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs) - print( - f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量" - ) + logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量") return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings @@ -201,7 +202,7 @@ class EmbeddingGenerator: Returns: 更新后的三元组映射列表(实体包含嵌入向量) """ - print("\n=== 生成实体嵌入向量 ===") + logger.debug("=== 生成实体嵌入向量 ===") entity_texts: List[str] = [] entity_refs: List[Any] = [] @@ -219,7 +220,7 @@ class EmbeddingGenerator: entity_refs.append(ent) if not entity_texts: - print("没有找到需要生成嵌入向量的实体") + logger.debug("没有找到需要生成嵌入向量的实体") return triplet_maps # 批量生成嵌入向量 @@ -227,13 +228,13 @@ class EmbeddingGenerator: # 打印前几个嵌入向量的维度 for i in range(min(5, len(embeddings))): - print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") + logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") # 将嵌入向量赋值给实体 for ent, emb in zip(entity_refs, embeddings): setattr(ent, "name_embedding", emb) - print(f"为 {len(entity_refs)} 个实体生成了嵌入向量") + logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量") return triplet_maps @@ -296,7 +297,7 @@ async def embedding_generation_all( Returns: (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表) """ - print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===") + logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===") generator = EmbeddingGenerator(embedding_id) diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index 443ee36a..5e39ba36 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -188,7 +188,6 @@ async def _process_chunk_summary( response_model=MemorySummaryResponse, ) summary_text = structured.summary.strip() - # Generate title and type for the summary title = None episodic_type = None diff --git a/api/app/core/models/__init__.py b/api/app/core/models/__init__.py index f54afc08..f98d073f 100644 --- a/api/app/core/models/__init__.py +++ b/api/app/core/models/__init__.py @@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto from .llm import RedBearLLM from .embedding import RedBearEmbeddings from .rerank import RedBearRerank +from .generation import RedBearImageGenerator, RedBearVideoGenerator __all__ = [ "RedBearModelConfig", @@ -9,5 +10,7 @@ __all__ = [ "RedBearEmbeddings", "RedBearRerank", "RedBearModelFactory", - "get_provider_llm_class" + "get_provider_llm_class", + "RedBearImageGenerator", + "RedBearVideoGenerator" ] \ No newline at end of file diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 4a453c6b..80117f27 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -67,7 +67,7 @@ class RedBearModelFactory: **config.extra_params } - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]: # 使用 httpx.Timeout 对象来设置详细的超时配置 # 这样可以分别控制连接超时和读取超时 import httpx @@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy # dashscope 的 omni 模型使用 OpenAI 兼容模式 if provider == ModelProvider.DASHSCOPE and config.is_omni: return ChatOpenAI - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]: if type == ModelType.LLM: return OpenAI elif type == ModelType.CHAT: return ChatOpenAI + else: + raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED) elif provider == ModelProvider.DASHSCOPE: return ChatTongyi elif provider == ModelProvider.OLLAMA: diff --git a/api/app/core/models/embedding.py b/api/app/core/models/embedding.py index 16af2567..3269e1d0 100644 --- a/api/app/core/models/embedding.py +++ b/api/app/core/models/embedding.py @@ -1,23 +1,190 @@ -from typing import Any, Dict, List, Optional, TypeVar, Callable +from typing import Any, Dict, List, Optional, Union from langchain_core.embeddings import Embeddings -from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory +from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory +from app.models.models_model import ModelProvider + class RedBearEmbeddings(Embeddings): - """Embedding → 完全符合 LangChain Embeddings""" + """统一的 Embedding 类,自动支持多模态(根据 provider 判断)""" + def __init__(self, config: RedBearModelConfig): - self._model = self._create_model(config) self._config = config + self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO + + if self._is_volcano: + # 火山引擎使用 Ark SDK + self._client = self._create_volcano_client(config) + self._model = None + else: + # 其他 provider 使用 LangChain + self._model = self._create_model(config) + self._client = None def _create_model(self, config: RedBearModelConfig) -> Embeddings: - """根据配置创建模型""" + """根据配置创建 LangChain 模型""" embedding_class = get_provider_embedding_class(config.provider) model_params = RedBearModelFactory.get_model_params(config) return embedding_class(**model_params) + + def _create_volcano_client(self, config: RedBearModelConfig): + """创建火山引擎客户端""" + from volcenginesdkarkruntime import Ark + return Ark(api_key=config.api_key, base_url=config.base_url) + # ==================== LangChain 标准接口 ==================== + def embed_documents(self, texts: list[str]) -> list[list[float]]: - return self._model.embed_documents(texts) + """批量文本向量化(LangChain 标准接口)""" + if self._is_volcano: + # 火山引擎多模态 Embedding + contents = [{"type": "text", "text": text} for text in texts] + response = self._client.multimodal_embeddings.create( + model=self._config.model_name, + input=contents, + encoding_format="float" + ) + return [response.data.embedding] + else: + # 其他 provider + return self._model.embed_documents(texts) def embed_query(self, text: str) -> List[float]: - return self._model.embed_query(text) + """单个文本向量化(LangChain 标准接口)""" + if self._is_volcano: + # 火山引擎多模态 Embedding + result = self.embed_documents([text]) + return result[0] if result else [] + else: + # 其他 provider + return self._model.embed_query(text) + + # ==================== 多模态扩展方法 ==================== + + def embed_multimodal( + self, + contents: List[Dict[str, Any]], + **kwargs + ) -> List[List[float]]: + """ + 多模态向量化(仅火山引擎支持) + + Args: + contents: 内容列表,格式: + - 文本: {"type": "text", "text": "..."} + - 图片: {"type": "image_url", "image_url": {"url": "..."}} + - 视频: {"type": "video_url", "video_url": {"url": "..."}} + **kwargs: 其他参数 + + Returns: + 向量列表 + """ + if not self._is_volcano: + raise NotImplementedError( + f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}" + ) + + response = self._client.multimodal_embeddings.create( + model=self._config.model_name, + input=contents, + **kwargs + ) + return [response.data.embedding] + + async def aembed_multimodal( + self, + contents: List[Dict[str, Any]], + **kwargs + ) -> List[List[float]]: + """异步多模态向量化""" + # 火山引擎 SDK 暂不支持异步,使用同步方法 + return self.embed_multimodal(contents, **kwargs) + + def embed_text(self, text: str, **kwargs) -> List[float]: + """文本向量化(便捷方法)""" + if self._is_volcano: + result = self.embed_multimodal( + [{"type": "text", "text": text}], + **kwargs + ) + return result[0] if result else [] + else: + return self.embed_query(text) + + def embed_image(self, image_url: str, **kwargs) -> List[float]: + """图片向量化(仅火山引擎支持)""" + if not self._is_volcano: + raise NotImplementedError( + f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}" + ) + + result = self.embed_multimodal( + [{"type": "image_url", "image_url": {"url": image_url}}], + **kwargs + ) + return result[0] if result else [] + + def embed_video(self, video_url: str, **kwargs) -> List[float]: + """视频向量化(仅火山引擎支持)""" + if not self._is_volcano: + raise NotImplementedError( + f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}" + ) + + result = self.embed_multimodal( + [{"type": "video_url", "video_url": {"url": video_url}}], + **kwargs + ) + return result[0] if result else [] + + def embed_batch( + self, + items: List[Union[str, Dict[str, Any]]], + **kwargs + ) -> List[List[float]]: + """ + 批量向量化(支持混合类型) + + Args: + items: 可以是字符串列表或内容字典列表 + **kwargs: 其他参数 + + Returns: + 向量列表 + """ + # 如果全是字符串,使用标准方法 + if all(isinstance(item, str) for item in items): + return self.embed_documents(items) + + # 如果包含字典,需要多模态支持 + if not self._is_volcano: + raise NotImplementedError( + f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}" + ) + + # 标准化输入格式 + contents = [] + for item in items: + if isinstance(item, str): + contents.append({"type": "text", "text": item}) + elif isinstance(item, dict): + contents.append(item) + else: + raise ValueError(f"不支持的输入类型: {type(item)}") + + return self.embed_multimodal(contents, **kwargs) + + # ==================== 工具方法 ==================== + + def is_multimodal_supported(self) -> bool: + """检查是否支持多模态""" + return self._is_volcano + + def get_provider(self) -> str: + """获取 provider""" + return self._config.provider + + +# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容 +RedBearMultimodalEmbeddings = RedBearEmbeddings diff --git a/api/app/core/models/generation.py b/api/app/core/models/generation.py new file mode 100644 index 00000000..b6388d3f --- /dev/null +++ b/api/app/core/models/generation.py @@ -0,0 +1,344 @@ +""" +图片和视频生成模型封装 + +支持的 Provider: +- Volcano (火山引擎): 使用 volcenginesdkarkruntime +- OpenAI: 使用 openai SDK +""" +from typing import Any, Dict, Optional + +from volcenginesdkarkruntime import Ark +from volcenginesdkarkruntime.types.images.images import ( + SequentialImageGenerationOptions, + ContentGenerationTool, + OptimizePromptOptions +) + +from app.core.models.base import RedBearModelConfig +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.models.models_model import ModelProvider + + +class RedBearImageGenerator: + """图片生成模型封装""" + + def __init__(self, config: RedBearModelConfig): + self._config = config + self._client = self._create_client(config) + + def _create_client(self, config: RedBearModelConfig): + """根据 provider 创建客户端""" + provider = config.provider.lower() + + if provider == ModelProvider.VOLCANO: + return Ark(api_key=config.api_key, base_url=config.base_url) + # elif provider == ModelProvider.OPENAI: + # from openai import OpenAI + # return OpenAI(api_key=config.api_key, base_url=config.base_url) + else: + raise BusinessException( + f"不支持的图片生成提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + def generate( + self, + prompt: str, + image: Optional[Any] = None, + size: Optional[str] = "2K", + output_format: str = "png", + response_format: str = "url", + watermark: bool = False, + sequential_image_generation: Optional[str] = None, + sequential_image_generation_options: Optional[Dict] = None, + tools: Optional[list] = None, + optimize_prompt_options: Optional[Dict] = None, + stream: bool = False, + **kwargs + ) -> Dict[str, Any]: + """ + 生成图片 + + Args: + prompt: 提示词 + image: 参考图片URL或URL列表(图文生图/多图融合) + size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080" 等(至少3686400像素) + output_format: 输出格式,如 "png", "jpg" + response_format: 返回格式,"url" 或 "b64_json" + watermark: 是否添加水印 + sequential_image_generation: 组图生成模式,"auto" 或 "disabled" + sequential_image_generation_options: 组图生成选项,如 {"max_images": 4} + tools: 工具列表,如 [{"type": "web_search"}] 用于联网搜索生图 + optimize_prompt_options: 提示词优化选项,如 {"mode": "fast"} + stream: 是否使用流式生成 + **kwargs: 其他参数 + + Returns: + 生成结果 + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + params = { + "model": self._config.model_name, + "prompt": prompt, + "size": size, + "output_format": output_format, + "response_format": response_format, + "watermark": watermark, + } + + if image is not None: + params["image"] = image + + if sequential_image_generation: + params["sequential_image_generation"] = sequential_image_generation + if sequential_image_generation_options: + params["sequential_image_generation_options"] = SequentialImageGenerationOptions( + **sequential_image_generation_options + ) + + if tools: + params["tools"] = [ContentGenerationTool(**tool) if isinstance(tool, dict) else tool for tool in tools] + + if optimize_prompt_options: + params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options) + + if stream: + params["stream"] = True + + params.update(kwargs) + response = self._client.images.generate(**params) + + # elif provider == ModelProvider.OPENAI: + # response = self._client.images.generate( + # model=self._config.model_name, + # prompt=prompt, + # size=size, + # n=n, + # **kwargs + # ) + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + return response.model_dump() if hasattr(response, 'model_dump') else response + + async def agenerate( + self, + prompt: str, + image: Optional[Any] = None, + size: Optional[str] = "2K", + output_format: str = "png", + response_format: str = "url", + watermark: bool = False, + **kwargs + ) -> Dict[str, Any]: + """异步生成图片""" + return self.generate(prompt, image, size, output_format, response_format, watermark, **kwargs) + + +class RedBearVideoGenerator: + """视频生成模型封装""" + + def __init__(self, config: RedBearModelConfig): + self._config = config + self._client = self._create_client(config) + + def _create_client(self, config: RedBearModelConfig): + """根据 provider 创建客户端""" + provider = config.provider.lower() + + if provider == ModelProvider.VOLCANO: + return Ark(api_key=config.api_key, base_url=config.base_url) + else: + raise BusinessException( + f"不支持的视频生成提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + def generate( + self, + prompt: str, + image_url: Optional[str] = None, + first_frame_url: Optional[str] = None, + last_frame_url: Optional[str] = None, + reference_images: Optional[list] = None, + draft_task_id: Optional[str] = None, + duration: Optional[int] = None, + frames: Optional[int] = None, + ratio: Optional[str] = None, + resolution: Optional[str] = None, + generate_audio: bool = False, + watermark: bool = False, + camera_fixed: bool = False, + seed: Optional[int] = None, + return_last_frame: bool = False, + service_tier: str = "default", + execution_expires_after: Optional[int] = None, + draft: bool = False, + **kwargs + ) -> Dict[str, Any]: + """ + 生成视频 + + Args: + prompt: 提示词 + image_url: 首帧图片URL(图生视频-基于首帧) + first_frame_url: 首帧图片URL(图生视频-基于首尾帧) + last_frame_url: 尾帧图片URL(图生视频-基于首尾帧) + reference_images: 参考图片URL列表(图生视频-基于参考图) + draft_task_id: Draft任务ID(基于Draft生成正式视频) + duration: 视频时长(秒),与frames二选一 + frames: 视频帧数,与duration二选一 + ratio: 视频比例,如 "16:9", "9:16", "adaptive" + resolution: 视频分辨率,如 "720p", "1080p" + generate_audio: 是否生成音频 + watermark: 是否添加水印 + camera_fixed: 是否固定镜头 + seed: 随机种子 + return_last_frame: 是否返回最后一帧 + service_tier: 服务层级,"default" 或 "flex"(离线推理) + execution_expires_after: 任务过期时间(秒) + draft: 是否生成样片 + **kwargs: 其他参数 + + Returns: + 生成结果(包含任务ID,需要轮询获取结果) + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + content = [{"type": "text", "text": prompt}] + + if draft_task_id: + content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}] + else: + if image_url: + content.append({"type": "image_url", "image_url": {"url": image_url}}) + + if first_frame_url: + content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"}) + if last_frame_url: + content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"}) + + if reference_images: + for ref_url in reference_images: + content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"}) + + params = {"model": self._config.model_name, "content": content, "watermark": watermark} + + if duration: + params["duration"] = duration + if frames: + params["frames"] = frames + if ratio: + params["ratio"] = ratio + if resolution: + params["resolution"] = resolution + if generate_audio: + params["generate_audio"] = generate_audio + if camera_fixed: + params["camera_fixed"] = camera_fixed + if seed is not None: + params["seed"] = seed + if return_last_frame: + params["return_last_frame"] = return_last_frame + if service_tier != "default": + params["service_tier"] = service_tier + if execution_expires_after: + params["execution_expires_after"] = execution_expires_after + if draft: + params["draft"] = draft + + params.update(kwargs) + response = self._client.content_generation.tasks.create(**params) + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + return response.model_dump() if hasattr(response, 'model_dump') else response + + async def agenerate( + self, + prompt: str, + image_url: Optional[str] = None, + duration: Optional[int] = None, + **kwargs + ) -> Dict[str, Any]: + """异步生成视频""" + return self.generate(prompt, image_url=image_url, duration=duration, **kwargs) + + def get_task_status(self, task_id: str) -> Dict[str, Any]: + """ + 查询视频生成任务状态 + + Args: + task_id: 任务ID + + Returns: + 任务状态信息 + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + response = self._client.content_generation.tasks.get(task_id=task_id) + return response.model_dump() if hasattr(response, 'model_dump') else response + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + async def aget_task_status(self, task_id: str) -> Dict[str, Any]: + """异步查询任务状态""" + return self.get_task_status(task_id) + + def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]: + """ + 查询视频生成任务列表 + + Args: + page_size: 每页数量 + status: 任务状态筛选,如 "succeeded", "failed", "pending" + **kwargs: 其他参数 + + Returns: + 任务列表 + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + params = {"page_size": page_size} + if status: + params["status"] = status + params.update(kwargs) + response = self._client.content_generation.tasks.list(**params) + return response.model_dump() if hasattr(response, 'model_dump') else response + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + def delete_task(self, task_id: str) -> None: + """ + 删除或取消视频生成任务 + + Args: + task_id: 任务ID + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + self._client.content_generation.tasks.delete(task_id=task_id) + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) diff --git a/api/app/core/models/scripts/volcano_models.yaml b/api/app/core/models/scripts/volcano_models.yaml new file mode 100644 index 00000000..24609f5a --- /dev/null +++ b/api/app/core/models/scripts/volcano_models.yaml @@ -0,0 +1,334 @@ +provider: volcano +models: +# Doubao-Seed 2.0 系列 +- name: doubao-seed-2-0-pro-260215 + type: chat + provider: volcano + description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-2-0-lite-260215 + type: chat + provider: volcano + description: 面向高频企业场景兼顾性能与成本的均衡型模型,综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-2-0-mini-260215 + type: chat + provider: volcano + description: 面向低时延、高并发与成本敏感场景,提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解,适合成本和速度优先的轻量级任务。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-2-0-code-preview-260215 + type: chat + provider: volcano + description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills,可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + - 代码模型 + logo: volcano + +# Doubao-Seed 1.x 系列 +- name: doubao-seed-1-8-251228 + type: chat + provider: volcano + description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上,Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面,视觉基础能力显著提升,可低帧率理解超长视频,视频运动理解、复杂空间理解及文档结构化解析能力也有所优化,还原生支持智能上下文管理,用户可配置上下文策略。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-1-6-251015 + type: chat + provider: volcano + description: Doubao-Seed-1.6全新多模态深度思考模型,同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-1-6-lite-251015 + type: chat + provider: volcano + description: 更高性价比,常见任务的最佳选择,支持minimal、low、medium、high 四种reasoning_effort思考深度 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-1-6-flash-250828 + type: chat + provider: volcano + description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型,TPOT低至10ms; 同时支持文本和视觉理解,文本理解能力超过上一代lite,视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-code-preview-251028 + type: chat + provider: volcano + description: 面向Agentic编程任务进行了深度优化。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + - 代码模型 + logo: volcano + +- name: doubao-seed-1-6-vision-250815 + type: chat + provider: volcano + description: 全新Doubao-Seed-1.6系列视觉深度思考模型,视觉理解能力显著增强,并支持image_process视觉工具 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + - 多模态模型 + logo: volcano + +# Doubao 1.5 系列 +- name: doubao-1-5-vision-pro-32k-250115 + type: chat + provider: volcano + description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 大语言模型 + - 多模态模型 + logo: volcano + +- name: doubao-1-5-pro-32k-250115 + type: chat + provider: volcano + description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-1-5-lite-32k-250115 + type: chat + provider: volcano + description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 大语言模型 + logo: volcano + +# Doubao-Seedance 视频生成系列 +- name: doubao-seedance-1-5-pro-251215 + type: video + provider: volcano + description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + logo: volcano + +- name: doubao-seedance-1-0-pro-250528 + type: video + provider: volcano + description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + logo: volcano + +- name: doubao-seedance-1-0-pro-fast-251015 + type: video + provider: volcano + description: 一款价格触底、效能封顶的全面模型,在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + logo: volcano + +- name: doubao-seedance-1-0-lite-i2v-250428 + type: video + provider: volcano + description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + - 图生视频 + logo: volcano + +- name: doubao-seedance-1-0-lite-t2v-250428 + type: video + provider: volcano + description: 基于文本提示词生成视频 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 视频生成 + - 文生视频 + logo: volcano + +# Doubao-Seedream 图像生成系列 +- name: doubao-seedream-5-0-260128 + type: image + provider: volcano + description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 图像生成 + logo: volcano + +- name: doubao-seedream-4-5-251128 + type: image + provider: volcano + description: 字节跳动最新推出的图像多模态模型,整合了文生图、图生图、组图输出等能力,融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 图像生成 + logo: volcano + +- name: doubao-seedream-4-0-250828 + type: image + provider: volcano + description: 基于领先架构的SOTA级多模态图像创作模型,其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一,原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 图像生成 + logo: volcano + +- name: doubao-seedream-3-0-t2i-250415 + type: image + provider: volcano + description: 一款支持原生高分辨率的中英双语图像生成基础模型,综合能力媲美GPT-4o,处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 图像生成 + - 文生图 + logo: volcano + +# Doubao 翻译系列 +- name: doubao-seed-translation-250915 + type: chat + provider: volcano + description: 通用多语言翻译模型,支持30余种语言互译,支持 4K 上下文窗口,输出长度支持最大 3K tokens + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 翻译模型 + logo: volcano + +# Doubao Embedding 系列 +- name: doubao-embedding-vision-251215 + type: embedding + provider: volcano + description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 向量模型 + - 多模态模型 + logo: volcano diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 198d1473..386920e0 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel): class ElasticSearchVector(BaseVector): def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): super().__init__(index_name.lower()) - # self.embeddings = XinferenceEmbeddings( - # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port - # model_uid="bge-m3" # replace model_uid with the model UID return from launching the model - # ) - # Remove debug printing to avoid leaking sensitive information - # print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base) + + # 初始化 Embedding 模型(自动支持火山引擎多模态) self.embeddings = RedBearEmbeddings(RedBearModelConfig( model_name=embedding_config.model_name, provider=embedding_config.provider, api_key=embedding_config.api_key, base_url=embedding_config.api_base )) - # self.reranker = XinferenceRerank( - # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), - # model_uid="bge-reranker-large" - # ) - # Remove debug printing to avoid leaking sensitive information - # print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base) + self.is_multimodal_embedding = self.embeddings.is_multimodal_supported() + self.reranker = RedBearRerank(RedBearModelConfig( model_name=reranker_config.model_name, provider=reranker_config.provider, @@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector): def add_chunks(self, chunks: list[DocumentChunk], **kwargs): # 实现 Elasticsearch 保存向量 texts = [chunk.page_content for chunk in chunks] - embeddings = self.embeddings.embed_documents(list(texts)) + if self.is_multimodal_embedding: + # 火山引擎多模态 Embedding + embeddings = self.embeddings.embed_batch(texts) + else: + embeddings = self.embeddings.embed_documents(list(texts)) self.create(chunks, embeddings, **kwargs) def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): @@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector): updated count. """ indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" - chunk.vector = self.embeddings.embed_query(chunk.page_content) + if self.is_multimodal_embedding: + # 火山引擎多模态 Embedding + chunk.vector = self.embeddings.embed_text(chunk.page_content) + else: + chunk.vector = self.embeddings.embed_query(chunk.page_content) body = { "script": { @@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector): def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]: """Search the nearest neighbors to a vector.""" - query_vector = self.embeddings.embed_query(query) + if self.is_multimodal_embedding: + # 火山引擎多模态 Embedding + query_vector = self.embeddings.embed_text(query) + else: + query_vector = self.embeddings.embed_query(query) top_k = kwargs.get("top_k", 1024) score_threshold = float(kwargs.get("score_threshold") or 0.3) indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" diff --git a/api/app/core/storage/base.py b/api/app/core/storage/base.py index 8ab0fcde..09824c3f 100644 --- a/api/app/core/storage/base.py +++ b/api/app/core/storage/base.py @@ -109,17 +109,13 @@ class StorageBackend(ABC): pass @abstractmethod - async def get_url(self, file_key: str, expires: int = 3600) -> str: - """ - Get an access URL for the file. - - Args: - file_key: Unique identifier for the file in the storage system. - expires: URL validity period in seconds (default: 1 hour). - - Returns: - URL for accessing the file. - """ + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None + ) -> str: + """Get an access URL for the file.""" pass async def get_permanent_url(self, file_key: str) -> Optional[str]: diff --git a/api/app/core/storage/local.py b/api/app/core/storage/local.py index 4b8ae829..13adfc20 100644 --- a/api/app/core/storage/local.py +++ b/api/app/core/storage/local.py @@ -210,7 +210,12 @@ class LocalStorage(StorageBackend): cause=e, ) - async def get_url(self, file_key: str, expires: int = 3600) -> str: + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None + ) -> str: """ Get an access URL for the file. @@ -220,6 +225,7 @@ class LocalStorage(StorageBackend): Args: file_key: Unique identifier for the file in the storage system. expires: URL validity period in seconds (not used for local storage). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: A relative URL path for accessing the file. diff --git a/api/app/core/storage/oss.py b/api/app/core/storage/oss.py index 27669ffa..1db86fef 100644 --- a/api/app/core/storage/oss.py +++ b/api/app/core/storage/oss.py @@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK. import io import logging +import urllib.parse from typing import AsyncIterator, Optional import oss2 @@ -242,24 +243,33 @@ class OSSStorage(StorageBackend): logger.error(f"Failed to check file existence in OSS {file_key}: {e}") return False - async def get_url(self, file_key: str, expires: int = 3600) -> str: + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None, + ) -> str: """ Get a presigned URL for accessing the file. Args: file_key: Unique identifier for the file in the storage system. expires: URL validity period in seconds (default: 1 hour). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: A presigned URL for accessing the file. """ try: - url = self.bucket.sign_url("GET", file_key, expires) + params = {} + if file_name: + filename_encoded = urllib.parse.quote(file_name.encode("utf-8")) + params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}" + url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None) logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") return url except Exception as e: logger.error(f"Failed to generate presigned URL for {file_key}: {e}") - # Return a basic URL format as fallback return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}" async def get_permanent_url(self, file_key: str) -> str: diff --git a/api/app/core/storage/s3.py b/api/app/core/storage/s3.py index c7b33ffe..f156f4a7 100644 --- a/api/app/core/storage/s3.py +++ b/api/app/core/storage/s3.py @@ -6,6 +6,7 @@ using the boto3 SDK. """ import io +import urllib.parse import logging from typing import AsyncIterator, Optional @@ -352,31 +353,37 @@ class S3Storage(StorageBackend): logger.error(f"Failed to check file existence in S3 {file_key}: {e}") return False - async def get_url(self, file_key: str, expires: int = 3600) -> str: + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None, + ) -> str: """ Get a presigned URL for accessing the file. Args: file_key: Unique identifier for the file in the storage system. expires: URL validity period in seconds (default: 1 hour). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: A presigned URL for accessing the file. """ try: + params = {"Bucket": self.bucket_name, "Key": file_key} + if file_name: + filename_encoded = urllib.parse.quote(file_name.encode("utf-8")) + params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}" url = self.client.generate_presigned_url( "get_object", - Params={ - "Bucket": self.bucket_name, - "Key": file_key, - }, + Params=params, ExpiresIn=expires, ) logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") return url except Exception as e: logger.error(f"Failed to generate presigned URL for {file_key}: {e}") - # Return a basic URL format as fallback return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}" async def get_permanent_url(self, file_key: str) -> str: diff --git a/api/app/core/workflow/adapters/base_adapter.py b/api/app/core/workflow/adapters/base_adapter.py index 49321b89..2e24d085 100644 --- a/api/app/core/workflow/adapters/base_adapter.py +++ b/api/app/core/workflow/adapters/base_adapter.py @@ -9,7 +9,7 @@ from typing import Any from pydantic import BaseModel, Field -from app.core.workflow.adapters.errors import ExceptionDefineition +from app.core.workflow.adapters.errors import ExceptionDefinition from app.schemas.workflow_schema import ( EdgeDefinition, NodeDefinition, @@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel): edges: list[EdgeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list) - warnings: list[ExceptionDefineition] = Field(default_factory=list) - errors: list[ExceptionDefineition] = Field(default_factory=list) + warnings: list[ExceptionDefinition] = Field(default_factory=list) + errors: list[ExceptionDefinition] = Field(default_factory=list) class WorkflowImportResult(BaseModel): @@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel): edges: list[EdgeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list) - warnings: list[ExceptionDefineition] = Field(default_factory=list) - errors: list[ExceptionDefineition] = Field(default_factory=list) + warnings: list[ExceptionDefinition] = Field(default_factory=list) + errors: list[ExceptionDefinition] = Field(default_factory=list) class BasePlatformAdapter(ABC): diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py index 467beb07..4fa9508b 100644 --- a/api/app/core/workflow/adapters/dify/converter.py +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -9,9 +9,9 @@ from urllib.parse import quote from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.errors import ( - UnsupportVariableType, - UnknowModelWarning, - ExceptionDefineition, + UnsupportedVariableType, + UnknownModelWarning, + ExceptionDefinition, ExceptionType ) from app.core.workflow.nodes.assigner.config import AssignmentItem @@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import ( HttpFormData, HttpTimeOutConfig, HttpRetryConfig, - HttpErrorDefaultTamplete, + HttpErrorDefaultTemplate, HttpErrorHandleConfig ) from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig @@ -108,7 +108,7 @@ class DifyConverter(BaseConverter): try: return config.model_validate(value) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node_id, node_name=node_name, @@ -138,7 +138,7 @@ class DifyConverter(BaseConverter): var_selector = mapping.get(var_selector, var_selector) return var_selector - def _process_list_variable_litearl(self, variable_selector: list) -> str | None: + def _process_list_variable_literal(self, variable_selector: list) -> str | None: if not self.process_var_selector(".".join(variable_selector)): return None return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}" @@ -269,7 +269,7 @@ class DifyConverter(BaseConverter): var_type = self.variable_type_map(var["type"]) if not var_type: self.errors.append( - UnsupportVariableType( + UnsupportedVariableType( scope=node["id"], name=var["variable"], var_type=var["type"], @@ -281,7 +281,7 @@ class DifyConverter(BaseConverter): if var_type in ["file", "array[file]"]: self.errors.append( - ExceptionDefineition( + ExceptionDefinition( type=ExceptionType.VARIABLE, node_id=node["id"], node_name=node_data["title"], @@ -311,7 +311,7 @@ class DifyConverter(BaseConverter): def convert_question_classifier_node_config(self, node: dict) -> dict: node_data = node["data"] self.warnings.append( - UnknowModelWarning( + UnknownModelWarning( node_id=node["id"], node_name=node_data["title"], model_name=node_data["model"].get("name") @@ -327,7 +327,7 @@ class DifyConverter(BaseConverter): ) result = QuestionClassifierNodeConfig.model_construct( - input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")), + input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")), user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")), categories=categories, ).model_dump() @@ -337,13 +337,13 @@ class DifyConverter(BaseConverter): def convert_llm_node_config(self, node: dict) -> dict: node_data = node["data"] self.warnings.append( - UnknowModelWarning( + UnknownModelWarning( node_id=node["id"], node_name=node_data["title"], model_name=node_data["model"].get("name") ) ) - context = self._process_list_variable_litearl(node_data["context"]["variable_selector"]) + context = self._process_list_variable_literal(node_data["context"]["variable_selector"]) memory = MemoryWindowSetting( enable=bool(node_data.get("memory")), enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), @@ -367,7 +367,7 @@ class DifyConverter(BaseConverter): ) ) vision = node_data["vision"]["enabled"] - vision_input = self._process_list_variable_litearl( + vision_input = self._process_list_variable_literal( node_data["vision"]["configs"]["variable_selector"] ) if vision else None result = LLMNodeConfig.model_construct( @@ -433,7 +433,7 @@ class DifyConverter(BaseConverter): conditions.append( LoopConditionDetail.model_construct( operator=self.convert_compare_operator(condition["comparison_operator"]), - left=self._process_list_variable_litearl(condition["variable_selector"]), + left=self._process_list_variable_literal(condition["variable_selector"]), right=self.trans_variable_format( right_value ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( @@ -453,7 +453,7 @@ class DifyConverter(BaseConverter): right_input_type = variable["value_type"] right_value_type = self.variable_type_map(variable["var_type"]) if right_input_type == ValueInputType.VARIABLE: - right_value = self._process_list_variable_litearl(variable.get("value", "")) + right_value = self._process_list_variable_literal(variable.get("value", "")) else: right_value = self.convert_variable_type(right_value_type, variable.get("value", "")) loop_variables.append( @@ -475,10 +475,10 @@ class DifyConverter(BaseConverter): def convert_iteration_node_config(self, node: dict) -> dict: node_data = node["data"] result = IterationNodeConfig.model_construct( - input=self._process_list_variable_litearl(node_data["iterator_selector"]), + input=self._process_list_variable_literal(node_data["iterator_selector"]), parallel=node_data["is_parallel"], parallel_count=node_data["parallel_nums"], - output=self._process_list_variable_litearl(node_data["output_selector"]), + output=self._process_list_variable_literal(node_data["output_selector"]), output_type=self.variable_type_map(node_data.get("output_type")), flatten=node_data["flatten_output"], ).model_dump() @@ -494,8 +494,8 @@ class DifyConverter(BaseConverter): continue assignments.append( AssignmentItem( - variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]), - value=self._process_list_variable_litearl( + variable_selector=self._process_list_variable_literal(assignment["variable_selector"]), + value=self._process_list_variable_literal( assignment["value"] ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], operation=self.convert_assignment_operator(assignment["operation"]) @@ -514,7 +514,7 @@ class DifyConverter(BaseConverter): input_variables.append( InputVariable.model_construct( name=input_variable["variable"], - variable=self._process_list_variable_litearl(input_variable["value_selector"]), + variable=self._process_list_variable_literal(input_variable["value_selector"]), ) ) @@ -570,7 +570,7 @@ class DifyConverter(BaseConverter): else: if node_data["body"]["data"]: body_content = (node_data["body"]["data"][0].get("value") or - self._process_list_variable_litearl(node_data["body"]["data"][0].get("file"))) + self._process_list_variable_literal(node_data["body"]["data"][0].get("file"))) else: body_content = "" @@ -585,7 +585,7 @@ class DifyConverter(BaseConverter): self.trans_variable_format(key_value[0]) ] = self.trans_variable_format(key_value[1]) else: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node["id"], node_name=node_data["title"], @@ -603,7 +603,7 @@ class DifyConverter(BaseConverter): self.trans_variable_format(key_value[0]) ] = self.trans_variable_format(key_value[1]) else: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node["id"], node_name=node_data["title"], @@ -625,7 +625,7 @@ class DifyConverter(BaseConverter): default_header = var["value"] elif var["key"] == "status_code": default_status_code = var["value"] - default_value = HttpErrorDefaultTamplete( + default_value = HttpErrorDefaultTemplate( body=default_body, headers=default_header, status_code=default_status_code, @@ -668,7 +668,7 @@ class DifyConverter(BaseConverter): for variable in node_data["variables"]: mapping.append(VariablesMappingConfig.model_construct( name=variable["variable"], - value=self._process_list_variable_litearl(variable["value_selector"]) + value=self._process_list_variable_literal(variable["value_selector"]) )) result = JinjaRenderNodeConfig.model_construct( template=node_data["template"], @@ -679,14 +679,14 @@ class DifyConverter(BaseConverter): def convert_knowledge_node_config(self, node: dict) -> dict: node_data = node["data"] - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( node_id=node["id"], node_name=node_data["title"], type=ExceptionType.CONFIG, detail=f"Please reconfigure the Knowledge Retrieval node.", )) result = KnowledgeRetrievalNodeConfig.model_construct( - query=self._process_list_variable_litearl(node_data["query_variable_selector"]), + query=self._process_list_variable_literal(node_data["query_variable_selector"]), ).model_dump() self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result) @@ -695,7 +695,7 @@ class DifyConverter(BaseConverter): def convert_parameter_extractor_node_config(self, node: dict) -> dict: node_data = node["data"] self.warnings.append( - UnknowModelWarning( + UnknownModelWarning( node_id=node["id"], node_name=node_data["title"], model_name=node_data["model"].get("name") @@ -712,7 +712,7 @@ class DifyConverter(BaseConverter): ) ) result = ParameterExtractorNodeConfig.model_construct( - text=self._process_list_variable_litearl(node_data["query"]), + text=self._process_list_variable_literal(node_data["query"]), params=params, prompt=node_data.get("instruction") ).model_dump() @@ -727,14 +727,14 @@ class DifyConverter(BaseConverter): group_type = {} if not advanced_settings or not advanced_settings["group_enabled"]: group_variables = [ - self._process_list_variable_litearl(variable) + self._process_list_variable_literal(variable) for variable in node_data["variables"] ] group_type["output"] = node_data["output_type"] else: for group in advanced_settings["groups"]: group_variables[group["group_name"]] = [ - self._process_list_variable_litearl(variable) + self._process_list_variable_literal(variable) for variable in group["variables"] ] group_type[group["group_name"]] = group["output_type"] @@ -751,7 +751,7 @@ class DifyConverter(BaseConverter): def convert_tool_node_config(self, node: dict) -> dict: node_data = node["data"] - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( node_id=node["id"], node_name=node_data["title"], type=ExceptionType.CONFIG, diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py index 10397ad0..abd95408 100644 --- a/api/app/core/workflow/adapters/dify/dify_adapter.py +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import ( WorkflowParserResult ) from app.core.workflow.adapters.dify.converter import DifyConverter -from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType from app.core.workflow.nodes.enums import NodeType from app.schemas.workflow_schema import ( NodeDefinition, @@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): if not all(field in self.config for field in require_fields): return False if self.config.get("app", {}).get("mode") == "workflow": - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.PLATFORM, detail="workflow mode is not supported" )) @@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): edge = self._convert_edge(edge) if edge: self.edges.append(edge) - # + for variable in self.config.get("workflow").get("conversation_variables"): con_var = self._convert_variable(variable) if variable: self.conv_variables.append(con_var) - # + # for variables in config.get("workflow").get("environment_variables"): # variable = self._convert_variable(variables) # conv_variables.append(variable) @@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): "y": node["position"]["y"] + position["y"] } self.errors.append( - ExceptionDefineition( + ExceptionDefinition( type=ExceptionType.NODE, node_id=node_id, detail="parent cycle node not found" @@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): node_data = node["data"] converter = self.get_node_convert(node_type) if node_type == NodeType.UNKNOWN: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.NODE, node_id=node["id"], node_name=node["data"]["title"], @@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): )) return converter(node) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.NODE, node_id=node["id"], node_name=node["data"]["title"], @@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: try: - source = edge["source"] target = edge["target"] label = None @@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): label=label, ) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.EDGE, detail=f"convert edge error - {e}", )) @@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): description=variable.get("description") ) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.VARIABLE, name=variable.get("name"), detail=f"convert variable error - {e}", diff --git a/api/app/core/workflow/adapters/errors.py b/api/app/core/workflow/adapters/errors.py index c0340a5e..cb743c68 100644 --- a/api/app/core/workflow/adapters/errors.py +++ b/api/app/core/workflow/adapters/errors.py @@ -18,7 +18,7 @@ class ExceptionType(StrEnum): UNKNOWN = "unknown" -class ExceptionDefineition(BaseModel): +class ExceptionDefinition(BaseModel): type: ExceptionType detail: str @@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel): name: str | None = None -class UnknowModelWarning(ExceptionDefineition): +class UnknownModelWarning(ExceptionDefinition): type: ExceptionType = ExceptionType.NODE def __init__(self, node_id, node_name, model_name): @@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition): ) -class UnknowError(ExceptionDefineition): +class UnknownError(ExceptionDefinition): type: ExceptionType = ExceptionType.UNKNOWN def __init__(self, detail: str, **kwargs): super().__init__(detail=detail, **kwargs) -class UnsupportPlatform(ExceptionDefineition): +class UnsupportedPlatform(ExceptionDefinition): type: ExceptionType = ExceptionType.PLATFORM def __init__(self, platform: str): - super().__init__(detail=f"Unsupport platform {platform}") + super().__init__(detail=f"Unsupported platform {platform}") -class UnsupportVariableType(ExceptionDefineition): +class UnsupportedVariableType(ExceptionDefinition): type: ExceptionType = ExceptionType.VARIABLE def __init__(self, scope, name, var_type: str, **kwargs): - super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs) + super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs) -class InvalidConfiguration(ExceptionDefineition): +class InvalidConfiguration(ExceptionDefinition): type: ExceptionType = ExceptionType.CONFIG def __init__(self): super().__init__(detail="Invalid workflow configuration format") -class UnsupportNodeType(ExceptionDefineition): +class UnsupportedNodeType(ExceptionDefinition): type: ExceptionType = ExceptionType.NODE def __init__(self, node_id: str, node_type: str): - super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}") + super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}") diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py index 3516cb58..a2608a01 100644 --- a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py @@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import ( BasePlatformAdapter, WorkflowParserResult ) -from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType +from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter from app.core.workflow.nodes.enums import NodeType from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition @@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): try: node_type = self.map_node_type(node["type"]) if node_type == NodeType.UNKNOWN: - self.errors.append(UnsupportNodeType( + self.errors.append(UnsupportedNodeType( node_id=node_id, node_type=node["type"] )) @@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): return NodeDefinition(**node) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.NODE, node_id=node_id, node_name=node_name, @@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None: try: if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.EDGE, detail=f"edge {edge.get('id')} skipped: source or target node not found" )) return None return EdgeDefinition(**edge) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.EDGE, detail=f"convert edge error - {e}" )) @@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): try: return VariableDefinition(**variable) except Exception as e: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.VARIABLE, name=variable.get("name"), detail=f"convert variable error - {e}" diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py index 031c7025..e96e0bf2 100644 --- a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py @@ -1,6 +1,6 @@ # -*- coding: UTF-8 -*- from app.core.workflow.adapters.base_converter import BaseConverter -from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.configs import ( StartNodeConfig, @@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter): try: return config_cls.model_validate(value) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node_id, node_name=node_name, diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 674c45d0..daef6e82 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -7,7 +7,7 @@ import re import uuid from collections import defaultdict from functools import lru_cache -from typing import Any, Iterable +from typing import Any, Iterable, Callable from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import START, END @@ -41,48 +41,31 @@ class GraphBuilder: self, workflow_config: dict[str, Any], stream: bool = False, - subgraph: bool = False, + cycle: str = '', variable_pool: VariablePool | None = None ): self.workflow_config = workflow_config self.stream = stream - self.subgraph = subgraph + self.cycle = cycle self.start_node_id: str | None = None - self.node_map = {node["id"]: node for node in self.nodes} + self.node_map: dict[str, dict] = {} self.end_node_map: dict[str, StreamOutputConfig] = {} - self._find_upstream_activation_dep = lru_cache( - maxsize=len(self.nodes) * 2 - )(self._find_upstream_activation_dep) + self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep if variable_pool: self.variable_pool = variable_pool else: self.variable_pool = VariablePool() - self.graph = StateGraph(WorkflowState) - self.add_nodes() - self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) - self.end_nodes = [ - node - for node in self.nodes - if node.get("type") == "end" and node.get("id") in self.reachable_nodes - ] - self.add_edges() - # EDGES MUST BE ADDED AFTER NODES ARE ADDED. - + self.graph: StateGraph | None = None + self.nodes: list = [] + self.edges: list = [] + self.reachable_nodes: set[str] | None = None + self.end_nodes: list[dict] = [] self._reverse_adj: dict[str, list[dict]] = defaultdict(list) - self._build_reverse_adj() - self._analyze_end_node_output() - - @property - def nodes(self) -> list[dict[str, Any]]: - return self.workflow_config.get("nodes", []) - - @property - def edges(self) -> list[dict[str, Any]]: - return self.workflow_config.get("edges", []) + self._adj: dict[str, list[str]] = defaultdict(list) def get_node_type(self, node_id: str) -> str: """Retrieve the type of node given its ID. @@ -108,13 +91,14 @@ class GraphBuilder: result[node[0]].append(node[1]) return result - def _build_reverse_adj(self): + def _build_adj(self): for edge in self.edges: if edge["source"] not in self.reachable_nodes: continue self._reverse_adj[edge.get("target")].append({ "id": edge["source"], "branch": edge.get("label") }) + self._adj[edge.get("source")].append(edge["target"]) def _find_upstream_activation_dep( self, @@ -302,22 +286,13 @@ class GraphBuilder: """ for node in self.nodes: node_type = node.get("type") - if node_type == NodeType.NOTES: - continue node_id = node.get("id") - cycle_node = node.get("cycle") - if cycle_node: - # Nodes within a loop subgraph are constructed by CycleGraphNode - if not self.subgraph: - continue - - # Record start and end node IDs - if node_type in [NodeType.START, NodeType.CYCLE_START]: - self.start_node_id = node_id + if node_id not in self.reachable_nodes: + continue # Create node instance (start and end nodes are also created) # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph - node_instance = NodeFactory.create_node(node, self.workflow_config) + node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id]) if node_type in BRANCH_NODES: @@ -390,6 +365,8 @@ class GraphBuilder: for edge in self.edges: source = edge.get("source") target = edge.get("target") + if source not in self.reachable_nodes or target not in self.reachable_nodes: + continue condition = edge.get("condition") edge_type = edge.get("type") @@ -411,11 +388,12 @@ class GraphBuilder: # Add conditional edges for source_node, branches in conditional_edges.items(): def make_router(src, branch_list): - """reate a router function for each source node that routes to a NOP node for later merging.""" + """Create a router function for each source node that routes to a NOP node for later merging.""" def make_branch_node(node_name, targets): def node(s): - # NOTE: NOP NODE MUST NOT MODIFY STATE + # NOTE: NOP NODE USED FOR ROUTING ONLY. + # MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS. return { "activate": { node_id: s["activate"][node_name] @@ -502,14 +480,52 @@ class GraphBuilder: logger.debug(f"Added waiting edge: {sources} -> {target}") # Connect End nodes to the global END node - for end_node in self.end_nodes: - end_node_id = end_node.get("id") - if end_node_id: - self.graph.add_edge(end_node_id, END) - logger.debug(f"Added edge: {end_node_id} -> END") + for node in self.reachable_nodes: + if not self._adj[node]: + self.graph.add_edge(node, END) return def build(self) -> CompiledStateGraph: + nodes = self.workflow_config.get("nodes", []) + edges = self.workflow_config.get("edges", []) + + for node in nodes: + if (node.get("cycle") or '') == self.cycle: + node_type = node.get("type") + if node_type in [NodeType.START, NodeType.CYCLE_START]: + self.start_node_id = node.get("id") + elif node_type == NodeType.NOTES: + continue + self.nodes.append(node) + self.node_map[node.get("id")] = node + + for edge in edges: + source_in = edge.get("source") in self.node_map + target_in = edge.get("target") in self.node_map + if source_in ^ target_in: + raise ValueError( + f"Cycle node is connected to external node, " + f"source: {edge.get('source')}, target: {edge.get('target')}" + ) + + if source_in and target_in: + self.edges.append(edge) + + self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) + self.end_nodes = [ + node + for node in self.nodes + if node.get("type") == "end" and node.get("id") in self.reachable_nodes + ] + self._build_adj() + self._find_upstream_activation_dep: Callable = lru_cache( + maxsize=len(self.nodes)*2 + )(self._find_upstream_activation_dep) + + self.graph = StateGraph(WorkflowState) + self.add_nodes() + self.add_edges() + + self._analyze_end_node_output() checkpointer = InMemorySaver() - self.graph = self.graph.compile(checkpointer=checkpointer) - return self.graph + return self.graph.compile(checkpointer=checkpointer) diff --git a/api/app/core/workflow/engine/result_builder.py b/api/app/core/workflow/engine/result_builder.py index e5a03c1c..be0c957a 100644 --- a/api/app/core/workflow/engine/result_builder.py +++ b/api/app/core/workflow/engine/result_builder.py @@ -2,6 +2,7 @@ # Author: Eternity # @Email: 1533512157@qq.com # @Time : 2026/2/10 13:33 +from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.engine.variable_pool import VariablePool @@ -9,6 +10,7 @@ class WorkflowResultBuilder: def build_final_output( self, result: dict, + execution_context: ExecutionContext, variable_pool: VariablePool, elapsed_time: float, final_output: str, @@ -26,6 +28,8 @@ class WorkflowResultBuilder: - "node_outputs" (dict): Outputs of executed nodes. - "messages" (list): Conversation messages exchanged during execution. - "error" (str, optional): Error message if any node failed. + execution_context (ExecutionContext): The execution context containing metadata like + execution ID, workspace ID, and user ID.) variable_pool (VariablePool): Variable Pool elapsed_time (float): Total execution time in seconds. final_output (Any): The aggregated or final output content of the workflow @@ -48,18 +52,23 @@ class WorkflowResultBuilder: """ node_outputs = result.get("node_outputs", {}) token_usage = self.aggregate_token_usage(node_outputs) - conversation_id = variable_pool.get_value("sys.conversation_id") + conversation_vars = {} + sys_vars = {} + + if variable_pool: + conversation_vars = variable_pool.get_all_conversation_vars() + sys_vars = variable_pool.get_all_system_vars() return { "status": "completed" if success else "failed", "output": final_output, "variables": { - "conv": variable_pool.get_all_conversation_vars(), - "sys": variable_pool.get_all_system_vars() + "conv": conversation_vars, + "sys": sys_vars }, "node_outputs": node_outputs, "messages": result.get("messages", []), - "conversation_id": conversation_id, + "conversation_id": execution_context.conversation_id, "elapsed_time": elapsed_time, "token_usage": token_usage, "error": result.get("error"), diff --git a/api/app/core/workflow/engine/runtime_schema.py b/api/app/core/workflow/engine/runtime_schema.py index e4bf65af..036ce0e8 100644 --- a/api/app/core/workflow/engine/runtime_schema.py +++ b/api/app/core/workflow/engine/runtime_schema.py @@ -12,14 +12,29 @@ class ExecutionContext(BaseModel): execution_id: str workspace_id: str user_id: str + conversation_id: str + memory_storage_type: str + user_rag_memory_id: str checkpoint_config: RunnableConfig @classmethod - def create(cls, execution_id: str, workspace_id: str, user_id: str): + def create( + cls, + execution_id: str, + workspace_id: str, + user_id: str, + conversation_id: str, + memory_storage_type: str, + user_rag_memory_id: str + ): return cls( execution_id=execution_id, workspace_id=workspace_id, user_id=user_id, + conversation_id=conversation_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id, + checkpoint_config=RunnableConfig( configurable={ "thread_id": uuid.uuid4(), diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py index 0a4a1463..2da0d3a8 100644 --- a/api/app/core/workflow/engine/state_manager.py +++ b/api/app/core/workflow/engine/state_manager.py @@ -33,6 +33,8 @@ class WorkflowState(dict): "workspace_id", "user_id", "activate", + "memory_storage_type", + "user_rag_memory_id" }) __optional_keys__ = frozenset({ "error", @@ -62,6 +64,9 @@ class WorkflowState(dict): # node activate status activate: Annotated[dict[str, bool], merge_activate_state] + memory_storage_type: str + user_rag_memory_id: str + class WorkflowStateManager: def create_initial_state( @@ -85,7 +90,9 @@ class WorkflowStateManager: looping=0, activate={ start_node_id: True - } + }, + memory_storage_type=execution_context.memory_storage_type, + user_rag_memory_id=execution_context.user_rag_memory_id ) @staticmethod diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index 6685a49e..dcc92fdb 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -3,7 +3,7 @@ # @Email: 1533512157@qq.com # @Time : 2026/2/9 15:11 import re -from queue import Queue +from collections import deque from typing import AsyncGenerator from pydantic import BaseModel, Field, PrivateAttr @@ -256,7 +256,7 @@ class StreamOutputCoordinator: def __init__(self): self.end_outputs: dict[str, StreamOutputConfig] = {} self.activate_end: str | None = None - self.output_queue: Queue = Queue() + self.output_queue: deque[str] = deque() self.processed_outputs = [] def initialize_end_outputs( @@ -266,7 +266,7 @@ class StreamOutputCoordinator: self.end_outputs = end_node_map self.processed_outputs = [] self.activate_end = None - self.output_queue = Queue() + self.output_queue = deque() @property def current_activate_end_info(self): @@ -296,13 +296,13 @@ class StreamOutputCoordinator: scope (str): The node ID or scope that has completed execution. status (str | None): Optional status of the node (used for branch/control nodes). """ - for node in self.end_outputs.keys(): + for node in self.end_outputs: self.end_outputs[node].update_activate(scope, status) if self.end_outputs[node].activate and node not in self.processed_outputs: - self.output_queue.put(node) + self.output_queue.append(node) self.processed_outputs.append(node) - if self.activate_end is None and not self.output_queue.empty(): - self.activate_end = self.output_queue.get_nowait() + if self.activate_end is None and self.output_queue: + self.activate_end = self.output_queue.popleft() async def emit_activate_chunk( self, @@ -414,8 +414,8 @@ class StreamOutputCoordinator: async for msg_event in self.emit_activate_chunk(variable_pool, force=True): yield msg_event - if not self.output_queue.empty(): - self.activate_end = self.output_queue.get_nowait() + if self.output_queue: + self.activate_end = self.output_queue.popleft() # Move to next active End node if current one is done if not self.activate_end and self.end_outputs: self.activate_end = list(self.end_outputs.keys())[0] diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index cf6f4a7b..60f1257e 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE -from app.core.workflow.variable.variable_objects import T, create_variable_instance +from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable logger = logging.getLogger(__name__) @@ -373,6 +373,16 @@ class VariablePool: def copy(self, pool: 'VariablePool'): self.variables = deepcopy(pool.variables) + def is_file_variable(self, selector): + variable_struct = self.get_instance(selector, default=None, strict=False) + if variable_struct is None: + return False + if isinstance(variable_struct, FileVariable): + return True + elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable: + return True + return False + def to_dict(self) -> dict[str, Any]: """导出为字典 diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index c9ed6e65..0a820826 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -3,6 +3,7 @@ # @Email: 1533512157@qq.com # @Time : 2026/2/9 13:51 import datetime +import time import logging from typing import Any @@ -82,13 +83,15 @@ class WorkflowExecutor: CompiledStateGraph: The compiled and ready-to-run state graph. """ logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}") + start_time = time.time() builder = GraphBuilder( self.workflow_config, stream=stream, ) + + self.graph = builder.build() self.start_node_id = builder.start_node_id self.variable_pool = builder.variable_pool - self.graph = builder.build() self.stream_coordinator.initialize_end_outputs(builder.end_node_map) self.event_handler = EventStreamHandler( @@ -96,7 +99,8 @@ class WorkflowExecutor: variable_pool=self.variable_pool, execution_id=self.execution_context.execution_id ) - logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}") + logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, " + f"cost: {time.time() - start_time:.4f}s") return self.graph @@ -134,94 +138,12 @@ class WorkflowExecutor: return event.get("data") return self.result_builder.build_final_output( {"error": "Workflow execution did not end as expected"}, + self.execution_context, self.variable_pool, (datetime.datetime.now() - start).total_seconds(), "", success=False ) - # logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}") - # - # start_time = datetime.datetime.now() - # - # # Execute the workflow - # try: - # # Build the workflow graph - # graph = self.build_graph() - # - # # Initialize the variable pool with input data - # await self.variable_initializer.initialize( - # variable_pool=self.variable_pool, - # input_data=input_data, - # execution_context=self.execution_context - # ) - # initial_state = self.state_manager.create_initial_state( - # workflow_config=self.workflow_config, - # input_data=input_data, - # execution_context=self.execution_context, - # start_node_id=self.start_node_id - # ) - # - # result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config) - # - # # Aggregate output from all End nodes - # full_content = '' - # for end_id in self.stream_coordinator.end_outputs.keys(): - # full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) - # - # # Append messages for user and assistant - # if input_data.get("files"): - # result["messages"].extend( - # [ - # { - # "role": "user", - # "content": input_data.get("message", '') - # }, - # { - # "role": "user", - # "content": input_data.get("files") - # }, - # { - # "role": "assistant", - # "content": full_content - # } - # ] - # ) - # else: - # result["messages"].extend( - # [ - # { - # "role": "user", - # "content": input_data.get("message", '') - # }, - # { - # "role": "assistant", - # "content": full_content - # } - # ] - # ) - # # Calculate elapsed time - # end_time = datetime.datetime.now() - # elapsed_time = (end_time - start_time).total_seconds() - # - # logger.info( - # f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms") - # - # return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) - # - # except Exception as e: - # end_time = datetime.datetime.now() - # elapsed_time = (end_time - start_time).total_seconds() - # - # logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", - # exc_info=True) - # return { - # "status": "failed", - # "error": str(e), - # "output": None, - # "node_outputs": {}, - # "elapsed_time": elapsed_time, - # "token_usage": None - # } async def execute_stream( self, @@ -255,7 +177,7 @@ class WorkflowExecutor: "data": { "execution_id": self.execution_context.execution_id, "workspace_id": self.execution_context.workspace_id, - "conversation_id": input_data.get("conversation_id"), + "conversation_id": self.execution_context.conversation_id, "timestamp": int(start_time.timestamp() * 1000) } } @@ -376,6 +298,7 @@ class WorkflowExecutor: "event": "workflow_end", "data": self.result_builder.build_final_output( result, + self.execution_context, self.variable_pool, elapsed_time, full_content, @@ -396,6 +319,7 @@ class WorkflowExecutor: "event": "workflow_end", "data": self.result_builder.build_final_output( result, + self.execution_context, self.variable_pool, elapsed_time, full_content, @@ -409,7 +333,9 @@ async def execute_workflow( input_data: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str ) -> dict[str, Any]: """ Execute a workflow (convenience function, non-streaming). @@ -420,6 +346,8 @@ async def execute_workflow( execution_id (str): Execution ID. workspace_id (str): Workspace ID. user_id (str): User ID. + user_rag_memory_id: rag knowledge db id + memory_storage_type: neo4j / rag Returns: dict: Workflow execution result. @@ -427,7 +355,10 @@ async def execute_workflow( execution_context = ExecutionContext.create( execution_id=execution_id, workspace_id=workspace_id, - user_id=user_id + user_id=user_id, + conversation_id=input_data.get("conversation_id"), + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( workflow_config=workflow_config, @@ -441,7 +372,9 @@ async def execute_workflow_stream( input_data: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str ): """ Execute a workflow in streaming mode (convenience function). @@ -452,6 +385,8 @@ async def execute_workflow_stream( execution_id (str): Execution ID. workspace_id (str): Workspace ID. user_id (str): User ID. + user_rag_memory_id: rag knowledge db id + memory_storage_type: neo4j / rag Yields: dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end. @@ -459,7 +394,10 @@ async def execute_workflow_stream( execution_context = ExecutionContext.create( execution_id=execution_id, workspace_id=workspace_id, - user_id=user_id + user_id=user_id, + memory_storage_type=memory_storage_type, + conversation_id=input_data.get("conversation_id"), + user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( workflow_config=workflow_config, diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index 8959e27c..7b146a9c 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -64,9 +64,7 @@ class AgentNode(BaseNode): if not release: raise ValueError(f"Agent 不存在: {agent_id}") - - return release, message async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index 4c897d5a..f5bdf000 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -14,8 +14,8 @@ logger = logging.getLogger(__name__) class AssignerNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.variable_updater = True self.typed_config: AssignerNodeConfig | None = None diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 0e3fecee..8567ebbe 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -28,7 +28,7 @@ class BaseNode(ABC): All node types should inherit from this class and implement the `execute` method. """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): """Initialize the node. Args: @@ -41,6 +41,7 @@ class BaseNode(ABC): self.node_type = node_config["type"] self.cycle = node_config.get("cycle") self.node_name = node_config.get("name", self.node_id) + self.down_stream_nodes = down_stream_nodes # 使用 or 运算符处理 None 值 self.config = node_config.get("config") or {} self.error_handling = node_config.get("error_handling") or {} @@ -93,18 +94,16 @@ class BaseNode(ABC): dict: A dict with a single key 'activate', mapping node IDs to their activation status (True/False). """ - edges = self.workflow_config.get("edges") - under_stream_nodes = [ - edge.get("target") - for edge in edges - if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES - ] - return { - "activate": { - node_id: self.check_activate(state) - for node_id in under_stream_nodes - } | {self.node_id: self.check_activate(state)} - } + activate_flag = self.check_activate(state) + + if self.node_type not in BRANCH_NODES: + activate = {node_id: activate_flag for node_id in self.down_stream_nodes} + else: + activate = {} + + activate[self.node_id] = activate_flag + + return {"activate": activate} @abstractmethod async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: @@ -315,8 +314,8 @@ class BaseNode(ABC): elapsed_time = (time.time() - start_time) * 1000 - logger.info(f"Node {self.node_id} streaming execution finished, " - f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") + logger.debug(f"Node {self.node_id} streaming execution finished, " + f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) @@ -428,8 +427,8 @@ class BaseNode(ABC): when an error edge exists. If no error edge exists, this method raises an exception to stop the workflow. """ - # Check if the node has an error edge defined - error_edge = self._find_error_edge() + # # Check if the node has an error edge defined + # error_edge = self._find_error_edge() # Extract input data (for logging or audit purposes) input_data = self._extract_input(state, variable_pool) @@ -447,27 +446,26 @@ class BaseNode(ABC): "error": error_message } - if error_edge: - # If an error edge exists, log a warning and continue to error node - logger.warning( - f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" - ) - return { - "node_outputs": { - self.node_id: node_output - }, - "error": error_message, - "error_node": self.node_id - } - else: - # If no error edge, send the error via stream writer and stop the workflow - writer = get_stream_writer() - writer({ - "type": "node_error", - **node_output - }) - logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") - raise Exception(f"Node {self.node_id} execution failed: {error_message}") + # if error_edge: + # # If an error edge exists, log a warning and continue to error node + # logger.warning( + # f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" + # ) + # return { + # "node_outputs": { + # self.node_id: node_output + # }, + # "error": error_message, + # "error_node": self.node_id + # } + # else: + writer = get_stream_writer() + writer({ + "type": "node_error", + **node_output + }) + logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") + raise Exception(f"Node {self.node_id} execution failed: {error_message}") def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """Extracts the input data for this node (used for logging or audit). @@ -623,7 +621,6 @@ class BaseNode(ABC): async def process_message( api_config: ModelInfo, content: str | dict | FileObject, - end_user_id: str, enable_file=False ) -> list | str | None: provider = api_config.provider @@ -642,10 +639,10 @@ class BaseNode(ABC): return content elif isinstance(content, FileObject): - if content.content_cache.get(provider): - return content.content_cache[provider] + if content.content_cache.get(f"{provider}_{api_config.is_omni}"): + return content.content_cache[f"{provider}_{api_config.is_omni}"] with get_db_read() as db: - multimodel_service = MultimodalService(db, api_config=api_config) + multimodal_service = MultimodalService(db, api_config=api_config) file_obj = FileInput( type=content.type, url=content.url, @@ -654,16 +651,15 @@ class BaseNode(ABC): upload_file_id=uuid.UUID(content.file_id) if content.file_id else None, ) file_obj.set_content(content.get_content()) - message = await multimodel_service.process_files( - end_user_id, + message = await multimodal_service.process_files( [file_obj], ) content.set_content(file_obj.get_content()) if message: - content.content_cache[provider] = message + content.content_cache[f"{provider}_{api_config.is_omni}"] = message return message return None - raise TypeError(f'Unexpect input value type - {type(content)}') + raise TypeError(f'Unexpected input value type - {type(content)}') @staticmethod def process_model_output(content) -> str: diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index 1e055002..d89b208b 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -51,8 +51,8 @@ console.log(result) class CodeNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: CodeNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 71e0dbdb..fc80939f 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -30,17 +30,13 @@ class CycleGraphNode(BaseNode): It acts as a container and execution controller for a subgraph. """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) - - self.cycle_nodes = list() # Nodes belonging to this cycle - self.cycle_edges = list() # Edges connecting nodes within the cycle + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) + self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() self.start_node_id = None # ID of the start node within the cycle self.graph: StateGraph | CompiledStateGraph | None = None self.child_variable_pool: VariablePool | None = None - self.build_graph() - self.iteration_flag = True def _output_types(self) -> dict[str, VariableType]: outputs = {"__child_state": VariableType.ARRAY_OBJECT} @@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode): else: remain_edges.append(edge) - # Update workflow_config by removing cycle nodes and internal edges - self.workflow_config["nodes"] = [ - node for node in nodes if node.get("cycle") != self.node_id - ] - self.workflow_config["edges"] = remain_edges + # # Update workflow_config by removing cycle nodes and internal edges + # self.workflow_config["nodes"] = [ + # node for node in nodes if node.get("cycle") != self.node_id + # ] + # self.workflow_config["edges"] = remain_edges return cycle_nodes, cycle_edges @@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode): 3. Compile the graph for runtime execution """ from app.core.workflow.engine.graph_builder import GraphBuilder - self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() + self.child_variable_pool = VariablePool() builder = GraphBuilder( { "nodes": self.cycle_nodes, "edges": self.cycle_edges, }, - subgraph=True, - variable_pool=self.child_variable_pool + variable_pool=self.child_variable_pool, + cycle=self.node_id ) - self.start_node_id = builder.start_node_id self.graph = builder.build() + self.start_node_id = builder.start_node_id self.child_variable_pool = builder.variable_pool async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: @@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode): Raises: RuntimeError: If the node type is unsupported. """ + self.build_graph() if self.node_type == NodeType.LOOP: return await LoopRuntime( start_id=self.start_node_id, @@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode): raise RuntimeError("Unknown cycle node type") async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): + self.build_graph() if self.node_type == NodeType.LOOP: yield { "__final__": True, diff --git a/api/app/core/workflow/nodes/document_extractor/__init__.py b/api/app/core/workflow/nodes/document_extractor/__init__.py new file mode 100644 index 00000000..c51bc2c0 --- /dev/null +++ b/api/app/core/workflow/nodes/document_extractor/__init__.py @@ -0,0 +1,4 @@ +from .config import DocExtractorNodeConfig +from .node import DocExtractorNode + +__all__ = ["DocExtractorNode", "DocExtractorNodeConfig"] diff --git a/api/app/core/workflow/nodes/document_extractor/config.py b/api/app/core/workflow/nodes/document_extractor/config.py new file mode 100644 index 00000000..69f7f76d --- /dev/null +++ b/api/app/core/workflow/nodes/document_extractor/config.py @@ -0,0 +1,18 @@ +from pydantic import Field +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class DocExtractorNodeConfig(BaseNodeConfig): + file_selector: str = Field( + ..., + description="File variable selector, e.g. {{ sys.files }} or {{ node_id.file }}" + ) + + class Config: + json_schema_extra = { + "examples": [ + { + "file_selector": "{{ sys.files }}" + } + ] + } diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py new file mode 100644 index 00000000..40641f3c --- /dev/null +++ b/api/app/core/workflow/nodes/document_extractor/node.py @@ -0,0 +1,103 @@ +import logging +from typing import Any + +from app.core.workflow.engine.state_manager import WorkflowState +from app.core.workflow.engine.variable_pool import VariablePool +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig +from app.core.workflow.variable.base_variable import VariableType, FileObject +from app.db import get_db_read +from app.schemas.app_schema import FileInput, FileType, TransferMethod + +logger = logging.getLogger(__name__) + + +def _file_object_to_file_input(f: FileObject) -> FileInput: + """Convert workflow FileObject to multimodal FileInput.""" + return FileInput( + type=FileType.DOCUMENT, + transfer_method=TransferMethod(f.transfer_method), + url=f.url or None, + upload_file_id=f.file_id or None, + file_type=f.origin_file_type or "", + ) + + +def _normalise_files(val: Any) -> list[FileObject]: + if isinstance(val, FileObject): + return [val] + if isinstance(val, dict) and val.get("is_file"): + return [FileObject(**val)] + if isinstance(val, list): + result: list[FileObject] = [] + for item in val: + if isinstance(item, FileObject): + result.append(item) + elif isinstance(item, dict) and item.get("is_file"): + result.append(FileObject(**item)) + else: + logger.warning("Ignoring non-file entry in file list for document extractor: %r", item) + return result + return [] + + +class DocExtractorNode(BaseNode): + """Document Extractor Node. + + Reads one or more file variables and extracts their text content + by delegating to MultimodalService._extract_document_text. + + Outputs: + text (string) – full concatenated text of all input files + chunks (array[string]) – per-file extracted text + """ + + def _output_types(self) -> dict[str, VariableType]: + return { + "text": VariableType.STRING, + "chunks": VariableType.ARRAY_STRING, + } + + def _extract_output(self, business_result: Any) -> Any: + return business_result + + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + return {"file_selector": self.config.get("file_selector")} + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: + config = DocExtractorNodeConfig(**self.config) + + raw_val = self.get_variable(config.file_selector, variable_pool, strict=False) + if raw_val is None: + logger.warning(f"Node {self.node_id}: file variable '{config.file_selector}' is empty") + return {"text": "", "chunks": []} + + files = _normalise_files(raw_val) + if not files: + return {"text": "", "chunks": []} + + chunks: list[str] = [] + with get_db_read() as db: + from app.services.multimodal_service import MultimodalService + svc = MultimodalService(db) + for f in files: + try: + file_input = _file_object_to_file_input(f) + # Ensure URL is populated for local files + if not file_input.url: + file_input.url = await svc.get_file_url(file_input) + # Reuse cached bytes if already fetched + if f.get_content(): + file_input.set_content(f.get_content()) + text = await svc._extract_document_text(file_input) + chunks.append(text) + except Exception as e: + logger.error( + f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}", + exc_info=True, + ) + chunks.append("") + + full_text = "\n\n".join(c for c in chunks if c) + logger.info(f"Node {self.node_id}: extracted {len(files)} file(s), total chars={len(full_text)}") + return {"text": full_text, "chunks": chunks} diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py index 5c2a6c2a..02df5091 100644 --- a/api/app/core/workflow/nodes/end/config.py +++ b/api/app/core/workflow/nodes/end/config.py @@ -1,9 +1,7 @@ """End 节点配置""" - from pydantic import Field -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition -from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig class EndNodeConfig(BaseNodeConfig): diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 2799316a..770cf328 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -36,8 +36,6 @@ class EndNode(BaseNode): Returns: 最终输出字符串 """ - logger.info(f"节点 {self.node_id} (End) 开始执行") - # 获取配置的输出模板 output_template = self.config.get("output") @@ -46,11 +44,4 @@ class EndNode(BaseNode): output = self._render_template(output_template, variable_pool, strict=False) else: output = "" - - # 统计信息(用于日志) - node_outputs = state.get("node_outputs", {}) - total_nodes = len(node_outputs) - - logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") - return output diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 43ab593b..529cd0b3 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -23,12 +23,13 @@ class NodeType(StrEnum): BREAK = "break" MEMORY_READ = "memory-read" MEMORY_WRITE = "memory-write" + DOCUMENT_EXTRACTOR = "document-extractor" UNKNOWN = "unknown" NOTES = "notes" -BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] +BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER}) class ComparisonOperator(StrEnum): diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py index fe38fafb..e1b84f0c 100644 --- a/api/app/core/workflow/nodes/http_request/config.py +++ b/api/app/core/workflow/nodes/http_request/config.py @@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel): ) -class HttpErrorDefaultTamplete(BaseModel): +class HttpErrorDefaultTemplate(BaseModel): body: str = Field( default="", description="Default body returned on HTTP error", @@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel): description="Error handling strategy: 'none', 'default', or 'branch'", ) - default: HttpErrorDefaultTamplete | None = Field( + default: HttpErrorDefaultTemplate | None = Field( default=None, description="Default response template for error handling", ) diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 23378c83..086bee4a 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput -from app.core.workflow.utils.file_processer import mime_to_file_type +from app.core.workflow.utils.file_processor import mime_to_file_type from app.core.workflow.variable.base_variable import VariableType, FileObject from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.schemas import FileType, TransferMethod @@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode): or a branch identifier string when error branching is enabled. """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: HttpRequestNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 5d2bdf9a..ec46b20b 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -14,8 +14,8 @@ logger = logging.getLogger(__name__) class IfElseNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: IfElseNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index e13709d4..abf21524 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -12,8 +12,8 @@ logger = logging.getLogger(__name__) class JinjaRenderNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: JinjaRenderNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index d3e9efd9..92699cb4 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -21,8 +21,8 @@ logger = logging.getLogger(__name__) class KnowledgeRetrievalNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.vector_service: ElasticSearchVector | None = None diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index b293d1f4..a691001f 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -70,8 +70,8 @@ class LLMNode(BaseNode): - ai/assistant: AI 消息(AIMessage) """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: LLMNodeConfig | None = None self.messages = [] @@ -144,7 +144,6 @@ class LLMNode(BaseNode): f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") messages_config = self.typed_config.messages - if messages_config: # 使用 LangChain 消息格式 messages = [] @@ -153,7 +152,6 @@ class LLMNode(BaseNode): content_template = msg_config.content content_template = self._render_context(content_template, variable_pool) content = self._render_template(content_template, variable_pool) - user_id = self.get_variable("sys.user_id", variable_pool) # 根据角色创建对应的消息对象 if role == "system": messages.append({ @@ -161,32 +159,31 @@ class LLMNode(BaseNode): "content": await self.process_message( model_info, content, - user_id, self.typed_config.vision, ) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] files = variable_pool.get_instance(self.typed_config.vision_input) for file in files.value: - content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision) + content = await self.process_message(model_info, file.value, self.typed_config.vision) if content: file_content.extend(content) if messages and messages[-1]["role"] == 'user': @@ -200,7 +197,7 @@ class LLMNode(BaseNode): if isinstance(message["content"], list): file_content = [] for file in message["content"]: - content = await self.process_message(model_info, file, user_id, self.typed_config.vision) + content = await self.process_message(model_info, file, self.typed_config.vision) if content: file_content.extend(content) history_message.append( @@ -210,7 +207,6 @@ class LLMNode(BaseNode): message["content"] = await self.process_message( model_info, message["content"], - user_id, self.typed_config.vision ) history_message.append(message) diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 1d42e82e..73c52b79 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -1,3 +1,4 @@ +import re from typing import Any from app.core.workflow.engine.state_manager import WorkflowState @@ -5,14 +6,16 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.db import get_db_read +from app.schemas import FileInput from app.services.memory_agent_service import MemoryAgentService from app.tasks import write_message_task class MemoryReadNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: MemoryReadNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: @@ -36,19 +39,32 @@ class MemoryReadNode(BaseNode): search_switch=self.typed_config.search_switch, history=[], db=db, - storage_type="neo4j", - user_rag_memory_id="" + storage_type=state["memory_storage_type"], + user_rag_memory_id=state["user_rag_memory_id"] ) class MemoryWriteNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: MemoryWriteNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: return {"output": VariableType.STRING} + @staticmethod + def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]: + variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}' + variable_pattern = re.compile(variable_pattern_string) + variables = variable_pattern.findall(content) + file_variables = [] + for variable in variables: + if variable_pool.is_file_variable(variable): + file_variables.append(variable) + for var in file_variables: + content = content.replace(var, "") + return file_variables, content + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = MemoryWriteNodeConfig(**self.config) end_user_id = self.get_variable("sys.user_id", variable_pool) @@ -63,17 +79,42 @@ class MemoryWriteNode(BaseNode): }) for message in self.typed_config.messages: + file_variables, content = self._extract_multimodal_memory_variables( + message.content, + variable_pool + ) + file_info = [] + for var in file_variables: + instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var) + if isinstance(instence, FileVariable): + file_info.append(FileInput( + type=instence.value.type, + transfer_method=instence.value.transfer_method, + upload_file_id=instence.value.file_id, + url=instence.value.url, + file_type=instence.value.origin_file_type + ).model_dump()) + elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable: + for file_instence in instence.value: + file_info.append(FileInput( + type=file_instence.value.type, + transfer_method=file_instence.value.transfer_method, + upload_file_id=file_instence.value.file_id, + url=file_instence.value.url, + file_type=file_instence.value.origin_file_type + ).model_dump()) messages.append({ "role": message.role, - "content": self._render_template(message.content, variable_pool) + "content": self._render_template(content, variable_pool), + "files": file_info }) write_message_task.delay( - end_user_id, - messages, - str(self.typed_config.config_id), - "neo4j", - "" + end_user_id=end_user_id, + message=messages, + config_id=str(self.typed_config.config_id), + storage_type=state["memory_storage_type"], + user_rag_memory_id=state["user_rag_memory_id"] ) return "success" diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 864e3251..49add867 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -26,6 +26,7 @@ from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.breaker import BreakNode from app.core.workflow.nodes.tool import ToolNode +from app.core.workflow.nodes.document_extractor import DocExtractorNode logger = logging.getLogger(__name__) @@ -49,7 +50,8 @@ WorkflowNode = Union[ ToolNode, MemoryReadNode, MemoryWriteNode, - CodeNode + CodeNode, + DocExtractorNode ] @@ -81,6 +83,7 @@ class NodeFactory: NodeType.MEMORY_READ: MemoryReadNode, NodeType.MEMORY_WRITE: MemoryWriteNode, NodeType.CODE: CodeNode, + NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode } @classmethod @@ -104,13 +107,15 @@ class NodeFactory: def create_node( cls, node_config: dict[str, Any], - workflow_config: dict[str, Any] + workflow_config: dict[str, Any], + down_stream_nodes: list[str] ) -> WorkflowNode | None: """创建节点实例 Args: node_config: 节点配置 workflow_config: 工作流配置 + down_stream_nodes: 下游节点 Returns: 节点实例或 None(对于不支持的节点类型) @@ -127,7 +132,7 @@ class NodeFactory: # 创建节点实例 logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})") - return node_class(node_config, workflow_config) + return node_class(node_config, workflow_config, down_stream_nodes) @classmethod def get_supported_types(cls) -> list[str]: diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index acac09e4..3dc5fcc3 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -21,8 +21,8 @@ logger = logging.getLogger(__name__) class ParameterExtractorNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: ParameterExtractorNodeConfig | None = None self.response_metadata = {} diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 5cebd886..31fadaf6 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1" class QuestionClassifierNode(BaseNode): """问题分类器节点""" - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: QuestionClassifierNodeConfig | None = None self.category_to_case_map = {} self.response_metadata = {} diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index a9618f7b..7a324cc4 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -27,14 +27,8 @@ class StartNode(BaseNode): 注意:变量的验证和默认值处理由 Executor 在初始化时完成。 """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - """初始化 Start 节点 - - Args: - node_config: 节点配置 - workflow_config: 工作流配置 - """ - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) # 解析并验证配置 self.typed_config: StartNodeConfig | None = None @@ -62,7 +56,6 @@ class StartNode(BaseNode): 包含系统参数、会话变量和自定义变量的字典 """ self.typed_config = StartNodeConfig(**self.config) - logger.info(f"节点 {self.node_id} (Start) 开始执行") # 处理自定义变量(传入 pool 避免重复创建) custom_vars = self._process_custom_variables(variable_pool) @@ -77,9 +70,9 @@ class StartNode(BaseNode): **custom_vars # 自定义变量作为节点输出的一部分 } - logger.info( - f"节点 {self.node_id} (Start) 执行完成," - f"输出了 {len(custom_vars)} 个自定义变量" + logger.debug( + f"Node {self.node_id} (Start) execution completed, " + f"outputting {len(custom_vars)} custom variables" ) return result diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 0e9d3c62..72c5c6a8 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}") class ToolNode(BaseNode): """工具节点""" - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: ToolNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py index de82f8ff..9a9c5566 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/node.py +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -12,8 +12,8 @@ logger = logging.getLogger(__name__) class VariableAggregatorNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: VariableAggregatorNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/utils/file_processer.py b/api/app/core/workflow/utils/file_processor.py similarity index 100% rename from api/app/core/workflow/utils/file_processer.py rename to api/app/core/workflow/utils/file_processor.py diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index 424fdf20..6a73efc4 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -153,7 +153,8 @@ class TemplateRenderer: # 全局渲染器实例(严格模式) -_default_renderer = TemplateRenderer(strict=True) +_strict_renderer = TemplateRenderer(strict=True) +_lenient_renderer = TemplateRenderer(strict=False) def render_template( @@ -184,7 +185,7 @@ def render_template( ... ) '请分析: 这是一段文本' """ - renderer = TemplateRenderer(strict=strict) + renderer = _strict_renderer if strict else _lenient_renderer return renderer.render(template, conv_vars, node_outputs, system_vars) @@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]: Returns: 错误列表 """ - return _default_renderer.validate(template) + return _strict_renderer.validate(template) diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index fe4aea19..0ad74865 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -6,6 +6,7 @@ import copy import logging +from collections import defaultdict, deque from typing import Any, Union, TYPE_CHECKING from app.core.workflow.nodes.enums import NodeType @@ -119,7 +120,6 @@ class WorkflowValidator: errors = [] graphs = cls.get_subgraph(workflow_config) - logger.info(graphs) for index, graph in enumerate(graphs): nodes = graph.get("nodes", []) edges = graph.get("edges", []) @@ -183,7 +183,7 @@ class WorkflowValidator: has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) if has_cycle: errors.append( - f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" + f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}" ) # 8. 验证变量名 @@ -204,18 +204,18 @@ class WorkflowValidator: Returns: 可达节点 ID 集合 """ + adj = defaultdict(list) + for edge in edges: + adj[edge["source"]].append(edge["target"]) + reachable = {start_id} - queue = [start_id] - + queue = deque([start_id]) while queue: - current = queue.pop(0) - for edge in edges: - if edge.get("source") == current: - target = edge.get("target") - if target and target not in reachable: - reachable.add(target) - queue.append(target) - + current = queue.popleft() + for target in adj[current]: + if target not in reachable: + reachable.add(target) + queue.append(target) return reachable @staticmethod @@ -229,10 +229,6 @@ class WorkflowValidator: Returns: (has_cycle, cycle_path): 是否有循环和循环路径 """ - # 排除 loop 类型的节点 - loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"} - - # 构建邻接表(排除 loop 节点的边和错误边) graph: dict[str, list[str]] = {} for edge in edges: source = edge.get("source") @@ -243,10 +239,6 @@ class WorkflowValidator: if edge_type == "error": continue - # 如果涉及 loop 节点,跳过 - if source in loop_nodes or target in loop_nodes: - continue - if source and target: if source not in graph: graph[source] = [] diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 5e8e3f1e..79e023c1 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -54,7 +54,7 @@ class DictVariable(BaseVariable): def valid_value(self, value) -> dict: if not isinstance(value, dict): - raise TypeError(f"Value must be a dict - {type(value)}:{value}") + raise TypeError(f"Value must be a dict - {type(value)}:{value}") return value def to_literal(self) -> str: diff --git a/api/app/models/memory_config_model.py b/api/app/models/memory_config_model.py index 1095a386..616f7f3a 100644 --- a/api/app/models/memory_config_model.py +++ b/api/app/models/memory_config_model.py @@ -30,6 +30,9 @@ class MemoryConfig(Base): llm_id = Column(String, nullable=True, comment="LLM模型配置ID") embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") + vision_id = Column(String, nullable=True, comment="视觉模型配置ID") + audio_id = Column(String, nullable=True, comment="语音模型配置ID") + video_id = Column(String, nullable=True, comment="视频模型配置ID") # 记忆萃取引擎配置 enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 23fafcef..69bedc3d 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -2,10 +2,11 @@ import datetime import uuid from enum import StrEnum -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text -from sqlalchemy.dialects.postgresql import UUID, JSON +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text +from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY from sqlalchemy.orm import relationship from sqlalchemy.sql import func + from app.db import Base @@ -26,9 +27,9 @@ class ModelType(StrEnum): RERANK = "rerank" # TTS = "tts" # SPEECH2TEXT = "speech2text" - # IMAGE = "image" + IMAGE = "image" # AUDIO = "audio" - # VISION = "vision" + VIDEO = "video" class ModelProvider(StrEnum): @@ -45,6 +46,7 @@ class ModelProvider(StrEnum): XINFERENCE = "xinference" GPUSTACK = "gpustack" BEDROCK = "bedrock" + VOLCANO = "volcano" COMPOSITE = "composite" diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index 044857d2..8f101eb5 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -23,6 +23,21 @@ class Tenants(Base): # 国际化语言配置字段 default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言 supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表 + + # 租户联系信息 + contact_name = Column(String(100), nullable=True) # 联系人姓名 + contact_email = Column(String(255), nullable=True) # 联系人邮箱 + contact_phone = Column(String(50), nullable=True) # 联系人电话 + + # 租户套餐信息 + plan = Column(String(50), nullable=True) # 套餐类型 + plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间 + api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制 + status = Column(String(50), nullable=True, default='active') # 租户状态 + + # 租户功能开关字段 + feature_billing = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用收费管理菜单") + feature_user_management = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用用户管理菜单") # Relationship to users - one tenant has many users users = relationship("User", back_populates="tenant") diff --git a/api/app/models/user_model.py b/api/app/models/user_model.py index b6de28ec..81319789 100644 --- a/api/app/models/user_model.py +++ b/api/app/models/user_model.py @@ -9,7 +9,7 @@ class User(Base): __tablename__ = "users" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - username = Column(String, unique=True, index=True, nullable=False) + username = Column(String, index=True, nullable=False) # 社区版:用户名不唯一,仅邮箱唯一 email = Column(String, unique=True, index=True, nullable=False) hashed_password = Column(String, nullable=False) is_active = Column(Boolean, default=True, nullable=False) diff --git a/api/app/repositories/home_page_repository.py b/api/app/repositories/home_page_repository.py index bcb3b622..6d74bcaf 100644 --- a/api/app/repositories/home_page_repository.py +++ b/api/app/repositories/home_page_repository.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta from sqlalchemy.orm import Session from sqlalchemy import func from uuid import UUID -from typing import Dict +from typing import Dict, Optional, Any from app.models.end_user_model import EndUser from app.models.user_model import User @@ -190,4 +190,63 @@ class HomePageRepository: user_count_dict = {workspace_id: count for workspace_id, count in user_counts} - return workspaces, app_count_dict, user_count_dict \ No newline at end of file + return workspaces, app_count_dict, user_count_dict + + @staticmethod + def get_version_introduction(db: Session, version: str) -> Optional[Dict[str, Any]]: + """ + 从数据库获取版本说明(优先读取已发布的版本) + 使用反射方式读取表结构,不依赖 premium 模型类 + + Args: + db: 数据库会话 + version: 版本号,如 "v0.2.7" + + Returns: + 版本说明字典,格式与 version_info.json 一致 + 如果数据库中没有该版本,返回 None + """ + try: + from sqlalchemy import Table, MetaData + + metadata = MetaData() + version_notes = Table('version_notes', metadata, autoload_with=db.engine) + version_note_items = Table('version_note_items', metadata, autoload_with=db.engine) + + note = db.query(version_notes).filter( + version_notes.c.version == version, + version_notes.c.is_published == True + ).first() + + if not note: + return None + + items = db.query(version_note_items).filter( + version_note_items.c.note_id == note.id + ).order_by(version_note_items.c.sort_order).all() + + core_upgrades = [] + for item in items: + title = item.title + content = item.content + if content: + core_upgrades.append(f"{title}
{content}") + else: + core_upgrades.append(title) + + return { + "introduction": { + "codeName": "", + "releaseDate": note.release_date.isoformat() if note.release_date else "", + "upgradePosition": "", + "coreUpgrades": core_upgrades + }, + "introduction_en": { + "codeName": "", + "releaseDate": note.release_date.isoformat() if note.release_date else "", + "upgradePosition": "", + "coreUpgrades": core_upgrades + } + } + except Exception: + return None \ No newline at end of file diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 22f13449..e64d19a3 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -9,21 +9,22 @@ Classes: """ import uuid -from uuid import UUID from typing import Dict, List, Optional, Tuple +from uuid import UUID + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session + from app.core.exceptions import BusinessException from app.core.logging_config import get_config_logger, get_db_logger from app.models.memory_config_model import MemoryConfig +from app.models.workspace_model import Workspace from app.schemas.memory_storage_schema import ( - ConfigKey, ConfigParamsCreate, ConfigUpdate, ConfigUpdateExtracted, ConfigUpdateForget, ) -from sqlalchemy import desc, select -from sqlalchemy.orm import Session - from app.utils.config_utils import resolve_config_id # 获取数据库专用日志器 @@ -157,7 +158,7 @@ class MemoryConfigRepository: return memory_config_obj @staticmethod - def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig: + def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig: """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) Args: @@ -309,57 +310,21 @@ class MemoryConfigRepository: Returns: Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None - - Raises: - ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新萃取配置: config_id={update.config_id}") try: - db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first() + stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id) + db_config = db.execute(stmt).scalar_one_or_none() if not db_config: db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") return None - # 更新字段映射 - field_mapping = { - # 模型选择 - "llm_id": "llm_id", - "embedding_id": "embedding_id", - "rerank_id": "rerank_id", - # 记忆萃取引擎 - "enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise", - "enable_llm_disambiguation": "enable_llm_disambiguation", - "deep_retrieval": "deep_retrieval", - "t_type_strict": "t_type_strict", - "t_name_strict": "t_name_strict", - "t_overall": "t_overall", - "state": "state", - "chunker_strategy": "chunker_strategy", - # 句子提取 - "statement_granularity": "statement_granularity", - "include_dialogue_context": "include_dialogue_context", - "max_context": "max_context", - # 剪枝配置 - "pruning_enabled": "pruning_enabled", - "pruning_scene": "pruning_scene", - "pruning_threshold": "pruning_threshold", - # 自我反思配置 - "enable_self_reflexion": "enable_self_reflexion", - "iteration_period": "iteration_period", - "reflexion_range": "reflexion_range", - "baseline": "baseline", - } + update_data = update.model_dump(exclude_unset=True) + update_data.pop("config_id", None) - has_update = False - for api_field, db_field in field_mapping.items(): - value = getattr(update, api_field, None) - if value is not None: - setattr(db_config, db_field, value) - has_update = True - - if not has_update: - raise ValueError("No fields to update") + for field, value in update_data.items(): + setattr(db_config, field, value) db.commit() db.refresh(db_config) @@ -443,6 +408,9 @@ class MemoryConfigRepository: "llm_id": db_config.llm_id, "embedding_id": db_config.embedding_id, "rerank_id": db_config.rerank_id, + "vision_id": db_config.vision_id, + "audio_id": db_config.audio_id, + "video_id": db_config.video_id, "enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise, "enable_llm_disambiguation": db_config.enable_llm_disambiguation, "deep_retrieval": db_config.deep_retrieval, @@ -527,7 +495,10 @@ class MemoryConfigRepository: raise @staticmethod - def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]: + def get_config_with_workspace( + db: Session, + config_id: uuid.UUID | int | str + ) -> Optional[tuple[MemoryConfig, Workspace]]: """Get memory config and its associated workspace information Args: @@ -542,8 +513,6 @@ class MemoryConfigRepository: """ import time - from app.models.workspace_model import Workspace - start_time = time.time() config_id = resolve_config_id(config_id, db) @@ -630,7 +599,7 @@ class MemoryConfigRepository: db_logger.debug( f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") - return (config, workspace) + return config, workspace except ValueError: # Re-raise known business exceptions @@ -666,7 +635,7 @@ class MemoryConfigRepository: List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称) """ from app.models.ontology_scene import OntologyScene - + db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") try: @@ -730,7 +699,7 @@ class MemoryConfigRepository: Optional[MemoryConfig]: 默认配置对象,不存在则返回None """ db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}") - + try: # 优先查找显式标记为默认的配置 stmt = ( @@ -742,13 +711,13 @@ class MemoryConfigRepository: ) .limit(1) ) - + config = db.scalars(stmt).first() - + if config: db_logger.debug(f"找到默认配置: config_id={config.config_id}") return config - + # 回退:获取最早创建的活跃配置 stmt = ( select(MemoryConfig) @@ -759,25 +728,25 @@ class MemoryConfigRepository: .order_by(MemoryConfig.created_at.asc()) .limit(1) ) - + config = db.scalars(stmt).first() - + if config: db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}") else: db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}") - + return config - + except Exception as e: db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}") raise @staticmethod def get_with_fallback( - db: Session, - config_id: Optional[uuid.UUID], - workspace_id: uuid.UUID + db: Session, + config_id: Optional[uuid.UUID], + workspace_id: uuid.UUID ) -> Optional[MemoryConfig]: """获取记忆配置,支持回退到工作空间默认配置 @@ -792,19 +761,18 @@ class MemoryConfigRepository: Optional[MemoryConfig]: 配置对象,如果都不存在则返回None """ db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}") - + if not config_id: db_logger.debug("config_id 为空,使用工作空间默认配置") return MemoryConfigRepository.get_workspace_default(db, workspace_id) - + config = db.get(MemoryConfig, config_id) - + if config: return config - + db_logger.warning( f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}" ) - - return MemoryConfigRepository.get_workspace_default(db, workspace_id) + return MemoryConfigRepository.get_workspace_default(db, workspace_id) diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index f49227d3..8c477d39 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -1,14 +1,15 @@ -from sqlalchemy.orm import Session, joinedload, selectinload -from sqlalchemy import and_, or_, func, desc, select -from typing import List, Optional, Dict, Any, Tuple import uuid +from typing import List, Optional, Dict, Any, Tuple +from sqlalchemy import and_, or_, func, desc +from sqlalchemy.orm import Session, joinedload + +from app.core.logging_config import get_db_logger from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association from app.schemas.model_schema import ( ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, ModelConfigQuery, ModelConfigQueryNew ) -from app.core.logging_config import get_db_logger # 获取数据库专用日志器 db_logger = get_db_logger() @@ -137,6 +138,9 @@ class ModelConfigRepository: type_values.append(ModelType.LLM) filters.append(ModelConfig.type.in_(type_values)) + if query.capability: + filters.append(ModelConfig.capability.contains(query.capability)) + if query.is_active is not None: filters.append(ModelConfig.is_active == query.is_active) @@ -435,7 +439,6 @@ class ModelConfigRepository: ModelConfig.is_public ), ModelConfig.provider == provider, - ModelConfig.is_active, ~ModelConfig.is_composite ) ).all() diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 42c178b3..1939a062 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -1,17 +1,22 @@ +import logging from typing import List, Optional -from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode +from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \ + MEMORY_SUMMARY_NODE_SAVE # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector +logger = logging.getLogger(__name__) + async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): """Delete all nodes in the database.""" result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n") - print(f"All end_user_id: {end_user_id} node and edge deleted successfully") + logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result + async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add dialogue nodes to Neo4j database. @@ -23,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn List of created node UUIDs or None if failed """ if not dialogues: - print("No dialogues to save") + logger.info("No dialogues to save") return [] try: @@ -48,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") + logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") return created_uuids except Exception as e: - print(f"Error creating dialogue nodes: {e}") + logger.error(f"Error creating dialogue nodes: {e}") return None @@ -67,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC List of created node UUIDs or None if failed """ if not statements: - print("No statements to save") + logger.info("No statements to save") return [] try: @@ -120,13 +125,14 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} statement nodes") + logger.info(f"Successfully created {len(created_uuids)} statement nodes") return created_uuids except Exception as e: - print(f"Error creating statement nodes: {e}") + logger.error(f"Error creating statement nodes: {e}") return None + async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add chunk nodes to Neo4j in batch. @@ -138,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> List of created chunk UUIDs or None if failed """ if not chunks: - print("No chunk nodes to add") + logger.info("No chunk nodes to add") return [] try: @@ -171,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} chunk nodes") + logger.info(f"Successfully created {len(created_uuids)} chunk nodes") return created_uuids except Exception as e: - print(f"Error creating chunk nodes: {e}") + logger.error(f"Error creating chunk nodes: {e}") return None - -async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]: +async def add_memory_summary_nodes( + summaries: List[MemorySummaryNode], + connector: Neo4jConnector +) -> Optional[List[str]]: """Add memory summary nodes to Neo4j in batch. Args: @@ -191,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector List of created summary node ids or None if failed """ if not summaries: - print("No memory summary nodes to add") + logger.info("No memory summary nodes to add") return [] try: @@ -211,16 +219,14 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector "summary_embedding": s.summary_embedding if s.summary_embedding else None, "config_id": s.config_id, # 添加 config_id }) - + result = await connector.execute_query( MEMORY_SUMMARY_NODE_SAVE, summaries=flattened ) created_ids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") + logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids except Exception as e: - print(f"Failed to save MemorySummary nodes to Neo4j: {e}") + logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None - - diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 7273340e..bd448c99 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -300,7 +300,7 @@ class CommunityRepository: ) return bool(result) except Exception as e: - logger.error(f"update_community_metadata failed: {e}") + logger.error(f"update_community_metadata failed: {e}", exc_info=True) return False async def batch_update_community_metadata( diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 0cdaeb59..1f699ad8 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id, RETURN elementId(r) AS uuid """ - # Entity Merge Query MERGE_ENTITIES = """ MATCH (canonical:ExtractedEntity {id: $canonical_id}) @@ -829,9 +828,8 @@ neo4j_query_all = """ other as entity2 """ - '''针对当前节点下扩长的句子,实体和总结''' -Memory_Timeline_ExtractedEntity=""" +Memory_Timeline_ExtractedEntity = """ MATCH (n)-[r1]-(e)-[r2]-(ms) WHERE elementId(n) = $id AND (ms:ExtractedEntity OR ms:MemorySummary) @@ -869,7 +867,7 @@ RETURN """ -Memory_Timeline_MemorySummary=""" +Memory_Timeline_MemorySummary = """ MATCH (n)-[r1]-(e)-[r2]-(ms) WHERE elementId(n) =$id AND (ms:MemorySummary OR ms:ExtractedEntity) @@ -904,7 +902,7 @@ RETURN } ) AS statement; """ -Memory_Timeline_Statement=""" +Memory_Timeline_Statement = """ MATCH (n) WHERE elementId(n) = $id @@ -947,7 +945,7 @@ RETURN """ '''针对当前节点,主要获取更加完整的句子节点''' -Memory_Space_Emotion_Statement=""" +Memory_Space_Emotion_Statement = """ MATCH (n) WHERE elementId(n) = $id RETURN @@ -957,7 +955,7 @@ RETURN n.statement AS statement; """ -Memory_Space_Emotion_MemorySummary=""" +Memory_Space_Emotion_MemorySummary = """ MATCH (n)-[]-(e) WHERE elementId(n) = $id AND EXISTS { @@ -970,7 +968,7 @@ RETURN DISTINCT e.emotion_type AS emotion_type, e.statement AS statement; """ -Memory_Space_Emotion_ExtractedEntity=""" +Memory_Space_Emotion_ExtractedEntity = """ MATCH (n)-[]-(e) WHERE elementId(n) = $id AND EXISTS { @@ -985,18 +983,18 @@ RETURN DISTINCT '''获取实体''' -Memory_Space_User=""" +Memory_Space_User = """ MATCH (n)-[r]->(m) WHERE n.end_user_id = $end_user_id AND m.name="用户" return DISTINCT elementId(m) as id """ -Memory_Space_Entity=""" +Memory_Space_Entity = """ MATCH (n)-[]-(m) WHERE elementId(m) = $id AND m.entity_type = "Person" RETURN DISTINCT m.name as name,m.end_user_id as end_user_id """ -Memory_Space_Associative=""" +Memory_Space_Associative = """ MATCH (u)-[]-(x)-[]-(h) WHERE elementId(u) = $user_id AND elementId(h) = $id @@ -1005,61 +1003,69 @@ RETURN DISTINCT """ Graph_Node_query = """ - MATCH (n:MemorySummary) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 0 AS priority - LIMIT $limit +MATCH (n:MemorySummary) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 0 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:Dialogue) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 1 AS priority - LIMIT 1 +MATCH (n:Dialogue) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority +LIMIT 1 - UNION ALL +UNION ALL - MATCH (n:Statement) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 1 AS priority - LIMIT $limit +MATCH (n:Statement) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:ExtractedEntity) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 2 AS priority - LIMIT $limit +MATCH (n:ExtractedEntity) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 2 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:Chunk) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 3 AS priority - LIMIT $limit +MATCH (n:Chunk) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 3 AS priority +LIMIT $limit - """ +UNION ALL +MATCH (n:Perceptual) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 4 AS priority +""" # ============================================================ # Community 节点 & BELONGS_TO_COMMUNITY 边 @@ -1069,6 +1075,7 @@ Graph_Node_query = """ COMMUNITY_NODE_UPSERT = """ MERGE (c:Community {community_id: $community_id}) +ON CREATE SET c.id = $community_id SET c.end_user_id = $end_user_id, c.member_count = $member_count, c.updated_at = datetime() @@ -1175,7 +1182,8 @@ RETURN c.community_id AS community_id, cnt AS member_count UPDATE_COMMUNITY_METADATA = """ MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -SET c.name = $name, +SET c.id = coalesce(c.id, $community_id), + c.name = $name, c.summary = $summary, c.core_entities = $core_entities, c.summary_embedding = $summary_embedding, @@ -1186,7 +1194,8 @@ RETURN c.community_id AS community_id BATCH_UPDATE_COMMUNITY_METADATA = """ UNWIND $communities AS row MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id}) -SET c.name = row.name, +SET c.id = coalesce(c.id, row.community_id), + c.name = row.name, c.summary = row.summary, c.core_entities = row.core_entities, c.summary_embedding = row.summary_embedding, @@ -1270,6 +1279,40 @@ RETURN startNode(r) = e AS r_from_e """ +CHECK_COMMUNITY_IS_COMPLETE = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +RETURN ( + c.name IS NOT NULL AND c.name <> '' AND + c.summary IS NOT NULL AND c.summary <> '' AND + c.core_entities IS NOT NULL +) AS is_complete +""" + +CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +RETURN ( + c.name IS NOT NULL AND c.name <> '' AND + c.summary IS NOT NULL AND c.summary <> '' AND + c.core_entities IS NOT NULL AND + c.summary_embedding IS NOT NULL +) AS is_complete +""" + +GET_INCOMPLETE_COMMUNITIES = """ +MATCH (c:Community {end_user_id: $end_user_id}) +WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL + OR c.name = '' OR c.summary = '' +RETURN c.community_id AS community_id +""" + +GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """ +MATCH (c:Community {end_user_id: $end_user_id}) +WHERE c.name IS NULL OR c.name = '' + OR c.summary IS NULL OR c.summary = '' + OR c.core_entities IS NULL + OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)') +RETURN c.community_id AS community_id +""" # Community keyword search: matches name or summary via fulltext index SEARCH_COMMUNITIES_BY_KEYWORD = """ @@ -1327,37 +1370,35 @@ ORDER BY COALESCE(s.activation_value, 0) DESC LIMIT $limit """ -CHECK_COMMUNITY_IS_COMPLETE = """ -MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -RETURN ( - c.name IS NOT NULL AND c.name <> '' AND - c.summary IS NOT NULL AND c.summary <> '' AND - c.core_entities IS NOT NULL -) AS is_complete +# 感知记忆节点保存 +PERCEPTUAL_NODE_SAVE = """ +UNWIND $perceptuals AS p +MERGE (n:Perceptual {id: p.id}) +SET n += { + id: p.id, + end_user_id: p.end_user_id, + perceptual_type: p.perceptual_type, + file_path: p.file_path, + file_name: p.file_name, + file_ext: p.file_ext, + summary: p.summary, + keywords: p.keywords, + topic: p.topic, + domain: p.domain, + created_at: p.created_at, + file_type: p.file_type, + summary_embedding: p.summary_embedding +} +RETURN n.id AS uuid """ -CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ -MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -RETURN ( - c.name IS NOT NULL AND c.name <> '' AND - c.summary IS NOT NULL AND c.summary <> '' AND - c.core_entities IS NOT NULL AND - c.summary_embedding IS NOT NULL -) AS is_complete -""" - -GET_INCOMPLETE_COMMUNITIES = """ -MATCH (c:Community {end_user_id: $end_user_id}) -WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL - OR c.name = '' OR c.summary = '' -RETURN c.community_id AS community_id -""" - -GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """ -MATCH (c:Community {end_user_id: $end_user_id}) -WHERE c.name IS NULL OR c.name = '' - OR c.summary IS NULL OR c.summary = '' - OR c.core_entities IS NULL - OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)') -RETURN c.community_id AS community_id +# 感知记忆与对话的关联边 +PERCEPTUAL_CHUNK_EDGE_SAVE = """ +UNWIND $edges AS edge +MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id}) +MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id}) +MERGE (c)-[r:HAS_PERCEPTUAL]->(p) +ON CREATE SET r.end_user_id = edge.end_user_id, + r.created_at = edge.created_at +RETURN elementId(r) AS uuid """ diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 34497d5b..adc266fe 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import ( StatementNode, ExtractedEntityNode, EntityEntityEdge, + PerceptualNode, + PerceptualEdge, ) import logging + logger = logging.getLogger(__name__) + + async def save_entities_and_relationships( - entity_nodes: List[ExtractedEntityNode], - entity_entity_edges: List[EntityEntityEdge], - connector: Neo4jConnector + entity_nodes: List[ExtractedEntityNode], + entity_entity_edges: List[EntityEntityEdge], + connector: Neo4jConnector ): """Save entities and their relationships using graph models""" all_entities = [entity.model_dump() for entity in entity_nodes] @@ -73,8 +78,8 @@ async def save_entities_and_relationships( async def save_chunk_nodes( - chunk_nodes: List[ChunkNode], - connector: Neo4jConnector + chunk_nodes: List[ChunkNode], + connector: Neo4jConnector ): """Save chunk nodes using graph models""" if not chunk_nodes: @@ -89,8 +94,8 @@ async def save_chunk_nodes( async def save_statement_chunk_edges( - statement_chunk_edges: List[StatementChunkEdge], - connector: Neo4jConnector + statement_chunk_edges: List[StatementChunkEdge], + connector: Neo4jConnector ): """Save statement-chunk edges using graph models""" if not statement_chunk_edges: @@ -118,8 +123,8 @@ async def save_statement_chunk_edges( async def save_statement_entity_edges( - statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + statement_entity_edges: List[StatementEntityEdge], + connector: Neo4jConnector ): """Save statement-entity edges using graph models""" if not statement_entity_edges: @@ -142,7 +147,7 @@ async def save_statement_entity_edges( if all_se_edges: try: await connector.execute_query( - STATEMENT_ENTITY_EDGE_SAVE, + STATEMENT_ENTITY_EDGE_SAVE, relationships=all_se_edges ) except Exception: @@ -154,24 +159,28 @@ async def save_dialog_and_statements_to_neo4j( chunk_nodes: List[ChunkNode], statement_nodes: List[StatementNode], entity_nodes: List[ExtractedEntityNode], + perceptual_nodes: List[PerceptualNode], entity_edges: List[EntityEntityEdge], statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], + perceptual_edges: List[PerceptualEdge], connector: Neo4jConnector, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. 只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过 - schedule_clustering_after_write() 显式触发。 + _trigger_clustering_sync() 显式触发。 Args: dialogue_nodes: List of DialogueNode objects to save chunk_nodes: List of ChunkNode objects to save statement_nodes: List of StatementNode objects to save entity_nodes: List of ExtractedEntityNode objects to save + perceptual_nodes: List of PerceptualNode objects to save entity_edges: List of EntityEntityEdge objects to save statement_chunk_edges: List of StatementChunkEdge objects to save statement_entity_edges: List of StatementEntityEdge objects to save + perceptual_edges: List of PerceptualEdge objects to save connector: Neo4j connector instance Returns: @@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) dialogue_uuids = [record["uuid"] async for record in result] results['dialogues'] = dialogue_uuids - print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") + logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") # 2. Save all chunk nodes in batch if chunk_nodes: @@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j( results['chunks'] = chunk_uuids logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") + if perceptual_nodes: + from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE + perceptual_data = [node.model_dump() for node in perceptual_nodes] + result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data) + perceptual_uuids = [record["uuid"] async for record in result] + results["perceptuals"] = perceptual_uuids + logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j") + # 3. Save all statement nodes in batch if statement_nodes: from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE @@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j( results['statement_entity_edges'] = se_uuids logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") + if perceptual_edges: + from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE + perceptual_edge_data = [] + for edge in perceptual_edges: + print(edge.source, edge.target) + perceptual_edge_data.append({ + "perceptual_id": edge.source, + "chunk_id": edge.target, + "end_user_id": edge.end_user_id, + "created_at": edge.created_at.isoformat() if edge.created_at else None, + }) + result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data) + perceptual_edges_uuids = [record["uuid"] async for record in result] + results['perceptual_chunk_edges'] = perceptual_edges_uuids + logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j") + return results try: @@ -303,16 +336,13 @@ async def save_dialog_and_statements_to_neo4j( return False -def schedule_clustering_after_write( - entity_nodes: List, - llm_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, +async def _trigger_clustering_sync( + entity_nodes: List, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ - 写入 Neo4j 成功后,调度后台聚类任务。 - - 可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。 - 使用 asyncio.create_task 异步触发,不阻塞写入响应。 + 同步等待聚类完成,避免与其他 LLM 任务并发冲突。 """ if not entity_nodes: return @@ -324,15 +354,16 @@ def schedule_clustering_after_write( end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] - logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) + logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") + await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, + embedding_model_id=embedding_model_id) async def _trigger_clustering( - new_entity_ids: List[str], - end_user_id: str, - llm_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, + new_entity_ids: List[str], + end_user_id: str, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 1582d862..e34945eb 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -196,6 +196,13 @@ class CitationConfig(BaseModel): enabled: bool = Field(default=False) +class Citation(BaseModel): + document_id: str + file_name: str + knowledge_id: str + score: float + + class WebSearchConfig(BaseModel): """联网搜索配置""" enabled: bool = Field(default=False) diff --git a/api/app/schemas/memory_config_schema.py b/api/app/schemas/memory_config_schema.py index 8d7490fe..e186e54b 100644 --- a/api/app/schemas/memory_config_schema.py +++ b/api/app/schemas/memory_config_schema.py @@ -387,6 +387,12 @@ class MemoryConfig: rerank_model_id: Optional[UUID] = None rerank_model_name: Optional[str] = None + video_model_id: Optional[UUID] = None + video_model_name: Optional[str] = None + vision_model_id: Optional[UUID] = None + vision_model_name: Optional[str] = None + audio_model_id: Optional[UUID] = None + audio_model_name: Optional[str] = None llm_params: Dict[str, Any] = field(default_factory=dict) embedding_params: Dict[str, Any] = field(default_factory=dict) diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 046b79e7..711b6de9 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -8,9 +8,6 @@ import uuid from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator - - - # ============================================================================ # 从 json_schema.py 迁移的 Schema # ============================================================================ @@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel): class ConflictResultSchema(BaseModel): """Schema for the conflict result data in the reflexion_data.json file.""" - data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") + data: List[BaseDataSchema] = Field(..., + description="The conflict memory data. Only contains conflicting records when conflict is True.") conflict: bool = Field(..., description="Whether the memory is in conflict.") - quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") - memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") + quality_assessment: Optional[QualityAssessmentSchema] = Field(None, + description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") + memory_verify: Optional[MemoryVerifySchema] = Field(None, + description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") @model_validator(mode="before") def _normalize_data(cls, v): @@ -101,16 +101,19 @@ class ChangeRecordSchema(BaseModel): - entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式 """ field: List[Dict[str, Any]] = Field( - ..., + ..., description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" ) + class ResolvedSchema(BaseModel): """Schema for the resolved memory data in the reflexion_data""" original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") - resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") - change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") + resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, + description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") + change: Optional[List[ChangeRecordSchema]] = Field(None, + description="List of detailed change records with IDs and field information.") class SingleReflexionResultSchema(BaseModel): @@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel): resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") type: str = Field("reflexion_result", description="The type identifier.") + class ReflexionResultSchema(BaseModel): """Schema for the complete reflexion result data - a list of individual conflict resolutions.""" - results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") + results: List[SingleReflexionResultSchema] = Field(..., + description="List of individual conflict resolution results, grouped by conflict type.") @model_validator(mode="before") def _normalize_resolved(cls, v): @@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel): # Composite key identifying a config row class ConfigKey(BaseModel): # 配置参数键模型 model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") - user_id: str = Field("user_id", description="用户标识(字符串)") - apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") + user_id: str | None = Field(default=None, description="用户标识(字符串)") + apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)") # Allowed chunking strategies (extendable later) @@ -228,23 +233,25 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, config_name: str = Field("配置名称", description="配置名称(字符串)") config_desc: str = Field("配置描述", description="配置描述(字符串)") workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)") - + # 本体场景关联(可选) scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表") - + # 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入) pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充") - + # 模型配置字段(可选,用于手动指定或自动填充) llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致") + + class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) model_config = ConfigDict(populate_by_name=True, extra="forbid") # config_name: str = Field("配置名称", description="配置名称(字符串)") - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 @@ -255,8 +262,11 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 - config_id:Union[uuid.UUID, int, str] = None + config_id: Union[uuid.UUID, int, str] = None llm_id: Optional[str] = Field(None, description="LLM模型配置ID") + audio_id: Optional[str] = Field(None, description="语音模型ID") + vision_id: Optional[str] = Field(None, description="视觉模型ID") + video_id: Optional[str] = Field(None, description="视频模型ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") enable_llm_dedup_blockwise: Optional[bool] = None @@ -322,14 +332,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 # 遗忘引擎配置参数更新模型 - config_id:Union[uuid.UUID, int, str] = None + config_id: Union[uuid.UUID, int, str] = None lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0") class ConfigPilotRun(BaseModel): # 试运行触发请求模型 - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行") model_config = ConfigDict(populate_by_name=True, extra="forbid") @@ -364,11 +374,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) def fail( - msg: str, - error_code: str = "ERROR", - data: Optional[Any] = None, - time: Optional[int] = None, - query_preview: Optional[str] = None, + msg: str, + error_code: str = "ERROR", + data: Optional[Any] = None, + time: Optional[int] = None, + query_preview: Optional[str] = None, ) -> ApiResponse: payload = data if query_preview is not None: @@ -387,12 +397,13 @@ def fail( time=time or _now_ms(), ) + class GenerateCacheRequest(BaseModel): """缓存生成请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + end_user_id: Optional[str] = Field( - None, + None, description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成" ) @@ -404,7 +415,7 @@ class GenerateCacheRequest(BaseModel): class ForgettingTriggerRequest(BaseModel): """手动触发遗忘周期请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + end_user_id: str = Field(..., description="组ID(即终端用户ID,必填)") max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)") min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)") @@ -413,7 +424,7 @@ class ForgettingTriggerRequest(BaseModel): class ForgettingConfigResponse(BaseModel): """遗忘引擎配置响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") decay_constant: float = Field(..., description="衰减常数 d") lambda_time: float = Field(..., description="时间衰减参数") @@ -432,7 +443,7 @@ class ForgettingConfigUpdateRequest(BaseModel): """遗忘引擎配置更新请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") @@ -448,7 +459,7 @@ class ForgettingConfigUpdateRequest(BaseModel): class ForgettingCycleHistoryPoint(BaseModel): """遗忘周期历史数据点模型(用于趋势图)""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + date: str = Field(..., description="日期(格式: '1/1', '1/2')") merged_count: int = Field(..., description="每日融合节点数") average_activation: Optional[float] = Field(None, description="平均激活值") @@ -459,7 +470,7 @@ class ForgettingCycleHistoryPoint(BaseModel): class PendingForgettingNode(BaseModel): """待遗忘节点模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + node_id: str = Field(..., description="节点ID") node_type: str = Field(..., description="节点类型:statement/entity/summary") content_summary: str = Field(..., description="内容摘要") @@ -472,7 +483,8 @@ class ForgettingStatsResponse(BaseModel): model_config = ConfigDict(populate_by_name=True, extra="forbid") activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标") node_distribution: Dict[str, int] = Field(..., description="节点类型分布") - recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") + recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., + description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)") timestamp: int = Field(..., description="统计时间(时间戳)") @@ -480,7 +492,7 @@ class ForgettingStatsResponse(BaseModel): class ForgettingReportResponse(BaseModel): """遗忘周期报告响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + merged_count: int = Field(..., description="融合的节点对数量") nodes_before: int = Field(..., description="遗忘前的节点总数") nodes_after: int = Field(..., description="遗忘后的节点总数") @@ -495,7 +507,7 @@ class ForgettingReportResponse(BaseModel): class ForgettingCurvePoint(BaseModel): """遗忘曲线数据点模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + day: int = Field(..., description="天数") activation: float = Field(..., description="激活值") retention_rate: float = Field(..., description="保持率(与激活值相同)") @@ -504,7 +516,7 @@ class ForgettingCurvePoint(BaseModel): class ForgettingCurveRequest(BaseModel): """遗忘曲线请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)") days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") @@ -513,6 +525,6 @@ class ForgettingCurveRequest(BaseModel): class ForgettingCurveResponse(BaseModel): """遗忘曲线响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表") config: Dict[str, Any] = Field(..., description="使用的配置参数") diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 058f082d..668a84a8 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase): updated_at: datetime.datetime api_keys: List["ModelApiKey"] = [] + @staticmethod + def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str: + if not key or len(key) <= prefix + suffix: + return "*" * len(key) + return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:] + @field_validator("api_keys", mode="after") @classmethod def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]: @@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase): def _serialize_created_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None + @field_serializer("api_keys", when_used="json") + def _serialize_api_keys(self, api_keys: List["ModelApiKey"]): + result = [] + for api_key in api_keys: + data = api_key.model_dump() + data["api_key"] = self.mask_api_key(api_key.api_key) + result.append(data) + return result + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -165,20 +180,20 @@ class ModelApiKey(ModelApiKeyBase): if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict): self.model_config_ids = [ mc.id for mc in self.model_configs - if hasattr(mc, 'id') - and not getattr(mc, 'is_composite', False) - and getattr(mc, 'name', None) == self.model_name + if hasattr(mc, 'id') + and not getattr(mc, 'is_composite', False) + and getattr(mc, 'name', None) == self.model_name ] # 情况2:字典列表 elif isinstance(self.model_configs, list): self.model_config_ids = [ mc['id'] if isinstance(mc, dict) else mc.id for mc in self.model_configs - if ((isinstance(mc, dict) - and 'id' in mc + if ((isinstance(mc, dict) + and 'id' in mc and not mc.get('is_composite', False) - and mc.get('name') == self.model_name) or - (hasattr(mc, 'id') + and mc.get('name') == self.model_name) or + (hasattr(mc, 'id') and not getattr(mc, 'is_composite', False) and getattr(mc, 'name', None) == self.model_name)) ] @@ -193,11 +208,10 @@ class ModelApiKey(ModelApiKeyBase): validate_assignment=True # 确保赋值触发校验 ) - @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel): """模型配置查询Schema""" type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)") provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)") + capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)") is_active: Optional[bool] = Field(None, description="激活状态筛选") is_public: Optional[bool] = Field(None, description="公开状态筛选") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) @@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel): is_composite: Optional[bool] = Field(None, description="组合模型筛选") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + class ModelMarketplace(BaseModel): """模型广场响应Schema""" llm_models: List[ModelConfig] = [] @@ -304,7 +320,7 @@ class ModelBaseUpdate(BaseModel): class ModelBase(BaseModel): """基础模型Schema""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID name: str type: str @@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel): is_deprecated: Optional[bool] = Field(None, description="是否弃用") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + class ModelInfo(BaseModel): """模型信息Schema""" model_name: str = Field(..., description="模型名称") @@ -336,4 +353,3 @@ class ModelInfo(BaseModel): is_omni: bool = Field(default=False, description="是否为omni模型") model_type: ModelType = Field(..., description="模型类型") capability: List[str] = Field(default_factory=list, description="模型能力列表") - diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 6fcf680b..3dda6fc0 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -82,6 +82,12 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt + # opening_statement:首轮对话注入开场白 + is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) + system_prompt = self.agent_service._inject_opening_statement( + features_config, system_prompt, is_new_conversation + ) + # 准备工具列表 tools = [] @@ -93,7 +99,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + tools.extend(kb_tools) memory_flag = False if memory: memory_tools, memory_flag = self.agent_service.load_memory_config( @@ -129,45 +136,18 @@ class AppChatService: ) # 加载历史消息 - messages = self.conversation_service.get_messages( + history = await self.conversation_service.get_conversation_history( conversation_id=conversation_id, - limit=10 + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni ) - history = [] - for msg in messages: - content = [{"type": "text", "text": msg.content}] - - # 处理 meta_data 中的 files - if msg.meta_data and msg.meta_data.get("files"): - files = msg.meta_data.get("files", []) - # 使用 MultimodalService 处理文件 - multimodal_service = MultimodalService(self.db, api_config=model_info) - - # 将 files 转换为 FileInput 格式 - file_inputs = [] - for file in files: - from app.schemas.app_schema import FileInput, TransferMethod - file_input = FileInput( - type=file.get("type"), - transfer_method=TransferMethod.REMOTE_URL, - url=file.get("url") - ) - file_inputs.append(file_input) - - history_processed_files = await multimodal_service.history_process_files(files=file_inputs) - - content.extend(history_processed_files) - - history.append({ - "role": msg.role, - "content": content - }) # 处理多模态文件 processed_files = None if files: multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") # 调用 Agent(支持多模态) @@ -206,7 +186,8 @@ class AppChatService: # 构建用户消息内容(含多模态文件) human_meta = { - "files": [] + "files": [], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -221,6 +202,13 @@ class AppChatService: "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } + # 保存消息 if audio_url: assistant_meta["audio_url"] = audio_url @@ -249,8 +237,9 @@ class AppChatService: }), "elapsed_time": elapsed_time, "suggested_questions": suggested_questions, - "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), + "citations": self.agent_service._filter_citations(features_config, citations_collector), "audio_url": audio_url, + "audio_status": "pending" } async def agnet_chat_stream( @@ -301,6 +290,12 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt + # opening_statement:首轮对话注入开场白 + is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) + system_prompt = self.agent_service._inject_opening_statement( + features_config, system_prompt, is_new_conversation + ) + # 准备工具列表 tools = [] @@ -313,7 +308,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: @@ -350,45 +346,18 @@ class AppChatService: ) # 加载历史消息 - messages = self.conversation_service.get_messages( + history = await self.conversation_service.get_conversation_history( conversation_id=conversation_id, - limit=10 + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni ) - history = [] - for msg in messages: - content = [{"type": "text", "text": msg.content}] - - # 处理 meta_data 中的 files - if msg.meta_data and msg.meta_data.get("files"): - history_files = msg.meta_data.get("files", []) - # 使用 MultimodalService 处理文件 - multimodal_service = MultimodalService(self.db, api_config=model_info) - - # 将 files 转换为 FileInput 格式 - file_inputs = [] - for file in history_files: - from app.schemas.app_schema import FileInput, TransferMethod - file_input = FileInput( - type=file.get("type"), - transfer_method=TransferMethod.REMOTE_URL, - url=file.get("url") - ) - file_inputs.append(file_input) - - history_processed_files = await multimodal_service.history_process_files(files=file_inputs) - - content.extend(history_processed_files) - - history.append({ - "role": msg.role, - "content": content - }) # 处理多模态文件 processed_files = None if files: multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") # 流式调用 Agent(支持多模态),同时并行启动 TTS @@ -433,7 +402,7 @@ class AppChatService: elapsed_time = time.time() - start_time ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) - # 发送结束事件(包含 suggested_questions、tts、citations) + # 发送结束事件(包含 suggested_questions、tts、audio_status、citations) end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} sq_config = features_config.get("suggested_questions_after_answer", {}) if isinstance(sq_config, dict) and sq_config.get("enabled"): @@ -443,11 +412,23 @@ class AppChatService: "api_base": api_key_obj.api_base}, {} ) end_data["audio_url"] = stream_audio_url - end_data["citations"] = self.agent_service._filter_citations(features_config, []) + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None + end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector) # 保存消息 human_meta = { - "files":[] + "files":[], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -457,11 +438,16 @@ class AppChatService: if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } if stream_audio_url: assistant_meta["audio_url"] = stream_audio_url diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 19aaac42..4dcabff8 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1638,7 +1638,7 @@ class AppService: # ==================== 记忆配置提取方法 ==================== - def _extract_memory_config_id( + def _get_memory_config_id_from_release( self, app_type: str, config: Dict[str, Any] @@ -1863,7 +1863,7 @@ class AppService: self.db.flush() # 先 flush,确保 release 已插入数据库 # 提取记忆配置ID并更新终端用户 - memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config) + memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(app.type, config) # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: @@ -2001,7 +2001,7 @@ class AppService: raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}") # 提取记忆配置ID并更新终端用户 - memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config) + memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(release.type, release.config) # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index f8a01a40..014d96b7 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -274,7 +274,8 @@ class ConversationService: self, conversation_id: uuid.UUID, max_history: Optional[int] = None, - api_config: Optional[ModelInfo] = None + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[dict]: """ Retrieve historical conversation messages formatted as dictionaries. @@ -282,7 +283,8 @@ class ConversationService: Args: conversation_id (uuid.UUID): Conversation UUID. max_history (Optional[int]): Maximum number of messages to retrieve. - api_config (Optional[ModelInfo]): Model API configuration for multimodal processing. + current_provider (Optional[str]): Current provider for file handling. + current_is_omni (Optional[bool]): Current omni flag for file handling. Returns: List[dict]: List of message dictionaries with keys 'role' and 'content'. @@ -292,38 +294,30 @@ class ConversationService: limit=max_history ) - # 转换为字典格式 history = [] for msg in messages: - content = [{"type": "text", "text": msg.content}] - - # 处理 meta_data 中的 files - if msg.meta_data and msg.meta_data.get("files"): - files = msg.meta_data.get("files", []) - if api_config: - # 使用 MultimodalService 处理文件 - from app.services.multimodal_service import MultimodalService - multimodal_service = MultimodalService(self.db, api_config=api_config) - - # 将 files 转换为 FileInput 格式 - file_inputs = [] - for file in files: - from app.schemas.app_schema import FileInput, TransferMethod - file_input = FileInput( - type=file.get("type"), - transfer_method=TransferMethod.REMOTE_URL, - url=file.get("url") - ) - file_inputs.append(file_input) - - processed_files = await multimodal_service.history_process_files(files=file_inputs) - - content.extend(processed_files) - - history.append({ + msg_dict = { "role": msg.role, - "content": content - }) + "content": [{"type": "text", "text": msg.content}] + } + + # 处理用户消息中的多模态文件 + if msg.role == "user" and msg.meta_data: + history_files = msg.meta_data.get("history_files", {}) + + if history_files and current_provider and current_is_omni is not None: + # 检查是否需要重新处理文件 + stored_provider = history_files.get("provider") + stored_is_omni = history_files.get("is_omni") + + # 如果provider或is_omni不匹配,需要重新处理 + if stored_provider != current_provider or stored_is_omni != current_is_omni: + continue + + # provider和is_omni匹配,直接使用存储的内容 + msg_dict["content"].extend(history_files.get("content")) + + history.append(msg_dict) return history @@ -539,6 +533,7 @@ class ConversationService: provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type llm = RedBearLLM( @@ -546,7 +541,8 @@ class ConversationService: model_name=model_name, provider=provider, api_key=api_key, - base_url=api_base + base_url=api_base, + is_omni=is_omni ), type=ModelType(model_type) ) @@ -554,15 +550,8 @@ class ConversationService: conversation_messages = await self.get_conversation_history( conversation_id=conversation_id, max_history=20, - api_config=ModelInfo( - model_name=model_name, - provider=provider, - api_key=api_key, - api_base=api_base, - capability=api_config.capability, - is_omni=api_config.is_omni, - model_type=model_type - ) + current_provider=provider, + current_is_omni=is_omni ) if len(conversation_messages) == 0: return ConversationOut( diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 5989f0f8..ac34b4de 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -26,7 +26,7 @@ from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context from app.models import AgentConfig, ModelConfig, ModelType from app.repositories.tool_repository import ToolRepository -from app.schemas.app_schema import FileInput +from app.schemas.app_schema import FileInput, Citation from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service @@ -190,13 +190,19 @@ def create_web_search_tool(web_search_config: Dict[str, Any]): return web_search_tool -def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): +def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: kb_config: 知识库配置 kb_ids: 知识库ID列表 user_id: 用户ID + citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充) + 列表元素类型为 Citation,包含字段: + - document_id: 文档唯一标识 + - file_name: 文件名 + - knowledge_id: 知识库 ID + - score: 检索相关性得分 Returns: 检索到的相关知识内容 @@ -229,6 +235,21 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): } ) + # 收集引用信息 + if citations_collector is not None: + seen_doc_ids = {c.get("document_id") for c in citations_collector} + for chunk in retrieve_chunks_result: + meta = chunk.metadata or {} + doc_id = meta.get("document_id") or meta.get("doc_id") + if doc_id and doc_id not in seen_doc_ids: + seen_doc_ids.add(doc_id) + citations_collector.append(Citation( + document_id=doc_id, + file_name=meta.get("file_name", ""), + knowledge_id=str(meta.get("knowledge_id", "")), + score=meta.get("score", 0) + )) + return f"检索到以下相关信息:\n\n{context}" else: logger.warning("知识库检索未找到结果") @@ -320,26 +341,26 @@ class AgentRunService: self, knowledge_retrieval_config: dict | None, user_id - ) -> list: + ) -> tuple[list, list]: + """返回 (tools, citations_collector)""" if not knowledge_retrieval_config: - return [] + return [], [] + citations_collector = [] tools = [] knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) + kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")] if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) + kb_tool = create_knowledge_retrieval_tool( + knowledge_retrieval_config, kb_ids, user_id, + citations_collector=citations_collector + ) tools.append(kb_tool) - logger.debug( "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } + extra={"kb_ids": kb_ids, "tool_count": len(tools)} ) - return tools + return tools, citations_collector def load_memory_config( self, @@ -441,12 +462,12 @@ class AgentRunService: @staticmethod def _filter_citations( features_config: Dict[str, Any], - citations: List[Any] + citations: List[Citation] ) -> List[Any]: """根据 citation 开关决定是否返回引用来源""" citation_cfg = features_config.get("citation", {}) if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"): - return citations + return [cit.model_dump() for cit in citations] return [] async def run( @@ -549,7 +570,8 @@ class AgentRunService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) + kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: @@ -592,8 +614,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - api_config=model_info, - max_history=10 + max_history=10, + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -602,7 +625,7 @@ class AgentRunService: # 获取 provider 信息 provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 @@ -661,7 +684,10 @@ class AgentRunService: }) }, files=files, - audio_url=audio_url + processed_files=processed_files, + audio_url=audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) response = { @@ -676,8 +702,9 @@ class AgentRunService: "suggested_questions": await self._generate_suggested_questions( features_config, result["content"], api_key_config, effective_params ) if not sub_agent else [], - "citations": self._filter_citations(features_config, result.get("citations", [])), + "citations": self._filter_citations(features_config, citations_collector), "audio_url": audio_url, + "audio_status": "pending" } logger.info( @@ -785,7 +812,8 @@ class AgentRunService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) + kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False @@ -830,8 +858,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - api_config=model_info, - max_history=memory_config.get("max_history", 10) + max_history=memory_config.get("max_history", 10), + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -840,7 +869,7 @@ class AgentRunService: # 获取 provider 信息 provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 @@ -909,10 +938,13 @@ class AgentRunService: "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} }, files=files, - audio_url=stream_audio_url + processed_files=processed_files, + audio_url=stream_audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) - # 12. 发送结束事件(包含 suggested_questions 和 tts) + # 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status) end_data: Dict[str, Any] = { "conversation_id": conversation_id, "elapsed_time": elapsed_time, @@ -923,7 +955,18 @@ class AgentRunService: features_config, full_content, api_key_config, effective_params ) end_data["audio_url"] = stream_audio_url - end_data["citations"] = self._filter_citations(features_config, []) + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None + end_data["citations"] = self._filter_citations(features_config, citations_collector) yield self._format_sse_event("end", end_data) logger.info( @@ -1119,14 +1162,17 @@ class AgentRunService: async def _load_conversation_history( self, conversation_id: str, - api_config: ModelInfo | None = None, - max_history: int = 10 + max_history: int = 10, + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[Dict[str, str]]: - """加载会话历史消息 + """加载会话历史消息,并根据当前模型配置处理多模态文件 Args: conversation_id: 会话ID max_history: 最大历史消息数量 + current_provider: 当前模型的provider + current_is_omni: 当前模型的is_omni Returns: List[Dict]: 历史消息列表 @@ -1138,7 +1184,8 @@ class AgentRunService: history = await conversation_service.get_conversation_history( conversation_id=uuid.UUID(conversation_id), max_history=max_history, - api_config=api_config + current_provider=current_provider, + current_is_omni=current_is_omni ) logger.debug( @@ -1166,7 +1213,10 @@ class AgentRunService: app_id: Optional[uuid.UUID] = None, user_id: Optional[str] = None, files: Optional[List[FileInput]] = None, - audio_url: Optional[str] = None + processed_files: Optional[List[Dict[str, Any]]] = None, + audio_url: Optional[str] = None, + provider: Optional[str] = None, + is_omni: Optional[bool] = None ) -> None: """保存会话消息(会话已通过 _ensure_conversation 确保存在) @@ -1177,6 +1227,11 @@ class AgentRunService: app_id: 应用ID(未使用,保留用于兼容性) user_id: 用户ID(未使用,保留用于兼容性) meta_data: token消耗 + files: 原始文件输入 + processed_files: 处理后的文件 + audio_url: 音频URL + provider: 模型供应商 + is_omni: 是否为全模态模型 """ try: from app.services.conversation_service import ConversationService @@ -1186,15 +1241,24 @@ class AgentRunService: # 保存消息(会话已经存在) human_meta = { - "files": [] + "files": [], + "history_files": {} } if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + + # 保存 history_files,包含 provider 和 is_omni 信息 + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": provider, + "is_omni": is_omni + } + # 保存用户消息 conversation_service.add_message( conversation_id=conv_uuid, @@ -1420,8 +1484,9 @@ class AgentRunService: workspace_id: Optional[uuid.UUID] = None, ) -> tuple[Optional[str], Optional[asyncio.Task]]: """文本流式输入并行合成音频。 - 返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。 + 返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。 调用方向 text_queue put 文本 chunk,结束时 put None。 + 前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。 """ tts_config = features_config.get("text_to_speech", {}) if not isinstance(tts_config, dict) or not tts_config.get("enabled"): @@ -1808,6 +1873,7 @@ class AgentRunService: ), "cost_estimate": self._estimate_cost(usage, model_info["model_config"]), "audio_url": result.get("audio_url"), + "audio_status": result.get("audio_status"), "citations": result.get("citations", []), "suggested_questions": result.get("suggested_questions", []), "error": None @@ -1885,6 +1951,7 @@ class AgentRunService: "results": [{ **r, "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), } for r in results], @@ -2016,6 +2083,7 @@ class AgentRunService: full_content = "" returned_conversation_id = model_conversation_id audio_url = None + audio_status = None citations = [] suggested_questions = [] @@ -2074,6 +2142,7 @@ class AgentRunService: # 从 end 事件中提取 features 输出字段 if event_type == "end" and event_data: audio_url = event_data.get("audio_url") + audio_status = event_data.get("audio_status") citations = event_data.get("citations", []) suggested_questions = event_data.get("suggested_questions", []) @@ -2103,6 +2172,7 @@ class AgentRunService: "message": full_content, "elapsed_time": elapsed, "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "error": None @@ -2117,6 +2187,7 @@ class AgentRunService: "elapsed_time": elapsed, "message_length": len(full_content), "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "timestamp": time.time() @@ -2253,6 +2324,7 @@ class AgentRunService: "message": r.get("message"), "elapsed_time": r.get("elapsed_time", 0), "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), "error": r.get("error") diff --git a/api/app/services/file_storage_service.py b/api/app/services/file_storage_service.py index 2ebc5d9a..5897936b 100644 --- a/api/app/services/file_storage_service.py +++ b/api/app/services/file_storage_service.py @@ -325,27 +325,30 @@ class FileStorageService: ) raise - async def get_file_url(self, file_key: str, expires: int = 3600) -> str: + async def get_file_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None, + ) -> str: """ Get an access URL for a file. Args: file_key: The file key. expires: URL validity period in seconds (default: 1 hour). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: URL for accessing the file. """ logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s") - try: - url = await self.storage.get_url(file_key, expires) + url = await self.storage.get_url(file_key, expires, file_name=file_name) logger.debug(f"File URL generated: file_key={file_key}") return url except Exception as e: - logger.error( - f"Error getting file URL: file_key={file_key}, error={str(e)}" - ) + logger.error(f"Error getting file URL: file_key={file_key}, error={str(e)}") raise diff --git a/api/app/services/generation_service.py b/api/app/services/generation_service.py new file mode 100644 index 00000000..2505793c --- /dev/null +++ b/api/app/services/generation_service.py @@ -0,0 +1,162 @@ +""" +图片和视频生成服务 + +提供统一的生成接口,支持多种 Provider +""" +from typing import Dict, Any, Optional +from sqlalchemy.orm import Session +import uuid + +from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.models.models_model import ModelType +from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository +from app.services.model_service import ModelApiKeyService + + +class GenerationService: + """生成服务""" + + def __init__(self, db: Session): + self.db = db + + async def generate_image( + self, + model_config_id: str, + prompt: str, + size: Optional[str] = "2k", + **kwargs + ) -> Dict[str, Any]: + """ + 生成图片 + + Args: + model_config_id: 模型配置ID + prompt: 提示词 + size: 图片尺寸 + **kwargs: 其他参数 + + Returns: + 生成结果 + """ + # 获取模型配置 + model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id)) + if not model_config: + raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND) + + if model_config.type != ModelType.IMAGE: + raise BusinessException( + f"模型类型错误,期望 {ModelType.IMAGE},实际 {model_config.type}", + code=BizCode.INVALID_PARAMETER + ) + + # 获取 API Key + api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id)) + if not api_key_info: + raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND) + + # 创建配置 + config = RedBearModelConfig( + model_name=api_key_info.model_name, + provider=api_key_info.provider, + api_key=api_key_info.api_key, + base_url=api_key_info.api_base, + extra_params=api_key_info.config or {} + ) + + # 生成图片 + generator = RedBearImageGenerator(config) + result = await generator.agenerate(prompt, size, **kwargs) + + return result + + async def generate_video( + self, + model_config_id: str, + prompt: str, + duration: Optional[int] = None, + **kwargs + ) -> Dict[str, Any]: + """ + 生成视频 + + Args: + model_config_id: 模型配置ID + prompt: 提示词 + duration: 视频时长(秒) + **kwargs: 其他参数 + + Returns: + 生成结果(包含任务ID) + """ + # 获取模型配置 + model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id)) + if not model_config: + raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND) + + if model_config.type != ModelType.VIDEO: + raise BusinessException( + f"模型类型错误,期望 {ModelType.VIDEO},实际 {model_config.type}", + code=BizCode.INVALID_PARAMETER + ) + + # 获取 API Key + api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id)) + if not api_key_info: + raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND) + + # 创建配置 + config = RedBearModelConfig( + model_name=api_key_info.model_name, + provider=api_key_info.provider, + api_key=api_key_info.api_key, + base_url=api_key_info.api_base, + extra_params=api_key_info.config or {} + ) + + # 生成视频 + generator = RedBearVideoGenerator(config) + result = await generator.agenerate(prompt, duration, **kwargs) + + return result + + async def get_video_task_status( + self, + model_config_id: str, + task_id: str + ) -> Dict[str, Any]: + """ + 查询视频生成任务状态 + + Args: + model_config_id: 模型配置ID + task_id: 任务ID + + Returns: + 任务状态信息 + """ + # 获取模型配置 + model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id)) + if not model_config: + raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND) + + # 获取 API Key + api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id)) + if not api_key_info: + raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND) + + # 创建配置 + config = RedBearModelConfig( + model_name=api_key_info.model_name, + provider=api_key_info.provider, + api_key=api_key_info.api_key, + base_url=api_key_info.api_base, + extra_params=api_key_info.config or {} + ) + + # 查询任务状态 + generator = RedBearVideoGenerator(config) + result = await generator.aget_task_status(task_id) + + return result diff --git a/api/app/services/home_page_service.py b/api/app/services/home_page_service.py index 8326ad40..4e6bf664 100644 --- a/api/app/services/home_page_service.py +++ b/api/app/services/home_page_service.py @@ -94,29 +94,38 @@ class HomePageService: @staticmethod def load_version_introduction(version: str) -> Dict[str, Any]: """ - 从 JSON 文件加载对应版本的介绍 + 加载对应版本的介绍(优先从数据库读取,fallback 到 JSON 文件) :param version: 系统版本号(如 "0.2.0") :return: 对应版本的详细介绍 """ - # 2. 定义 JSON 文件路径(简化路径处理,保留绝对路径调试特性) + from copy import deepcopy + from app.db import SessionLocal + from app.repositories.home_page_repository import HomePageRepository + + result = deepcopy(HomePageService.DEFAULT_RETURN_DATA) + + try: + db = SessionLocal() + try: + db_result = HomePageRepository.get_version_introduction(db, version) + if db_result: + return db_result + finally: + db.close() + except Exception as e: + pass + json_abs_path = Path(__file__).parent.parent / "version_info.json" json_abs_path = json_abs_path.resolve() - # 3. 初始化返回结果(深拷贝默认模板,避免修改原常量) - from copy import deepcopy - result = deepcopy(HomePageService.DEFAULT_RETURN_DATA) - try: - # 4. 简化文件存在性判断(合并逻辑,减少分支) if not json_abs_path.exists(): result["message"] = f"版本介绍文件不存在:{json_abs_path}" return result - # 5. 读取并解析 JSON 文件(简化文件操作流程) with open(json_abs_path, "r", encoding="utf-8") as f: changelogs = json.load(f) - # 6. 简化版本匹配逻辑,直接返回结果或更新提示信息 if version in changelogs: return changelogs[version] result["message"] = f"暂未查询到 {version} 版本的详细介绍" diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index af9a04e2..289fd74c 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional from uuid import UUID import redis -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session +from app.cache import InterestMemoryCache from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.utils.messages_tools import ( merge_multiple_search_results, reorder_output_results, ) from app.core.memory.agent.utils.type_classifier import status_typle +from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.schemas import FileInput from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) +from app.services.memory_perceptual_service import MemoryPerceptualService try: from app.core.memory.utils.log.audit_logger import audit_logger @@ -267,8 +270,16 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int, - db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: + async def write_memory( + self, + end_user_id: str, + messages: list[dict], + config_id: Optional[uuid.UUID] | int, + db: Session, + storage_type: str, + user_rag_memory_id: str, + language: str = "zh" + ) -> str: """ Process write operation with config_id @@ -297,8 +308,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError( - f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError(f"No memory configuration found for end_user {end_user_id}. " + f"Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -334,48 +345,58 @@ class MemoryAgentService: raise ValueError(error_msg) + perceptual_serivce = MemoryPerceptualService(db) + for message in messages: + message["file_content"] = [] + for file in (message.get("files") or []): + file_object = await perceptual_serivce.generate_perceptual_memory( + end_user_id=end_user_id, + memory_config=memory_config, + file=FileInput(**file) + ) + if file_object is None: + continue + message["file_content"].append((file_object, file["type"])) + logger.info(messages) + + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) try: if storage_type == "rag": # For RAG storage, convert messages to single string - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - result = await write_rag(end_user_id, message_text, user_rag_memory_id) - return result + await write_rag(end_user_id, message_text, user_rag_memory_id) + return "success" else: - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # Convert structured messages to LangChain messages - langchain_messages = [] - for msg in messages: - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - print(100 * '-') - print(langchain_messages) - print(100 * '-') - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config, - "language": language + await write_neo4j( + end_user_id=end_user_id, + messages=messages, + memory_config=memory_config, + ref_id='', + language=language + ) + for lang in ["zh", "en"]: + deleted = await InterestMemoryCache.delete_interest_distribution( + end_user_id, lang + ) + if deleted: + logger.info( + f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") + for message in messages: + message["file_content"] = [ + perceptual[0].file_path for perceptual in message["file_content"] + ] + return self.writer_messages_deal( + "success", + start_time, + end_user_id, + config_id, + message_text, + { + "status": "success", + "data": messages, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name } - - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - # Convert messages back to string for logging - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, - contents) + ) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" @@ -586,7 +607,7 @@ class MemoryAgentService: retrieved_content.append({query: statements}) # 如果 retrieved_content 为空,设置为空字符串 - if retrieved_content == []: + if not retrieved_content: retrieved_content = '' # 只有当回答不是"信息不足"且不是快速检索时才保存 diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 01bc6267..9282fc28 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -28,7 +28,7 @@ class MemoryAPIService: 2. Maps end_user_id to end_user_id for memory operations 3. Delegates to MemoryAgentService for actual memory read/write operations """ - + def __init__(self, db: Session): """Initialize MemoryAPIService. @@ -36,11 +36,11 @@ class MemoryAPIService: db: SQLAlchemy database session """ self.db = db - + def validate_end_user( - self, - end_user_id: str, - workspace_id: uuid.UUID + self, + end_user_id: str, + workspace_id: uuid.UUID ) -> EndUser: """Validate that end_user exists and belongs to the workspace. @@ -56,7 +56,7 @@ class MemoryAPIService: BusinessException: If end_user not in authorized workspace """ logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}") - + # Query end_user by ID try: end_user_uuid = uuid.UUID(end_user_id) @@ -66,7 +66,7 @@ class MemoryAPIService: message=f"Invalid end_user_id format: {end_user_id}", code=BizCode.INVALID_PARAMETER ) - + end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first() if not end_user: @@ -75,13 +75,13 @@ class MemoryAPIService: resource_type="EndUser", resource_id=end_user_id ) - + # Verify end_user belongs to the workspace via App relationship app = self.db.query(App).filter( App.id == end_user.app_id, App.is_active.is_(True) ).first() - + if not app: logger.warning(f"App not found for end_user: {end_user_id}") # raise ResourceNotFoundException( @@ -99,7 +99,7 @@ class MemoryAPIService: # message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}", # code=BizCode.FORBIDDEN # ) - + logger.info(f"End user {end_user_id} validated successfully") return end_user @@ -125,13 +125,13 @@ class MemoryAPIService: logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}") async def write_memory( - self, - workspace_id: uuid.UUID, - end_user_id: str, - message: str, - config_id: str, - storage_type: str = "neo4j", - user_rag_memory_id: Optional[str] = None, + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + config_id: str, + storage_type: str = "neo4j", + user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Write memory with validation. @@ -154,13 +154,13 @@ class MemoryAPIService: BusinessException: If end_user not in authorized workspace or write fails """ logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}") - + # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - + # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) - + try: # Delegate to MemoryAgentService # Convert string message to list[dict] format expected by MemoryAgentService @@ -171,11 +171,11 @@ class MemoryAPIService: config_id=config_id, db=self.db, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id or "" + user_rag_memory_id=user_rag_memory_id or "", ) - + logger.info(f"Memory write successful for end_user: {end_user_id}") - + # result may be a string "success" or a dict with a "status" key # Preserve the full dict so callers don't silently lose extra fields # (e.g. error codes, metadata) returned by MemoryAgentService. @@ -189,7 +189,7 @@ class MemoryAPIService: "status": result if isinstance(result, str) else "success", "end_user_id": end_user_id, } - + except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") raise BusinessException( @@ -204,16 +204,16 @@ class MemoryAPIService: message=f"Memory write failed: {str(e)}", code=BizCode.MEMORY_WRITE_FAILED ) - + async def read_memory( - self, - workspace_id: uuid.UUID, - end_user_id: str, - message: str, - search_switch: str = "0", - config_id: str = "", - storage_type: str = "neo4j", - user_rag_memory_id: Optional[str] = None, + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + search_switch: str = "0", + config_id: str = "", + storage_type: str = "neo4j", + user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Read memory with validation. @@ -237,14 +237,13 @@ class MemoryAPIService: BusinessException: If end_user not in authorized workspace or read fails """ logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}") - + # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - + # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) - try: # Delegate to MemoryAgentService result = await MemoryAgentService().read_memory( @@ -257,15 +256,15 @@ class MemoryAPIService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id or "" ) - + logger.info(f"Memory read successful for end_user: {end_user_id}") - + return { "answer": result.get("answer", ""), "intermediate_outputs": result.get("intermediate_outputs", []), "end_user_id": end_user_id } - + except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") raise BusinessException( @@ -282,8 +281,8 @@ class MemoryAPIService: ) def list_memory_configs( - self, - workspace_id: uuid.UUID, + self, + workspace_id: uuid.UUID, ) -> Dict[str, Any]: """List all memory configs for a workspace. diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index a3751c07..66c110b1 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None): """Validate configuration ID format (supports both UUID and integer).""" if isinstance(config_id, uuid.UUID): return config_id - + if config_id is None: raise InvalidConfigError( "Configuration ID cannot be None", @@ -52,26 +52,30 @@ def _validate_config_id(config_id, db: Session = None): field_name="config_id", invalid_value=config_id, ) - # 如果提供了数据库会话,尝试通过 user_id 查询 config_id + # 如果提供了数据库会话,尝试通过 config_id_old 查询 config_id if db is not None: - # 查询 user_id 匹配的记录 - stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == str(config_id)) + # 查询 config_id_old 匹配的记录 + stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == config_id) result = db.execute(stmt).scalars().first() if result: - logger.info(f"Found config_id {result.config_id} for user_id {config_id}") + logger.info(f"Found config_id {result.config_id} for config_id_old {config_id}") return result.config_id - return config_id + raise InvalidConfigError( + f"未找到 config_id_old={config_id} 对应的配置", + field_name="config_id", + invalid_value=config_id, + ) if isinstance(config_id, str): config_id_stripped = config_id.strip() - + # Try parsing as UUID first try: return uuid.UUID(config_id_stripped) except ValueError: pass - + # Fall back to integer parsing try: parsed_id = int(config_id_stripped) @@ -81,18 +85,22 @@ def _validate_config_id(config_id, db: Session = None): field_name="config_id", invalid_value=config_id, ) - + # 如果提供了数据库会话,尝试通过 user_id 查询 config_id if db is not None: - # 查询 user_id 匹配的记录 - stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id)) + # 查询 config_id_old 匹配的记录 + stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == parsed_id) result = db.execute(stmt).scalars().first() - + if result: - logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}") + logger.info(f"Found config_id {result.config_id} for config_id_old {parsed_id}") return result.config_id - return parsed_id + raise InvalidConfigError( + f"未找到 config_id_old={parsed_id} 对应的配置", + field_name="config_id", + invalid_value=config_id, + ) except ValueError: raise InvalidConfigError( f"Invalid configuration ID format: '{config_id}' (must be UUID or positive integer)", @@ -154,10 +162,10 @@ class MemoryConfigService: self.db = db def load_memory_config( - self, - config_id: Optional[UUID] = None, - workspace_id: Optional[UUID] = None, - service_name: str = "MemoryConfigService", + self, + config_id: Optional[UUID] = None, + workspace_id: Optional[UUID] = None, + service_name: str = "MemoryConfigService", ) -> MemoryConfig: """ Load memory configuration from database with optional fallback. @@ -194,14 +202,14 @@ class MemoryConfigService: try: # Use get_config_with_fallback if workspace_id is provided memory_config = None + validated_config_id = None if workspace_id: - validated_config_id = None if config_id: try: validated_config_id = _validate_config_id(config_id, self.db) except Exception: validated_config_id = None - + memory_config = self.get_config_with_fallback( memory_config_id=validated_config_id, workspace_id=workspace_id @@ -210,7 +218,7 @@ class MemoryConfigService: validated_config_id = _validate_config_id(config_id, self.db) from app.models.memory_config_model import MemoryConfig as MemoryConfigModel memory_config = self.db.get(MemoryConfigModel, validated_config_id) - + if not memory_config: elapsed_ms = (time.time() - start_time) * 1000 config_logger.error( @@ -233,7 +241,7 @@ class MemoryConfigService: result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id) db_query_time = time.time() - db_query_start logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s") - + if not result: raise ConfigurationError( f"Workspace not found for config {memory_config.config_id}" @@ -243,10 +251,10 @@ class MemoryConfigService: # Helper function to validate model with workspace fallback def _validate_model_with_fallback( - model_id: str, - model_type: str, - workspace_default: str, - required: bool = False + model_id: str, + model_type: str, + workspace_default: str, + required: bool = False ) -> tuple: """Validate model ID, falling back to workspace default if invalid. @@ -275,7 +283,7 @@ class MemoryConfigService: logger.warning( f"{model_type} model validation failed, trying workspace default: {e}" ) - + # Fallback to workspace default if workspace_default: try: @@ -297,7 +305,7 @@ class MemoryConfigService: logger.error(f"Workspace default {model_type} model also invalid: {e}") if required: raise - + if required: raise InvalidConfigError( f"{model_type.title()} model is required but not configured", @@ -306,7 +314,7 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id ) - + return None, None # Step 2: Validate embedding model with workspace fallback @@ -343,6 +351,35 @@ class MemoryConfigService: if memory_config.rerank_id or workspace.rerank: logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") + vision_uuid, vision_name = validate_and_resolve_model_id( + memory_config.vision_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) + + audio_uuid, audio_name = validate_and_resolve_model_id( + memory_config.audio_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) + + video_uuid, video_name = validate_and_resolve_model_id( + memory_config.video_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) # Create immutable MemoryConfig object config = MemoryConfig( config_id=memory_config.config_id, @@ -356,6 +393,12 @@ class MemoryConfigService: embedding_model_name=embedding_name, rerank_model_id=rerank_uuid, rerank_model_name=rerank_name, + video_model_id=video_uuid, + video_model_name=video_name, + vision_model_id=vision_uuid, + vision_model_name=vision_name, + audio_model_id=audio_uuid, + audio_model_name=audio_name, storage_type=workspace.storage_type or "neo4j", chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker", reflexion_enabled=memory_config.enable_self_reflexion or False, @@ -364,24 +407,31 @@ class MemoryConfigService: reflexion_baseline=memory_config.baseline or "Time", loaded_at=datetime.now(), # Pipeline config: Deduplication - enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, - enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, + enable_llm_dedup_blockwise=bool( + memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, + enable_llm_disambiguation=bool( + memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True, t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8, t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8, t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8, # Pipeline config: Statement extraction - statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, - include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, - max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000, + statement_granularity=int( + memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, + include_dialogue_context=bool( + memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, + max_dialogue_context_chars=int( + memory_config.max_context) if memory_config.max_context is not None else 1000, # Pipeline config: Forgetting engine lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5, lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5, offset=float(memory_config.offset) if memory_config.offset is not None else 0.0, # Pipeline config: Pruning - pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, + pruning_enabled=bool( + memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, pruning_scene=memory_config.pruning_scene or "education", - pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, + pruning_threshold=float( + memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, # Ontology scene association scene_id=memory_config.scene_id, ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id), @@ -448,9 +498,9 @@ class MemoryConfigService: if not config: logger.warning(f"Model ID {model_id} not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在") - + api_config: ModelApiKey = config.api_keys[0] - + return { "model_name": api_config.model_name, "provider": api_config.provider, @@ -481,9 +531,9 @@ class MemoryConfigService: if not config: logger.warning(f"Embedding model ID {embedding_id} not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在") - + api_config: ModelApiKey = config.api_keys[0] - + return { "model_name": api_config.model_name, "provider": api_config.provider, @@ -571,25 +621,25 @@ class MemoryConfigService: """ from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.repositories.ontology_class_repository import OntologyClassRepository - + if not memory_config.scene_id: logger.debug("No scene_id configured, skipping ontology type fetch") return None - + try: ontology_repo = OntologyClassRepository(self.db) ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id) - + if not ontology_classes: logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}") return None - + ontology_types = OntologyTypeList.from_db_models(ontology_classes) logger.info( f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}" ) return ontology_types - + except Exception as e: logger.warning( f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}", @@ -598,8 +648,8 @@ class MemoryConfigService: return None def get_workspace_default_config( - self, - workspace_id: UUID + self, + workspace_id: UUID ) -> Optional["MemoryConfigModel"]: """Get workspace default memory config. @@ -613,19 +663,19 @@ class MemoryConfigService: Optional[MemoryConfigModel]: Default config or None if no configs exist """ config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id) - + if not config: logger.warning( "No active memory config found for workspace fallback", extra={"workspace_id": str(workspace_id)} ) - + return config def get_config_with_fallback( - self, - memory_config_id: Optional[UUID], - workspace_id: UUID + self, + memory_config_id: Optional[UUID], + workspace_id: UUID ) -> Optional["MemoryConfigModel"]: """Get memory config with fallback to workspace default. @@ -644,13 +694,13 @@ class MemoryConfigService: "No memory config ID provided, using workspace default", extra={"workspace_id": str(workspace_id)} ) - + config = MemoryConfigRepository.get_with_fallback( self.db, memory_config_id, workspace_id ) - + if not config and memory_config_id: logger.warning( "Memory config not found, falling back to workspace default", @@ -659,13 +709,13 @@ class MemoryConfigService: "workspace_id": str(workspace_id) } ) - + return config def delete_config( - self, - config_id: UUID | int, - force: bool = False + self, + config_id: UUID | int, + force: bool = False ) -> dict: """Delete memory config with protection against in-use configs. @@ -687,7 +737,7 @@ class MemoryConfigService: from app.core.exceptions import ResourceNotFoundException from app.models.memory_config_model import MemoryConfig as MemoryConfigModel from app.repositories.end_user_repository import EndUserRepository - + # 处理旧格式 int 类型的 config_id if isinstance(config_id, int): logger.warning( @@ -699,11 +749,11 @@ class MemoryConfigService: "message": "旧格式配置ID不支持删除操作,请使用新版配置", "legacy_int_id": config_id } - + config = self.db.get(MemoryConfigModel, config_id) if not config: raise ResourceNotFoundException("MemoryConfig", str(config_id)) - + # Check if this is the default config - default configs cannot be deleted if config.is_default: logger.warning( @@ -715,11 +765,11 @@ class MemoryConfigService: "message": "默认配置不允许删除", "is_default": True } - + # Use repository to count connected end users end_user_repo = EndUserRepository(self.db) connected_count = end_user_repo.count_by_memory_config_id(config_id) - + if connected_count > 0 and not force: logger.warning( "Attempted to delete memory config with connected end users", @@ -728,18 +778,18 @@ class MemoryConfigService: "connected_count": connected_count } ) - + return { "status": "warning", "message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置", "connected_count": connected_count, "force_required": True } - + # Force delete: use repository to clear end user references first if connected_count > 0 and force: cleared_count = end_user_repo.clear_memory_config_id(config_id) - + logger.warning( "Force deleting memory config, clearing end user references", extra={ @@ -747,11 +797,11 @@ class MemoryConfigService: "cleared_end_users": cleared_count } ) - + try: self.db.delete(config) self.db.commit() - + logger.info( "Memory config deleted", extra={ @@ -760,16 +810,16 @@ class MemoryConfigService: "affected_users": connected_count } ) - + return { "status": "success", "message": "记忆配置删除成功", "affected_users": connected_count } - + except IntegrityError as e: self.db.rollback() - + # Handle foreign key violation gracefully error_str = str(e.orig) if e.orig else str(e) if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower(): @@ -785,7 +835,7 @@ class MemoryConfigService: "message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除", "force_required": True } - + # Re-raise other integrity errors logger.error( "Delete failed due to integrity error", @@ -800,9 +850,9 @@ class MemoryConfigService: # ==================== 记忆配置提取方法 ==================== def extract_memory_config_id( - self, - app_type: str, - config: dict + self, + app_type: str, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从发布配置中提取 memory_config_id(根据应用类型分发) @@ -827,9 +877,26 @@ class MemoryConfigService: logger.warning(f"不支持的应用类型,无法提取记忆配置: app_type={app_type}") return None, False + def _resolve_config_id_old(self, config_id_old: int) -> Optional[uuid.UUID]: + """通过 config_id_old 查询对应的 UUID config_id。 + + Args: + config_id_old: 旧格式的整数配置ID + + Returns: + 对应的 UUID config_id,未找到返回 None + """ + from app.models.memory_config_model import MemoryConfig as MemoryConfigModel + result = self.db.query(MemoryConfigModel).filter( + MemoryConfigModel.config_id_old == config_id_old + ).first() + if result: + return result.config_id + return None + def _extract_memory_config_id_from_agent( - self, - config: dict + self, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从 Agent 应用配置中提取 memory_config_id @@ -858,10 +925,11 @@ class MemoryConfigService: elif isinstance(memory_value, str): # Check if it's a numeric string (legacy int format) if memory_value.isdigit(): - logger.warning( - f"Agent 配置中 memory_config_id 为旧格式 int 字符串,将使用工作空间默认配置: " - f"value={memory_value}" - ) + resolved = self._resolve_config_id_old(int(memory_value)) + if resolved: + logger.info(f"Resolved legacy config_id_old={memory_value} to config_id={resolved}") + return resolved, False + logger.warning(f"未找到 config_id_old={memory_value} 对应的配置,将使用工作空间默认配置") return None, True try: return uuid.UUID(memory_value), False @@ -869,11 +937,11 @@ class MemoryConfigService: logger.warning(f"Invalid UUID string: {memory_value}") return None, False elif isinstance(memory_value, int): - # 旧数据存储为 int,需要回退到工作空间默认配置 - logger.warning( - f"Agent 配置中 memory_config_id 为旧格式 int,将使用工作空间默认配置: " - f"value={memory_value}" - ) + resolved = self._resolve_config_id_old(memory_value) + if resolved: + logger.info(f"Resolved legacy config_id_old={memory_value} to config_id={resolved}") + return resolved, False + logger.warning(f"未找到 config_id_old={memory_value} 对应的配置,将使用工作空间默认配置") return None, True else: logger.warning( @@ -888,8 +956,8 @@ class MemoryConfigService: return None, False def _extract_memory_config_id_from_workflow( - self, - config: dict + self, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从 Workflow 应用配置中提取 memory_config_id @@ -905,14 +973,14 @@ class MemoryConfigService: - is_legacy_int: 是否检测到旧格式 int 数据 """ nodes = config.get("nodes", []) - + for node in nodes: node_type = node.get("type", "") - + # 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite) if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]: config_id = node.get("config", {}).get("config_id") - + if config_id: try: # 处理字符串、UUID 和 int(旧数据兼容)三种情况 @@ -921,10 +989,16 @@ class MemoryConfigService: elif isinstance(config_id, str): return uuid.UUID(config_id), False elif isinstance(config_id, int): - # 旧数据存储为 int,需要回退到工作空间默认配置 + resolved = self._resolve_config_id_old(config_id) + if resolved: + logger.info( + f"Resolved workflow legacy config_id_old={config_id} to config_id={resolved}: " + f"node_id={node.get('id')}, node_type={node_type}" + ) + return resolved, False logger.warning( - f"工作流记忆节点 config_id 为旧格式 int,将使用工作空间默认配置: " - f"node_id={node.get('id')}, node_type={node_type}, value={config_id}" + f"未找到工作流记忆节点 config_id_old={config_id} 对应的配置,将使用工作空间默认配置: " + f"node_id={node.get('id')}, node_type={node_type}" ) return None, True else: @@ -937,6 +1011,6 @@ class MemoryConfigService: f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, " f"node_type={node_type}, error={str(e)}" ) - + logger.debug("工作流配置中未找到记忆节点") return None, False diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index a0bcc1a1..11118571 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -315,6 +315,12 @@ class MemoryForgetService: # 获取遗忘引擎组件 _, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id) + # 如果参数为 None,使用配置中的默认值 + if max_merge_batch_size is None: + max_merge_batch_size = config.get('max_merge_batch_size', 100) + if min_days_since_access is None: + min_days_since_access = config.get('min_days_since_access', 30) + # 记录执行开始时间 execution_time = datetime.now() diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py index b8961d33..523adadb 100644 --- a/api/app/services/memory_konwledges_server.py +++ b/api/app/services/memory_konwledges_server.py @@ -341,7 +341,7 @@ async def memory_konwledges_up( ) db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user) - return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") + return db_document async def create_document_chunk( @@ -350,7 +350,7 @@ async def create_document_chunk( create_data: ChunkCreate, db: Session, current_user: User -): +) -> DocumentChunk: """ 创建文档块 @@ -439,10 +439,10 @@ async def create_document_chunk( db_document.chunk_num += 1 db.commit() - return success(data=chunk, msg="文档块创建成功") + return chunk -async def write_rag(end_user_id, message, user_rag_memory_id): +async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk: """ 将消息写入 RAG 知识库 @@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id): document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") print('======', document) api_logger.info(f"查找文档结果: document_id={document}") + create_chunks = ChunkCreate(content=message) if document is not None: # 文档已存在,直接添加新块 api_logger.info(f"文档已存在,添加新块: document_id={document}") - create_chunks = ChunkCreate(content=message) result = await create_document_chunk( kb_id=kb_uuid, document_id=uuid.UUID(document), @@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id): else: # 文档不存在,创建新文档 api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}") - result = await memory_konwledges_up( + document = await memory_konwledges_up( kb_id=user_rag_memory_id, parent_id=user_rag_memory_id, create_data=create_data, db=db, current_user=current_user ) + result = await create_document_chunk( + kb_id=kb_uuid, + document_id=document.id, + create_data=create_chunks, + db=db, + current_user=current_user + ) # 重新查询刚创建的文档ID new_document_id = find_document_id_by_kb_and_filename( db=db, diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 8a7c86e2..3ee238e2 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -12,11 +12,12 @@ from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.models import RedBearLLM, RedBearModelConfig -from app.models import FileMetadata +from app.models import FileMetadata, ModelApiKey, ModelType from app.models.memory_perceptual_model import PerceptualType, FileStorageService from app.models.prompt_optimizer_model import RoleType from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository -from app.schemas import FileType +from app.schemas import FileType, FileInput +from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_perceptual_schema import ( PerceptualQuerySchema, PerceptualTimelineResponse, @@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import ( AudioModal, Content, VideoModal, TextModal ) from app.schemas.model_schema import ModelInfo +from app.services.model_service import ModelApiKeyService +from app.services.multimodal_service import MultimodalService business_logger = get_business_logger() @@ -195,21 +198,58 @@ class MemoryPerceptualService: business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) + def _get_mutlimodal_client( + self, + file_type: FileType, + config: MemoryConfig + ) -> tuple[RedBearLLM | None, ModelApiKey | None]: + model_config = None + if file_type == FileType.AUDIO: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.audio_model_id + ) + elif file_type == FileType.VIDEO: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.video_model_id + ) + elif file_type == FileType.DOCUMENT: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.llm_model_id + ) + elif file_type == FileType.IMAGE: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.vision_model_id + ) + llm = None + if model_config: + llm = RedBearLLM( + RedBearModelConfig( + model_name=model_config.model_name, + provider=model_config.provider, + api_key=model_config.api_key, + base_url=model_config.api_base, + is_omni=model_config.is_omni + ) + ) + return llm, model_config + async def generate_perceptual_memory( self, end_user_id: str, - model_config: ModelInfo, - file_type: str, - file_url: str, - file_message: dict, + memory_config: MemoryConfig, + file: FileInput ): - memories = self.repository.get_by_url(file_url) + memories = self.repository.get_by_url(file.url) if memories: - business_logger.info(f"Perceptual memory already exists: {file_url}") + business_logger.info(f"Perceptual memory already exists: {file.url}") if end_user_id not in [memory.end_user_id for memory in memories]: business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") memory_cache = memories[0] - self.repository.create_perceptual_memory( + memory = self.repository.create_perceptual_memory( end_user_id=uuid.UUID(end_user_id), perceptual_type=PerceptualType(memory_cache.perceptual_type), file_path=memory_cache.file_path, @@ -219,20 +259,33 @@ class MemoryPerceptualService: meta_data=memory_cache.meta_data ) self.db.commit() - - return - llm = RedBearLLM(RedBearModelConfig( + return memory + else: + for memory in memories: + if memory.end_user_id == uuid.UUID(end_user_id): + return memory + llm, model_config = self._get_mutlimodal_client(file.type, memory_config) + multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, provider=model_config.provider, api_key=model_config.api_key, - base_url=model_config.api_base, - is_omni=model_config.is_omni - ), type=model_config.model_type) + api_base=model_config.api_base, + is_omni=model_config.is_omni, + capability=model_config.capability, + model_type=ModelType.LLM + )) + file_message = await multimodel_service.process_files( + files=[file] + ) + if not file_message: + business_logger.warning(f"Unsupported file type {file}, model capability: {model_config.capability}") + return None + file_message = file_message[0] try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() - rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh') + rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh') except FileNotFoundError: raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) messages = [ @@ -242,8 +295,22 @@ class MemoryPerceptualService: ]} ] result = await llm.ainvoke(messages) - content = json_repair.repair_json(result.content, return_objects=True) - path = urlparse(file_url).path + content = result.content + final_output = "" + if isinstance(content, list): + for msg in content: + if isinstance(msg, dict): + final_output += msg.get("text", "") + elif isinstance(msg, str): + final_output += msg + elif isinstance(content, dict): + final_output += content.get("text", "") + elif isinstance(content, str): + final_output = content + else: + raise ValueError(f"Unexcept Model Output Type: {result.content}") + content = json_repair.repair_json(final_output, return_objects=True) + path = urlparse(file.url).path filename = os.path.basename(path) filename = unquote(filename) file_ext = os.path.splitext(filename)[1] @@ -252,21 +319,21 @@ class MemoryPerceptualService: stmt = select(FileMetadata).where( FileMetadata.id == file_id ) - file = self.db.execute(stmt).scalar_one_or_none() + file_obj = self.db.execute(stmt).scalar_one_or_none() - if file: - filename = file.file_name - file_ext = file.file_ext + if file_obj: + filename = file_obj.file_name + file_ext = file_obj.file_ext except ValueError: business_logger.debug(f"Remote file, file_id={filename}") if not file_ext: - if file_type == FileType.AUDIO: + if file.type == FileType.AUDIO: file_ext = ".mp3" - elif file_type == FileType.VIDEO: + elif file.type == FileType.VIDEO: file_ext = ".mp4" - elif file_type == FileType.DOCUMENT: + elif file.type == FileType.DOCUMENT: file_ext = ".txt" - elif file_type == FileType.IMAGE: + elif file.type == FileType.IMAGE: file_ext = ".jpg" filename += file_ext file_content = { @@ -274,11 +341,11 @@ class MemoryPerceptualService: "topic": content.get("topic"), "domain": content.get("domain") } - if file_type in [FileType.IMAGE, FileType.VIDEO]: + if file.type in [FileType.IMAGE, FileType.VIDEO]: file_modalities = { "scene": content.get("scene", []) } - elif file_type in [FileType.DOCUMENT]: + elif file.type in [FileType.DOCUMENT]: file_modalities = { "section_count": content.get("section_count", 0), "title": content.get("title", ""), @@ -288,10 +355,10 @@ class MemoryPerceptualService: file_modalities = { "speaker_count": content.get("speaker_count", 0) } - self.repository.create_perceptual_memory( + memory = self.repository.create_perceptual_memory( end_user_id=uuid.UUID(end_user_id), - perceptual_type=PerceptualType.trans_from_file_type(file_type), - file_path=file_url, + perceptual_type=PerceptualType.trans_from_file_type(file.type), + file_path=file.url, file_name=filename, file_ext=file_ext, summary=content.get('summary', ""), @@ -301,3 +368,4 @@ class MemoryPerceptualService: } ) self.db.commit() + return memory diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 6e7c1ad4..58f3e8bd 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -11,9 +11,11 @@ import time from datetime import datetime from typing import Any, AsyncGenerator, Dict, List, Optional +from dotenv import load_dotenv +from sqlalchemy.orm import Session + from app.core.logging_config import get_config_logger, get_logger from app.core.memory.analytics.hot_memory_tags import ( - get_hot_memory_tags, get_raw_tags_from_db, filter_tags_with_llm, ) @@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import ( ) from app.services.memory_config_service import MemoryConfigService from app.utils.sse_utils import format_sse_message -from dotenv import load_dotenv -from sqlalchemy.orm import Session logger = get_logger(__name__) config_logger = get_config_logger() @@ -45,10 +45,10 @@ _neo4j_connector = Neo4jConnector() class MemoryStorageService: """Service for memory storage operations""" - + def __init__(self): logger.info("MemoryStorageService initialized") - + async def get_storage_info(self) -> dict: """ Example wrapper method - retrieves storage information @@ -59,17 +59,17 @@ class MemoryStorageService: Storage information dictionary """ logger.info("Getting storage info ") - + # Empty wrapper - implement your logic here result = { "status": "active", "message": "This is an example wrapper" } - - return result - -class DataConfigService: # 数据配置服务类(PostgreSQL) + return result + + +class DataConfigService: # 数据配置服务类(PostgreSQL) """Service layer for config params CRUD. 使用 SQLAlchemy ORM 进行数据库操作。 @@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return data_list # --- Create --- - def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) + def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) # 业务层检查同一工作空间下是否已存在同名配置 if params.workspace_id and params.config_name: from app.models.memory_config_model import MemoryConfig @@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return None # --- Delete --- - def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) + def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) success = MemoryConfigRepository.delete(self.db, key.config_id) if not success: raise ValueError("未找到配置") return {"affected": 1} # --- Update --- - def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 + def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 config = MemoryConfigRepository.update(self.db, update) if not config: raise ValueError("未找到配置") return {"affected": 1} - def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 + def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 config = MemoryConfigRepository.update_extracted(self.db, update) if not config: raise ValueError("未找到配置") @@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config() # --- Read --- - def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 + def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id) if not result: raise ValueError("未找到配置") return result # --- Read All --- - def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 + def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数 results = MemoryConfigRepository.get_all(self.db, workspace_id) # 检查并修正 pruning_scene 与 scene_name 不一致的记录 @@ -241,13 +241,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) except (ValueError, TypeError): config_id_old = None - - if config_id_old: - memory_config=config_id_old - else: - memory_config=config.config_id config_dict = { - "config_id": memory_config, + "config_id": str(config.config_id), "config_name": config.config_name, "config_desc": config.config_desc, "workspace_id": str(config.workspace_id) if config.workspace_id else None, @@ -289,7 +284,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 return self._convert_timestamps_to_format(data_list) - async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]: """ 流式执行试运行,产生 SSE 格式的进度事件 @@ -311,14 +305,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) """ from pathlib import Path project_root = str(Path(__file__).resolve().parents[2]) - + try: # 发出初始进度事件 yield format_sse_message("starting", { "message": "开始试运行...", "time": int(time.time() * 1000) }) - + # 步骤 1: 配置加载和验证(数据库优先) payload_cid = str(getattr(payload, "config_id", "") or "").strip() cid: Optional[str] = payload_cid if payload_cid else None @@ -344,27 +338,28 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 关联了本体场景,优先使用 custom_text if hasattr(payload, 'custom_text') and payload.custom_text: dialogue_text = payload.custom_text.strip() - logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") + logger.info( + f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") else: # 如果没有提供 custom_text,回退到 dialogue_text dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" - logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") + logger.info( + f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") else: # 没有关联本体场景,使用 dialogue_text dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}") - + # 验证最终使用的文本不为空 if not dialogue_text: raise ValueError("试运行模式必须提供有效的文本内容(dialogue_text 或 custom_text)") - - logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") + logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") # 步骤 2: 创建进度回调函数捕获管线进度 # 使用队列在回调和生成器之间传递进度事件 progress_queue: asyncio.Queue = asyncio.Queue() - + async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None: """ 进度回调函数,将进度事件放入队列 @@ -375,14 +370,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) data: 可选的结果数据(用于传递节点执行结果) """ await progress_queue.put((stage, message, data)) - + # 步骤 3: 在后台任务中执行管线 async def run_pipeline(): """在后台执行管线并捕获异常""" try: from app.services.pilot_run_service import run_pilot_extraction - - logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") + + logger.info( + f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") await run_pilot_extraction( memory_config=memory_config, dialogue_text=dialogue_text, @@ -391,60 +387,60 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) language=language, ) logger.info("[PILOT_RUN_STREAM] pipeline_main completed") - + # 标记管线完成 await progress_queue.put(("__PIPELINE_COMPLETE__", "", None)) except Exception as e: # 将异常放入队列 await progress_queue.put(("__PIPELINE_ERROR__", str(e), None)) - + # 启动后台任务 pipeline_task = asyncio.create_task(run_pipeline()) - + # 步骤 4: 从队列中读取进度事件并发出 while True: try: # 等待进度事件,设置超时以检测客户端断开 stage, message, data = await asyncio.wait_for( - progress_queue.get(), + progress_queue.get(), timeout=0.5 ) - + # 检查特殊标记 if stage == "__PIPELINE_COMPLETE__": break elif stage == "__PIPELINE_ERROR__": raise RuntimeError(message) - + # 构建进度事件数据 progress_data = { "message": message, "time": int(time.time() * 1000) } - + # 如果有结果数据,添加到事件中 if data: progress_data["data"] = data - + # 发出进度事件,使用 stage 作为事件类型 yield format_sse_message(stage, progress_data) - + except TimeoutError: # 超时,继续等待(这允许检测客户端断开) continue - + # 等待管线任务完成 await pipeline_task - + # 步骤 5: 读取提取结果 from app.core.config import settings result_path = settings.get_memory_output_path("extracted_result.json") if not os.path.isfile(result_path): raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}") - + with open(result_path, "r", encoding="utf-8") as rf: extracted_result = json.load(rf) - + # 步骤 6: 计算本体覆盖率并合并到结果中 result_data = { "config_id": cid, @@ -460,15 +456,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) result_data["ontology_coverage"] = ontology_coverage except Exception as cov_err: logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True) - + yield format_sse_message("result", result_data) - + # 步骤 7: 发出完成事件 yield format_sse_message("done", { "message": "试运行完成", "time": int(time.time() * 1000) }) - + except asyncio.CancelledError: # 客户端断开连接 logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming") @@ -483,11 +479,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "time": int(time.time() * 1000) }) - async def _compute_ontology_coverage( - self, - extracted_result: Dict[str, Any], - memory_config, + self, + extracted_result: Dict[str, Any], + memory_config, ) -> Optional[Dict[str, Any]]: """根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。 @@ -580,8 +575,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- # Ensure env for connector (e.g., NEO4J_PASSWORD) -load_dotenv() -_neo4j_connector = Neo4jConnector() async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: @@ -664,7 +657,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A # 检查结果是否为空或长度不足 if not result or len(result) < 4: data = { - "total": 0, + "total": 0, "distribution": [ {"type": "dialogue", "count": 0}, {"type": "chunk", "count": 0}, @@ -701,10 +694,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any] ) return result + async def analytics_hot_memory_tags( - db: Session, - current_user: User, - limit: int = 10 + db: Session, + current_user: User, + limit: int = 10 ) -> List[Dict[str, Any]]: """ 获取热门记忆标签,按数量排序并返回前N个 @@ -721,27 +715,27 @@ async def analytics_hot_memory_tags( from app.services.memory_dashboard_service import get_workspace_end_users # 使用 asyncio.to_thread 避免阻塞事件循环 end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user) - + if not end_users: return [] - + # 步骤1: 收集所有用户的原始标签(不调用LLM) connector = Neo4jConnector() try: all_raw_tags = [] for end_user in end_users: raw_tags = await get_raw_tags_from_db( - connector, - str(end_user.id), - limit=raw_limit, + connector, + str(end_user.id), + limit=raw_limit, by_user=False ) if raw_tags: all_raw_tags.extend(raw_tags) - + if not all_raw_tags: return [] - + # 步骤2: 聚合相同标签的频率 tag_frequency_map = {} for tag_name, frequency in all_raw_tags: @@ -749,36 +743,36 @@ async def analytics_hot_memory_tags( tag_frequency_map[tag_name] += frequency else: tag_frequency_map[tag_name] = frequency - + # 步骤3: 按频率降序排序,取前raw_limit个 sorted_tags = sorted( - tag_frequency_map.items(), - key=lambda x: x[1], + tag_frequency_map.items(), + key=lambda x: x[1], reverse=True )[:raw_limit] - + if not sorted_tags: return [] - + # 步骤4: 只调用一次LLM进行筛选 tag_names = [tag for tag, _ in sorted_tags] - + # 使用第一个用户的end_user_id来获取LLM配置 # 因为同一工作空间下的用户应该使用相同的配置 first_end_user_id = str(end_users[0].id) filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id) - + # 步骤5: 根据LLM筛选结果构建最终列表(保留频率) final_tags = [] for tag, freq in sorted_tags: if tag in filtered_tag_names: final_tags.append((tag, freq)) - + # 步骤6: 只返回前limit个 top_tags = final_tags[:limit] - + return [{"name": t, "frequency": f} for t, f in top_tags] - + finally: await connector.close() @@ -815,11 +809,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> source = "log" total = ( - stats.get("chunk_count", 0) - + stats.get("statements_count", 0) - + stats.get("triplet_entities_count", 0) - + stats.get("triplet_relations_count", 0) - + stats.get("temporal_count", 0) + stats.get("chunk_count", 0) + + stats.get("statements_count", 0) + + stats.get("triplet_entities_count", 0) + + stats.get("triplet_relations_count", 0) + + stats.get("temporal_count", 0) ) # 计算"最新一次活动多久前"(仅日志来源时有效) @@ -845,5 +839,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source} return data - - diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index a7398504..b98674ba 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -154,10 +154,17 @@ class ModelConfigService: } elif model_type_lower == "embedding": - # Embedding 模型验证(在线程中运行同步方法) + # Embedding 模型验证 + # 统一使用 RedBearEmbeddings(自动支持火山引擎多模态) embedding = RedBearEmbeddings(model_config) test_texts = [test_message, "测试文本"] - vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) + + # 火山引擎使用 embed_batch,其他使用 embed_documents + if provider.lower() == "volcano": + vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) + else: + vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) + elapsed_time = time.time() - start_time return { @@ -193,6 +200,56 @@ class ModelConfigService: }, "error": None } + + elif model_type_lower == "image": + # 图片生成模型验证 + from app.core.models.generation import RedBearImageGenerator + + generator = RedBearImageGenerator(model_config) + result = await generator.agenerate( + prompt="a cute panda", + size="2K" + ) + elapsed_time = time.time() - start_time + logger.info(f"成功生成图片,结果: {result}") + + return { + "valid": True, + "message": "图片生成模型配置验证成功", + "response": f"成功生成图片,结果: {result}", + "elapsed_time": elapsed_time, + "usage": { + "prompt_length": len("a cute panda"), + "image_count": 1 + }, + "error": None + } + + elif model_type_lower == "video": + # 视频生成模型验证 + from app.core.models.generation import RedBearVideoGenerator + + generator = RedBearVideoGenerator(model_config) + result = await generator.agenerate( + prompt="a cute panda playing in bamboo forest", + duration=5 + ) + elapsed_time = time.time() - start_time + + # 视频生成是异步任务,返回任务ID + task_id = result.get("task_id") if isinstance(result, dict) else None + + return { + "valid": True, + "message": "视频生成模型配置验证成功", + "response": f"成功创建视频生成任务,任务ID: {task_id}", + "elapsed_time": elapsed_time, + "usage": { + "prompt_length": len("a cute panda playing in bamboo forest"), + "task_id": task_id + }, + "error": None + } else: return { diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 6cb0a7f0..4cf3d89d 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -9,17 +9,15 @@ - OpenAI: 支持 URL 和 base64 格式 """ import base64 +import csv import io -import uuid +import json import zipfile -import chardet from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional -import csv -import json - import PyPDF2 +import chardet import httpx import magic import openpyxl @@ -35,7 +33,6 @@ from app.models.file_metadata_model import FileMetadata from app.schemas.app_schema import FileInput, FileType, TransferMethod from app.schemas.model_schema import ModelInfo from app.services.audio_transcription_service import AudioTranscriptionService -from app.tasks import write_perceptual_memory logger = get_business_logger() @@ -297,6 +294,7 @@ PROVIDER_STRATEGIES = { "bedrock": BedrockFormatStrategy, "anthropic": BedrockFormatStrategy, "openai": OpenAIFormatStrategy, + "volcano": OpenAIFormatStrategy, } @@ -342,92 +340,14 @@ class MultimodalService: async def process_files( self, - end_user_id: uuid.UUID | str, files: Optional[List[FileInput]], - ) -> List[Dict[str, Any]]: """ 处理文件列表,返回 LLM 可用的格式 Args: - end_user_id: 用户ID files: 文件输入列表 - Returns: - List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式) - """ - if not files: - return [] - if isinstance(end_user_id, uuid.UUID): - end_user_id = str(end_user_id) - - # 获取对应的策略 - # dashscope 的 omni 模型使用 OpenAI 兼容格式 - if self.provider == "dashscope" and self.is_omni: - strategy_class = OpenAIFormatStrategy - else: - strategy_class = PROVIDER_STRATEGIES.get(self.provider) - if not strategy_class: - logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略") - strategy_class = DashScopeFormatStrategy - - result = [] - for idx, file in enumerate(files): - strategy = strategy_class(file) - if not file.url: - file.url = await self.get_file_url(file) - try: - if file.type == FileType.IMAGE and "vision" in self.capability: - is_support, content = await self._process_image(file, strategy) - result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) - elif file.type == FileType.DOCUMENT: - is_support, content = await self._process_document(file, strategy) - result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) - elif file.type == FileType.AUDIO and "audio" in self.capability: - is_support, content = await self._process_audio(file, strategy) - result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) - elif file.type == FileType.VIDEO and "video" in self.capability: - is_support, content = await self._process_video(file, strategy) - result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) - else: - logger.warning(f"不支持的文件类型: {file.type}") - except Exception as e: - logger.error( - f"处理文件失败", - extra={ - "file_index": idx, - "file_type": file.type, - "error": str(e) - }, - exc_info=True - ) - # 继续处理其他文件,不中断整个流程 - result.append({ - "type": "text", - "text": f"[文件处理失败: {str(e)}]" - }) - - logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") - return result - - async def history_process_files( - self, - files: Optional[List[FileInput]], - ) -> List[Dict[str, Any]]: - """ - 处理文件列表,返回 LLM 可用的格式 - - Args: - files: 文件输入列表 - Returns: List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式) """ @@ -483,17 +403,6 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - def write_perceptual_memory( - self, - end_user_id: str, - file_type: str, - file_url: str, - file_message: dict - ): - """写入感知记忆""" - if end_user_id and self.api_config: - write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message) - async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: """ 处理图片文件 diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index fc749157..4617946b 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -297,9 +297,12 @@ async def run_pilot_extraction( chunk_nodes, statement_nodes, entity_nodes, + _, statement_chunk_edges, statement_entity_edges, entity_edges, + _, + _ ) = extraction_result log_time("Extraction Pipeline", time.time() - step_start, log_file) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 12e0c324..585fdd78 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1887,7 +1887,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_ "Chunk": ["content", "created_at"], "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"], "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"], - "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 + "MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段 + "Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"] } # 获取该节点类型的白名单字段 diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py index e23b1ac3..3122d282 100644 --- a/api/app/services/user_service.py +++ b/api/app/services/user_service.py @@ -78,18 +78,7 @@ def create_user(db: Session, user: UserCreate) -> User: business_logger.info(f"创建用户: {user.username}, email: {user.email}") try: - # 检查用户名是否已存在 - business_logger.debug(f"检查用户名是否已存在: {user.username}") - db_user_by_username = user_repository.get_user_by_username(db, username=user.username) - if db_user_by_username: - business_logger.warning(f"用户名已存在: {user.username}") - raise BusinessException( - "用户名已存在", - code=BizCode.DUPLICATE_NAME, - context={"username": user.username, "email": user.email} - ) - - # 检查邮箱是否已注册 + # 检查邮箱是否已注册(邮箱保持唯一) business_logger.debug(f"检查邮箱是否已注册: {user.email}") db_user_by_email = user_repository.get_user_by_email(db, email=user.email) if db_user_by_email: @@ -164,22 +153,7 @@ def create_superuser(db: Session, user: UserCreate, current_user: User) -> User: ) try: - # 检查用户名是否已存在 - business_logger.debug(f"检查用户名是否已存在: {user.username}") - db_user_by_username = user_repository.get_user_by_username(db, username=user.username) - if db_user_by_username: - business_logger.warning(f"用户名已存在: {user.username}") - raise BusinessException( - "用户名已存在", - code=BizCode.DUPLICATE_NAME, - context={ - "username": user.username, - "email": user.email, - "created_by": str(current_user.id) - } - ) - - # 检查邮箱是否已注册 + # 检查邮箱是否已注册(邮箱保持唯一) business_logger.debug(f"检查邮箱是否已注册: {user.email}") db_user_by_email = user_repository.get_user_by_email(db, email=user.email) if db_user_by_email: @@ -276,6 +250,20 @@ def deactivate_user(db: Session, user_id_to_deactivate: uuid.UUID, current_user: } ) + # 检查是否为租户联系人 + from app.models.tenant_model import Tenants + tenant = db.query(Tenants).filter(Tenants.id == db_user.tenant_id).first() + if tenant and tenant.contact_email and tenant.contact_email == db_user.email: + business_logger.warning(f"尝试停用租户联系人: {db_user.email}, tenant_id={db_user.tenant_id}") + raise BusinessException( + "该管理员是租户联系人,请先在租户信息中更换联系邮箱,再禁用此管理员", + code=BizCode.FORBIDDEN, + context={ + "user_id": str(user_id_to_deactivate), + "tenant_id": str(db_user.tenant_id) + } + ) + # 停用用户 business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})") db_user.is_active = False diff --git a/api/app/services/workflow_import_service.py b/api/app/services/workflow_import_service.py index 2b36c5ea..fd8f25f3 100644 --- a/api/app/services/workflow_import_service.py +++ b/api/app/services/workflow_import_service.py @@ -12,7 +12,7 @@ from app.aioRedis import aio_redis_set, aio_redis_get from app.core.config import settings from app.core.exceptions import BusinessException from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult -from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration +from app.core.workflow.adapters.errors import UnsupportedPlatform, InvalidConfiguration from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.schemas import AppCreate from app.schemas.workflow_schema import WorkflowConfigCreate @@ -46,7 +46,7 @@ class WorkflowImportService: success=False, temp_id=None, workflow_id=None, - errors=[UnsupportPlatform(platform=platform)] + errors=[UnsupportedPlatform(platform=platform)] ) adapter = self.registry.get_adapter(platform, config) diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index aee3d75f..c7d7f2b1 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject from app.db import get_db from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution +from app.repositories import knowledge_repository from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, @@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService +from app.services.workspace_service import get_workspace_storage_type_without_auth logger = logging.getLogger(__name__) @@ -540,6 +542,25 @@ class WorkflowService: mapped = internal_event return mapped + def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]: + storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id) + user_rag_memory_id = "" + if storage_type == "rag": + knowledge = knowledge_repository.get_knowledge_by_name( + db=self.db, + name="USER_RAG_MERORY", + workspace_id=workspace_id + ) + if knowledge: + user_rag_memory_id = str(knowledge.id) + else: + logger.warning( + f"No knowledge base named 'USER_RAG_MEMORY' found, " + f"workspace_id: {workspace_id}, will use neo4j storage" + ) + storage_type = 'neo4j' + return storage_type, user_rag_memory_id + # ==================== 工作流执行 ==================== async def run( @@ -607,6 +628,7 @@ class WorkflowService: try: files = await self._handle_file_input(payload.files) + storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files message_id = uuid.uuid4() # 更新状态为运行中 @@ -631,7 +653,9 @@ class WorkflowService: input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=payload.user_id + user_id=payload.user_id, + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ) # 更新执行结果 if result.get("status") == "completed": @@ -780,6 +804,7 @@ class WorkflowService: try: files = await self._handle_file_input(payload.files) + storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) @@ -801,6 +826,8 @@ class WorkflowService: execution_id=execution.execution_id, workspace_id=str(workspace_id), user_id=payload.user_id, + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): if event.get("event") == "workflow_end": status = event.get("data", {}).get("status") diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index cefb8380..90b5cf65 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -863,7 +863,7 @@ def get_workspace_storage_type( def get_workspace_storage_type_without_auth( db: Session, workspace_id: uuid.UUID, -) -> Optional[str]: +) -> str: """获取工作空间的存储类型(无需权限验证,用于公开分享等场景) Args: diff --git a/api/app/tasks.py b/api/app/tasks.py index 3a237d82..61736275 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -36,9 +36,11 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ) from app.db import get_db, get_db_context from app.models import Document, File, Knowledge +from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema from app.schemas.model_schema import ModelInfo -from app.services.memory_agent_service import MemoryAgentService +from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config +from app.services.memory_forget_service import MemoryForgetService from app.services.memory_perceptual_service import MemoryPerceptualService from app.utils.config_utils import resolve_config_id from app.utils.redis_lock import RedisLock @@ -1073,9 +1075,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, - user_rag_memory_id: str, - language: str = "zh") -> Dict[str, Any]: +def write_message_task( + self, + end_user_id: str, + message: list[dict], + config_id: str | int, + storage_type: str, + user_rag_memory_id: str, + language: str = "zh" +) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) @@ -1091,7 +1099,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s Raises: Exception on failure """ - logger.info( f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " f"config_id={config_id} (type: {type(config_id).__name__}), " @@ -1105,14 +1112,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s try: with get_db_context() as db: actual_config_id = resolve_config_id(config_id, db) - print(100 * '-') - print(actual_config_id) - print(100 * '-') - logger.info( - f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") + logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} " + f"(type: {type(actual_config_id).__name__})") except (ValueError, AttributeError) as e: - logger.error( - f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") + logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} " + f"(type: {type(config_id).__name__}), error: {e}") return { "status": "FAILURE", "error": f"Invalid config_id format: {config_id} - {str(e)}", @@ -1151,8 +1155,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time - logger.info( - f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + logger.info(f"[CELERY WRITE] Task completed successfully " + f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: @@ -1167,7 +1171,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s ) except Exception as _e: logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}") - return { "status": "SUCCESS", "result": result, @@ -1859,7 +1862,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: @celery_app.task( name="app.tasks.run_forgetting_cycle_task", bind=True, - ignore_result=True, + ignore_result=False, # 改为 False 以便在 Flower 中查看结果 max_retries=0, acks_late=False, time_limit=7200, @@ -1867,68 +1870,77 @@ def workspace_reflection_task(self) -> Dict[str, Any]: ) def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]: """定时任务:运行遗忘周期 - - 定期执行遗忘周期,识别并融合低激活值的知识节点。 - - Args: - config_id: 配置ID(可选,如果为None则使用默认配置) - - Returns: - 包含任务执行结果的字典 + + 遍历所有终端用户,执行遗忘周期。 """ start_time = time.time() - async def _run() -> Dict[str, Any]: - from app.services.memory_forget_service import MemoryForgetService - + async def _process_users() -> Dict[str, Any]: with get_db_context() as db: - try: - logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}") + end_users = db.query(EndUser).all() + if not end_users: + logger.info("没有终端用户,跳过遗忘周期") + return {"status": "SUCCESS", "message": "没有终端用户", + "report": {"merged_count": 0, "failed_count": 0, "processed_users": 0}, + "duration_seconds": time.time() - start_time} - forget_service = MemoryForgetService() + logger.info(f"开始处理 {len(end_users)} 个终端用户的遗忘周期") + forget_service = MemoryForgetService() + total_merged = total_failed = processed_users = 0 + failed_users = [] - # 运行遗忘周期 - # FIXME: MemeoryForgetService - report = await forget_service.trigger_forgetting( - db=db, - end_user_id=None, # 处理所有组 - config_id=config_id - ) + for end_user in end_users: + try: + # 获取用户配置(自动回退到工作空间默认配置) + connected_config = get_end_user_connected_config(str(end_user.id), db) + user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db) + + if not user_config_id: + failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"}) + continue - duration = time.time() - start_time + # 执行遗忘周期 + report = await forget_service.trigger_forgetting_cycle( + db=db, end_user_id=str(end_user.id), config_id=user_config_id + ) + + total_merged += report.get('merged_count', 0) + total_failed += report.get('failed_count', 0) + processed_users += 1 + + logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点") + + except Exception as e: + logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True) + failed_users.append({"end_user_id": str(end_user.id), "error": str(e)}) - logger.info( - f"遗忘周期定时任务完成: " - f"融合 {report['merged_count']} 对节点, " - f"失败 {report['failed_count']} 对, " - f"耗时 {duration:.2f} 秒" - ) + duration = time.time() - start_time + logger.info(f"遗忘周期完成: {processed_users}/{len(end_users)} 用户, " + f"融合 {total_merged} 对, 耗时 {duration:.2f}s") - return { - "status": "SUCCESS", - "message": "遗忘周期执行成功", - "report": report, - "duration_seconds": duration - } - - except Exception as e: - duration = time.time() - start_time - logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True) - - return { - "status": "FAILED", - "message": f"遗忘周期执行失败: {str(e)}", - "duration_seconds": duration - } + return { + "status": "SUCCESS", + "message": f"处理 {processed_users} 个用户", + "report": { + "merged_count": total_merged, + "failed_count": total_failed, + "processed_users": processed_users, + "total_users": len(end_users), + "failed_users": failed_users + }, + "duration_seconds": duration + } # 运行异步函数 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) try: - result = loop.run_until_complete(_run()) - return result - finally: - loop.close() + return asyncio.run(_process_users()) + except Exception as e: + logger.error(f"遗忘周期任务失败: {e}", exc_info=True) + return { + "status": "FAILED", + "message": f"任务失败: {str(e)}", + "duration_seconds": time.time() - start_time + } # ============================================================================= @@ -2611,57 +2623,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ } -@celery_app.task( - name="app.tasks.write_perceptual_memory", - bind=True, - ignore_result=True, - max_retries=0, - acks_late=False, - time_limit=3600, - soft_time_limit=3300, -) -def write_perceptual_memory( - self, - end_user_id: str, - model_api_config: dict, - file_type: str, - file_url: str, - file_message: dict -): - """ - Write perceptual memory for a user into PostgreSQL and Neo4j. - - This task generates or updates the user's perceptual memory - in the backend databases. It is intended to be executed asynchronously - via Celery. - - Args: - end_user_id (uuid.UUID): The unique identifier of the end user. - model_api_config (ModelInfo): API configuration for the model - used to generate perceptual memory. - file_type (str): The file type - file_url (url): The url of file - file_message (dict): The file message containing details about the file - to be processed. - - Returns: - None - """ - file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest() - set_asyncio_event_loop() - with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()): - model_info = ModelInfo(**model_api_config) - with get_db_context() as db: - memory_perceptual_service = MemoryPerceptualService(db) - return asyncio.run(memory_perceptual_service.generate_perceptual_memory( - end_user_id, - model_info, - file_type, - file_url, - file_message, - )) - - # ============================================================================= # 社区聚类补全任务(触发型) # ============================================================================= @@ -2672,7 +2633,7 @@ def write_perceptual_memory( ignore_result=False, max_retries=0, acks_late=False, - time_limit=7200, # 2小时硬超时 + time_limit=7200, # 2小时硬超时 soft_time_limit=6900, ) def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]: @@ -2760,7 +2721,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace patch_fail = 0 for cid in incomplete_ids: try: - await engine._generate_community_metadata(cid, end_user_id) + await engine._generate_community_metadata([cid], end_user_id) patch_ok += 1 except Exception as patch_err: patch_fail += 1 @@ -2787,7 +2748,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace embedding_model_id=embedding_model_id, ) - logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") + logger.info( + f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") await engine.full_clustering(end_user_id) initialized += 1 logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") @@ -2810,12 +2772,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace } try: - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time diff --git a/api/migrations/versions/05a681a6ca93_202603231611.py b/api/migrations/versions/05a681a6ca93_202603231611.py new file mode 100644 index 00000000..5ab9c4de --- /dev/null +++ b/api/migrations/versions/05a681a6ca93_202603231611.py @@ -0,0 +1,32 @@ +"""202603231611 + +Revision ID: 05a681a6ca93 +Revises: 74b51dfece29 +Create Date: 2026-03-23 16:12:44.110292 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '05a681a6ca93' +down_revision: Union[str, None] = '74b51dfece29' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_users_username'), table_name='users') + op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_users_username'), table_name='users') + op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True) + # ### end Alembic commands ### diff --git a/api/migrations/versions/1ea8fe97b5b7_202603252115.py b/api/migrations/versions/1ea8fe97b5b7_202603252115.py new file mode 100644 index 00000000..1f0df3e7 --- /dev/null +++ b/api/migrations/versions/1ea8fe97b5b7_202603252115.py @@ -0,0 +1,42 @@ +"""202603252115 + +Revision ID: 1ea8fe97b5b7 +Revises: e28bcc212da5 +Create Date: 2026-03-25 21:14:41.825048 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '1ea8fe97b5b7' +down_revision: Union[str, None] = 'e28bcc212da5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('tenants', sa.Column('contact_name', sa.String(length=100), nullable=True)) + op.add_column('tenants', sa.Column('contact_email', sa.String(length=255), nullable=True)) + op.add_column('tenants', sa.Column('contact_phone', sa.String(length=50), nullable=True)) + op.add_column('tenants', sa.Column('plan', sa.String(length=50), nullable=True)) + op.add_column('tenants', sa.Column('plan_expired_at', sa.DateTime(), nullable=True)) + op.add_column('tenants', sa.Column('api_ops_rate_limit', sa.String(length=100), nullable=True)) + op.add_column('tenants', sa.Column('status', sa.String(length=50), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('tenants', 'status') + op.drop_column('tenants', 'api_ops_rate_limit') + op.drop_column('tenants', 'plan_expired_at') + op.drop_column('tenants', 'plan') + op.drop_column('tenants', 'contact_phone') + op.drop_column('tenants', 'contact_email') + op.drop_column('tenants', 'contact_name') + # ### end Alembic commands ### diff --git a/api/migrations/versions/adaefcbe2aa1_202603261630.py b/api/migrations/versions/adaefcbe2aa1_202603261630.py new file mode 100644 index 00000000..b8235dd7 --- /dev/null +++ b/api/migrations/versions/adaefcbe2aa1_202603261630.py @@ -0,0 +1,32 @@ +"""202603261630 + +Revision ID: adaefcbe2aa1 +Revises: 1ea8fe97b5b7 +Create Date: 2026-03-26 16:27:17.590077 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'adaefcbe2aa1' +down_revision: Union[str, None] = '1ea8fe97b5b7' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('tenants', sa.Column('feature_billing', sa.Boolean(), server_default='false', nullable=False, comment='是否启用收费管理菜单')) + op.add_column('tenants', sa.Column('feature_user_management', sa.Boolean(), server_default='false', nullable=False, comment='是否启用用户管理菜单')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('tenants', 'feature_user_management') + op.drop_column('tenants', 'feature_billing') + # ### end Alembic commands ### diff --git a/api/migrations/versions/e28bcc212da5_202603241530.py b/api/migrations/versions/e28bcc212da5_202603241530.py new file mode 100644 index 00000000..00173522 --- /dev/null +++ b/api/migrations/versions/e28bcc212da5_202603241530.py @@ -0,0 +1,34 @@ +"""202603241530 + +Revision ID: e28bcc212da5 +Revises: 05a681a6ca93 +Create Date: 2026-03-24 15:32:14.461480 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'e28bcc212da5' +down_revision: Union[str, None] = '05a681a6ca93' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('memory_config', sa.Column('vision_id', sa.String(), nullable=True, comment='视觉模型配置ID')) + op.add_column('memory_config', sa.Column('audio_id', sa.String(), nullable=True, comment='语音模型配置ID')) + op.add_column('memory_config', sa.Column('video_id', sa.String(), nullable=True, comment='视频模型配置ID')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('memory_config', 'video_id') + op.drop_column('memory_config', 'audio_id') + op.drop_column('memory_config', 'vision_id') + # ### end Alembic commands ### diff --git a/api/pyproject.toml b/api/pyproject.toml index e6fddea8..8ced574c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -147,6 +147,7 @@ dependencies = [ "modelscope>=1.34.0", "python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'", "python-magic-bin>=0.4.14; sys_platform=='win32'", + "volcengine-python-sdk[ark]==5.0.19" ] [tool.pytest.ini_options] diff --git a/web/package.json b/web/package.json index db6a8408..0284f397 100644 --- a/web/package.json +++ b/web/package.json @@ -30,7 +30,7 @@ "@lexical/list": "^0.39.0", "@lexical/react": "^0.39.0", "@lexical/rich-text": "^0.39.0", - "antd": "^5.27.4", + "antd": "^5.29.2", "axios": "^1.12.2", "clsx": "^2.1.1", "codemirror": "^6.0.2", diff --git a/web/src/App.tsx b/web/src/App.tsx index 1d298358..a10f9409 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -21,7 +21,6 @@ import { useTranslation } from 'react-i18next'; import { lightTheme } from './styles/antdThemeConfig.ts' import router from './routes'; import { useI18n } from '@/store/locale' -import LayoutBg from '@/components/Layout/LayoutBg' import dayjs from 'dayjs' import 'dayjs/locale/en' import 'dayjs/locale/zh-cn' @@ -61,7 +60,6 @@ function App() { theme={lightTheme} > - }> { } }) } +// Get workspace API call statistics +export const getWorkspaceApiStatistics = (data: { start_date: number; end_date: number; }) => { + return request.get(`/apps/workspace/api-statistics`, data) +} // Export application export const appExport = (app_id: string, appName: string, data?: { release_id: string }) => { return request.getDownloadFile(`/apps/${app_id}/export`, `${appName}.yml`, data) @@ -165,4 +169,9 @@ export const cancelShare = (app_id: string, target_workspace_id?: string) => { export const cancelSpaceShare = (target_workspace_id?: string) => { return request.delete(`/apps/share/${target_workspace_id}`) } - +// Application conversation logs +export const getAppLogsUrl = (app_id: string) => `/apps/${app_id}/logs` +// Get full conversation message history +export const getAppLogDetail = (app_id: string, conversation_id: string) => { + return request.get(`/apps/${app_id}/logs/${conversation_id}`) +} \ No newline at end of file diff --git a/web/src/api/fileStorage.ts b/web/src/api/fileStorage.ts index ce133565..83f5b212 100644 --- a/web/src/api/fileStorage.ts +++ b/web/src/api/fileStorage.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 13:59:56 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 16:24:05 + * @Last Modified time: 2026-03-23 18:05:43 */ import { request, API_PREFIX } from '@/utils/request' @@ -32,4 +32,13 @@ export const deleteFile = (fileId: string) => { } export const shareFileUploadUrlWithoutApiPrefix = `/storage/share/files` -export const shareFileUploadUrl = `${API_PREFIX}${shareFileUploadUrlWithoutApiPrefix}` \ No newline at end of file +export const shareFileUploadUrl = `${API_PREFIX}${shareFileUploadUrlWithoutApiPrefix}` + +// Get file info +export const getFileInfoByUrl = (url: string) => { + return request.get('/storage/files/info-by-url', {url}) +} +// Get file status +export const getFileStatusById = (file_id: string) => { + return request.get(`/storage/files/${file_id}/status`) +} \ No newline at end of file diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 9a464893..1ec2d7dc 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:00:06 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-19 18:35:10 + * @Last Modified time: 2026-03-24 17:48:01 */ import { request } from '@/utils/request' import type { AxiosRequestConfig } from 'axios' @@ -87,12 +87,13 @@ export const getUserSummary = (end_user_id: string) => { export const getNodeStatistics = (end_user_id: string) => { return request.get(`/memory-storage/analytics/node_statistics`, { end_user_id }) } -// Basic information -export const getEndUserProfile = (end_user_id: string) => { - return request.get(`/memory-storage/read_end_user/profile`, { end_user_id }) +// 查询用户别名及信息 +export const getEndUserInfo = (end_user_id: string) => { + return request.get(`/memory-storage/end_user_info`, { end_user_id }) } -export const updatedEndUserProfile = (values: EndUser) => { - return request.post(`/memory-storage/updated_end_user/profile`, values) +// 更新用户别名及信息 +export const updatedEndUserInfo = (values: EndUser) => { + return request.post(`/memory-storage/end_user_info/updated`, values) } // User Memory - Relationship network export const getMemorySearchEdges = (end_user_id: string, config?: AxiosRequestConfig) => { diff --git a/web/src/assets/font/MiSans/MiSans-Bold.woff2 b/web/src/assets/font/MiSans/MiSans-Bold.woff2 new file mode 100644 index 00000000..e4a21bee Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Bold.woff2 differ diff --git a/web/src/assets/font/MiSans/MiSans-Demibold.woff2 b/web/src/assets/font/MiSans/MiSans-Demibold.woff2 new file mode 100644 index 00000000..70205afb Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Demibold.woff2 differ diff --git a/web/src/assets/font/MiSans/MiSans-ExtraLight.woff2 b/web/src/assets/font/MiSans/MiSans-ExtraLight.woff2 new file mode 100644 index 00000000..45d16c98 Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-ExtraLight.woff2 differ diff --git a/web/src/assets/font/MiSans/MiSans-Heavy.woff2 b/web/src/assets/font/MiSans/MiSans-Heavy.woff2 new file mode 100644 index 00000000..09ee22e3 Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Heavy.woff2 differ diff --git a/web/src/assets/font/MiSans/MiSans-Light.woff2 b/web/src/assets/font/MiSans/MiSans-Light.woff2 new file mode 100644 index 00000000..a2bb950b Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Light.woff2 differ diff --git a/web/src/assets/font/MiSans/MiSans-Medium.woff2 b/web/src/assets/font/MiSans/MiSans-Medium.woff2 new file mode 100644 index 00000000..617f7407 Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Medium.woff2 differ diff --git a/web/src/assets/font/MiSans/MiSans-Normal.woff2 b/web/src/assets/font/MiSans/MiSans-Normal.woff2 new file mode 100644 index 00000000..d24e89dd Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Normal.woff2 differ diff --git a/web/src/assets/font/MiSans/MiSans-Regular.woff2 b/web/src/assets/font/MiSans/MiSans-Regular.woff2 new file mode 100644 index 00000000..6a699b50 Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Regular.woff2 differ diff --git a/web/src/assets/font/MiSans/MiSans-Semibold.woff2 b/web/src/assets/font/MiSans/MiSans-Semibold.woff2 new file mode 100644 index 00000000..34f43f7c Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Semibold.woff2 differ diff --git a/web/src/assets/font/MiSans/MiSans-Thin.woff2 b/web/src/assets/font/MiSans/MiSans-Thin.woff2 new file mode 100644 index 00000000..ec8a3b55 Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Thin.woff2 differ diff --git a/web/src/assets/font/MiSans/index.ts b/web/src/assets/font/MiSans/index.ts new file mode 100644 index 00000000..e69de29b diff --git a/web/src/assets/images/CloudUploadOutlined.svg b/web/src/assets/images/CloudUploadOutlined.svg new file mode 100644 index 00000000..86fdf286 --- /dev/null +++ b/web/src/assets/images/CloudUploadOutlined.svg @@ -0,0 +1,26 @@ + + + 编组 12 + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/application/arrow_right.svg b/web/src/assets/images/application/arrow_right.svg new file mode 100644 index 00000000..06400efc --- /dev/null +++ b/web/src/assets/images/application/arrow_right.svg @@ -0,0 +1,17 @@ + + + 编组 25 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/application/clean.svg b/web/src/assets/images/application/clean.svg index 5d134404..a728abaa 100644 --- a/web/src/assets/images/application/clean.svg +++ b/web/src/assets/images/application/clean.svg @@ -1,13 +1,15 @@ 编组 11 - - - - - - - + + + + + + + + + diff --git a/web/src/assets/images/application/copy.svg b/web/src/assets/images/application/copy.svg new file mode 100644 index 00000000..1bd47c0b --- /dev/null +++ b/web/src/assets/images/application/copy.svg @@ -0,0 +1,18 @@ + + + 复制 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/application/debuggingEmpty.png b/web/src/assets/images/application/debuggingEmpty.png index 0879d4e3..f5d4ef0d 100644 Binary files a/web/src/assets/images/application/debuggingEmpty.png and b/web/src/assets/images/application/debuggingEmpty.png differ diff --git a/web/src/assets/images/application/model.svg b/web/src/assets/images/application/model.svg index 4d482df5..a93f5771 100644 --- a/web/src/assets/images/application/model.svg +++ b/web/src/assets/images/application/model.svg @@ -1,12 +1,12 @@ -_模型预测 - - - - - - + + + + + + diff --git a/web/src/assets/images/application/model_hover.svg b/web/src/assets/images/application/model_hover.svg deleted file mode 100644 index 04e25219..00000000 --- a/web/src/assets/images/application/model_hover.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - -_模型预测 - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/application/save.svg b/web/src/assets/images/application/save.svg new file mode 100644 index 00000000..02dbf635 --- /dev/null +++ b/web/src/assets/images/application/save.svg @@ -0,0 +1,19 @@ + + + 保存 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/application/set.svg b/web/src/assets/images/application/set.svg new file mode 100644 index 00000000..797d2bad --- /dev/null +++ b/web/src/assets/images/application/set.svg @@ -0,0 +1,15 @@ + + + 设置-灰 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/close.svg b/web/src/assets/images/close.svg index cba672fc..d6e1a9b4 100644 --- a/web/src/assets/images/close.svg +++ b/web/src/assets/images/close.svg @@ -1,11 +1,11 @@ 关闭 - - - - - + + + + + diff --git a/web/src/assets/images/common/arrow_right_dark.svg b/web/src/assets/images/common/arrow_right_dark.svg new file mode 100644 index 00000000..b20a440c --- /dev/null +++ b/web/src/assets/images/common/arrow_right_dark.svg @@ -0,0 +1,18 @@ + + + 编组 5 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/arrow_up.svg b/web/src/assets/images/common/arrow_up.svg new file mode 100644 index 00000000..a5105d46 --- /dev/null +++ b/web/src/assets/images/common/arrow_up.svg @@ -0,0 +1,14 @@ + + + 下拉 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/caret_right_outlined.svg b/web/src/assets/images/common/caret_right_outlined.svg new file mode 100644 index 00000000..fcb3c68c --- /dev/null +++ b/web/src/assets/images/common/caret_right_outlined.svg @@ -0,0 +1,16 @@ + + + 编组 38 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/check_green.svg b/web/src/assets/images/common/check_green.svg new file mode 100644 index 00000000..a16b1ee2 --- /dev/null +++ b/web/src/assets/images/common/check_green.svg @@ -0,0 +1,20 @@ + + + 完成 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/copy_dark.svg b/web/src/assets/images/common/copy_dark.svg new file mode 100644 index 00000000..faa6fca1 --- /dev/null +++ b/web/src/assets/images/common/copy_dark.svg @@ -0,0 +1,14 @@ + + + 复制 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/dash.svg b/web/src/assets/images/common/dash.svg new file mode 100644 index 00000000..cf9efb7d --- /dev/null +++ b/web/src/assets/images/common/dash.svg @@ -0,0 +1,15 @@ + + + 编组 27@3x + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/delete.svg b/web/src/assets/images/common/delete.svg new file mode 100644 index 00000000..4eb610ed --- /dev/null +++ b/web/src/assets/images/common/delete.svg @@ -0,0 +1,30 @@ + + + 编组 33 + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/delete_dark.svg b/web/src/assets/images/common/delete_dark.svg new file mode 100644 index 00000000..cf93cfd6 --- /dev/null +++ b/web/src/assets/images/common/delete_dark.svg @@ -0,0 +1,16 @@ + + + 删除 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/delete_hover.svg b/web/src/assets/images/common/delete_hover.svg new file mode 100644 index 00000000..bf38179b --- /dev/null +++ b/web/src/assets/images/common/delete_hover.svg @@ -0,0 +1,20 @@ + + + 编组 33 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/delete_red.svg b/web/src/assets/images/common/delete_red.svg new file mode 100644 index 00000000..58ad4d41 --- /dev/null +++ b/web/src/assets/images/common/delete_red.svg @@ -0,0 +1,30 @@ + + + 编组 33 + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/edit.svg b/web/src/assets/images/common/edit.svg new file mode 100644 index 00000000..cf00d703 --- /dev/null +++ b/web/src/assets/images/common/edit.svg @@ -0,0 +1,27 @@ + + + 编辑 + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/global_outline.svg b/web/src/assets/images/common/global_outline.svg new file mode 100644 index 00000000..86301a0e --- /dev/null +++ b/web/src/assets/images/common/global_outline.svg @@ -0,0 +1,20 @@ + + + 互联网 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/more.svg b/web/src/assets/images/common/more.svg new file mode 100644 index 00000000..0d4d9cd2 --- /dev/null +++ b/web/src/assets/images/common/more.svg @@ -0,0 +1,14 @@ + + + 更多 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/more_hover.svg b/web/src/assets/images/common/more_hover.svg new file mode 100644 index 00000000..04fc6eb5 --- /dev/null +++ b/web/src/assets/images/common/more_hover.svg @@ -0,0 +1,16 @@ + + + 更多 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/plus.svg b/web/src/assets/images/common/plus.svg new file mode 100644 index 00000000..5a2d7b83 --- /dev/null +++ b/web/src/assets/images/common/plus.svg @@ -0,0 +1,11 @@ + + + 形状结合@2x + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/plus_dark.svg b/web/src/assets/images/common/plus_dark.svg new file mode 100644 index 00000000..b0882a02 --- /dev/null +++ b/web/src/assets/images/common/plus_dark.svg @@ -0,0 +1,15 @@ + + + 编组 5 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/plus_grey.svg b/web/src/assets/images/common/plus_grey.svg new file mode 100644 index 00000000..05fb64e3 --- /dev/null +++ b/web/src/assets/images/common/plus_grey.svg @@ -0,0 +1,13 @@ + + + 形状结合@2x + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/question.svg b/web/src/assets/images/common/question.svg new file mode 100644 index 00000000..f8b0fee4 --- /dev/null +++ b/web/src/assets/images/common/question.svg @@ -0,0 +1,15 @@ + + + 问号小 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/return.svg b/web/src/assets/images/common/return.svg new file mode 100644 index 00000000..cb8166c0 --- /dev/null +++ b/web/src/assets/images/common/return.svg @@ -0,0 +1,17 @@ + + + 退出 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/save.svg b/web/src/assets/images/common/save.svg new file mode 100644 index 00000000..5970236d --- /dev/null +++ b/web/src/assets/images/common/save.svg @@ -0,0 +1,19 @@ + + + 保存 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/conversation/audio.svg b/web/src/assets/images/conversation/audio.svg index 57a5ca49..c7c4e1fd 100644 --- a/web/src/assets/images/conversation/audio.svg +++ b/web/src/assets/images/conversation/audio.svg @@ -1,17 +1,28 @@ - - 编组 15 - - - - - - - - - - - + + 语音 + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/audio_ing.svg b/web/src/assets/images/conversation/audio_ing.svg deleted file mode 100644 index 280a1bd9..00000000 --- a/web/src/assets/images/conversation/audio_ing.svg +++ /dev/null @@ -1,21 +0,0 @@ - - - 编组 15 - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/conversation/conversation.svg b/web/src/assets/images/conversation/conversation.svg index 2ebc02fb..a21f34bc 100644 --- a/web/src/assets/images/conversation/conversation.svg +++ b/web/src/assets/images/conversation/conversation.svg @@ -1,11 +1,12 @@ - + 对话 - - - - - + + + + + + diff --git a/web/src/assets/images/conversation/conversationEmpty.svg b/web/src/assets/images/conversation/conversationEmpty.svg index 2b642355..8320fd75 100644 --- a/web/src/assets/images/conversation/conversationEmpty.svg +++ b/web/src/assets/images/conversation/conversationEmpty.svg @@ -1,21 +1,23 @@ 编组 14 - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/deepThinking.svg b/web/src/assets/images/conversation/deepThinking.svg index b7658bf4..58ad411f 100644 --- a/web/src/assets/images/conversation/deepThinking.svg +++ b/web/src/assets/images/conversation/deepThinking.svg @@ -1,16 +1,29 @@ 深度思考 - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/delete.svg b/web/src/assets/images/conversation/delete.svg index 27f1c15f..b46dea12 100644 --- a/web/src/assets/images/conversation/delete.svg +++ b/web/src/assets/images/conversation/delete.svg @@ -5,7 +5,7 @@ - + diff --git a/web/src/assets/images/conversation/exclamation_circle.svg b/web/src/assets/images/conversation/exclamation_circle.svg new file mode 100644 index 00000000..9b96bbce --- /dev/null +++ b/web/src/assets/images/conversation/exclamation_circle.svg @@ -0,0 +1,15 @@ + + + 告警实心 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/conversation/link.svg b/web/src/assets/images/conversation/link.svg index 18031b71..17298c10 100644 --- a/web/src/assets/images/conversation/link.svg +++ b/web/src/assets/images/conversation/link.svg @@ -1,18 +1,26 @@ - - 链接 - - - - - - - - - - - - + + 编组 6 + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/loading.svg b/web/src/assets/images/conversation/loading.svg index 01adc786..7bed9e7f 100644 --- a/web/src/assets/images/conversation/loading.svg +++ b/web/src/assets/images/conversation/loading.svg @@ -1,13 +1,24 @@ - - 编组 5 - - - - - - - + + 编组 14 + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/memoryFunction.svg b/web/src/assets/images/conversation/memoryFunction.svg index f63dc231..d0f3daf8 100644 --- a/web/src/assets/images/conversation/memoryFunction.svg +++ b/web/src/assets/images/conversation/memoryFunction.svg @@ -1,15 +1,26 @@ - brain-2-line - - - - - - - - - + 1 + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/memoryFunctionChecked.svg b/web/src/assets/images/conversation/memoryFunctionChecked.svg index db12f037..cf136428 100644 --- a/web/src/assets/images/conversation/memoryFunctionChecked.svg +++ b/web/src/assets/images/conversation/memoryFunctionChecked.svg @@ -1,14 +1,27 @@ - brain-2-line - - - - - - - - + 1 + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/normalReply.svg b/web/src/assets/images/conversation/normalReply.svg new file mode 100644 index 00000000..19b8c28d --- /dev/null +++ b/web/src/assets/images/conversation/normalReply.svg @@ -0,0 +1,28 @@ + + + 正常 + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/conversation/online.svg b/web/src/assets/images/conversation/online.svg index 0ae567ca..c9c5812b 100644 --- a/web/src/assets/images/conversation/online.svg +++ b/web/src/assets/images/conversation/online.svg @@ -1,17 +1,28 @@ - 互联网 - - - - - - - - - - - + 联网 + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/onlineChecked.svg b/web/src/assets/images/conversation/onlineChecked.svg index 89fd61c4..fdd2b4b2 100644 --- a/web/src/assets/images/conversation/onlineChecked.svg +++ b/web/src/assets/images/conversation/onlineChecked.svg @@ -1,16 +1,29 @@ - 互联网 - - - - - - - - - - + 联网 + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/quickReply.svg b/web/src/assets/images/conversation/quickReply.svg new file mode 100644 index 00000000..9a90ef1c --- /dev/null +++ b/web/src/assets/images/conversation/quickReply.svg @@ -0,0 +1,28 @@ + + + 快速回复 + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/conversation/redbear.png b/web/src/assets/images/conversation/redbear.png new file mode 100644 index 00000000..8fd5e2f6 Binary files /dev/null and b/web/src/assets/images/conversation/redbear.png differ diff --git a/web/src/assets/images/conversation/send.svg b/web/src/assets/images/conversation/send.svg index a44dcc40..5e9f5a21 100644 --- a/web/src/assets/images/conversation/send.svg +++ b/web/src/assets/images/conversation/send.svg @@ -1,15 +1,27 @@ - - 发送 - - - - - - - - - + + 发送-2@2x + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/sendDisabled.svg b/web/src/assets/images/conversation/sendDisabled.svg index bf774bfd..7eb01380 100644 --- a/web/src/assets/images/conversation/sendDisabled.svg +++ b/web/src/assets/images/conversation/sendDisabled.svg @@ -1,16 +1,26 @@ - - 发送-2 - - - - - - - - - - + + 发送-2@2x + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/variables.svg b/web/src/assets/images/conversation/variables.svg new file mode 100644 index 00000000..e95c6922 --- /dev/null +++ b/web/src/assets/images/conversation/variables.svg @@ -0,0 +1,31 @@ + + + 变量 (1) + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/copy_active.svg b/web/src/assets/images/copy_active.svg index 29a0f520..27f3c265 100644 --- a/web/src/assets/images/copy_active.svg +++ b/web/src/assets/images/copy_active.svg @@ -2,7 +2,7 @@ 复制 - + diff --git a/web/src/assets/images/deleteBg.svg b/web/src/assets/images/deleteBg.svg index 47deed9a..90409bdb 100644 --- a/web/src/assets/images/deleteBg.svg +++ b/web/src/assets/images/deleteBg.svg @@ -1,13 +1,13 @@ 编组 8 - - - - - - - + + + + + + + diff --git a/web/src/assets/images/deleteBorder.svg b/web/src/assets/images/deleteBorder.svg index 6e90bf4a..62b7bf96 100644 --- a/web/src/assets/images/deleteBorder.svg +++ b/web/src/assets/images/deleteBorder.svg @@ -1,12 +1,12 @@ 编组 8 - - - - - - + + + + + + diff --git a/web/src/assets/images/edit.svg b/web/src/assets/images/edit.svg index 67b90d2b..f503f005 100644 --- a/web/src/assets/images/edit.svg +++ b/web/src/assets/images/edit.svg @@ -2,7 +2,7 @@ 编辑 - + diff --git a/web/src/assets/images/editBg.svg b/web/src/assets/images/editBg.svg index 54ce218f..cfdaceef 100644 --- a/web/src/assets/images/editBg.svg +++ b/web/src/assets/images/editBg.svg @@ -1,13 +1,13 @@ 编组 13 - - - - - - - + + + + + + + diff --git a/web/src/assets/images/editBorder.svg b/web/src/assets/images/editBorder.svg index 6a0bd89f..4f6b0762 100644 --- a/web/src/assets/images/editBorder.svg +++ b/web/src/assets/images/editBorder.svg @@ -1,12 +1,12 @@ 编组 13 - - - - - - + + + + + + diff --git a/web/src/assets/images/edit_active.svg b/web/src/assets/images/edit_active.svg new file mode 100644 index 00000000..7beb376b --- /dev/null +++ b/web/src/assets/images/edit_active.svg @@ -0,0 +1,14 @@ + + + 编辑 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/edit_hover.svg b/web/src/assets/images/edit_hover.svg index 6cb4e043..b69ed65a 100644 --- a/web/src/assets/images/edit_hover.svg +++ b/web/src/assets/images/edit_hover.svg @@ -2,7 +2,7 @@ 编辑 - + diff --git a/web/src/assets/images/empty/noData.png b/web/src/assets/images/empty/noData.png new file mode 100644 index 00000000..5258d466 Binary files /dev/null and b/web/src/assets/images/empty/noData.png differ diff --git a/web/src/assets/images/home/application.svg b/web/src/assets/images/home/application.svg new file mode 100644 index 00000000..65ce7ccd --- /dev/null +++ b/web/src/assets/images/home/application.svg @@ -0,0 +1,20 @@ + + + icon_应用管理 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/home/arrow_top_right.svg b/web/src/assets/images/home/arrow_top_right.svg deleted file mode 100644 index fe969a19..00000000 --- a/web/src/assets/images/home/arrow_top_right.svg +++ /dev/null @@ -1,16 +0,0 @@ - - - 编组 16 - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/home/arrow_top_right_hover.svg b/web/src/assets/images/home/arrow_top_right_hover.svg deleted file mode 100644 index 903f9618..00000000 --- a/web/src/assets/images/home/arrow_top_right_hover.svg +++ /dev/null @@ -1,16 +0,0 @@ - - - 编组 16 - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/home/arrow_up.svg b/web/src/assets/images/home/arrow_up.svg new file mode 100644 index 00000000..914cb156 --- /dev/null +++ b/web/src/assets/images/home/arrow_up.svg @@ -0,0 +1,15 @@ + + + 箭头_向上 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/home/chunk_count.svg b/web/src/assets/images/home/chunk_count.svg index 830dac67..544f3cc6 100644 --- a/web/src/assets/images/home/chunk_count.svg +++ b/web/src/assets/images/home/chunk_count.svg @@ -1,22 +1,16 @@ 编组 32 - - - - - - - - - - - - - - - - + + + + + + + + + + diff --git a/web/src/assets/images/home/knowledge.svg b/web/src/assets/images/home/knowledge.svg new file mode 100644 index 00000000..91624510 --- /dev/null +++ b/web/src/assets/images/home/knowledge.svg @@ -0,0 +1,19 @@ + + + 知识库 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/home/memoryConversation.svg b/web/src/assets/images/home/memoryConversation.svg new file mode 100644 index 00000000..59f74de2 --- /dev/null +++ b/web/src/assets/images/home/memoryConversation.svg @@ -0,0 +1,19 @@ + + + 编组 10 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/home/statements_count.svg b/web/src/assets/images/home/statements_count.svg index a20666d1..6d545356 100644 --- a/web/src/assets/images/home/statements_count.svg +++ b/web/src/assets/images/home/statements_count.svg @@ -1,15 +1,18 @@ 编组 38 - - - - - - - - - + + + + + + + + + + + + diff --git a/web/src/assets/images/home/temporal_count.svg b/web/src/assets/images/home/temporal_count.svg index 050697bc..739acb30 100644 --- a/web/src/assets/images/home/temporal_count.svg +++ b/web/src/assets/images/home/temporal_count.svg @@ -1,17 +1,20 @@ 编组 39 - - - - - - - - - - - + + + + + + + + + + + + + + diff --git a/web/src/assets/images/home/totalMemoryCapacity.png b/web/src/assets/images/home/totalMemoryCapacity.png new file mode 100644 index 00000000..4a58cfad Binary files /dev/null and b/web/src/assets/images/home/totalMemoryCapacity.png differ diff --git a/web/src/assets/images/home/triplet_count.svg b/web/src/assets/images/home/triplet_count.svg index ebcfd0aa..603ede84 100644 --- a/web/src/assets/images/home/triplet_count.svg +++ b/web/src/assets/images/home/triplet_count.svg @@ -1,15 +1,14 @@ 编组 37 - - - - - - - - - + + + + + + + + diff --git a/web/src/assets/images/index/apps.svg b/web/src/assets/images/index/apps.svg index 58907fd6..b49bda51 100644 --- a/web/src/assets/images/index/apps.svg +++ b/web/src/assets/images/index/apps.svg @@ -1,14 +1,14 @@ 编组 34 - - - - + + + + - + diff --git a/web/src/assets/images/index/arrow_down.svg b/web/src/assets/images/index/arrow_down.svg index b77a3f8a..366e5848 100644 --- a/web/src/assets/images/index/arrow_down.svg +++ b/web/src/assets/images/index/arrow_down.svg @@ -1,10 +1,10 @@ 箭头_向上 - - - - + + + + diff --git a/web/src/assets/images/index/arrow_down_d.svg b/web/src/assets/images/index/arrow_down_d.svg index 40e5d94b..7393ca80 100644 --- a/web/src/assets/images/index/arrow_down_d.svg +++ b/web/src/assets/images/index/arrow_down_d.svg @@ -1,10 +1,10 @@ 编组 30 - - - - + + + + diff --git a/web/src/assets/images/index/arrow_up.svg b/web/src/assets/images/index/arrow_up.svg index 62aeee96..8a8bae53 100644 --- a/web/src/assets/images/index/arrow_up.svg +++ b/web/src/assets/images/index/arrow_up.svg @@ -1,10 +1,10 @@ 箭头_向上 - - - - + + + + diff --git a/web/src/assets/images/index/arrow_up_d.svg b/web/src/assets/images/index/arrow_up_d.svg index 3c19fef3..3529a291 100644 --- a/web/src/assets/images/index/arrow_up_d.svg +++ b/web/src/assets/images/index/arrow_up_d.svg @@ -1,10 +1,10 @@ 编组 30 - - - - + + + + diff --git a/web/src/assets/images/index/guide_bg@2x.png b/web/src/assets/images/index/guide_bg@2x.png index 3b7490fb..fbf452e6 100644 Binary files a/web/src/assets/images/index/guide_bg@2x.png and b/web/src/assets/images/index/guide_bg@2x.png differ diff --git a/web/src/assets/images/index/help_center.svg b/web/src/assets/images/index/help_center.svg index 6d272121..28595b0a 100644 --- a/web/src/assets/images/index/help_center.svg +++ b/web/src/assets/images/index/help_center.svg @@ -1,13 +1,13 @@ 编组 17 - - - - - - - + + + + + + + diff --git a/web/src/assets/images/index/index_bg@2x.png b/web/src/assets/images/index/index_bg@2x.png index d20ee4d3..fbf30083 100644 Binary files a/web/src/assets/images/index/index_bg@2x.png and b/web/src/assets/images/index/index_bg@2x.png differ diff --git a/web/src/assets/images/index/model_mgt.svg b/web/src/assets/images/index/model_mgt.svg index 89e13ec3..536f8877 100644 --- a/web/src/assets/images/index/model_mgt.svg +++ b/web/src/assets/images/index/model_mgt.svg @@ -1,26 +1,26 @@ 编组 25 - - - - - - + + + + + + - - - - - - - - - - - - - + + + + + + + + + + + + + diff --git a/web/src/assets/images/index/models.svg b/web/src/assets/images/index/models.svg index 890f240a..60863681 100644 --- a/web/src/assets/images/index/models.svg +++ b/web/src/assets/images/index/models.svg @@ -1,9 +1,9 @@ 编组 14 - - - + + + diff --git a/web/src/assets/images/index/space_mgt.svg b/web/src/assets/images/index/space_mgt.svg index af1db66c..a71f7431 100644 --- a/web/src/assets/images/index/space_mgt.svg +++ b/web/src/assets/images/index/space_mgt.svg @@ -1,13 +1,13 @@ 编组 26 - - - - - - - + + + + + + + diff --git a/web/src/assets/images/index/spaces.svg b/web/src/assets/images/index/spaces.svg index 1c61bc6b..e79eb113 100644 --- a/web/src/assets/images/index/spaces.svg +++ b/web/src/assets/images/index/spaces.svg @@ -7,10 +7,10 @@ - - - - + + + + diff --git a/web/src/assets/images/index/user_mgt.svg b/web/src/assets/images/index/user_mgt.svg index d53a97b9..4ec237aa 100644 --- a/web/src/assets/images/index/user_mgt.svg +++ b/web/src/assets/images/index/user_mgt.svg @@ -1,13 +1,13 @@ 编组 24 - - - - - - - + + + + + + + diff --git a/web/src/assets/images/index/users.svg b/web/src/assets/images/index/users.svg index 545d9636..bfb37872 100644 --- a/web/src/assets/images/index/users.svg +++ b/web/src/assets/images/index/users.svg @@ -1,10 +1,10 @@ 编组 33 - - - - + + + + diff --git a/web/src/assets/images/memory/arrow_right.svg b/web/src/assets/images/memory/arrow_right.svg index 0d17ec3b..090330e9 100644 --- a/web/src/assets/images/memory/arrow_right.svg +++ b/web/src/assets/images/memory/arrow_right.svg @@ -1,12 +1,12 @@ - 下拉备份 - - - - - - + 下拉 + + + + + + diff --git a/web/src/assets/images/memory/clock_orange.svg b/web/src/assets/images/memory/clock_orange.svg new file mode 100644 index 00000000..5c2b58cf --- /dev/null +++ b/web/src/assets/images/memory/clock_orange.svg @@ -0,0 +1,18 @@ + + + 时间戳 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/memory/debug.svg b/web/src/assets/images/memory/debug.svg new file mode 100644 index 00000000..325a355a --- /dev/null +++ b/web/src/assets/images/memory/debug.svg @@ -0,0 +1,15 @@ + + + 配置管理 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/apiKey.png b/web/src/assets/images/menu/apiKey.png deleted file mode 100644 index 53d19428..00000000 Binary files a/web/src/assets/images/menu/apiKey.png and /dev/null differ diff --git a/web/src/assets/images/menu/apiKey_active.png b/web/src/assets/images/menu/apiKey_active.png deleted file mode 100644 index 4f8d1cfa..00000000 Binary files a/web/src/assets/images/menu/apiKey_active.png and /dev/null differ diff --git a/web/src/assets/images/menu/dashboard.svg b/web/src/assets/images/menu/dashboard.svg deleted file mode 100644 index 43e05b3a..00000000 --- a/web/src/assets/images/menu/dashboard.svg +++ /dev/null @@ -1,18 +0,0 @@ - - - 编组 27 - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/dashboard_active.svg b/web/src/assets/images/menu/dashboard_active.svg deleted file mode 100644 index 3f1bc65c..00000000 --- a/web/src/assets/images/menu/dashboard_active.svg +++ /dev/null @@ -1,18 +0,0 @@ - - - 编组 27 - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/knowledge.svg b/web/src/assets/images/menu/knowledge.svg deleted file mode 100644 index 3fc1ec0f..00000000 --- a/web/src/assets/images/menu/knowledge.svg +++ /dev/null @@ -1,19 +0,0 @@ - - - 知识库 - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/knowledge_active.svg b/web/src/assets/images/menu/knowledge_active.svg deleted file mode 100644 index 9b09bbf4..00000000 --- a/web/src/assets/images/menu/knowledge_active.svg +++ /dev/null @@ -1,19 +0,0 @@ - - - 知识库 - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/member.svg b/web/src/assets/images/menu/member.svg deleted file mode 100644 index 56cca8c1..00000000 --- a/web/src/assets/images/menu/member.svg +++ /dev/null @@ -1,18 +0,0 @@ - - - 用户总数总计 - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/member_active.svg b/web/src/assets/images/menu/member_active.svg deleted file mode 100644 index 30cf9261..00000000 --- a/web/src/assets/images/menu/member_active.svg +++ /dev/null @@ -1,18 +0,0 @@ - - - 用户总数总计 - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/memory.svg b/web/src/assets/images/menu/memory.svg deleted file mode 100644 index 71696861..00000000 --- a/web/src/assets/images/menu/memory.svg +++ /dev/null @@ -1,16 +0,0 @@ - - - brain-2-line - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/memoryConversation.svg b/web/src/assets/images/menu/memoryConversation.svg deleted file mode 100644 index 369cbc5a..00000000 --- a/web/src/assets/images/menu/memoryConversation.svg +++ /dev/null @@ -1,19 +0,0 @@ - - - 编组 10 - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/memoryConversation_active.svg b/web/src/assets/images/menu/memoryConversation_active.svg deleted file mode 100644 index c79a75f6..00000000 --- a/web/src/assets/images/menu/memoryConversation_active.svg +++ /dev/null @@ -1,19 +0,0 @@ - - - 编组 10 - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/memory_active.svg b/web/src/assets/images/menu/memory_active.svg deleted file mode 100644 index eabe9221..00000000 --- a/web/src/assets/images/menu/memory_active.svg +++ /dev/null @@ -1,16 +0,0 @@ - - - brain-2-line - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/model.svg b/web/src/assets/images/menu/model.svg deleted file mode 100644 index bbb7e103..00000000 --- a/web/src/assets/images/menu/model.svg +++ /dev/null @@ -1,13 +0,0 @@ - - - -_模型预测 - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/model_active.svg b/web/src/assets/images/menu/model_active.svg deleted file mode 100644 index 274b146e..00000000 --- a/web/src/assets/images/menu/model_active.svg +++ /dev/null @@ -1,13 +0,0 @@ - - - -_模型预测 - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/ontology.svg b/web/src/assets/images/menu/ontology.svg deleted file mode 100644 index 9bfda42b..00000000 --- a/web/src/assets/images/menu/ontology.svg +++ /dev/null @@ -1,11 +0,0 @@ - - - 本体管理备份 - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/ontology_active.svg b/web/src/assets/images/menu/ontology_active.svg deleted file mode 100644 index 1271c2c3..00000000 --- a/web/src/assets/images/menu/ontology_active.svg +++ /dev/null @@ -1,11 +0,0 @@ - - - 本体管理 - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/pricing.svg b/web/src/assets/images/menu/pricing.svg deleted file mode 100644 index 5510ba23..00000000 --- a/web/src/assets/images/menu/pricing.svg +++ /dev/null @@ -1,22 +0,0 @@ - - - 菜单-收费管理 - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/pricing_active.svg b/web/src/assets/images/menu/pricing_active.svg deleted file mode 100644 index f708877d..00000000 --- a/web/src/assets/images/menu/pricing_active.svg +++ /dev/null @@ -1,22 +0,0 @@ - - - 菜单-收费管理 - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/prompt.svg b/web/src/assets/images/menu/prompt.svg deleted file mode 100644 index ffef9a34..00000000 --- a/web/src/assets/images/menu/prompt.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - 提示词备份 - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/prompt_active.svg b/web/src/assets/images/menu/prompt_active.svg deleted file mode 100644 index ac45e13c..00000000 --- a/web/src/assets/images/menu/prompt_active.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - 提示词 - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/skills.svg b/web/src/assets/images/menu/skills.svg deleted file mode 100644 index ac121d1e..00000000 --- a/web/src/assets/images/menu/skills.svg +++ /dev/null @@ -1,14 +0,0 @@ - - - 技能点 - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/skills_active.svg b/web/src/assets/images/menu/skills_active.svg deleted file mode 100644 index 789b5586..00000000 --- a/web/src/assets/images/menu/skills_active.svg +++ /dev/null @@ -1,14 +0,0 @@ - - - 技能点备份 - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/space.svg b/web/src/assets/images/menu/space.svg deleted file mode 100644 index c82c7922..00000000 --- a/web/src/assets/images/menu/space.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - 模型管理 - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/spaceConfig.svg b/web/src/assets/images/menu/spaceConfig.svg deleted file mode 100644 index bcfeae12..00000000 --- a/web/src/assets/images/menu/spaceConfig.svg +++ /dev/null @@ -1,17 +0,0 @@ - - - 模型 (1) - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/spaceConfig_active.svg b/web/src/assets/images/menu/spaceConfig_active.svg deleted file mode 100644 index 41b25689..00000000 --- a/web/src/assets/images/menu/spaceConfig_active.svg +++ /dev/null @@ -1,17 +0,0 @@ - - - 模型 (1) - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/space_active.svg b/web/src/assets/images/menu/space_active.svg deleted file mode 100644 index 69b1629c..00000000 --- a/web/src/assets/images/menu/space_active.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - 模型管理 - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/tool.png b/web/src/assets/images/menu/tool.png deleted file mode 100644 index 669238e8..00000000 Binary files a/web/src/assets/images/menu/tool.png and /dev/null differ diff --git a/web/src/assets/images/menu/tool_active.png b/web/src/assets/images/menu/tool_active.png deleted file mode 100644 index 252cd702..00000000 Binary files a/web/src/assets/images/menu/tool_active.png and /dev/null differ diff --git a/web/src/assets/images/menu/user.svg b/web/src/assets/images/menu/user.svg deleted file mode 100644 index b1eaf5b9..00000000 --- a/web/src/assets/images/menu/user.svg +++ /dev/null @@ -1,13 +0,0 @@ - - - 138设置、系统设置、功能设置、属性 - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/userMemory1.svg b/web/src/assets/images/menu/userMemory1.svg deleted file mode 100644 index c4b9cd51..00000000 --- a/web/src/assets/images/menu/userMemory1.svg +++ /dev/null @@ -1,18 +0,0 @@ - - - 编组 29 - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menu/user_active.svg b/web/src/assets/images/menu/user_active.svg deleted file mode 100644 index 38de2069..00000000 --- a/web/src/assets/images/menu/user_active.svg +++ /dev/null @@ -1,13 +0,0 @@ - - - 138设置、系统设置、功能设置、属性 - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menuNew/apiKey.svg b/web/src/assets/images/menuNew/apiKey.svg new file mode 100644 index 00000000..c31e2d5c --- /dev/null +++ b/web/src/assets/images/menuNew/apiKey.svg @@ -0,0 +1,13 @@ + + + api + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/apiKey_active.svg b/web/src/assets/images/menuNew/apiKey_active.svg new file mode 100644 index 00000000..7520cb86 --- /dev/null +++ b/web/src/assets/images/menuNew/apiKey_active.svg @@ -0,0 +1,13 @@ + + + api + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/application.svg b/web/src/assets/images/menuNew/application.svg similarity index 65% rename from web/src/assets/images/menu/application.svg rename to web/src/assets/images/menuNew/application.svg index 37967d3a..a8fe8fc0 100644 --- a/web/src/assets/images/menu/application.svg +++ b/web/src/assets/images/menuNew/application.svg @@ -1,11 +1,11 @@ 应用管理 - - - - - + + + + + diff --git a/web/src/assets/images/menu/application_active.svg b/web/src/assets/images/menuNew/application_active.svg similarity index 65% rename from web/src/assets/images/menu/application_active.svg rename to web/src/assets/images/menuNew/application_active.svg index 3fe48200..0d8f91f9 100644 --- a/web/src/assets/images/menu/application_active.svg +++ b/web/src/assets/images/menuNew/application_active.svg @@ -1,11 +1,11 @@ 应用管理 - - - - - + + + + + diff --git a/web/src/assets/images/menuNew/dashboard.svg b/web/src/assets/images/menuNew/dashboard.svg new file mode 100644 index 00000000..d35e35fb --- /dev/null +++ b/web/src/assets/images/menuNew/dashboard.svg @@ -0,0 +1,18 @@ + + + 编组 27 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/dashboard_active.svg b/web/src/assets/images/menuNew/dashboard_active.svg new file mode 100644 index 00000000..4a0f57b6 --- /dev/null +++ b/web/src/assets/images/menuNew/dashboard_active.svg @@ -0,0 +1,18 @@ + + + 编组 27 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/knowledge.svg b/web/src/assets/images/menuNew/knowledge.svg new file mode 100644 index 00000000..2d7a28de --- /dev/null +++ b/web/src/assets/images/menuNew/knowledge.svg @@ -0,0 +1,20 @@ + + + 知识库 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/knowledge_active.svg b/web/src/assets/images/menuNew/knowledge_active.svg new file mode 100644 index 00000000..0a2fba96 --- /dev/null +++ b/web/src/assets/images/menuNew/knowledge_active.svg @@ -0,0 +1,20 @@ + + + 知识库 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/member.svg b/web/src/assets/images/menuNew/member.svg new file mode 100644 index 00000000..35edbe1a --- /dev/null +++ b/web/src/assets/images/menuNew/member.svg @@ -0,0 +1,18 @@ + + + 用户总数总计 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/member_active.svg b/web/src/assets/images/menuNew/member_active.svg new file mode 100644 index 00000000..96269cd5 --- /dev/null +++ b/web/src/assets/images/menuNew/member_active.svg @@ -0,0 +1,18 @@ + + + 用户总数总计 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/memory.svg b/web/src/assets/images/menuNew/memory.svg new file mode 100644 index 00000000..17c8368b --- /dev/null +++ b/web/src/assets/images/menuNew/memory.svg @@ -0,0 +1,16 @@ + + + brain-2-line + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/memoryConversation.svg b/web/src/assets/images/menuNew/memoryConversation.svg new file mode 100644 index 00000000..f74146b0 --- /dev/null +++ b/web/src/assets/images/menuNew/memoryConversation.svg @@ -0,0 +1,13 @@ + + + 对话 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/memoryConversation_active.svg b/web/src/assets/images/menuNew/memoryConversation_active.svg new file mode 100644 index 00000000..c2c4aae3 --- /dev/null +++ b/web/src/assets/images/menuNew/memoryConversation_active.svg @@ -0,0 +1,13 @@ + + + 对话 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/memory_active.svg b/web/src/assets/images/menuNew/memory_active.svg new file mode 100644 index 00000000..3aa5ff94 --- /dev/null +++ b/web/src/assets/images/menuNew/memory_active.svg @@ -0,0 +1,16 @@ + + + brain-2-line + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/menuFold.svg b/web/src/assets/images/menuNew/menuFold.svg new file mode 100644 index 00000000..3350cfc4 --- /dev/null +++ b/web/src/assets/images/menuNew/menuFold.svg @@ -0,0 +1,15 @@ + + + 收起 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/model.svg b/web/src/assets/images/menuNew/model.svg new file mode 100644 index 00000000..8fdc015a --- /dev/null +++ b/web/src/assets/images/menuNew/model.svg @@ -0,0 +1,13 @@ + + + -_模型预测 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/model_active.svg b/web/src/assets/images/menuNew/model_active.svg new file mode 100644 index 00000000..6145f360 --- /dev/null +++ b/web/src/assets/images/menuNew/model_active.svg @@ -0,0 +1,11 @@ + + + -_模型预测 + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/ontology.svg b/web/src/assets/images/menuNew/ontology.svg new file mode 100644 index 00000000..68798ccd --- /dev/null +++ b/web/src/assets/images/menuNew/ontology.svg @@ -0,0 +1,17 @@ + + + 本体管理 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/ontology_active.svg b/web/src/assets/images/menuNew/ontology_active.svg new file mode 100644 index 00000000..f05a1069 --- /dev/null +++ b/web/src/assets/images/menuNew/ontology_active.svg @@ -0,0 +1,17 @@ + + + 本体管理 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/pricing.svg b/web/src/assets/images/menuNew/pricing.svg new file mode 100644 index 00000000..8c412ac0 --- /dev/null +++ b/web/src/assets/images/menuNew/pricing.svg @@ -0,0 +1,13 @@ + + + 收费管理 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/pricing_active.svg b/web/src/assets/images/menuNew/pricing_active.svg new file mode 100644 index 00000000..54a0afb4 --- /dev/null +++ b/web/src/assets/images/menuNew/pricing_active.svg @@ -0,0 +1,11 @@ + + + 收费管理 + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/prompt.svg b/web/src/assets/images/menuNew/prompt.svg new file mode 100644 index 00000000..8007982b --- /dev/null +++ b/web/src/assets/images/menuNew/prompt.svg @@ -0,0 +1,21 @@ + + + 提示词 + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/prompt_active.svg b/web/src/assets/images/menuNew/prompt_active.svg new file mode 100644 index 00000000..4c94cac2 --- /dev/null +++ b/web/src/assets/images/menuNew/prompt_active.svg @@ -0,0 +1,21 @@ + + + 提示词 + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/skills.svg b/web/src/assets/images/menuNew/skills.svg new file mode 100644 index 00000000..3c8dd525 --- /dev/null +++ b/web/src/assets/images/menuNew/skills.svg @@ -0,0 +1,18 @@ + + + skills-icon + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/skills_active.svg b/web/src/assets/images/menuNew/skills_active.svg new file mode 100644 index 00000000..86191a8a --- /dev/null +++ b/web/src/assets/images/menuNew/skills_active.svg @@ -0,0 +1,16 @@ + + + skills-icon + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/space.svg b/web/src/assets/images/menuNew/space.svg new file mode 100644 index 00000000..d0e7a5e4 --- /dev/null +++ b/web/src/assets/images/menuNew/space.svg @@ -0,0 +1,15 @@ + + + 模型管理 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/spaceConfig.svg b/web/src/assets/images/menuNew/spaceConfig.svg new file mode 100644 index 00000000..f03b2f05 --- /dev/null +++ b/web/src/assets/images/menuNew/spaceConfig.svg @@ -0,0 +1,19 @@ + + + 空间配置 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/spaceConfig_active.svg b/web/src/assets/images/menuNew/spaceConfig_active.svg new file mode 100644 index 00000000..578963a0 --- /dev/null +++ b/web/src/assets/images/menuNew/spaceConfig_active.svg @@ -0,0 +1,19 @@ + + + 空间配置 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/space_active.svg b/web/src/assets/images/menuNew/space_active.svg new file mode 100644 index 00000000..e55efb3e --- /dev/null +++ b/web/src/assets/images/menuNew/space_active.svg @@ -0,0 +1,13 @@ + + + 模型管理 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/tool.svg b/web/src/assets/images/menuNew/tool.svg new file mode 100644 index 00000000..0a14a626 --- /dev/null +++ b/web/src/assets/images/menuNew/tool.svg @@ -0,0 +1,18 @@ + + + 工具管理 (2) + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/tool_active.svg b/web/src/assets/images/menuNew/tool_active.svg new file mode 100644 index 00000000..00544dac --- /dev/null +++ b/web/src/assets/images/menuNew/tool_active.svg @@ -0,0 +1,16 @@ + + + 工具管理 (2) + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/user.svg b/web/src/assets/images/menuNew/user.svg new file mode 100644 index 00000000..d04fb501 --- /dev/null +++ b/web/src/assets/images/menuNew/user.svg @@ -0,0 +1,13 @@ + + + 用户管理 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/userMemory.svg b/web/src/assets/images/menuNew/userMemory.svg similarity index 76% rename from web/src/assets/images/menu/userMemory.svg rename to web/src/assets/images/menuNew/userMemory.svg index c4b9cd51..9eb5b1fc 100644 --- a/web/src/assets/images/menu/userMemory.svg +++ b/web/src/assets/images/menuNew/userMemory.svg @@ -1,11 +1,11 @@ 编组 29 - - - - - + + + + + diff --git a/web/src/assets/images/menu/userMemory_active.svg b/web/src/assets/images/menuNew/userMemory_active.svg similarity index 76% rename from web/src/assets/images/menu/userMemory_active.svg rename to web/src/assets/images/menuNew/userMemory_active.svg index 554dc0bc..d31e4859 100644 --- a/web/src/assets/images/menu/userMemory_active.svg +++ b/web/src/assets/images/menuNew/userMemory_active.svg @@ -1,11 +1,11 @@ 编组 29 - - - - - + + + + + diff --git a/web/src/assets/images/menuNew/user_active.svg b/web/src/assets/images/menuNew/user_active.svg new file mode 100644 index 00000000..33778047 --- /dev/null +++ b/web/src/assets/images/menuNew/user_active.svg @@ -0,0 +1,11 @@ + + + 用户管理 + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/model/volcano.png b/web/src/assets/images/model/volcano.png new file mode 100644 index 00000000..9aeb3bf3 Binary files /dev/null and b/web/src/assets/images/model/volcano.png differ diff --git a/web/src/assets/images/question.svg b/web/src/assets/images/question.svg new file mode 100644 index 00000000..539ab03a --- /dev/null +++ b/web/src/assets/images/question.svg @@ -0,0 +1,17 @@ + + + 问号小 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/refresh.svg b/web/src/assets/images/refresh.svg index c592feff..79d2f836 100644 --- a/web/src/assets/images/refresh.svg +++ b/web/src/assets/images/refresh.svg @@ -1,12 +1,16 @@ - - 刷新 - - - - - - + + 编组 28 + + + + + + + + + + diff --git a/web/src/assets/images/refresh_hover.svg b/web/src/assets/images/refresh_dark.svg similarity index 97% rename from web/src/assets/images/refresh_hover.svg rename to web/src/assets/images/refresh_dark.svg index 1d4dcf7c..07864e99 100644 --- a/web/src/assets/images/refresh_hover.svg +++ b/web/src/assets/images/refresh_dark.svg @@ -2,7 +2,7 @@ 刷新 - + diff --git a/web/src/assets/images/tool/market.png b/web/src/assets/images/tool/market.png new file mode 100644 index 00000000..9639e253 Binary files /dev/null and b/web/src/assets/images/tool/market.png differ diff --git a/web/src/assets/images/userMemory/aboutMe.svg b/web/src/assets/images/userMemory/aboutMe.svg new file mode 100644 index 00000000..16630fdb --- /dev/null +++ b/web/src/assets/images/userMemory/aboutMe.svg @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/aboutMe_active.svg b/web/src/assets/images/userMemory/aboutMe_active.svg new file mode 100644 index 00000000..53e6362e --- /dev/null +++ b/web/src/assets/images/userMemory/aboutMe_active.svg @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/aboutUs.svg b/web/src/assets/images/userMemory/aboutUs.svg index 1d75eeae..b8fa9e45 100644 --- a/web/src/assets/images/userMemory/aboutUs.svg +++ b/web/src/assets/images/userMemory/aboutUs.svg @@ -1,13 +1,15 @@ - + - - - - - - - + + + + + + + + + diff --git a/web/src/assets/images/userMemory/ai.png b/web/src/assets/images/userMemory/ai.png new file mode 100644 index 00000000..3783a543 Binary files /dev/null and b/web/src/assets/images/userMemory/ai.png differ diff --git a/web/src/assets/images/userMemory/arrow_right.svg b/web/src/assets/images/userMemory/arrow_right.svg index aca820f8..3fa0eb49 100644 --- a/web/src/assets/images/userMemory/arrow_right.svg +++ b/web/src/assets/images/userMemory/arrow_right.svg @@ -1,12 +1,14 @@ 编组 5 - - - - - - + + + + + + + + diff --git a/web/src/assets/images/userMemory/arrow_right_dark.svg b/web/src/assets/images/userMemory/arrow_right_dark.svg new file mode 100644 index 00000000..38cfd953 --- /dev/null +++ b/web/src/assets/images/userMemory/arrow_right_dark.svg @@ -0,0 +1,16 @@ + + + 编组 5 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/arrow_right_hover.svg b/web/src/assets/images/userMemory/arrow_right_hover.svg index 0fed7c6b..444a7a03 100644 --- a/web/src/assets/images/userMemory/arrow_right_hover.svg +++ b/web/src/assets/images/userMemory/arrow_right_hover.svg @@ -1,12 +1,14 @@ 编组 5 - - - - - - + + + + + + + + diff --git a/web/src/assets/images/userMemory/chat.svg b/web/src/assets/images/userMemory/chat.svg new file mode 100644 index 00000000..11b34345 --- /dev/null +++ b/web/src/assets/images/userMemory/chat.svg @@ -0,0 +1,17 @@ + + + 编组 61 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/close.svg b/web/src/assets/images/userMemory/close.svg new file mode 100644 index 00000000..1b511252 --- /dev/null +++ b/web/src/assets/images/userMemory/close.svg @@ -0,0 +1,13 @@ + + + 关闭 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/down.svg b/web/src/assets/images/userMemory/down.svg index ae263f65..07a70e0d 100644 --- a/web/src/assets/images/userMemory/down.svg +++ b/web/src/assets/images/userMemory/down.svg @@ -1,13 +1,15 @@ - 下拉备份 - - - - - - - + 下拉 + + + + + + + + + diff --git a/web/src/assets/images/userMemory/download.svg b/web/src/assets/images/userMemory/download.svg new file mode 100644 index 00000000..1aa4f1ac --- /dev/null +++ b/web/src/assets/images/userMemory/download.svg @@ -0,0 +1,21 @@ + + + 更多 + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/download_hover.svg b/web/src/assets/images/userMemory/download_hover.svg new file mode 100644 index 00000000..5079a1ff --- /dev/null +++ b/web/src/assets/images/userMemory/download_hover.svg @@ -0,0 +1,22 @@ + + + 更多 + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/entity.svg b/web/src/assets/images/userMemory/entity.svg new file mode 100644 index 00000000..ad6a2692 --- /dev/null +++ b/web/src/assets/images/userMemory/entity.svg @@ -0,0 +1,20 @@ + + + 编组 5 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/file.svg b/web/src/assets/images/userMemory/file.svg new file mode 100644 index 00000000..6bfd562e --- /dev/null +++ b/web/src/assets/images/userMemory/file.svg @@ -0,0 +1,85 @@ + + + 编组 9 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + TEXT + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/forget.png b/web/src/assets/images/userMemory/forget.png new file mode 100644 index 00000000..f38ff9bd Binary files /dev/null and b/web/src/assets/images/userMemory/forget.png differ diff --git a/web/src/assets/images/userMemory/interestDistribution.svg b/web/src/assets/images/userMemory/interestDistribution.svg index f39d3bb3..d26a9952 100644 --- a/web/src/assets/images/userMemory/interestDistribution.svg +++ b/web/src/assets/images/userMemory/interestDistribution.svg @@ -1,16 +1,18 @@ - + 兴趣爱好 - - - - - - - - - - + + + + + + + + + + + + diff --git a/web/src/assets/images/userMemory/interestDistribution_active.svg b/web/src/assets/images/userMemory/interestDistribution_active.svg new file mode 100644 index 00000000..87b8d548 --- /dev/null +++ b/web/src/assets/images/userMemory/interestDistribution_active.svg @@ -0,0 +1,21 @@ + + + 兴趣爱好 + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/logo.png b/web/src/assets/images/userMemory/logo.png new file mode 100644 index 00000000..ab37dda6 Binary files /dev/null and b/web/src/assets/images/userMemory/logo.png differ diff --git a/web/src/assets/images/userMemory/logout.svg b/web/src/assets/images/userMemory/logout.svg new file mode 100644 index 00000000..8c21f4e2 --- /dev/null +++ b/web/src/assets/images/userMemory/logout.svg @@ -0,0 +1,13 @@ + + + 退出 (1) + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/long_term_number.svg b/web/src/assets/images/userMemory/long_term_number.svg new file mode 100644 index 00000000..134af714 --- /dev/null +++ b/web/src/assets/images/userMemory/long_term_number.svg @@ -0,0 +1,19 @@ + + + 编组 5 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/me.svg b/web/src/assets/images/userMemory/me.svg new file mode 100644 index 00000000..b8fa9e45 --- /dev/null +++ b/web/src/assets/images/userMemory/me.svg @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/memoryInsight.svg b/web/src/assets/images/userMemory/memoryInsight.svg new file mode 100644 index 00000000..7dfa3dcf --- /dev/null +++ b/web/src/assets/images/userMemory/memoryInsight.svg @@ -0,0 +1,32 @@ + + + 编组 26 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/memoryInsight_active.svg b/web/src/assets/images/userMemory/memoryInsight_active.svg new file mode 100644 index 00000000..43c73a4b --- /dev/null +++ b/web/src/assets/images/userMemory/memoryInsight_active.svg @@ -0,0 +1,15 @@ + + + 热点洞察 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/mp3.svg b/web/src/assets/images/userMemory/mp3.svg new file mode 100644 index 00000000..6bc6f2c6 --- /dev/null +++ b/web/src/assets/images/userMemory/mp3.svg @@ -0,0 +1,60 @@ + + + 编组 9 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + MP3 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/pause.svg b/web/src/assets/images/userMemory/pause.svg new file mode 100644 index 00000000..95e5d0ca --- /dev/null +++ b/web/src/assets/images/userMemory/pause.svg @@ -0,0 +1,20 @@ + + + 播放 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/play.svg b/web/src/assets/images/userMemory/play.svg new file mode 100644 index 00000000..a3caf5be --- /dev/null +++ b/web/src/assets/images/userMemory/play.svg @@ -0,0 +1,15 @@ + + + 播放 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/play_opacity.svg b/web/src/assets/images/userMemory/play_opacity.svg new file mode 100644 index 00000000..78de47cf --- /dev/null +++ b/web/src/assets/images/userMemory/play_opacity.svg @@ -0,0 +1,15 @@ + + + 播放 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/play_speed.svg b/web/src/assets/images/userMemory/play_speed.svg new file mode 100644 index 00000000..0245a19e --- /dev/null +++ b/web/src/assets/images/userMemory/play_speed.svg @@ -0,0 +1,13 @@ + + + iconfont-PREV + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/question.svg b/web/src/assets/images/userMemory/question.svg new file mode 100644 index 00000000..f8b0fee4 --- /dev/null +++ b/web/src/assets/images/userMemory/question.svg @@ -0,0 +1,15 @@ + + + 问号小 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/refresh.svg b/web/src/assets/images/userMemory/refresh.svg new file mode 100644 index 00000000..46627009 --- /dev/null +++ b/web/src/assets/images/userMemory/refresh.svg @@ -0,0 +1,20 @@ + + + 重新生成 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/retrieval_number.svg b/web/src/assets/images/userMemory/retrieval_number.svg new file mode 100644 index 00000000..0257ad37 --- /dev/null +++ b/web/src/assets/images/userMemory/retrieval_number.svg @@ -0,0 +1,22 @@ + + + 编组 5 + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/up_border.svg b/web/src/assets/images/userMemory/up_border.svg index a7fe9978..9435cb19 100644 --- a/web/src/assets/images/userMemory/up_border.svg +++ b/web/src/assets/images/userMemory/up_border.svg @@ -1,12 +1,12 @@ 下拉备份 - - + + - - - + + + diff --git a/web/src/assets/images/userMemory/user.png b/web/src/assets/images/userMemory/user.png new file mode 100644 index 00000000..671ab044 Binary files /dev/null and b/web/src/assets/images/userMemory/user.png differ diff --git a/web/src/assets/images/userMemory/userProfile.svg b/web/src/assets/images/userMemory/userProfile.svg new file mode 100644 index 00000000..fd996bf0 --- /dev/null +++ b/web/src/assets/images/userMemory/userProfile.svg @@ -0,0 +1,22 @@ + + + 知识库 + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/userProfile_active.svg b/web/src/assets/images/userMemory/userProfile_active.svg new file mode 100644 index 00000000..76a1e9e8 --- /dev/null +++ b/web/src/assets/images/userMemory/userProfile_active.svg @@ -0,0 +1,22 @@ + + + 知识库 + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/agent_arbitration.png b/web/src/assets/images/workflow/agent_arbitration.png deleted file mode 100644 index d555e3e2..00000000 Binary files a/web/src/assets/images/workflow/agent_arbitration.png and /dev/null differ diff --git a/web/src/assets/images/workflow/agent_collaboration.png b/web/src/assets/images/workflow/agent_collaboration.png deleted file mode 100644 index 7a92aecf..00000000 Binary files a/web/src/assets/images/workflow/agent_collaboration.png and /dev/null differ diff --git a/web/src/assets/images/workflow/agent_scheduling.png b/web/src/assets/images/workflow/agent_scheduling.png deleted file mode 100644 index 97028422..00000000 Binary files a/web/src/assets/images/workflow/agent_scheduling.png and /dev/null differ diff --git a/web/src/assets/images/workflow/aggregator.png b/web/src/assets/images/workflow/aggregator.png deleted file mode 100644 index 6253733a..00000000 Binary files a/web/src/assets/images/workflow/aggregator.png and /dev/null differ diff --git a/web/src/assets/images/workflow/aggregator.svg b/web/src/assets/images/workflow/aggregator.svg new file mode 100644 index 00000000..c757e1a1 --- /dev/null +++ b/web/src/assets/images/workflow/aggregator.svg @@ -0,0 +1,31 @@ + + + 编组 6 + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/assigner.png b/web/src/assets/images/workflow/assigner.png deleted file mode 100644 index 4370bfdd..00000000 Binary files a/web/src/assets/images/workflow/assigner.png and /dev/null differ diff --git a/web/src/assets/images/workflow/assigner.svg b/web/src/assets/images/workflow/assigner.svg new file mode 100644 index 00000000..c653694f --- /dev/null +++ b/web/src/assets/images/workflow/assigner.svg @@ -0,0 +1,30 @@ + + + 编组 6 + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/break.png b/web/src/assets/images/workflow/break.png deleted file mode 100644 index 473ab068..00000000 Binary files a/web/src/assets/images/workflow/break.png and /dev/null differ diff --git a/web/src/assets/images/workflow/break.svg b/web/src/assets/images/workflow/break.svg new file mode 100644 index 00000000..aefc203a --- /dev/null +++ b/web/src/assets/images/workflow/break.svg @@ -0,0 +1,30 @@ + + + 编组 14 + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/classification.png b/web/src/assets/images/workflow/classification.png deleted file mode 100644 index 87d34bb8..00000000 Binary files a/web/src/assets/images/workflow/classification.png and /dev/null differ diff --git a/web/src/assets/images/workflow/code_execution.png b/web/src/assets/images/workflow/code_execution.png deleted file mode 100644 index 7f802b3c..00000000 Binary files a/web/src/assets/images/workflow/code_execution.png and /dev/null differ diff --git a/web/src/assets/images/workflow/code_execution.svg b/web/src/assets/images/workflow/code_execution.svg new file mode 100644 index 00000000..4d749ddd --- /dev/null +++ b/web/src/assets/images/workflow/code_execution.svg @@ -0,0 +1,27 @@ + + + 编组 13 + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/condition.png b/web/src/assets/images/workflow/condition.png deleted file mode 100644 index a0bf9160..00000000 Binary files a/web/src/assets/images/workflow/condition.png and /dev/null differ diff --git a/web/src/assets/images/workflow/condition.svg b/web/src/assets/images/workflow/condition.svg new file mode 100644 index 00000000..addb1122 --- /dev/null +++ b/web/src/assets/images/workflow/condition.svg @@ -0,0 +1,27 @@ + + + 编组 14 + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/delete.svg b/web/src/assets/images/workflow/delete.svg new file mode 100644 index 00000000..238a729c --- /dev/null +++ b/web/src/assets/images/workflow/delete.svg @@ -0,0 +1,23 @@ + + + 编组 33 + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/delete_hover.svg b/web/src/assets/images/workflow/delete_hover.svg new file mode 100644 index 00000000..2f145453 --- /dev/null +++ b/web/src/assets/images/workflow/delete_hover.svg @@ -0,0 +1,23 @@ + + + 编组 33 + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/end.png b/web/src/assets/images/workflow/end.png deleted file mode 100644 index 7f4628c6..00000000 Binary files a/web/src/assets/images/workflow/end.png and /dev/null differ diff --git a/web/src/assets/images/workflow/end.svg b/web/src/assets/images/workflow/end.svg new file mode 100644 index 00000000..7c8eb34e --- /dev/null +++ b/web/src/assets/images/workflow/end.svg @@ -0,0 +1,23 @@ + + + 编组 13 + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/file_fold.svg b/web/src/assets/images/workflow/file_fold.svg new file mode 100644 index 00000000..b50f10de --- /dev/null +++ b/web/src/assets/images/workflow/file_fold.svg @@ -0,0 +1,13 @@ + + + 文件夹 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/http_request.png b/web/src/assets/images/workflow/http_request.png deleted file mode 100644 index 64e55d36..00000000 Binary files a/web/src/assets/images/workflow/http_request.png and /dev/null differ diff --git a/web/src/assets/images/workflow/http_request.svg b/web/src/assets/images/workflow/http_request.svg new file mode 100644 index 00000000..36c8995f --- /dev/null +++ b/web/src/assets/images/workflow/http_request.svg @@ -0,0 +1,27 @@ + + + 编组 12 + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/iteration.png b/web/src/assets/images/workflow/iteration.png deleted file mode 100644 index dd73767b..00000000 Binary files a/web/src/assets/images/workflow/iteration.png and /dev/null differ diff --git a/web/src/assets/images/workflow/iteration.svg b/web/src/assets/images/workflow/iteration.svg new file mode 100644 index 00000000..5bc4d840 --- /dev/null +++ b/web/src/assets/images/workflow/iteration.svg @@ -0,0 +1,32 @@ + + + 编组 6 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/llm.png b/web/src/assets/images/workflow/llm.png deleted file mode 100644 index 5d9e7465..00000000 Binary files a/web/src/assets/images/workflow/llm.png and /dev/null differ diff --git a/web/src/assets/images/workflow/llm.svg b/web/src/assets/images/workflow/llm.svg new file mode 100644 index 00000000..54cee4e0 --- /dev/null +++ b/web/src/assets/images/workflow/llm.svg @@ -0,0 +1,20 @@ + + + 编组 14 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/loop.png b/web/src/assets/images/workflow/loop.png deleted file mode 100644 index a4313229..00000000 Binary files a/web/src/assets/images/workflow/loop.png and /dev/null differ diff --git a/web/src/assets/images/workflow/loop.svg b/web/src/assets/images/workflow/loop.svg new file mode 100644 index 00000000..78fbe8a2 --- /dev/null +++ b/web/src/assets/images/workflow/loop.svg @@ -0,0 +1,27 @@ + + + 编组 15 + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/memory-read.png b/web/src/assets/images/workflow/memory-read.png deleted file mode 100644 index 4b0cdc1d..00000000 Binary files a/web/src/assets/images/workflow/memory-read.png and /dev/null differ diff --git a/web/src/assets/images/workflow/memory-read.svg b/web/src/assets/images/workflow/memory-read.svg new file mode 100644 index 00000000..d385748e --- /dev/null +++ b/web/src/assets/images/workflow/memory-read.svg @@ -0,0 +1,26 @@ + + + 编组 13 + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/memory-write.png b/web/src/assets/images/workflow/memory-write.png deleted file mode 100644 index 83a50fd4..00000000 Binary files a/web/src/assets/images/workflow/memory-write.png and /dev/null differ diff --git a/web/src/assets/images/workflow/memory-write.svg b/web/src/assets/images/workflow/memory-write.svg new file mode 100644 index 00000000..404275b4 --- /dev/null +++ b/web/src/assets/images/workflow/memory-write.svg @@ -0,0 +1,29 @@ + + + 编组 13 + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/memory_enhancement.png b/web/src/assets/images/workflow/memory_enhancement.png deleted file mode 100644 index 998c02fe..00000000 Binary files a/web/src/assets/images/workflow/memory_enhancement.png and /dev/null differ diff --git a/web/src/assets/images/workflow/menuFold.svg b/web/src/assets/images/workflow/menuFold.svg new file mode 100644 index 00000000..77dc38ac --- /dev/null +++ b/web/src/assets/images/workflow/menuFold.svg @@ -0,0 +1,17 @@ + + + 收起 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/minus.png b/web/src/assets/images/workflow/minus.png new file mode 100644 index 00000000..8dabd4dc Binary files /dev/null and b/web/src/assets/images/workflow/minus.png differ diff --git a/web/src/assets/images/workflow/model_selection.png b/web/src/assets/images/workflow/model_selection.png deleted file mode 100644 index e3e93962..00000000 Binary files a/web/src/assets/images/workflow/model_selection.png and /dev/null differ diff --git a/web/src/assets/images/workflow/model_voting.png b/web/src/assets/images/workflow/model_voting.png deleted file mode 100644 index 8324541e..00000000 Binary files a/web/src/assets/images/workflow/model_voting.png and /dev/null differ diff --git a/web/src/assets/images/workflow/node_plus.png b/web/src/assets/images/workflow/node_plus.png new file mode 100644 index 00000000..61e83c65 Binary files /dev/null and b/web/src/assets/images/workflow/node_plus.png differ diff --git a/web/src/assets/images/workflow/output_audit.png b/web/src/assets/images/workflow/output_audit.png deleted file mode 100644 index 50128f82..00000000 Binary files a/web/src/assets/images/workflow/output_audit.png and /dev/null differ diff --git a/web/src/assets/images/workflow/parallel.png b/web/src/assets/images/workflow/parallel.png deleted file mode 100644 index e77d79d8..00000000 Binary files a/web/src/assets/images/workflow/parallel.png and /dev/null differ diff --git a/web/src/assets/images/workflow/parameter_extraction.png b/web/src/assets/images/workflow/parameter_extraction.png deleted file mode 100644 index d4b50ee0..00000000 Binary files a/web/src/assets/images/workflow/parameter_extraction.png and /dev/null differ diff --git a/web/src/assets/images/workflow/parameter_extraction.svg b/web/src/assets/images/workflow/parameter_extraction.svg new file mode 100644 index 00000000..f3472516 --- /dev/null +++ b/web/src/assets/images/workflow/parameter_extraction.svg @@ -0,0 +1,22 @@ + + + 编组 15 + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/plus.png b/web/src/assets/images/workflow/plus.png new file mode 100644 index 00000000..05f5066a Binary files /dev/null and b/web/src/assets/images/workflow/plus.png differ diff --git a/web/src/assets/images/workflow/process_evolution.png b/web/src/assets/images/workflow/process_evolution.png deleted file mode 100644 index 8262c00d..00000000 Binary files a/web/src/assets/images/workflow/process_evolution.png and /dev/null differ diff --git a/web/src/assets/images/workflow/question-classifier.png b/web/src/assets/images/workflow/question-classifier.png index 754a0a62..9a95e4ab 100644 Binary files a/web/src/assets/images/workflow/question-classifier.png and b/web/src/assets/images/workflow/question-classifier.png differ diff --git a/web/src/assets/images/workflow/question-classifier.svg b/web/src/assets/images/workflow/question-classifier.svg new file mode 100644 index 00000000..3a85ff8b --- /dev/null +++ b/web/src/assets/images/workflow/question-classifier.svg @@ -0,0 +1,23 @@ + + + 编组 14 + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/rag.png b/web/src/assets/images/workflow/rag.png deleted file mode 100644 index 3749dbfa..00000000 Binary files a/web/src/assets/images/workflow/rag.png and /dev/null differ diff --git a/web/src/assets/images/workflow/rag.svg b/web/src/assets/images/workflow/rag.svg new file mode 100644 index 00000000..e9648fc8 --- /dev/null +++ b/web/src/assets/images/workflow/rag.svg @@ -0,0 +1,26 @@ + + + 编组 14 + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/reasoning_control.png b/web/src/assets/images/workflow/reasoning_control.png deleted file mode 100644 index 649e165c..00000000 Binary files a/web/src/assets/images/workflow/reasoning_control.png and /dev/null differ diff --git a/web/src/assets/images/workflow/refresh_active.svg b/web/src/assets/images/workflow/refresh_active.svg new file mode 100644 index 00000000..f9b0b3d8 --- /dev/null +++ b/web/src/assets/images/workflow/refresh_active.svg @@ -0,0 +1,18 @@ + + + 刷新 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/robot-2-line@2x.png b/web/src/assets/images/workflow/robot-2-line@2x.png deleted file mode 100644 index f1dc247e..00000000 Binary files a/web/src/assets/images/workflow/robot-2-line@2x.png and /dev/null differ diff --git a/web/src/assets/images/workflow/self_optimization.png b/web/src/assets/images/workflow/self_optimization.png deleted file mode 100644 index 08ed8598..00000000 Binary files a/web/src/assets/images/workflow/self_optimization.png and /dev/null differ diff --git a/web/src/assets/images/workflow/self_reflection.png b/web/src/assets/images/workflow/self_reflection.png deleted file mode 100644 index 099aac60..00000000 Binary files a/web/src/assets/images/workflow/self_reflection.png and /dev/null differ diff --git a/web/src/assets/images/workflow/sensitive_detection.png b/web/src/assets/images/workflow/sensitive_detection.png deleted file mode 100644 index 637a4f13..00000000 Binary files a/web/src/assets/images/workflow/sensitive_detection.png and /dev/null differ diff --git a/web/src/assets/images/workflow/start.png b/web/src/assets/images/workflow/start.png deleted file mode 100644 index f6828988..00000000 Binary files a/web/src/assets/images/workflow/start.png and /dev/null differ diff --git a/web/src/assets/images/workflow/start.svg b/web/src/assets/images/workflow/start.svg new file mode 100644 index 00000000..7b89c1f7 --- /dev/null +++ b/web/src/assets/images/workflow/start.svg @@ -0,0 +1,21 @@ + + + 编组 12 + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/task_planning.png b/web/src/assets/images/workflow/task_planning.png deleted file mode 100644 index 33f322fd..00000000 Binary files a/web/src/assets/images/workflow/task_planning.png and /dev/null differ diff --git a/web/src/assets/images/workflow/template_rendering.png b/web/src/assets/images/workflow/template_rendering.png deleted file mode 100644 index 064caeb6..00000000 Binary files a/web/src/assets/images/workflow/template_rendering.png and /dev/null differ diff --git a/web/src/assets/images/workflow/template_rendering.svg b/web/src/assets/images/workflow/template_rendering.svg new file mode 100644 index 00000000..e52bf30d --- /dev/null +++ b/web/src/assets/images/workflow/template_rendering.svg @@ -0,0 +1,26 @@ + + + 编组 13 + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/tools.png b/web/src/assets/images/workflow/tools.png deleted file mode 100644 index 49ff2fa4..00000000 Binary files a/web/src/assets/images/workflow/tools.png and /dev/null differ diff --git a/web/src/assets/images/workflow/tools.svg b/web/src/assets/images/workflow/tools.svg new file mode 100644 index 00000000..7c772245 --- /dev/null +++ b/web/src/assets/images/workflow/tools.svg @@ -0,0 +1,23 @@ + + + 编组 6 + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/AudioRecorder/index.tsx b/web/src/components/AudioRecorder/index.tsx index 639a9109..8df31398 100644 --- a/web/src/components/AudioRecorder/index.tsx +++ b/web/src/components/AudioRecorder/index.tsx @@ -2,12 +2,13 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:11:51 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-17 18:39:09 + * @Last Modified time: 2026-03-20 14:25:26 */ import { type FC, useRef, useState } from 'react' import RecordRTC from 'recordrtc' -import { App } from 'antd' +import { App, Tooltip } from 'antd' import { useTranslation } from 'react-i18next'; +import clsx from 'clsx'; import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' import { request } from '@/utils/request' @@ -91,14 +92,17 @@ const AudioRecorder: FC = ({ // Toggle between recording/idle states on click; // swap background image to reflect current state return ( -
+ +
+ ) } diff --git a/web/src/components/BtnTabs/index.tsx b/web/src/components/BtnTabs/index.tsx new file mode 100644 index 00000000..772a4c8d --- /dev/null +++ b/web/src/components/BtnTabs/index.tsx @@ -0,0 +1,49 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-03-19 14:05:09 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-19 14:05:09 + */ +import { type FC } from 'react' +import { Flex } from 'antd'; +import clsx from 'clsx' + +/** A single tab item with a display label and unique key */ +interface Tab { + label: string + key: string +} + +/** Props for the BtnTabs component */ +interface BtnTabsProps { + /** List of tab items to render */ + items: Tab[] + /** Key of the currently active tab */ + activeKey: string + /** Callback fired when a tab is clicked */ + onChange: (key: string) => void; + /** Optional extra class name for the container */ + className?: string; +} + +/** Button-style tab switcher — renders tabs as pill-shaped buttons with active highlight */ +const BtnTabs: FC = ({ items, activeKey, onChange, className }) => { + return ( + + {items.map((tab) => ( +
onChange(tab.key)} + className={clsx('rb:px-2 rb:py-1 rb:rounded-[13px] rb:text-[12px] rb:leading-4.5 rb:cursor-pointer', { + 'rb:bg-[#F6F6F6]': activeKey !== tab.key, + 'rb:bg-[#171719] rb:text-white': activeKey === tab.key, + })} + > + {tab.label} +
+ ))} +
+ ) +} + +export default BtnTabs diff --git a/web/src/components/ButtonCheckbox/index.tsx b/web/src/components/ButtonCheckbox/index.tsx index 18bca7c6..8c52701b 100644 --- a/web/src/components/ButtonCheckbox/index.tsx +++ b/web/src/components/ButtonCheckbox/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-02 15:01:59 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-19 13:41:26 + * @Last Modified time: 2026-03-19 20:45:13 */ /** @@ -64,13 +64,11 @@ const ButtonCheckbox: FC = ({ align="center" justify={cicle ? 'center' : 'start'} gap={4} - className={clsx("rb:flex rb:items-center rb:cursor-pointer rb:px-2! rb:border rb:hover:bg-[#F6F6F6]", { - 'rb:size-7 rb:rounded-[14px] rb:border-[0.5px] rb:border-[#EBEBEB]': cicle, - 'rb:rounded-lg rb:text-[12px] rb:h-6': !cicle, + className={clsx("rb:border rb:rounded-lg rb:px-2! rb:text-[12px] rb:h-6 rb:cursor-pointer", { // Checked state: blue background and border - "rb:bg-[rgba(21,94,239,0.06)] rb:border-[rgba(21,94,239,0.25)] rb:hover:bg-[rgba(21,94,239,0.06)] rb:text-[#155EEF]": checked, + "rb:bg-[#FAFAFA] rb:border-[#171719]": checked, // Unchecked state: gray border and dark text - "rb:border-[#DFE4ED] rb:text-[#212332]": !checked, + "rb:border-[#EBEBEB] rb:text-[#212332] rb:hover:bg-[#F0F3F8]": !checked, "rb:opacity-65 rb:cursor-not-allowed!": disabled })} onClick={handleChange} diff --git a/web/src/components/Charts/AreaLineChart.tsx b/web/src/components/Charts/AreaLineChart.tsx new file mode 100644 index 00000000..40bfabb1 --- /dev/null +++ b/web/src/components/Charts/AreaLineChart.tsx @@ -0,0 +1,306 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-10 13:36:03 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-25 13:51:52 + */ +/* + * AreaLineChart Component + * + * A reusable area line chart component built with ECharts that displays time-series data + * with gradient-filled areas under the lines. Supports multiple data series with + * customizable colors and responsive behavior. + * + * Features: + * - Multiple line series with gradient area fills + * - Gradient line colors (white to color to white) + * - Customizable x-axis key for flexible data structures + * - Date-based x-axis with formatted labels (DD/MM) + * - Responsive resizing using ResizeObserver + * - Interactive tooltips on hover + * - Customizable grid layout and colors + * - Legend at the bottom for series identification + * - Empty state when no data is available + * - Smooth rendering with requestAnimationFrame + */ +import { type FC, useEffect, useRef, useMemo } from 'react' +import ReactEcharts from 'echarts-for-react'; +import * as echarts from 'echarts'; + +import { formatDateTime } from '@/utils/format'; +import Empty from '@/components/Empty' + +/** Base configuration for all line series */ +const SeriesConfig = { + type: 'line', + stack: 'Total', + symbol: 'circle', + symbolSize: 5, + showSymbol: true, + label: { + show: false, + position: 'top' + }, + emphasis: { + focus: 'series' + }, +} + +/** Default color palette for area line series */ +const Colors = ['#155EEF', '#FFB048', '#4DA8FF'] + +/** + * Data structure for chart data points + * Flexible structure allowing any string key with string or number values + * + * @interface ChartData + * @property {string | number} [key: string] - Dynamic properties for x-axis and data series + */ +export interface ChartData { + [key: string]: string | number; +} + +/** + * Props for the AreaLineChart component + * + * @interface AreaLineChartProps + * @property {string} xAxisKey - Key name in chartData to use for x-axis values + * @property {ChartData[]} chartData - Array of data points with dynamic properties + * @property {Record} seriesList - Map of data keys to display names + * @property {string} [className] - Additional CSS classes for the container + * @property {number} [height] - Height of the chart in pixels + * @property {string[]} [colors] - Custom color array for line series and gradients + * @property {any} [grid] - ECharts grid configuration for chart positioning + */ +interface AreaLineChartProps { + xAxisKey: string; + chartData: ChartData[]; + seriesList: Record; + className?: string; + height?: number; + colors?: string[]; + grid?: any; + lineStyle?: any; + showLegend?: boolean; + smooth?: boolean; +} + +/** + * AreaLineChart Component + * + * Renders a multi-series area line chart with gradient fills. + * The area gradient goes from the series color at the top to white at the bottom. + * The line gradient goes from white to the series color and back to white. + * Automatically resizes when container dimensions change. + * + * @param {AreaLineChartProps} props - Component props + * @returns {JSX.Element} Rendered area line chart or empty state + * + * @example + * ```tsx + * + * ``` + */ +const AreaLineChart: FC = ({ + xAxisKey, + chartData, + seriesList, + height, + colors = Colors, + grid = { + top: 7, + left: 4, + right: 16, + bottom: 32, + containLabel: true + }, + lineStyle, + showLegend = true, + smooth = true +}) => { + /** Reference to the ECharts instance for programmatic control */ + const chartRef = useRef(null); + /** Flag to prevent multiple simultaneous resize operations */ + const resizeScheduledRef = useRef(false) + + /** + * Generate series configuration for each data series with gradient effects + * Creates area fills with vertical gradients (color to white) + * and line colors with horizontal gradients (white to color to white) + * + * @returns {Array} Array of ECharts series configurations with gradient styles + */ + const getSeries = () => { + return Object.entries(seriesList).map(([key, name], index) => ({ + ...SeriesConfig, + name: name, + data: chartData.map(vo => vo[key as keyof ChartData]), + areaStyle: { + opacity: 0.8, + color: new echarts.graphic.LinearGradient(0, 0, 0, 1, [ + { + offset: 0, + color: colors[index] + }, + { + offset: 1, + color: '#FFFFFF' + } + ]) + }, + lineStyle: lineStyle || { + width: 3, + color: new echarts.graphic.LinearGradient(0, 0, 1, 0, [ + { + offset: 0, + color: '#FFFFFF' + }, + { + offset: 0.8, + color: colors[index] + }, + { + offset: 1, + color: '#FFFFFF' + } + ]) + }, + smooth + })) + } + /** + * Memoized legend data to prevent unnecessary recalculations + * Formats series list for display in chart legend + */ + const formatSeriesList = useMemo(() => { + return Object.entries(seriesList).map(([_key, name]) => ({ + ...SeriesConfig, + name: name, + })) + }, [seriesList]) + + /** + * Set up responsive behavior using ResizeObserver + * Resizes chart when parent container dimensions change + */ + useEffect(() => { + const handleResize = () => { + if (chartRef.current && !resizeScheduledRef.current) { + resizeScheduledRef.current = true + requestAnimationFrame(() => { + chartRef.current?.getEchartsInstance().resize(); + resizeScheduledRef.current = false + }); + } + } + + const resizeObserver = new ResizeObserver(handleResize) + const chartElement = chartRef.current?.getEchartsInstance().getDom().parentElement + if (chartElement) { + resizeObserver.observe(chartElement) + } + + return () => { + resizeObserver.disconnect() + } + }, [chartData]) + + return ( +
+ {chartData && chartData.length > 0 + ? formatDateTime(item[xAxisKey], 'DD/MM')), + boundaryGap: false, + axisLabel: { + color: '#5B6167', + fontFamily: 'PingFangSC, PingFang SC', + lineHeight: 17, + }, + axisLine: { + show: false, + lineStyle: { + color: '#EBEBEB', + } + }, + splitLine: { + show: false, + }, + axisTick: { + show: false + } + }, + yAxis: { + type: 'value', + axisLabel: { + color: '#A8A9AA', + fontFamily: 'PingFangSC, PingFang SC', + align: 'right', + lineHeight: 17, + }, + axisLine: { + lineStyle: { + color: '#EBEBEB', + } + }, + }, + series: getSeries() + }} + style={{ height: `${height}px`, width: '100%', minWidth: '100%', boxSizing: 'border-box' }} + opts={{ renderer: 'canvas' }} + notMerge={true} + lazyUpdate={true} + /> + : + } +
+ ) +} + +export default AreaLineChart diff --git a/web/src/components/Charts/BarChart.tsx b/web/src/components/Charts/BarChart.tsx new file mode 100644 index 00000000..e476a2cb --- /dev/null +++ b/web/src/components/Charts/BarChart.tsx @@ -0,0 +1,295 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-10 13:36:03 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-25 13:49:04 + */ +/* + * BarChart Component + * + * A reusable area line chart component built with ECharts that displays time-series data + * with gradient-filled areas under the lines. Supports multiple data series with + * customizable colors and responsive behavior. + * + * Features: + * - Multiple line series with gradient area fills + * - Gradient line colors (white to color to white) + * - Customizable x-axis key for flexible data structures + * - Date-based x-axis with formatted labels (DD/MM) + * - Responsive resizing using ResizeObserver + * - Interactive tooltips on hover + * - Customizable grid layout and colors + * - Legend at the bottom for series identification + * - Empty state when no data is available + * - Smooth rendering with requestAnimationFrame + */ +import { type FC, useEffect, useRef, useMemo } from 'react' +import ReactEcharts from 'echarts-for-react'; +import * as echarts from 'echarts'; + +import { formatDateTime } from '@/utils/format'; +import Empty from '@/components/Empty' + +/** Base configuration for all line series */ +const SeriesConfig = { + type: 'bar', + stack: 'Total', + symbol: 'circle', + symbolSize: 5, + showSymbol: true, + label: { + show: false, + position: 'top' + }, + emphasis: { + focus: 'series' + }, + showBackground: true, +} + +/** Default color palette for area line series */ +const Colors = ['#155EEF', '#FFB048', '#4DA8FF'] + +/** + * Data structure for chart data points + * Flexible structure allowing any string key with string or number values + * + * @interface ChartData + * @property {string | number} [key: string] - Dynamic properties for x-axis and data series + */ +export interface ChartData { + [key: string]: string | number; +} + +/** + * Props for the BarChart component + * + * @interface BarChartProps + * @property {string} xAxisKey - Key name in chartData to use for x-axis values + * @property {ChartData[]} chartData - Array of data points with dynamic properties + * @property {Record} seriesList - Map of data keys to display names + * @property {string} [className] - Additional CSS classes for the container + * @property {number} [height] - Height of the chart in pixels + * @property {string[]} [colors] - Custom color array for line series and gradients + * @property {any} [grid] - ECharts grid configuration for chart positioning + */ +interface BarChartProps { + xAxisKey: string; + chartData: ChartData[]; + seriesList: Record; + className?: string; + height?: number; + colors?: string[]; + grid?: any; + itemStyle?: any; + showLegend?: boolean; + showBackground?: boolean; +} + +/** + * BarChart Component + * + * Renders a multi-series area line chart with gradient fills. + * The area gradient goes from the series color at the top to white at the bottom. + * The line gradient goes from white to the series color and back to white. + * Automatically resizes when container dimensions change. + * + * @param {BarChartProps} props - Component props + * @returns {JSX.Element} Rendered area line chart or empty state + * + * @example + * ```tsx + * + * ``` + */ +const BarChart: FC = ({ + xAxisKey, + chartData, + seriesList, + height, + colors = Colors, + grid = { + top: 7, + left: 4, + right: 16, + bottom: 32, + containLabel: true + }, + itemStyle, + showLegend = true, + showBackground = true, +}) => { + /** Reference to the ECharts instance for programmatic control */ + const chartRef = useRef(null); + /** Flag to prevent multiple simultaneous resize operations */ + const resizeScheduledRef = useRef(false) + + /** + * Generate series configuration for each data series with gradient effects + * Creates area fills with vertical gradients (color to white) + * and line colors with horizontal gradients (white to color to white) + * + * @returns {Array} Array of ECharts series configurations with gradient styles + */ + const getSeries = () => { + return Object.entries(seriesList).map(([key, name], index) => ({ + ...SeriesConfig, + name: name, + data: chartData.map(vo => vo[key as keyof ChartData]), + barWidth: 16, + itemStyle: itemStyle || { + color: new echarts.graphic.LinearGradient(0, 0, 0, 1, [ + { + offset: 0, + color: colors[index] + }, + { + offset: 1, + color: '#FFFFFF' + } + ]), + }, + emphasis: { + itemStyle: { + } + }, + barGap: '-100%', + showBackground: showBackground, + })) + } + /** + * Memoized legend data to prevent unnecessary recalculations + * Formats series list for display in chart legend + */ + const formatSeriesList = useMemo(() => { + return Object.entries(seriesList).map(([_key, name]) => ({ + ...SeriesConfig, + name: name, + })) + }, [seriesList]) + + /** + * Set up responsive behavior using ResizeObserver + * Resizes chart when parent container dimensions change + */ + useEffect(() => { + const handleResize = () => { + if (chartRef.current && !resizeScheduledRef.current) { + resizeScheduledRef.current = true + requestAnimationFrame(() => { + chartRef.current?.getEchartsInstance().resize(); + resizeScheduledRef.current = false + }); + } + } + + const resizeObserver = new ResizeObserver(handleResize) + const chartElement = chartRef.current?.getEchartsInstance().getDom().parentElement + if (chartElement) { + resizeObserver.observe(chartElement) + } + + return () => { + resizeObserver.disconnect() + } + }, [chartData]) + + return ( +
+ {chartData && chartData.length > 0 + ? formatDateTime(item[xAxisKey], 'DD/MM')), + boundaryGap: false, + axisLabel: { + color: '#5B6167', + fontFamily: 'PingFangSC, PingFang SC', + lineHeight: 17, + }, + axisLine: { + show: false, + itemStyle: { + color: '#EBEBEB', + } + }, + splitLine: { + show: false, + }, + axisTick: { + show: false + } + }, + yAxis: { + type: 'value', + axisLabel: { + color: '#A8A9AA', + fontFamily: 'PingFangSC, PingFang SC', + align: 'right', + lineHeight: 17, + }, + axisLine: { + itemStyle: { + color: '#EBEBEB', + } + }, + }, + series: getSeries() + }} + style={{ height: `${height}px`, width: '100%', minWidth: '100%', boxSizing: 'border-box' }} + opts={{ renderer: 'canvas' }} + notMerge={true} + lazyUpdate={true} + /> + : + } +
+ ) +} + +export default BarChart diff --git a/web/src/components/Charts/GraphNetworkChart.tsx b/web/src/components/Charts/GraphNetworkChart.tsx new file mode 100644 index 00000000..8f4ec796 --- /dev/null +++ b/web/src/components/Charts/GraphNetworkChart.tsx @@ -0,0 +1,200 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-10 14:06:09 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-10 14:06:09 + */ +/** + * GraphNetworkChart Component + * + * A force-directed graph visualization component built with ECharts. + * Displays nodes and edges in an interactive network diagram with physics-based layout. + * Supports zooming, panning, dragging nodes, and click interactions. + */ +import { type FC, useEffect, useRef, type SetStateAction, type Dispatch } from 'react' +import ReactEcharts from 'echarts-for-react'; + +import PageEmpty from '@/components/Empty/PageEmpty' + +// Default color palette for node categories +const Colors = ['#171719', '#155EEF', '#9C6FFF', '#FF8A4C'] + +/** + * Node interface representing a graph node/vertex + */ +export interface Node { + id: string; // Unique identifier for the node + label: string; // Display label for the node + category: number; // Category index for grouping and coloring + symbolSize: number; // Size of the node symbol in pixels + name: string; // Node name (used in ECharts) + itemStyle: { + color: string; // Custom color for this node + } + caption: string; // Additional description or caption + [key: string]: any; // Allow additional custom properties +} + +/** + * Edge interface representing a connection between two nodes + */ +export interface Edge { + id: string; // Unique identifier for the edge + source: string; // Source node ID + target: string; // Target node ID + type: string; // Type/category of the relationship + caption: string; // Description of the relationship + value: number; // Numeric value associated with the edge + weight: number; // Weight/strength of the connection +} + +/** + * Props for the GraphNetworkChart component + */ +interface GraphNetworkChartProps { + nodes: Node[]; // Array of nodes to display in the graph + links: Edge[]; // Array of edges connecting the nodes + categories: { name: string }[]; // Category definitions for node grouping + colors?: string[]; // Optional custom color palette (defaults to Colors) + onNodeClick: Dispatch>; // Callback when a node is clicked +} + +const GraphNetworkChart: FC = ({ + nodes, + links, + categories, + colors = Colors, + onNodeClick, +}) => { + // Reference to the ECharts instance for programmatic control + const chartRef = useRef(null); + + // Flag to prevent multiple simultaneous resize operations (debouncing) + const resizeScheduledRef = useRef(false) + + /** + * Effect: Handle responsive chart resizing + * + * Uses ResizeObserver to detect container size changes and resize the chart accordingly. + * Implements requestAnimationFrame for smooth, debounced resize operations. + * Re-runs when nodes change to ensure proper sizing with new data. + */ + useEffect(() => { + const handleResize = () => { + if (chartRef.current && !resizeScheduledRef.current) { + resizeScheduledRef.current = true + // Use requestAnimationFrame for smooth, optimized resize + requestAnimationFrame(() => { + chartRef.current?.getEchartsInstance().resize(); + resizeScheduledRef.current = false + }); + } + } + + // Observe the chart container for size changes + const resizeObserver = new ResizeObserver(handleResize) + const chartElement = chartRef.current?.getEchartsInstance().getDom().parentElement + if (chartElement) { + resizeObserver.observe(chartElement) + } + + // Cleanup: disconnect observer when component unmounts + return () => { + resizeObserver.disconnect() + } + }, [nodes]) + + return ( +
+ {/* Render chart only if nodes exist, otherwise show empty state */} + {nodes && nodes.length > 0 + ? { + // Only trigger callback for node clicks (not edges or background) + if (params.dataType === 'node') { + onNodeClick(params.data) + } + } + }} + /> + : + } +
+ ) +} + +export default GraphNetworkChart diff --git a/web/src/components/Charts/LineChart.tsx b/web/src/components/Charts/LineChart.tsx new file mode 100644 index 00000000..e5217336 --- /dev/null +++ b/web/src/components/Charts/LineChart.tsx @@ -0,0 +1,260 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-10 13:35:55 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-10 13:35:55 + */ +/* + * LineChart Component + * + * A reusable line chart component built with ECharts for displaying time-series data + * with multiple data series. Supports customizable colors, responsive behavior, + * and interactive tooltips. + * + * Features: + * - Multiple line series with different colors + * - Date-based x-axis with formatted labels (DD/MM) + * - Responsive resizing using ResizeObserver + * - Interactive tooltips on hover + * - Customizable grid layout and colors + * - Legend at the bottom for series identification + * - Empty state when no data is available + * - Smooth rendering with requestAnimationFrame + */ +import { type FC, useEffect, useRef, useMemo } from 'react' +import ReactEcharts from 'echarts-for-react'; + +import { formatDateTime } from '@/utils/format'; +import Empty from '@/components/Empty' + +/** Base configuration for all line series */ +const SeriesConfig = { + type: 'line', + stack: 'Total', + symbol: 'circle', + symbolSize: 5, + showSymbol: true, + label: { + show: false, + position: 'top' + }, + emphasis: { + focus: 'series' + }, +} + +/** Default color palette for line series */ +const Colors = ['#171719', '#155EEF', '#FF5D34'] + +/** + * Data structure for chart data points + * + * @interface ChartData + * @property {string | number} date - Date value for x-axis (timestamp or date string) + * @property {string | number} [key: string] - Dynamic properties for different data series + */ +export interface ChartData { + date: string | number; + [key: string]: string | number; +} + +/** + * Props for the LineChart component + * + * @interface LineChartProps + * @property {ChartData[]} chartData - Array of data points with date and series values + * @property {Record} seriesList - Map of data keys to display names + * @property {string} [className] - Additional CSS classes for the container + * @property {number} [height] - Height of the chart in pixels + * @property {string[]} [colors] - Custom color array for line series + * @property {any} [grid] - ECharts grid configuration for chart positioning + */ +interface LineChartProps { + chartData: ChartData[]; + seriesList: Record; + className?: string; + height?: number; + colors?: string[]; + grid?: any; +} + +/** + * LineChart Component + * + * Renders a multi-series line chart with date-based x-axis. + * Automatically resizes when container dimensions change. + * + * @param {LineChartProps} props - Component props + * @returns {JSX.Element} Rendered line chart or empty state + * + * @example + * ```tsx + * + * ``` + */ +const LineChart: FC = ({ + chartData, + seriesList, + height, + colors = Colors, + grid = { + top: 7, + right: 16, + } +}) => { + /** Reference to the ECharts instance for programmatic control */ + const chartRef = useRef(null); + /** Flag to prevent multiple simultaneous resize operations */ + const resizeScheduledRef = useRef(false) + + /** + * Generate series configuration for each data series + * Maps seriesList keys to chart series with corresponding data and colors + * + * @returns {Array} Array of ECharts series configurations + */ + const getSeries = () => { + return Object.entries(seriesList).map(([key, name], index) => ({ + ...SeriesConfig, + name: name, + data: chartData.map(vo => vo[key as keyof ChartData]), + lineStyle: { + width: 2, + color: colors[index] + }, + })) + } + /** + * Memoized legend data to prevent unnecessary recalculations + * Formats series list for display in chart legend + */ + const formatSeriesList = useMemo(() => { + return Object.entries(seriesList).map(([_key, name]) => ({ + ...SeriesConfig, + name: name, + })) + }, [seriesList]) + + /** + * Set up responsive behavior using ResizeObserver + * Resizes chart when parent container dimensions change + */ + useEffect(() => { + const handleResize = () => { + if (chartRef.current && !resizeScheduledRef.current) { + resizeScheduledRef.current = true + requestAnimationFrame(() => { + chartRef.current?.getEchartsInstance().resize(); + resizeScheduledRef.current = false + }); + } + } + + const resizeObserver = new ResizeObserver(handleResize) + const chartElement = chartRef.current?.getEchartsInstance().getDom().parentElement + if (chartElement) { + resizeObserver.observe(chartElement) + } + + return () => { + resizeObserver.disconnect() + } + }, [chartData]) + + return ( +
+ {chartData && chartData.length > 0 + ? formatDateTime(item.date, 'DD/MM')), + boundaryGap: false, + axisLabel: { + color: '#5B6167', + fontFamily: 'PingFangSC, PingFang SC', + lineHeight: 17, + }, + axisLine: { + show: false, + lineStyle: { + color: '#EBEBEB', + } + }, + splitLine: { + show: false, + }, + axisTick: { + show: false + } + }, + yAxis: { + type: 'value', + axisLabel: { + color: '#A8A9AA', + fontFamily: 'PingFangSC, PingFang SC', + align: 'right', + lineHeight: 17, + }, + axisLine: { + lineStyle: { + color: '#EBEBEB', + } + }, + }, + series: getSeries() + }} + style={{ height: '100%', width: '100%', minWidth: '100%', boxSizing: 'border-box' }} + opts={{ renderer: 'canvas' }} + notMerge={true} + lazyUpdate={true} + /> + : + } +
+ ) +} + +export default LineChart diff --git a/web/src/components/Charts/PieChart.tsx b/web/src/components/Charts/PieChart.tsx new file mode 100644 index 00000000..b0c67549 --- /dev/null +++ b/web/src/components/Charts/PieChart.tsx @@ -0,0 +1,204 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-10 13:35:45 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-16 11:34:30 + */ +/* + * PieChart Component + * + * A reusable pie chart component built with ECharts that displays data distribution + * in a donut chart format with customizable colors and responsive behavior. + * + * Features: + * - Donut-style pie chart with percentage labels + * - Customizable color palette + * - Responsive resizing using ResizeObserver + * - Hover tooltips showing percentage values + * - Legend at the bottom with horizontal layout + * - Empty state when no data is available + * - Shadow effects for better visual depth + */ +import { type FC, useEffect, useRef } from 'react' +import ReactEcharts from 'echarts-for-react'; + +import Empty from '@/components/Empty' + +/** Default color palette for pie chart segments */ +const Colors = ['#171719', '#155EEF', '#4DA8FF', '#9C6FFF', '#ABEBFF', '#DFE4ED'] + +/** + * Data structure for each pie chart segment + * + * @interface ChartData + * @property {string} name - Label for the segment (displayed in legend) + * @property {number} value - Numeric value for the segment (determines size) + */ +export interface ChartData { + name: string; + value: number; +} + +/** + * Props for the PieChart component + * + * @interface PieChartProps + * @property {ChartData[]} chartData - Array of data points to display in the chart + * @property {number} [height=260] - Height of the chart in pixels + * @property {string[]} [colors] - Custom color array for chart segments (defaults to Colors) + */ +interface PieChartProps { + chartData: ChartData[]; + height?: number; + colors?: string[]; + itemGap?: number; + seriesWidth?: number; + seriesHeight?: number; + seriesLabel?: boolean; + seriesTop?: number; +} + +/** + * PieChart Component + * + * Renders a donut-style pie chart with percentage labels and legend. + * Automatically resizes when container dimensions change. + * + * @param {PieChartProps} props - Component props + * @returns {JSX.Element} Rendered pie chart or empty state + * + * @example + * ```tsx + * + * ``` + */ +const PieChart: FC = ({ + chartData, + height = 260, + seriesWidth = 182, + seriesHeight = 182, + colors = Colors, + itemGap = 48, + seriesLabel = true, + seriesTop = 24, +}) => { + /** Reference to the ECharts instance for programmatic control */ + const chartRef = useRef(null); + /** Flag to prevent multiple simultaneous resize operations */ + const resizeScheduledRef = useRef(false) + + /** + * Set up responsive behavior using ResizeObserver + * Resizes chart when parent container dimensions change + */ + useEffect(() => { + const handleResize = () => { + if (chartRef.current && !resizeScheduledRef.current) { + resizeScheduledRef.current = true + // Use requestAnimationFrame for smooth resize performance + requestAnimationFrame(() => { + chartRef.current?.getEchartsInstance().resize(); + resizeScheduledRef.current = false + }); + } + } + + const resizeObserver = new ResizeObserver(handleResize) + const chartElement = chartRef.current?.getEchartsInstance().getDom().parentElement + if (chartElement) { + resizeObserver.observe(chartElement) + } + + // Cleanup: disconnect observer when component unmounts + return () => { + resizeObserver.disconnect() + } + }, [chartData]) + + return ( +
+ {chartData && chartData.length > 0 + ? + : + } +
+ ) +} + +export default PieChart diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index c7d3cffb..34472a2e 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -2,15 +2,15 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:46:17 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-19 19:45:40 + * @Last Modified time: 2026-03-26 13:32:29 */ import { type FC, useRef, useEffect, useState } from 'react' import clsx from 'clsx' import Markdown from '@/components/Markdown' import type { ChatContentProps } from './types' -import { Spin, Divider, Space, Image, Flex } from 'antd' +import { Spin, Divider, Space, Image, Flex, Button } from 'antd' import { SoundOutlined } from '@ant-design/icons' - +import { t } from 'i18next' const getFileUrl = (file: any) => { return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined) @@ -29,7 +29,8 @@ const ChatContent: FC = ({ labelPosition = 'bottom', labelFormat, errorDesc, - renderRuntime + renderRuntime, + onSend }) => { // Scroll container reference for controlling auto-scroll to bottom const scrollContainerRef = useRef<(HTMLDivElement | null)>(null) @@ -38,7 +39,8 @@ const ChatContent: FC = ({ const audioRef = useRef(null) const [playingIndex, setPlayingIndex] = useState(null) - const handlePlay = (index: number, audio_url: string) => { + const handlePlay = (index: number, audio_url: string, audio_status?: string) => { + if (audio_status !== 'completed' && !audio_status) return if (playingIndex === index) { audioRef.current?.pause() setPlayingIndex(null) @@ -114,7 +116,7 @@ const ChatContent: FC = ({ : <> {/* Top label (such as timestamp, username, etc.) */} {labelPosition === 'top' && -
+
{labelFormat(item)}
} @@ -162,26 +164,53 @@ const ChatContent: FC = ({ })} } {/* Message bubble */} -
+ {item.status &&
} {item.subContent && renderRuntime && renderRuntime(item, index)} {/* Render message content using Markdown component */} + {item.meta_data?.suggested_questions && item.meta_data?.suggested_questions?.length > 0 && + {item.meta_data?.suggested_questions?.map((question, idx) => ( + + ))} + } + {item.meta_data?.citations && item.meta_data?.citations.length > 0 &&
+
{t('memoryConversation.citations')}
+ {item.meta_data?.citations?.map((citation, idx) => ( + + ))} +
} {item.meta_data?.audio_url && <> - {playingIndex !== index - ? handlePlay(index, item.meta_data?.audio_url!)} /> + {playingIndex !== index && item.meta_data?.audio_status === 'pending' + ? + : playingIndex !== index + ? handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> :
handlePlay(index, item.meta_data?.audio_url!)} + onClick={() => handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> } @@ -189,7 +218,7 @@ const ChatContent: FC = ({
{/* Bottom label (such as timestamp, username, etc.) */} {labelPosition === 'bottom' && -
+
{labelFormat(item)}
} diff --git a/web/src/components/Chat/ChatInput.tsx b/web/src/components/Chat/ChatInput.tsx index aa0dd2f6..6495ff06 100644 --- a/web/src/components/Chat/ChatInput.tsx +++ b/web/src/components/Chat/ChatInput.tsx @@ -2,15 +2,12 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:46:14 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-19 18:44:51 + * @Last Modified time: 2026-03-23 17:46:25 */ -import { type FC, useEffect, useMemo } from 'react' -import { Flex, Input, Form, Spin } from 'antd' +import { type FC, useEffect, useMemo, useState } from 'react' +import { Flex, Input, Spin } from 'antd' import clsx from 'clsx' -import SendIcon from '@/assets/images/conversation/send.svg' -import SendDisabledIcon from '@/assets/images/conversation/sendDisabled.svg' -import LoadingIcon from '@/assets/images/conversation/loading.svg' import type { ChatInputProps } from './types' /** @@ -27,37 +24,27 @@ const ChatInput: FC = ({ className = '', onChange }) => { - const [form] = Form.useForm() - const values = Form.useWatch([], form) - // Monitor form value changes to control send button state + const [inputValue, setInputValue] = useState('') + const [isFocus, setIsFocus] = useState(false) - // Clear form when external message is empty + // Clear input when external message is cleared useEffect(() => { - if (!message) { - form.setFieldsValue({ - message: undefined, - }) - } - }, [form, message]) - + if (!message) setInputValue('') + }, [message]) + // Clear input when loading useEffect(() => { - if (loading) { - form.setFieldsValue({ - message: undefined, - }) - } + if (loading) setInputValue('') }, [loading]) - const handleDelete = (file: any) => { fileChange?.(fileList?.filter(item => { return item.thumbUrl && file.thumbUrl ? item.thumbUrl !== file.thumbUrl : item.url && file.url ? item.url !== file.url - : item.uid !== file.uid + : item.uid !== file.uid }) || []) } - // Convert file object to preview URL + const previewFileList = useMemo(() => { return fileList?.map(file => ({ ...file, @@ -66,24 +53,27 @@ const ChatInput: FC = ({ }, [fileList]) const handleSend = () => { - if (loading || !values || !values?.message || values?.message?.trim() === '') return - onSend(values.message) + if (loading || !inputValue || inputValue.trim() === '') return + onSend(inputValue) } - console.log('previewFileList', previewFileList) + const canSend = !loading && inputValue.trim() !== '' return (
- - {previewFileList.length > 0 &&
+ + {previewFileList.length > 0 &&
+ {previewFileList.map((file) => { - if (file.type.includes('image')) { + if (file.type?.includes('image')) { return ( -
- {file.name} + {file.name}
handleDelete(file)} @@ -92,30 +82,30 @@ const ChatInput: FC = ({ ) } - if (file.type.includes('video')) { + if (file.type?.includes('video')) { return ( -
-
) } - if (file.type.includes('audio')) { + if (file.type?.includes('audio')) { return ( -
-
@@ -124,68 +114,86 @@ const ChatInput: FC = ({ } return ( -
- {file.type.includes('pdf') - ?
- : (file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv')) - ?
- : (file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) - ?
- : null - } +
{file.name}
-
{file.type} · {file.size}
+
{file.type?.split('/')[file.type?.split('/').length - 1]} · {file.size}
handleDelete(file)} >
-
+
) })} -
} - {/* Message input form */} -
- - onChange?.(e.target.value)} - onKeyDown={(e) => { - // Enter to send, Shift+Enter for new line - if (e.key === 'Enter' && !e.shiftKey && (e.target as HTMLTextAreaElement).value?.trim() !== '' && !loading) { - e.preventDefault(); - handleSend(); - } - }} - /> - -
+ +
} + {/* Message input area */} + { + setInputValue(e.target.value) + onChange?.(e.target.value) + }} + onKeyDown={(e) => { + // Enter to send, Shift+Enter for new line + if (e.key === 'Enter' && !e.shiftKey && (e.target as HTMLTextAreaElement).value?.trim() !== '' && !loading) { + e.preventDefault(); + handleSend(); + } + }} + onFocus={() => setIsFocus(true)} + onBlur={() => setIsFocus(false)} + /> {/* Bottom action area */} - - {/* Child component content (such as buttons) */} +
{children}
-
- {/* Send button - display different icons based on state */} - {loading - ? - : !values || !values?.message || values?.message?.trim() === '' - ? - : - } -
+ +
+
diff --git a/web/src/components/Chat/ChatToolbar.tsx b/web/src/components/Chat/ChatToolbar.tsx index 936e7e63..3fbc0e3a 100644 --- a/web/src/components/Chat/ChatToolbar.tsx +++ b/web/src/components/Chat/ChatToolbar.tsx @@ -2,12 +2,11 @@ * @Author: ZhaoYing * @Date: 2026-03-17 14:22:25 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-19 18:59:37 + * @Last Modified time: 2026-03-23 17:42:38 */ // Toolbar component for chat input area, supporting file upload, audio recording, and variable configuration import { useRef, forwardRef, useImperativeHandle, type ReactNode, useEffect } from 'react' -import { Flex, Dropdown, Divider, App, Form, type MenuProps } from 'antd' -import { SettingOutlined } from '@ant-design/icons' +import { Flex, Dropdown, Divider, App, Form, type MenuProps, Tooltip } from 'antd' import { useTranslation } from 'react-i18next' import clsx from 'clsx' @@ -19,6 +18,7 @@ import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types' import type { UploadFileListModalRef } from '@/views/Conversation/types' import type { VariableConfigModalRef } from '@/views/Workflow/types' import type { Variable } from '@/views/Workflow/components/Properties/VariableList/types' +import { getFileInfoByUrl } from '@/api/fileStorage'; // Exposed methods via ref for parent components to access/set form state export interface ChatToolbarRef { @@ -31,7 +31,8 @@ export interface ChatToolbarRef { // Props for configuring toolbar features, upload settings, and event callbacks export interface ChatToolbarProps { features: FeaturesConfigForm - extra?: ReactNode + leftExtra?: ReactNode; + rightExtra?: ReactNode uploadAction?: string uploadRequestConfig?: { data?: Record @@ -52,7 +53,8 @@ interface FormValues { const max_file_count = 1; const ChatToolbar = forwardRef(({ features, - extra, + leftExtra, + rightExtra, uploadAction, uploadRequestConfig, onFilesChange, @@ -96,8 +98,6 @@ const ChatToolbar = forwardRef(({ } form.setFieldValue('files', [...lastFiles]) onFilesChange?.([...lastFiles]) - - console.log('lastFiles', lastFiles) } // Append recorded audio file to the file list and notify parent @@ -111,9 +111,33 @@ const ChatToolbar = forwardRef(({ // Merge a batch of files (e.g. from remote URL modal) into the file list const addFileList = (list?: any[]) => { if (!list?.length) return - const files = [...(queryValues?.files || []), ...list] + const uploadingList = list.map(f => ({ ...f, status: 'uploading' })) + const files = [...(queryValues?.files || []), ...uploadingList] form.setFieldValue('files', files) onFilesChange?.(files) + + uploadingList.forEach(file => { + getFileInfoByUrl(file.url) + .then((res) => { + const { file_name, file_size, content_type } = res as { file_name: string; file_size: number; content_type: string; } + const current: any[] = form.getFieldValue('files') || [] + const updated = current.map(f => f.uid === file.uid ? { + ...f, + status: 'done', + name: file_name, + size: file_size, + type: content_type, + } : f) + form.setFieldValue('files', updated) + onFilesChange?.(updated) + }) + .catch(() => { + const current: any[] = form.getFieldValue('files') || [] + const updated = current.map(f => f.uid === file.uid ? { ...f, status: 'error' } : f) + form.setFieldValue('files', updated) + onFilesChange?.(updated) + }) + }) } // Persist variable values from the config modal and notify parent @@ -163,28 +187,34 @@ const ChatToolbar = forwardRef(({ return (
- +