diff --git a/api/LICENSE b/LICENSE
similarity index 100%
rename from api/LICENSE
rename to LICENSE
diff --git a/api/app/celery_app.py b/api/app/celery_app.py
index 807c59f4..58c89f8f 100644
--- a/api/app/celery_app.py
+++ b/api/app/celery_app.py
@@ -1,5 +1,6 @@
import os
import platform
+import re
from datetime import timedelta
from urllib.parse import quote
@@ -11,21 +12,24 @@ from app.core.logging_config import get_logger
logger = get_logger(__name__)
+
+def _mask_url(url: str) -> str:
+    """Return ``url`` with the password in its userinfo replaced by ``***``.
+
+    Works for any ``scheme://user:password@host`` style URL (redis://,
+    amqp://, ...). The match is confined to the authority component (it never
+    crosses ``/``, ``?`` or ``#``), so a URL without credentials is returned
+    unchanged even when ``:`` or ``@`` appear later in the path or query, and
+    a password that itself contains ``@`` is masked in full (greedy match up
+    to the last ``@`` of the authority).
+    """
+    return re.sub(r'(://[^:/?#@]*:)[^/?#]+(@)', r'\1***\2', url)
+
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例
-# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
-# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
+# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议),
+# 未配置则回退到 Redis 方案
+# backend: 结果存储(使用 Redis)
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
-# Build canonical broker/backend URLs and force them into os.environ so that
-# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
-# cannot be overridden by stray env vars.
-# See: https://github.com/celery/celery/issues/4284
-_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
+_broker_url = os.getenv("CELERY_BROKER_URL") or \
+ f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
@@ -45,8 +49,8 @@ celery_app = Celery(
logger.info(
"Celery app initialized",
extra={
- "broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
- "backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
+ "broker": _mask_url(_broker_url),
+ "backend": _mask_url(_backend_url),
},
)
# Default queue for unrouted tasks
@@ -77,6 +81,7 @@ celery_app.conf.update(
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
+ worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
# 结果过期时间
result_expires=3600, # 结果保存1小时
diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py
index e9b539df..3ba9c3a9 100644
--- a/api/app/controllers/app_controller.py
+++ b/api/app/controllers/app_controller.py
@@ -57,6 +57,7 @@ def list_apps(
page: int = 1,
pagesize: int = 10,
ids: Optional[str] = None,
+ api_key: Optional[str] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
@@ -65,10 +66,25 @@ def list_apps(
- 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
+ - 当提供 api_key 参数时,查找该 API Key 关联的应用
"""
+ from sqlalchemy import select as sa_select
+ from app.models.api_key_model import ApiKey
+
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
+ # 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程
+ if api_key:
+ matched_id = db.execute(
+ sa_select(ApiKey.resource_id).where(
+ ApiKey.workspace_id == workspace_id,
+ ApiKey.api_key == api_key,
+ ApiKey.resource_id.isnot(None),
+ )
+ ).scalar_one_or_none()
+ ids = str(matched_id) if matched_id else ""
+
# 当 ids 存在且不为 None 时,根据 ids 获取应用
if ids is not None:
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py
index 55149cce..4e1ba74c 100644
--- a/api/app/controllers/file_storage_controller.py
+++ b/api/app/controllers/file_storage_controller.py
@@ -14,6 +14,9 @@ Routes:
import os
import uuid
from typing import Any
+import httpx
+import mimetypes
+from urllib.parse import urlparse, unquote
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, RedirectResponse
@@ -290,6 +293,101 @@ async def upload_file_with_share_token(
)
+# Limits for the unauthenticated remote-URL probe below.
+_URL_INFO_MAX_REDIRECTS = 5
+_URL_INFO_MAX_PROBE_BYTES = 50 * 1024 * 1024
+_URL_INFO_REDIRECT_STATUSES = (301, 302, 303, 307, 308)
+
+
+async def _ensure_public_http_url(url: str) -> None:
+    """Reject URLs that are not http(s) or that resolve to non-public hosts.
+
+    The endpoint using this helper is unauthenticated and makes the server
+    fetch a caller-supplied URL; without this check it could be used to reach
+    loopback, private-network or link-local services (SSRF).
+
+    Raises:
+        HTTPException: 400 when the URL is malformed, uses another scheme,
+            cannot be resolved, or resolves to a non-global address.
+    """
+    import asyncio
+    import ipaddress
+    import socket
+
+    rejected = HTTPException(
+        status_code=status.HTTP_400_BAD_REQUEST,
+        detail="URL must be http(s) and point to a publicly reachable host",
+    )
+    try:
+        parsed = urlparse(url)
+        port = parsed.port  # raises ValueError on a malformed port
+    except ValueError:
+        raise rejected from None
+    if parsed.scheme not in ("http", "https") or not parsed.hostname:
+        raise rejected
+    if port is None:
+        port = 443 if parsed.scheme == "https" else 80
+    try:
+        # Resolve through the event loop so the DNS lookup does not block it.
+        addr_infos = await asyncio.get_running_loop().getaddrinfo(
+            parsed.hostname, port, type=socket.SOCK_STREAM
+        )
+    except (OSError, UnicodeError):
+        raise rejected from None
+    for addr_info in addr_infos:
+        # Drop an IPv6 scope id ("fe80::1%eth0") before parsing the address.
+        address = ipaddress.ip_address(addr_info[4][0].split("%", 1)[0])
+        if not address.is_global:
+            raise rejected
+    # NOTE(review): the HTTP client resolves the host again when connecting,
+    # so a DNS-rebinding window remains; pin the resolved IP if that matters.
+
+
+async def _probe_remote_file(client, method: str, url: str):
+    """Send ``method`` to ``url`` and return ``(status_code, headers, streamed_size)``.
+
+    Redirects are followed manually (at most ``_URL_INFO_MAX_REDIRECTS``) so
+    that every hop is validated by ``_ensure_public_http_url``. The body is
+    streamed rather than buffered: it is only read for a successful GET that
+    has no Content-Length, and only counted, up to
+    ``_URL_INFO_MAX_PROBE_BYTES``. ``streamed_size`` is ``None`` when the body
+    was not read or exceeded that cap.
+    """
+    for _ in range(_URL_INFO_MAX_REDIRECTS + 1):
+        await _ensure_public_http_url(url)
+        async with client.stream(method, url, follow_redirects=False) as response:
+            location = response.headers.get("Location")
+            if response.status_code in _URL_INFO_REDIRECT_STATUSES and location:
+                url = str(response.url.join(location))
+                continue
+            streamed_size = None
+            if (
+                method == "GET"
+                and response.status_code == 200
+                and "Content-Length" not in response.headers
+            ):
+                streamed_size = 0
+                async for chunk in response.aiter_bytes():
+                    streamed_size += len(chunk)
+                    if streamed_size > _URL_INFO_MAX_PROBE_BYTES:
+                        streamed_size = None
+                        break
+            return response.status_code, response.headers, streamed_size
+    raise HTTPException(
+        status_code=status.HTTP_400_BAD_REQUEST,
+        detail="Unable to access file: too many redirects",
+    )
+
+
+def _filename_from_content_disposition(value):
+    """Extract the file name from a Content-Disposition header value.
+
+    Uses the standard-library header parser so quoted values, trailing
+    parameters (``filename="a.pdf"; size=1``) and RFC 2231 ``filename*=``
+    forms are handled. Any directory component is stripped. Returns ``None``
+    when the header is missing or carries no file name.
+    """
+    from email.message import Message
+
+    if not value:
+        return None
+    header = Message()
+    header["Content-Disposition"] = value
+    raw_name = header.get_filename()
+    if not raw_name:
+        return None
+    return os.path.basename(raw_name.strip().replace("\\", "/")) or None
+
+
+@router.get("/files/info-by-url", response_model=ApiResponse)
+async def get_file_info_by_url(
+    url: str,
+):
+    """
+    Get file information by network URL (no authentication required).
+
+    Probes the remote URL with an HTTP HEAD request and falls back to a
+    streamed GET request when HEAD does not answer 200. Only http(s) URLs
+    that resolve to public addresses are fetched, and each redirect hop is
+    validated the same way.
+
+    Args:
+        url: The network URL of the file.
+
+    Returns:
+        ApiResponse whose data holds ``url``, ``file_name``, ``file_ext``,
+        ``file_size`` (``None`` when the size could not be determined) and
+        ``content_type``.
+
+    Raises:
+        HTTPException: 400 when the URL is rejected or the remote server does
+            not answer 200; 500 on any other failure.
+    """
+    api_logger.info(f"File info by URL request: url={url}")
+
+    try:
+        async with httpx.AsyncClient(timeout=10.0) as client:
+            # Try HEAD request first
+            status_code, headers, streamed_size = await _probe_remote_file(client, "HEAD", url)
+
+            # If HEAD fails, try GET request (some servers don't support HEAD)
+            if status_code != 200:
+                api_logger.info(f"HEAD request failed with {status_code}, trying GET request")
+                status_code, headers, streamed_size = await _probe_remote_file(client, "GET", url)
+
+        if status_code != 200:
+            api_logger.error(f"Failed to fetch file info: HTTP {status_code}")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Unable to access file: HTTP {status_code}"
+            )
+
+        # Prefer the declared Content-Length; a malformed value must not turn
+        # into a 500, and an absent one means "unknown" rather than 0.
+        declared_length = (headers.get("Content-Length") or "").strip()
+        file_size = int(declared_length) if declared_length.isdigit() else streamed_size
+
+        # Get content type from Content-Type header
+        content_type = headers.get("Content-Type", "application/octet-stream")
+        # Remove charset and other parameters from content type
+        content_type = content_type.split(';')[0].strip()
+
+        # Extract filename from Content-Disposition or URL
+        file_name = _filename_from_content_disposition(headers.get("Content-Disposition"))
+        if not file_name:
+            parsed_url = urlparse(url)
+            file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
+
+        # Extract file extension from filename
+        _, file_ext = os.path.splitext(file_name)
+
+        # If no extension found, infer from content type
+        if not file_ext:
+            ext = mimetypes.guess_extension(content_type)
+            if ext:
+                file_ext = ext
+                file_name = f"{file_name}{file_ext}"
+
+        api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
+
+        return success(
+            data={
+                "url": url,
+                "file_name": file_name,
+                "file_ext": file_ext.lower() if file_ext else "",
+                "file_size": file_size,
+                "content_type": content_type,
+            },
+            msg="File information retrieved successfully"
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        api_logger.error(f"Unexpected error: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to retrieve file information: {str(e)}"
+        )
+
+
@router.get("/files/{file_id}", response_model=Any)
async def download_file(
request: Request,
@@ -476,8 +574,12 @@ async def get_file_url(
# For local storage, generate signed URL with expiration
url = generate_signed_url(str(file_id), expires)
else:
- # For remote storage (OSS/S3), get presigned URL
- url = await storage_service.get_file_url(file_key, expires=expires)
+ # For remote storage (OSS/S3), get presigned URL with forced download
+ url = await storage_service.get_file_url(
+ file_key,
+ expires=expires,
+ file_name=file_metadata.file_name,
+ )
url = _match_scheme(request, url)
api_logger.info(f"Generated file URL: file_id={file_id}")
@@ -688,7 +790,7 @@ async def permanent_download_file(
# For remote storage, redirect to presigned URL with long expiration
try:
# Use a very long expiration (7 days max for most cloud providers)
- presigned_url = await storage_service.get_file_url(file_key, expires=604800)
+ presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
presigned_url = _match_scheme(request, presigned_url)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e:
@@ -697,3 +799,44 @@ async def permanent_download_file(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}"
)
+
+
+@router.get("/files/{file_id}/status", response_model=ApiResponse)
+async def get_file_status(
+    file_id: uuid.UUID,
+    db: Session = Depends(get_db),
+):
+    """
+    Report the upload/processing status of a file (no authentication required).
+
+    Looks the file up by primary key and returns its stored status together
+    with basic metadata, e.g. so a client can poll until a generated file is
+    ready.
+
+    Args:
+        file_id: The UUID of the file.
+        db: Database session.
+
+    Returns:
+        ApiResponse whose data holds ``file_id``, ``status``, ``file_name``,
+        ``file_size`` and ``content_type``.
+
+    Raises:
+        HTTPException: 404 when no metadata row exists for ``file_id``.
+    """
+    api_logger.info(f"File status request: file_id={file_id}")
+
+    # Fetch the metadata row for this file, if any
+    lookup = db.query(FileMetadata).filter(FileMetadata.id == file_id)
+    record = lookup.first()
+
+    if record:
+        payload = {
+            "file_id": str(file_id),
+            "status": record.status,
+            "file_name": record.file_name,
+            "file_size": record.file_size,
+            "content_type": record.content_type,
+        }
+        return success(data=payload, msg="File status retrieved successfully")
+
+    # No row found: log it and answer 404
+    api_logger.warning(f"File not found in database: file_id={file_id}")
+    raise HTTPException(
+        status_code=status.HTTP_404_NOT_FOUND,
+        detail="The file does not exist"
+    )
diff --git a/api/app/controllers/mcp_market_config_controller.py b/api/app/controllers/mcp_market_config_controller.py
index 0f2da3b0..6f27d87a 100644
--- a/api/app/controllers/mcp_market_config_controller.py
+++ b/api/app/controllers/mcp_market_config_controller.py
@@ -91,9 +91,11 @@ async def get_mcp_servers(
try:
cookies = api.get_cookies(token)
+ headers=api.builder_headers(api.headers)
+ headers['Authorization'] = f'Bearer {token}'
r = api.session.put(
url=api.mcp_base_url,
- headers=api.builder_headers(api.headers),
+ headers=headers,
json=body,
cookies=cookies)
raise_for_http_status(r)
@@ -173,6 +175,7 @@ async def get_operational_mcp_servers(
url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers)
+ headers['Authorization'] = f'Bearer {token}'
try:
cookies = api.get_cookies(access_token=token, cookies_required=True)
@@ -260,7 +263,9 @@ async def create_mcp_market_config(
api.login(create_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(create_data.token)
- r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
+ headers = api.builder_headers(api.headers)
+ headers['Authorization'] = f'Bearer {create_data.token}'
+ r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r)
except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
@@ -290,9 +295,11 @@ async def create_mcp_market_config(
'search': ""
}
cookies = api.get_cookies(token)
+ headers = api.builder_headers(api.headers)
+ headers['Authorization'] = f'Bearer {token}'
r = api.session.put(
url=api.mcp_base_url,
- headers=api.builder_headers(api.headers),
+ headers=headers,
json=body,
cookies=cookies)
raise_for_http_status(r)
@@ -393,7 +400,9 @@ async def update_mcp_market_config(
api.login(update_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(update_data.token)
- r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
+ headers = api.builder_headers(api.headers)
+ headers['Authorization'] = f'Bearer {update_data.token}'
+ r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r)
except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py
index e3d2bf92..aa4d48e3 100644
--- a/api/app/controllers/memory_agent_controller.py
+++ b/api/app/controllers/memory_agent_controller.py
@@ -118,142 +118,142 @@ async def download_log(
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
-@router.post("/writer_service", response_model=ApiResponse)
-@cur_workspace_access_guard()
-async def write_server(
- user_input: Write_UserInput,
- language_type: str = Header(default=None, alias="X-Language-Type"),
- db: Session = Depends(get_db),
- current_user: User = Depends(get_current_user)
-):
- """
- Write service endpoint - processes write operations synchronously
-
- Args:
- user_input: Write request containing message and end_user_id
- language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
-
- Returns:
- Response with write operation status
- """
- # 使用集中化的语言校验
- language = get_language_from_header(language_type)
-
- config_id = user_input.config_id
- workspace_id = current_user.current_workspace_id
- api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
-
- # 获取 storage_type,如果为 None 则使用默认值
- storage_type = workspace_service.get_workspace_storage_type(
- db=db,
- workspace_id=workspace_id,
- user=current_user
- )
- if storage_type is None: storage_type = 'neo4j'
- user_rag_memory_id = ''
-
- # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
- if storage_type == 'rag':
- if workspace_id:
- knowledge = knowledge_repository.get_knowledge_by_name(
- db=db,
- name="USER_RAG_MERORY",
- workspace_id=workspace_id
- )
- if knowledge:
- user_rag_memory_id = str(knowledge.id)
- else:
- api_logger.warning(
- f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
- storage_type = 'neo4j'
- else:
- api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
- storage_type = 'neo4j'
-
- api_logger.info(
- f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
- try:
- messages_list = memory_agent_service.get_messages_list(user_input)
- result = await memory_agent_service.write_memory(
- user_input.end_user_id,
- messages_list,
- config_id,
- db,
- storage_type,
- user_rag_memory_id,
- language
- )
-
- return success(data=result, msg="写入成功")
- except BaseException as e:
- # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
- if hasattr(e, 'exceptions'):
- error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
- detailed_error = "; ".join(error_messages)
- api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
- return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
- api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
- return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
-
-
-@router.post("/writer_service_async", response_model=ApiResponse)
-@cur_workspace_access_guard()
-async def write_server_async(
- user_input: Write_UserInput,
- language_type: str = Header(default=None, alias="X-Language-Type"),
- db: Session = Depends(get_db),
- current_user: User = Depends(get_current_user)
-):
- """
- Async write service endpoint - enqueues write processing to Celery
-
- Args:
- user_input: Write request containing message and end_user_id
- language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
-
- Returns:
- Task ID for tracking async operation
- Use GET /memory/write_result/{task_id} to check task status and get result
- """
- # 使用集中化的语言校验
- language = get_language_from_header(language_type)
-
- config_id = user_input.config_id
- workspace_id = current_user.current_workspace_id
- api_logger.info(
- f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
-
- # 获取 storage_type,如果为 None 则使用默认值
- storage_type = workspace_service.get_workspace_storage_type(
- db=db,
- workspace_id=workspace_id,
- user=current_user
- )
- if storage_type is None: storage_type = 'neo4j'
- user_rag_memory_id = ''
- if workspace_id:
-
- knowledge = knowledge_repository.get_knowledge_by_name(
- db=db,
- name="USER_RAG_MERORY",
- workspace_id=workspace_id
- )
- if knowledge: user_rag_memory_id = str(knowledge.id)
- api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
- try:
- # 获取标准化的消息列表
- messages_list = memory_agent_service.get_messages_list(user_input)
-
- task = celery_app.send_task(
- "app.core.memory.agent.write_message",
- args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
- )
- api_logger.info(f"Write task queued: {task.id}")
-
- return success(data={"task_id": task.id}, msg="写入任务已提交")
- except Exception as e:
- api_logger.error(f"Async write operation failed: {str(e)}")
- return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
+# @router.post("/writer_service", response_model=ApiResponse)
+# @cur_workspace_access_guard()
+# async def write_server(
+# user_input: Write_UserInput,
+# language_type: str = Header(default=None, alias="X-Language-Type"),
+# db: Session = Depends(get_db),
+# current_user: User = Depends(get_current_user)
+# ):
+# """
+# Write service endpoint - processes write operations synchronously
+#
+# Args:
+# user_input: Write request containing message and end_user_id
+# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
+#
+# Returns:
+# Response with write operation status
+# """
+# # 使用集中化的语言校验
+# language = get_language_from_header(language_type)
+#
+# config_id = user_input.config_id
+# workspace_id = current_user.current_workspace_id
+# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
+#
+# # 获取 storage_type,如果为 None 则使用默认值
+# storage_type = workspace_service.get_workspace_storage_type(
+# db=db,
+# workspace_id=workspace_id,
+# user=current_user
+# )
+# if storage_type is None: storage_type = 'neo4j'
+# user_rag_memory_id = ''
+#
+# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
+# if storage_type == 'rag':
+# if workspace_id:
+# knowledge = knowledge_repository.get_knowledge_by_name(
+# db=db,
+# name="USER_RAG_MERORY",
+# workspace_id=workspace_id
+# )
+# if knowledge:
+# user_rag_memory_id = str(knowledge.id)
+# else:
+# api_logger.warning(
+# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
+# storage_type = 'neo4j'
+# else:
+# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
+# storage_type = 'neo4j'
+#
+# api_logger.info(
+# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
+# try:
+# messages_list = memory_agent_service.get_messages_list(user_input)
+# result = await memory_agent_service.write_memory(
+# user_input.end_user_id,
+# messages_list,
+# config_id,
+# db,
+# storage_type,
+# user_rag_memory_id,
+# language
+# )
+#
+# return success(data=result, msg="写入成功")
+# except BaseException as e:
+# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
+# if hasattr(e, 'exceptions'):
+# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
+# detailed_error = "; ".join(error_messages)
+# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
+# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
+# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
+# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
+#
+#
+# @router.post("/writer_service_async", response_model=ApiResponse)
+# @cur_workspace_access_guard()
+# async def write_server_async(
+# user_input: Write_UserInput,
+# language_type: str = Header(default=None, alias="X-Language-Type"),
+# db: Session = Depends(get_db),
+# current_user: User = Depends(get_current_user)
+# ):
+# """
+# Async write service endpoint - enqueues write processing to Celery
+#
+# Args:
+# user_input: Write request containing message and end_user_id
+# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
+#
+# Returns:
+# Task ID for tracking async operation
+# Use GET /memory/write_result/{task_id} to check task status and get result
+# """
+# # 使用集中化的语言校验
+# language = get_language_from_header(language_type)
+#
+# config_id = user_input.config_id
+# workspace_id = current_user.current_workspace_id
+# api_logger.info(
+# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
+#
+# # 获取 storage_type,如果为 None 则使用默认值
+# storage_type = workspace_service.get_workspace_storage_type(
+# db=db,
+# workspace_id=workspace_id,
+# user=current_user
+# )
+# if storage_type is None: storage_type = 'neo4j'
+# user_rag_memory_id = ''
+# if workspace_id:
+#
+# knowledge = knowledge_repository.get_knowledge_by_name(
+# db=db,
+# name="USER_RAG_MERORY",
+# workspace_id=workspace_id
+# )
+# if knowledge: user_rag_memory_id = str(knowledge.id)
+# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
+# try:
+# # 获取标准化的消息列表
+# messages_list = memory_agent_service.get_messages_list(user_input)
+#
+# task = celery_app.send_task(
+# "app.core.memory.agent.write_message",
+# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
+# )
+# api_logger.info(f"Write task queued: {task.id}")
+#
+# return success(data={"task_id": task.id}, msg="写入任务已提交")
+# except Exception as e:
+# api_logger.error(f"Async write operation failed: {str(e)}")
+# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/read_service", response_model=ApiResponse)
diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py
index cc0efab3..fe4337d1 100644
--- a/api/app/controllers/memory_dashboard_controller.py
+++ b/api/app/controllers/memory_dashboard_controller.py
@@ -663,9 +663,12 @@ async def dashboard_data(
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量
- from app.repositories import app_repository
- apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
- rag_data["total_app"] = len(apps_orm)
+ # 包含自有app + 被分享给本工作空间的app
+ from app.services import app_service as _app_svc
+ _, total_app = _app_svc.AppService(db).list_apps(
+ workspace_id=workspace_id, include_shared=True, pagesize=1
+ )
+ rag_data["total_app"] = total_app
# total_knowledge: 使用 total_kb(总知识库数)
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
@@ -687,7 +690,7 @@ async def dashboard_data(
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
rag_data["total_api_call"] = 0
- api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
+ api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py
index d91dfc36..d8b39325 100644
--- a/api/app/controllers/memory_storage_controller.py
+++ b/api/app/controllers/memory_storage_controller.py
@@ -54,8 +54,8 @@ router = APIRouter(
@router.get("/info", response_model=ApiResponse)
async def get_storage_info(
- storage_id: str,
- current_user: User = Depends(get_current_user)
+ storage_id: str,
+ current_user: User = Depends(get_current_user)
):
"""
Example wrapper endpoint - retrieves storage information
@@ -75,24 +75,19 @@ async def get_storage_info(
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
-
-
-
-
-
-@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
+@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
def create_config(
- payload: ConfigParamsCreate,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
- x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
+ payload: ConfigParamsCreate,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+ x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
-
+
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
try:
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
@@ -107,9 +102,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
- msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
+ msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
+ f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else:
- msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
+ msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
+ f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
@@ -119,9 +116,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
- msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
+ msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
+ f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else:
- msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
+ msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
+ f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@@ -129,10 +128,10 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config(
- config_id: UUID|int,
- force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ config_id: UUID | int,
+ force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
"""删除记忆配置(带终端用户保护)
@@ -145,24 +144,24 @@ def delete_config(
force: 设置为 true 可强制删除(即使有终端用户正在使用)
"""
workspace_id = current_user.current_workspace_id
- config_id=resolve_config_id(config_id, db)
+ config_id = resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
-
+
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
f"config_id={config_id}, force={force}"
)
-
+
try:
# 使用带保护的删除服务
from app.services.memory_config_service import MemoryConfigService
-
+
config_service = MemoryConfigService(db)
result = config_service.delete_config(config_id=config_id, force=force)
-
+
if result["status"] == "error":
api_logger.warning(
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
@@ -172,7 +171,7 @@ def delete_config(
msg=result["message"],
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
)
-
+
if result["status"] == "warning":
api_logger.warning(
f"记忆配置正在使用,无法删除: config_id={config_id}, "
@@ -186,7 +185,7 @@ def delete_config(
"force_required": result["force_required"]
}
)
-
+
api_logger.info(
f"记忆配置删除成功: config_id={config_id}, "
f"affected_users={result['affected_users']}"
@@ -195,7 +194,7 @@ def delete_config(
msg=result["message"],
data={"affected_users": result["affected_users"]}
)
-
+
except Exception as e:
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
@@ -203,9 +202,9 @@ def delete_config(
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config(
- payload: ConfigUpdate,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ payload: ConfigUpdate,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db)
@@ -213,12 +212,13 @@ def update_config(
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
-
+
# 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
- return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
-
+ return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
+ "config_name, config_desc, scene_id 均为空")
+
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try:
svc = DataConfigService(db)
@@ -231,9 +231,9 @@ def update_config(
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
def update_config_extracted(
- payload: ConfigUpdateExtracted,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ payload: ConfigUpdateExtracted,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db)
@@ -241,7 +241,7 @@ def update_config_extracted(
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
-
+
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
try:
svc = DataConfigService(db)
@@ -256,11 +256,11 @@ def update_config_extracted(
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
-@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
+@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted(
- config_id: UUID | int,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ config_id: UUID | int,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
config_id = resolve_config_id(config_id, db)
@@ -268,7 +268,7 @@ def read_config_extracted(
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
-
+
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
try:
svc = DataConfigService(db)
@@ -278,18 +278,19 @@ def read_config_extracted(
api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
-@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
+
+@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
def read_all_config(
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
-
+
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
-
+
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
try:
svc = DataConfigService(db)
@@ -303,14 +304,14 @@ def read_all_config(
@router.post("/pilot_run", response_model=None)
async def pilot_run(
- payload: ConfigPilotRun,
- language_type: str = Header(default=None, alias="X-Language-Type"),
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ payload: ConfigPilotRun,
+ language_type: str = Header(default=None, alias="X-Language-Type"),
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> StreamingResponse:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
-
+
api_logger.info(
f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}, "
@@ -333,9 +334,9 @@ async def pilot_run(
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution(
- end_user_id: Optional[str] = None,
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ end_user_id: Optional[str] = None,
+ current_user: User = Depends(get_current_user),
+) -> dict:
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
try:
result = await kb_type_distribution(end_user_id)
@@ -344,12 +345,12 @@ async def get_kb_type_distribution(
api_logger.error(f"KB type distribution failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
-
+
@router.get("/search/dialogue", response_model=ApiResponse)
async def search_dialogues_num(
- end_user_id: Optional[str] = None,
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ end_user_id: Optional[str] = None,
+ current_user: User = Depends(get_current_user),
+) -> dict:
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
try:
result = await search_dialogue(end_user_id)
@@ -361,9 +362,9 @@ async def search_dialogues_num(
@router.get("/search/chunk", response_model=ApiResponse)
async def search_chunks_num(
- end_user_id: Optional[str] = None,
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ end_user_id: Optional[str] = None,
+ current_user: User = Depends(get_current_user),
+) -> dict:
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
try:
result = await search_chunk(end_user_id)
@@ -375,9 +376,9 @@ async def search_chunks_num(
@router.get("/search/statement", response_model=ApiResponse)
async def search_statements_num(
- end_user_id: Optional[str] = None,
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ end_user_id: Optional[str] = None,
+ current_user: User = Depends(get_current_user),
+) -> dict:
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
try:
result = await search_statement(end_user_id)
@@ -389,9 +390,9 @@ async def search_statements_num(
@router.get("/search/entity", response_model=ApiResponse)
async def search_entities_num(
- end_user_id: Optional[str] = None,
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ end_user_id: Optional[str] = None,
+ current_user: User = Depends(get_current_user),
+) -> dict:
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
try:
result = await search_entity(end_user_id)
@@ -403,9 +404,9 @@ async def search_entities_num(
@router.get("/search", response_model=ApiResponse)
async def search_all_num(
- end_user_id: Optional[str] = None,
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ end_user_id: Optional[str] = None,
+ current_user: User = Depends(get_current_user),
+) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try:
result = await search_all(end_user_id)
@@ -417,9 +418,9 @@ async def search_all_num(
@router.get("/search/detials", response_model=ApiResponse)
async def search_entities_detials(
- end_user_id: Optional[str] = None,
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ end_user_id: Optional[str] = None,
+ current_user: User = Depends(get_current_user),
+) -> dict:
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
try:
result = await search_detials(end_user_id)
@@ -431,9 +432,9 @@ async def search_entities_detials(
@router.get("/search/edges", response_model=ApiResponse)
async def search_entity_edges(
- end_user_id: Optional[str] = None,
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ end_user_id: Optional[str] = None,
+ current_user: User = Depends(get_current_user),
+) -> dict:
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
try:
result = await search_edges(end_user_id)
@@ -443,14 +444,12 @@ async def search_entity_edges(
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
-
-
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
async def get_hot_memory_tags_api(
- limit: int = 10,
- db: Session = Depends(get_db),
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ limit: int = 10,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+) -> dict:
"""
获取热门记忆标签(带Redis缓存)
@@ -461,18 +460,18 @@ async def get_hot_memory_tags_api(
- 缓存未命中:~600-800ms(取决于LLM速度)
"""
workspace_id = current_user.current_workspace_id
-
+
# 构建缓存键
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
-
+
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
-
+
try:
# 尝试从Redis缓存获取
import json
from app.aioRedis import aio_redis_get, aio_redis_set
-
+
cached_result = await aio_redis_get(cache_key)
if cached_result:
api_logger.info(f"Cache hit for key: {cache_key}")
@@ -481,11 +480,11 @@ async def get_hot_memory_tags_api(
return success(data=data, msg="查询成功(缓存)")
except json.JSONDecodeError:
api_logger.warning(f"Failed to parse cached data, will refresh")
-
+
# 缓存未命中,执行查询
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
result = await analytics_hot_memory_tags(db, current_user, limit)
-
+
# 写入缓存(过期时间:5分钟)
# 注意:result是列表,需要转换为JSON字符串
try:
@@ -495,9 +494,9 @@ async def get_hot_memory_tags_api(
except Exception as cache_error:
# 缓存写入失败不影响主流程
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
-
+
return success(data=result, msg="查询成功")
-
+
except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@@ -505,8 +504,8 @@ async def get_hot_memory_tags_api(
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache(
- current_user: User = Depends(get_current_user),
- ) -> dict:
+ current_user: User = Depends(get_current_user),
+) -> dict:
"""
清除热门标签缓存
@@ -516,12 +515,12 @@ async def clear_hot_memory_tags_cache(
- 数据更新后立即生效
"""
workspace_id = current_user.current_workspace_id
-
+
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
-
+
try:
from app.aioRedis import aio_redis_delete
-
+
# 清除所有limit的缓存(常见的limit值)
cleared_count = 0
for limit in [5, 10, 15, 20, 30, 50]:
@@ -530,12 +529,12 @@ async def clear_hot_memory_tags_cache(
if result:
cleared_count += 1
api_logger.info(f"Cleared cache for key: {cache_key}")
-
+
return success(
- data={"cleared_count": cleared_count},
+ data={"cleared_count": cleared_count},
msg=f"成功清除 {cleared_count} 个缓存"
)
-
+
except Exception as e:
api_logger.error(f"Clear cache failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
@@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache(
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api(
- current_user: User = Depends(get_current_user),
+ current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
@@ -553,4 +552,3 @@ async def get_recent_activity_stats_api(
except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
-
diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py
index 6204a745..71fd41ad 100644
--- a/api/app/controllers/model_controller.py
+++ b/api/app/controllers/model_controller.py
@@ -42,6 +42,7 @@ def get_model_strategies():
@router.get("", response_model=ApiResponse)
def get_model_list(
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
+ capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
@@ -74,10 +75,21 @@ def get_model_list(
unique_flat_type = list(dict.fromkeys(flat_type))
type_list = [ModelType(t.lower()) for t in unique_flat_type]
+    # Capability filter: accepts repeated params and comma-separated values.
+    capability_list = []
+    if capability is not None:
+        flat_capability = []
+        for item in capability:
+            # Split on a bare "," so "chat,embedding" and "chat, embedding" both work; strip() trims padding.
+            flat_capability.extend(c.strip() for c in item.split(',') if c.strip())
+        # dict.fromkeys() removes duplicates while keeping first-seen order.
+        capability_list = list(dict.fromkeys(flat_capability))
+
api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery(
type=type_list,
provider=provider,
+ capability=capability_list,
is_active=is_active,
is_public=is_public,
search=search,
diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py
index 33d7b60c..f5284b46 100644
--- a/api/app/controllers/public_share_controller.py
+++ b/api/app/controllers/public_share_controller.py
@@ -669,6 +669,7 @@ async def config_query(
content = {
"app_type": release.app.type,
"variables": release.config.get("variables"),
+ "memory": (release.config.get("memory") or {}).get("enabled"),  # tolerate an explicit null "memory" entry
"features": release.config.get("features")
}
elif release.app.type == AppType.MULTI_AGENT:
diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py
index be796ff9..3ce1df6e 100644
--- a/api/app/controllers/user_memory_controllers.py
+++ b/api/app/controllers/user_memory_controllers.py
@@ -5,7 +5,7 @@
from typing import Optional
import datetime
from sqlalchemy.orm import Session
-from fastapi import APIRouter, Depends,Header
+from fastapi import APIRouter, Depends, Header
from app.db import get_db
from app.core.language_utils import get_language_from_header
@@ -19,7 +19,7 @@ from app.services.user_memory_service import (
analytics_graph_data,
analytics_community_graph_data,
)
-from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
+from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.repositories.workspace_repository import WorkspaceRepository
@@ -45,9 +45,9 @@ router = APIRouter(
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api(
- end_user_id: str,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ end_user_id: str,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
"""
获取缓存的记忆洞察报告
@@ -73,10 +73,10 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
- end_user_id: str,
- language_type: str = Header(default=None, alias="X-Language-Type"),
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ end_user_id: str,
+ language_type: str = Header(default=None, alias="X-Language-Type"),
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
"""
获取缓存的用户摘要
@@ -90,7 +90,7 @@ async def get_user_summary_api(
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
-
+
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
@@ -102,7 +102,7 @@ async def get_user_summary_api(
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
- result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
+ result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
if result["is_cached"]:
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
@@ -117,10 +117,10 @@ async def get_user_summary_api(
@router.post("/analytics/generate_cache", response_model=ApiResponse)
async def generate_cache_api(
- request: GenerateCacheRequest,
- language_type: str = Header(default=None, alias="X-Language-Type"),
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ request: GenerateCacheRequest,
+ language_type: str = Header(default=None, alias="X-Language-Type"),
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
"""
手动触发缓存生成
@@ -134,7 +134,7 @@ async def generate_cache_api(
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
-
+
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -155,10 +155,12 @@ async def generate_cache_api(
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
# 生成记忆洞察
- insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
+ insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
+ language=language)
# 生成用户摘要
- summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
+ summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
+ language=language)
# 构建响应
result = {
@@ -209,9 +211,9 @@ async def generate_cache_api(
@router.get("/analytics/node_statistics", response_model=ApiResponse)
async def get_node_statistics_api(
- end_user_id: str,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ end_user_id: str,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
@@ -220,7 +222,8 @@ async def get_node_statistics_api(
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
- api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
+ api_logger.info(
+ f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
try:
# 调用新的记忆类型统计函数
@@ -228,21 +231,23 @@ async def get_node_statistics_api(
# 计算总数用于日志
total_count = sum(item["count"] for item in result)
- api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
+ api_logger.info(
+ f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
+
@router.get("/analytics/graph_data", response_model=ApiResponse)
async def get_graph_data_api(
- end_user_id: str,
- node_types: Optional[str] = None,
- limit: int = 100,
- depth: int = 1,
- center_node_id: Optional[str] = None,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ end_user_id: str,
+ node_types: Optional[str] = None,
+ limit: int = 100,
+ depth: int = 1,
+ center_node_id: Optional[str] = None,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
@@ -298,9 +303,9 @@ async def get_graph_data_api(
@router.get("/analytics/community_graph", response_model=ApiResponse)
async def get_community_graph_data_api(
- end_user_id: str,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ end_user_id: str,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
@@ -334,9 +339,9 @@ async def get_community_graph_data_api(
@router.get("/read_end_user/profile", response_model=ApiResponse)
async def get_end_user_profile(
- end_user_id: str,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ end_user_id: str,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
@@ -385,9 +390,9 @@ async def get_end_user_profile(
@router.post("/updated_end_user/profile", response_model=ApiResponse)
async def update_end_user_profile(
- profile_update: EndUserProfileUpdate,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
+ profile_update: EndUserProfileUpdate,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
) -> dict:
"""
更新终端用户的基本信息
@@ -417,7 +422,7 @@ async def update_end_user_profile(
else:
error_msg = result["error"]
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
-
+
# 根据错误类型映射到合适的业务错误码
if error_msg == "终端用户不存在":
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
@@ -427,15 +432,18 @@ async def update_end_user_profile(
# 只有未预期的错误才使用 INTERNAL_ERROR
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
+
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
-async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
- ):
+async def memory_space_timeline_of_shared_memories(
+ id: str, label: str,
+ language_type: str = Header(default=None, alias="X-Language-Type"),
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+):
# 使用集中化的语言校验
language = get_language_from_header(language_type)
-
- workspace_id=current_user.current_workspace_id
+
+ workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
@@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
return success(data=timeline_memories_result, msg="共同记忆时间线")
+
+
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str,
- current_user: User = Depends(get_current_user),
- db: Session = Depends(get_db),
- ):
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+ ):
try:
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py
index 88b6371c..464a668a 100644
--- a/api/app/core/agent/langchain_agent.py
+++ b/api/app/core/agent/langchain_agent.py
@@ -598,8 +598,10 @@ class LangChainAgent:
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
- total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
- 0) if response_meta else 0
+ total_tokens = response_meta.get("token_usage", {}).get(
+ "total_tokens",
+ 0
+ ) if response_meta else 0
yield total_tokens
break
if memory_flag:
diff --git a/api/app/core/config.py b/api/app/core/config.py
index 4a944557..64c5520e 100644
--- a/api/app/core/config.py
+++ b/api/app/core/config.py
@@ -231,8 +231,8 @@ class Settings:
# Celery configuration (internal)
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
# 详见 docs/celery-env-bug-report.md
- # 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
- # 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
+ # 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离
+ # 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
diff --git a/api/app/core/logging_config.py b/api/app/core/logging_config.py
index 28a98a46..d0dda84b 100644
--- a/api/app/core/logging_config.py
+++ b/api/app/core/logging_config.py
@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
# Fallback to console only if file write fails
print(f"Warning: Could not write to timing log: {e}")
- # Always print to console (backward compatible behavior)
- print(f"✓ {step_name}: {duration:.2f}s")
+ # Always log at INFO level (avoids Celery treating stdout as WARNING)
+ _timing_logger = logging.getLogger(__name__)
+ _timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
def get_agent_logger(name: str = "agent_service",
diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py
index 6176caf5..2074b6ca 100644
--- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py
+++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py
@@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J')
- formatted_messages = (redis_messages)
+ formatted_messages = redis_messages
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py
index 3b06defe..4c667061 100644
--- a/api/app/core/memory/agent/utils/get_dialogs.py
+++ b/api/app/core/memory/agent/utils/get_dialogs.py
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "group_1",
messages: list = None,
- ref_id: str = "wyl_20251027",
+ ref_id: str = "",
config_id: str = None
) -> List[DialogData]:
"""Generate chunks from structured messages using the specified chunker strategy.
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
role = msg['role']
content = msg['content']
+ files = msg.get("file_content", [])
if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip():
- conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
+ conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering")
diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py
index b62eb50a..6829cf57 100644
--- a/api/app/core/memory/agent/utils/write_tools.py
+++ b/api/app/core/memory/agent/utils/write_tools.py
@@ -6,6 +6,7 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
"""
import asyncio
import time
+import uuid
from datetime import datetime
from dotenv import load_dotenv
@@ -13,28 +14,28 @@ from dotenv import load_dotenv
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
-from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
+from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
+ memory_summary_generation
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
-from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
+from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
-
load_dotenv()
logger = get_agent_logger(__name__)
async def write(
- end_user_id: str,
- memory_config: MemoryConfig,
- messages: list,
- ref_id: str = "wyl20251027",
- language: str = "zh",
+ end_user_id: str,
+ memory_config: MemoryConfig,
+ messages: list,
+ ref_id: str = "",
+ language: str = "zh",
) -> None:
"""
Execute the complete knowledge extraction pipeline.
@@ -43,9 +44,11 @@ async def write(
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...]
- ref_id: Reference ID, defaults to "wyl20251027"
+ ref_id: Reference ID; when empty (the default), a random UUID hex string is generated
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
"""
+ if not ref_id:
+ ref_id = uuid.uuid4().hex
# Extract config values
embedding_model_id = str(memory_config.embedding_model_id)
chunker_strategy = memory_config.chunker_strategy
@@ -99,14 +102,14 @@ async def write(
if memory_config.scene_id:
try:
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
-
+
with get_db_context() as db:
ontology_types = load_ontology_types_for_scene(
scene_id=memory_config.scene_id,
workspace_id=memory_config.workspace_id,
db=db
)
-
+
if ontology_types:
logger.info(
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
@@ -135,9 +138,11 @@ async def write(
all_chunk_nodes,
all_statement_nodes,
all_entity_nodes,
+ all_perceptual_nodes,
all_statement_chunk_edges,
all_statement_entity_edges,
all_entity_entity_edges,
+ all_perceptual_edges,
all_dedup_details,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
@@ -162,18 +167,21 @@ async def write(
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
+ perceptual_nodes=all_perceptual_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
+ perceptual_edges=all_perceptual_edges,
connector=neo4j_connector,
)
if success:
logger.info("Successfully saved all data to Neo4j")
- # 写入成功后,异步触发聚类(不阻塞写入响应)
- schedule_clustering_after_write(
+ # 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突)
+ await _trigger_clustering_sync(
all_entity_nodes,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
- embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
+ embedding_model_id=str(
+ memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
break
else:
@@ -208,9 +216,8 @@ async def write(
summaries = await memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
)
-
+ ms_connector = Neo4jConnector()
try:
- ms_connector = Neo4jConnector()
await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
@@ -251,4 +258,4 @@ async def write(
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
logger.info("=== Pipeline Complete ===")
- logger.info(f"Total execution time: {total_time:.2f} seconds")
\ No newline at end of file
+ logger.info(f"Total execution time: {total_time:.2f} seconds")
diff --git a/api/app/core/memory/llm_tools/chunker_client.py b/api/app/core/memory/llm_tools/chunker_client.py
index 93a2df82..51d15aab 100644
--- a/api/app/core/memory/llm_tools/chunker_client.py
+++ b/api/app/core/memory/llm_tools/chunker_client.py
@@ -1,10 +1,10 @@
-from typing import Any, List
-import re
-import os
import asyncio
import json
-import numpy as np
import logging
+import os
+from typing import Any, List
+
+import numpy as np
# Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -246,6 +246,7 @@ class ChunkerClient:
"total_sub_chunks": len(sub_chunks),
"chunker_strategy": self.chunker_config.chunker_strategy,
},
+ files=msg.files
)
dialogue.chunks.append(chunk)
else:
@@ -258,6 +259,7 @@ class ChunkerClient:
"message_role": msg.role,
"chunker_strategy": self.chunker_config.chunker_strategy,
},
+ files=msg.files
)
dialogue.chunks.append(chunk)
diff --git a/api/app/core/memory/llm_tools/openai_embedder.py b/api/app/core/memory/llm_tools/openai_embedder.py
index 2d6fccbc..6ae87887 100644
--- a/api/app/core/memory/llm_tools/openai_embedder.py
+++ b/api/app/core/memory/llm_tools/openai_embedder.py
@@ -2,6 +2,7 @@
OpenAI Embedder 客户端实现
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
+自动支持火山引擎的多模态 Embedding。
"""
from typing import List
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
)
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
+from app.models.models_model import ModelProvider
logger = logging.getLogger(__name__)
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
- 批量文本嵌入
- 自动重试机制
- 错误处理
+ - 火山引擎多模态 Embedding(自动识别)
"""
def __init__(self, model_config: RedBearModelConfig):
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
"""
super().__init__(model_config)
- # 初始化 RedBearEmbeddings 模型
+ # 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
self.model = RedBearEmbeddings(
RedBearModelConfig(
model_name=self.model_name,
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
timeout=self.timeout,
)
)
+ self.is_multimodal = self.model.is_multimodal_supported()
- logger.info("OpenAI Embedder 客户端初始化完成")
+ logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
async def response(
self,
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
return []
# 生成嵌入向量
- embeddings = await self.model.aembed_documents(texts)
+ if self.is_multimodal:
+ # 火山引擎多模态 Embedding
+ embeddings = await self.model.aembed_multimodal(
+ [{"type": "text", "text": text} for text in texts]
+ )
+ else:
+ # 普通 Embedding
+ embeddings = await self.model.aembed_documents(texts)
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
return embeddings
diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py
index 1880b9ab..1b8c9d52 100644
--- a/api/app/core/memory/models/graph_models.py
+++ b/api/app/core/memory/models/graph_models.py
@@ -44,21 +44,21 @@ def parse_historical_datetime(v):
"""
if v is None:
return v
-
+
# 处理 Neo4j DateTime 对象
if hasattr(v, 'to_native'):
return v.to_native()
-
+
# 处理 Python datetime 对象
if isinstance(v, datetime):
return v
-
+
if isinstance(v, str):
# 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM]
# 支持1-4位年份
pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?'
match = re.match(pattern, v)
-
+
if match:
try:
year = int(match.group(1))
@@ -68,31 +68,31 @@ def parse_historical_datetime(v):
minute = int(match.group(5)) if match.group(5) else 0
second = int(match.group(6)) if match.group(6) else 0
microsecond = 0
-
+
# 处理微秒
if match.group(7):
# 补齐或截断到6位
us_str = match.group(7).ljust(6, '0')[:6]
microsecond = int(us_str)
-
+
# 处理时区
tzinfo = None
if 'Z' in v or match.group(8):
tzinfo = timezone.utc
-
+
# 创建 datetime 对象
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo)
-
+
except (ValueError, OverflowError):
# 日期值无效(如月份13、日期32等)
return None
-
+
# 如果不匹配模式,尝试使用 fromisoformat(用于标准格式)
try:
return datetime.fromisoformat(v.replace('Z', '+00:00'))
except Exception:
return None
-
+
return v
@@ -114,7 +114,7 @@ class Edge(BaseModel):
end_user_id: str = Field(..., description="The end user ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
- expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
+ expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
class ChunkEdge(Edge):
@@ -167,7 +167,7 @@ class EntityEntityEdge(Edge):
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
-
+
@field_validator('valid_at', 'invalid_at', mode='before')
@classmethod
def validate_datetime(cls, v):
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
return parse_historical_datetime(v)
+class PerceptualEdge(Edge):
+ """Edge connecting perceptual nodes to their source chunks
+ """
+ pass
+
+
class Node(BaseModel):
"""Base class for all graph nodes in the knowledge graph.
@@ -206,7 +212,8 @@ class DialogueNode(Node):
ref_id: str = Field(..., description="Reference identifier of the dialog")
content: str = Field(..., description="Dialogue content")
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
- config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
+ config_id: Optional[int | str] = Field(None,
+ description="Configuration ID used to process this dialogue (integer or string)")
class StatementNode(Node):
@@ -241,17 +248,17 @@ class StatementNode(Node):
chunk_id: str = Field(..., description="ID of the parent chunk")
stmt_type: str = Field(..., description="Type of the statement")
statement: str = Field(..., description="The statement text content")
-
+
# Speaker identification
speaker: Optional[str] = Field(
None,
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
)
-
+
# Emotion fields (ordered as requested, emotion_intensity first for display)
emotion_intensity: Optional[float] = Field(
- None,
- ge=0.0,
+ None,
+ ge=0.0,
le=1.0,
description="Emotion intensity: 0.0-1.0 (displayed on node)"
)
@@ -264,25 +271,26 @@ class StatementNode(Node):
description="Emotion subject: self/other/object"
)
emotion_type: Optional[str] = Field(
- None,
+ None,
description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
)
emotion_keywords: Optional[List[str]] = Field(
default_factory=list,
description="Emotion keywords list, max 3 items"
)
-
+
# Temporal fields
temporal_info: TemporalInfo = Field(..., description="Temporal information")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
-
+
# Embedding and other fields
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
- config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
-
+ config_id: Optional[int | str] = Field(None,
+ description="Configuration ID used to process this statement (integer or string)")
+
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
@@ -309,13 +317,13 @@ class StatementNode(Node):
ge=0,
description="Total number of times this node has been accessed"
)
-
+
@field_validator('valid_at', 'invalid_at', mode='before')
@classmethod
def validate_datetime(cls, v):
"""使用通用的历史日期解析函数"""
return parse_historical_datetime(v)
-
+
@field_validator('emotion_type', mode='before')
@classmethod
def validate_emotion_type(cls, v):
@@ -326,7 +334,7 @@ class StatementNode(Node):
if v not in valid_types:
raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
return v
-
+
@field_validator('emotion_subject', mode='before')
@classmethod
def validate_emotion_subject(cls, v):
@@ -337,7 +345,7 @@ class StatementNode(Node):
if v not in valid_subjects:
raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
return v
-
+
@field_validator('emotion_keywords', mode='before')
@classmethod
def validate_emotion_keywords(cls, v):
@@ -405,19 +413,20 @@ class ExtractedEntityNode(Node):
entity_type: str = Field(..., description="Type of the entity")
description: str = Field(..., description="Entity description")
example: str = Field(
- default="",
+ default="",
description="A concise example (around 20 characters) to help understand the entity"
)
aliases: List[str] = Field(
- default_factory=list,
+ default_factory=list,
description="Entity aliases - alternative names for this entity"
)
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
- config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
-
+ config_id: Optional[int | str] = Field(None,
+ description="Configuration ID used to process this entity (integer or string)")
+
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
@@ -444,16 +453,16 @@ class ExtractedEntityNode(Node):
ge=0,
description="Total number of times this node has been accessed"
)
-
+
# Explicit Memory Classification
is_explicit_memory: bool = Field(
default=False,
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
)
-
+
@field_validator('aliases', mode='before')
@classmethod
- def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
+ def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
"""Validate and clean aliases field using utility function.
This validator ensures that the aliases field is always a valid list of strings.
@@ -507,8 +516,9 @@ class MemorySummaryNode(Node):
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
- config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
-
+ config_id: Optional[int | str] = Field(None,
+ description="Configuration ID used to process this summary (integer or string)")
+
# ACT-R Forgetting Engine Properties
original_statement_id: Optional[str] = Field(
None,
@@ -522,7 +532,7 @@ class MemorySummaryNode(Node):
None,
description="Timestamp when the nodes were merged"
)
-
+
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
@@ -549,3 +559,18 @@ class MemorySummaryNode(Node):
ge=0,
description="Total number of times this node has been accessed (reset to 1 on creation)"
)
+
+
+class PerceptualNode(Node):
+ """Node representing a multimodal message in the knowledge graph.
+ """
+ perceptual_type: int
+ file_path: str
+ file_name: str
+ file_ext: str
+ summary: str
+ keywords: list[str]
+ topic: str
+ domain: str
+ file_type: str
+ summary_embedding: list[float] | None
diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py
index 2f8660af..66203067 100644
--- a/api/app/core/memory/models/message_models.py
+++ b/api/app/core/memory/models/message_models.py
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
"""
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
msg: str = Field(..., description="The text content of the message.")
+ files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
class TemporalValidityRange(BaseModel):
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
content: str = Field(..., description="The content of the chunk as a string.")
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
- chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
+ files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
+ chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod
diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py
index d9c04f8b..0fa6a833 100644
--- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py
+++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py
@@ -71,13 +71,11 @@ class LabelPropagationEngine:
connector: Neo4jConnector,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
- embedding_model_id: Optional[str] = None,
):
self.connector = connector
self.repo = CommunityRepository(connector)
self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id
- self.embedding_model_id = embedding_model_id
# ──────────────────────────────────────────────────────────────────────────
# 公开接口
@@ -239,6 +237,7 @@ class LabelPropagationEngine:
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
+ await self._generate_community_metadata([new_cid], end_user_id)
return
# 统计邻居社区分布
@@ -273,7 +272,8 @@ class LabelPropagationEngine:
await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id
)
- await self._generate_community_metadata([target_cid], end_user_id)
+ # 新实体加入后成员变化,强制重新生成元数据
+ await self._generate_community_metadata([target_cid], end_user_id, force=True)
async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str
@@ -453,7 +453,7 @@ class LabelPropagationEngine:
return lines
async def _generate_community_metadata(
- self, community_ids: List[str], end_user_id: str
+ self, community_ids: List[str], end_user_id: str, force: bool = False
) -> None:
"""
为一个或多个社区生成并写入元数据。
@@ -462,69 +462,82 @@ class LabelPropagationEngine:
1. 逐个社区调 LLM 生成 name / summary(串行)
2. 收集所有 summary,一次性批量 embed
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
- """
- if not community_ids:
- return
+ Args:
+ force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
+ """
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
- # --- 阶段1:并发调 LLM 生成每个社区的 name / summary ---
- async def _build_one(cid: str):
- members = await self.repo.get_community_members(cid, end_user_id)
- if not members:
+ async def _build_one(cid: str) -> Optional[Dict]:
+ try:
+ if not force:
+ check_embedding = bool(self.embedding_model_id)
+ if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
+ return None
+
+ members = await self.repo.get_community_members(cid, end_user_id)
+ if not members:
+ logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
+ return None
+
+ sorted_members = sorted(
+ members,
+ key=lambda m: m.get("activation_value") or 0,
+ reverse=True,
+ )
+ core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
+ all_names = [m["name"] for m in members if m.get("name")]
+
+ name = "、".join(core_entities[:3]) if core_entities else cid[:8]
+ summary = f"包含实体:{', '.join(all_names)}"
+
+ if self.llm_model_id:
+ try:
+ entity_list_str = "\n".join(self._build_entity_lines(members))
+ relationships = await self.repo.get_community_relationships(cid, end_user_id)
+ rel_lines = [
+ f"- {r['subject']} → {r['predicate']} → {r['object']}"
+ for r in relationships
+ if r.get("subject") and r.get("predicate") and r.get("object")
+ ]
+ rel_section = (
+ f"\n实体间关系:\n" + "\n".join(rel_lines)
+ if rel_lines else ""
+ )
+ prompt = (
+ f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
+ f"请为这组实体所代表的主题:\n"
+ f"1. 起一个简洁的中文名称(不超过10个字)\n"
+ f"2. 写一句话摘要(不超过80个字)\n\n"
+ f"严格按以下格式输出,不要有其他内容:\n"
+ f"名称:<名称>\n摘要:<摘要>"
+ )
+ with get_db_context() as db:
+ llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
+ response = await llm_client.chat([{"role": "user", "content": prompt}])
+ text = response.content if hasattr(response, "content") else str(response)
+
+ for line in text.strip().splitlines():
+ if line.startswith("名称:"):
+ name = line[3:].strip()
+ elif line.startswith("摘要:"):
+ summary = line[3:].strip()
+ except Exception as e:
+ logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
+
+ return {
+ "community_id": cid,
+ "end_user_id": end_user_id,
+ "name": name,
+ "summary": summary,
+ "core_entities": core_entities,
+ "summary_embedding": None,
+ }
+ except Exception as e:
+ logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
return None
- sorted_members = sorted(
- members,
- key=lambda m: m.get("activation_value") or 0,
- reverse=True,
- )
- core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
-
- entity_list_str = "\n".join(self._build_entity_lines(members))
-
- # 方案四:注入社区内实体间关系三元组
- relationships = await self.repo.get_community_relationships(cid, end_user_id)
- rel_lines = [
- f"- {r['subject']} → {r['predicate']} → {r['object']}"
- for r in relationships
- if r.get("subject") and r.get("predicate") and r.get("object")
- ]
- rel_section = (
- f"\n实体间关系:\n" + "\n".join(rel_lines)
- if rel_lines else ""
- )
-
- prompt = (
- f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
- f"请为这组实体所代表的主题:\n"
- f"1. 起一个简洁的中文名称(不超过10个字)\n"
- f"2. 写一句话摘要(不超过80个字)\n\n"
- f"严格按以下格式输出,不要有其他内容:\n"
- f"名称:<名称>\n摘要:<摘要>"
- )
- with get_db_context() as db:
- llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
- response = await llm_client.chat([{"role": "user", "content": prompt}])
- text = response.content if hasattr(response, "content") else str(response)
-
- name, summary = "", ""
- for line in text.strip().splitlines():
- if line.startswith("名称:"):
- name = line[3:].strip()
- elif line.startswith("摘要:"):
- summary = line[3:].strip()
-
- return {
- "community_id": cid,
- "end_user_id": end_user_id,
- "name": name,
- "summary": summary,
- "core_entities": core_entities,
- "summary_embedding": None,
- }
-
results = await asyncio.gather(
*[_build_one(cid) for cid in community_ids],
return_exceptions=True,
@@ -537,15 +550,20 @@ class LabelPropagationEngine:
metadata_list.append(res)
if not metadata_list:
+ logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
return
# --- 阶段2:批量生成 summary_embedding ---
- summaries = [m["summary"] for m in metadata_list]
- with get_db_context() as db:
- embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
- embeddings = await embedder.response(summaries)
- for i, meta in enumerate(metadata_list):
- meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
+ if self.embedding_model_id:
+ try:
+ summaries = [m["summary"] for m in metadata_list]
+ with get_db_context() as db:
+ embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
+ embeddings = await embedder.response(summaries)
+ for i, meta in enumerate(metadata_list):
+ meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
+ except Exception as e:
+ logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
# --- 阶段3:写入(单个 or 批量)---
if len(metadata_list) == 1:
@@ -558,17 +576,13 @@ class LabelPropagationEngine:
core_entities=m["core_entities"],
summary_embedding=m["summary_embedding"],
)
- if result:
- logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
- else:
- logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
+ if not result:
+ logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
else:
ok = await self.repo.batch_update_community_metadata(metadata_list)
- if ok:
- logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
- else:
- logger.warning(f"[Clustering] 批量写入社区元数据失败")
+ if not ok:
+ logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
@staticmethod
def _new_community_id() -> str:
- return str(uuid.uuid4())
+ return str(uuid.uuid4())
\ No newline at end of file
diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py
index 248067e7..967f529e 100644
--- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py
+++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py
@@ -9,6 +9,7 @@
"""
import asyncio
+import logging
import os
import hashlib
import json
@@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
ScenePatterns
)
+logger = logging.getLogger(__name__)
+
class DialogExtractionResponse(BaseModel):
"""对话级一次性抽取的结构化返回,用于加速剪枝。
@@ -706,7 +709,7 @@ class SemanticPruner:
# 阈值保护:最高0.9
proportion = float(self.config.pruning_threshold)
if proportion > 0.9:
- print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
+ logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
proportion = 0.9
if proportion < 0.0:
proportion = 0.0
@@ -905,7 +908,7 @@ class SemanticPruner:
# Safety: avoid empty dataset
if not result:
- print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
+ logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
return dialogs
return result
@@ -915,8 +918,7 @@ class SemanticPruner:
try:
self.run_logs.append(msg)
except Exception:
- # 任何异常都不影响打印
pass
- print(msg)
+ logger.debug(msg)
diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py
index f28b8a5f..4b9c5718 100644
--- a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py
+++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def dedup_layers_and_merge_and_return(
- dialogue_nodes: List[DialogueNode],
- chunk_nodes: List[ChunkNode],
- statement_nodes: List[StatementNode],
- entity_nodes: List[ExtractedEntityNode],
- statement_chunk_edges: List[StatementChunkEdge],
- statement_entity_edges: List[StatementEntityEdge],
- entity_entity_edges: List[EntityEntityEdge],
- dialog_data_list: List[DialogData],
- pipeline_config: ExtractionPipelineConfig,
- connector: Optional[Neo4jConnector] = None,
- llm_client = None,
+ dialogue_nodes: List[DialogueNode],
+ chunk_nodes: List[ChunkNode],
+ statement_nodes: List[StatementNode],
+ entity_nodes: List[ExtractedEntityNode],
+ statement_chunk_edges: List[StatementChunkEdge],
+ statement_entity_edges: List[StatementEntityEdge],
+ entity_entity_edges: List[EntityEntityEdge],
+ dialog_data_list: List[DialogData],
+ pipeline_config: ExtractionPipelineConfig,
+ connector: Optional[Neo4jConnector] = None,
+ llm_client=None,
) -> Tuple[
List[DialogueNode],
List[ChunkNode],
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
List[StatementChunkEdge],
List[StatementEntityEdge],
List[EntityEntityEdge],
- dict, # 新增:返回去重详情
+ dict
]:
"""
执行两层实体去重与融合:
diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py
index 00d06f72..e0b86d8c 100644
--- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py
+++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py
@@ -32,10 +32,11 @@ from app.core.memory.models.graph_models import (
StatementChunkEdge,
StatementEntityEdge,
StatementNode,
+ PerceptualEdge,
+ PerceptualNode
)
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
-from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.models.variate_config import (
ExtractionPipelineConfig,
)
@@ -46,7 +47,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb
embedding_generation,
generate_entity_embeddings_from_triplets,
)
-
# 导入各个提取模块
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
StatementExtractor,
@@ -90,16 +90,16 @@ class ExtractionOrchestrator:
"""
def __init__(
- self,
- llm_client: LLMClient,
- embedder_client: OpenAIEmbedderClient,
- connector: Neo4jConnector,
- config: Optional[ExtractionPipelineConfig] = None,
- progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
- embedding_id: Optional[str] = None,
- ontology_types: Optional[OntologyTypeList] = None,
- enable_general_types: bool = True,
- language: str = "zh",
+ self,
+ llm_client: LLMClient,
+ embedder_client: OpenAIEmbedderClient,
+ connector: Neo4jConnector,
+ config: Optional[ExtractionPipelineConfig] = None,
+ progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
+ embedding_id: Optional[str] = None,
+ ontology_types: Optional[OntologyTypeList] = None,
+ enable_general_types: bool = True,
+ language: str = "zh",
):
"""
初始化流水线编排器
@@ -123,7 +123,7 @@ class ExtractionOrchestrator:
self.progress_callback = progress_callback # 保存进度回调函数
self.embedding_id = embedding_id # 保存嵌入模型ID
self.language = language # 保存语言配置
-
+
# 处理本体类型配置
# 根据 enable_general_types 参数决定是否将通用本体类型与场景特定类型合并
# 如果启用合并且配置中开启了通用本体功能,则使用 OntologyTypeMerger 进行融合
@@ -146,7 +146,7 @@ class ExtractionOrchestrator:
self.ontology_types = ontology_types
if not enable_general_types and ontology_types:
logger.info("enable_general_types=False,仅使用场景类型")
-
+
# 保存去重消歧的详细记录(内存中的数据结构)
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录
@@ -157,19 +157,27 @@ class ExtractionOrchestrator:
llm_client=llm_client,
config=self.config.statement_extraction,
)
- self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language)
+ self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types,
+ language=language)
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
logger.info("ExtractionOrchestrator 初始化完成")
async def run(
- self,
- dialog_data_list: List[DialogData],
- is_pilot_run: bool = False,
- ) -> Tuple[
- Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
- Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
- Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
+ self,
+ dialog_data_list: List[DialogData],
+ is_pilot_run: bool = False,
+ ) -> tuple[
+ list[DialogueNode],
+ list[ChunkNode],
+ list[StatementNode],
+ list[ExtractedEntityNode],
+ list[PerceptualNode],
+ list[StatementChunkEdge],
+ list[StatementEntityEdge],
+ list[EntityEntityEdge],
+ list[PerceptualEdge],
+ dict
]:
"""
运行完整的知识提取流水线(优化版:并行执行)
@@ -202,13 +210,12 @@ class ExtractionOrchestrator:
# 步骤 1: 陈述句提取
logger.info("步骤 1/6: 陈述句提取(全局分块级并行)")
dialog_data_list = await self._extract_statements(dialog_data_list)
-
+
# 收集陈述句内容和统计数量
all_statements_list = []
for dialog in dialog_data_list:
for chunk in dialog.chunks:
all_statements_list.extend(chunk.statements)
- len(all_statements_list)
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
@@ -220,7 +227,7 @@ class ExtractionOrchestrator:
chunk_embedding_maps,
dialog_embeddings,
) = await self._parallel_extract_and_embed(dialog_data_list)
-
+
# 收集实体和三元组内容,并统计数量
all_entities_list = []
all_triplets_list = []
@@ -229,10 +236,6 @@ class ExtractionOrchestrator:
if triplet_info:
all_entities_list.extend(triplet_info.entities)
all_triplets_list.extend(triplet_info.triplets)
-
- len(all_entities_list)
- len(all_triplets_list)
- sum(len(temporal_map) for temporal_map in temporal_maps)
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
logger.info("步骤 3/6: 生成实体嵌入")
@@ -252,17 +255,19 @@ class ExtractionOrchestrator:
# 步骤 5: 创建节点和边
logger.info("步骤 5/6: 创建节点和边")
-
+
# 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送
-
+
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
+ perceptual_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
+ perceptual_edges
) = await self._create_nodes_and_edges(dialog_data_list)
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
@@ -273,10 +278,19 @@ class ExtractionOrchestrator:
logger.info("步骤 6/6: 去重和消歧(试运行模式:仅第一层去重)")
else:
logger.info("步骤 6/6: 两阶段去重和消歧")
-
+
# 注意:deduplication 消息已在创建节点和边完成后立即发送
-
- result = await self._run_dedup_and_write_summary(
+
+ (
+ dialogue_nodes,
+ chunk_nodes,
+ statement_nodes,
+ entity_nodes,
+ statement_chunk_edges,
+ statement_entity_edges,
+ entity_entity_edges,
+ dialog_data_list,
+ ) = await self._run_dedup_and_write_summary(
dialogue_nodes,
chunk_nodes,
statement_nodes,
@@ -287,17 +301,26 @@ class ExtractionOrchestrator:
dialog_data_list,
)
-
-
logger.info(f"知识提取流水线运行完成({mode_str})")
- return result
+ return (
+ dialogue_nodes,
+ chunk_nodes,
+ statement_nodes,
+ entity_nodes,
+ perceptual_nodes,
+ statement_chunk_edges,
+ statement_entity_edges,
+ entity_entity_edges,
+ perceptual_edges,
+ dialog_data_list,
+ )
except Exception as e:
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
raise
async def _extract_statements(
- self, dialog_data_list: List[DialogData]
+ self, dialog_data_list: List[DialogData]
) -> List[DialogData]:
"""
从对话中提取陈述句(流式输出版本:边提取边发送进度)
@@ -313,7 +336,7 @@ class ExtractionOrchestrator:
# 收集所有分块及其元数据
all_chunks = []
chunk_metadata = [] # (dialog_idx, chunk_idx)
-
+
for d_idx, dialog in enumerate(dialog_data_list):
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
for c_idx, chunk in enumerate(dialog.chunks):
@@ -321,7 +344,7 @@ class ExtractionOrchestrator:
chunk_metadata.append((d_idx, c_idx))
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
-
+
# 用于跟踪已完成的分块数量
completed_chunks = 0
total_chunks = len(all_chunks)
@@ -332,7 +355,7 @@ class ExtractionOrchestrator:
chunk, end_user_id, dialogue_content = chunk_data
try:
statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
-
+
# 流式输出:每提取完一个分块的陈述句,立即发送进度
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
completed_chunks += 1
@@ -347,11 +370,11 @@ class ExtractionOrchestrator:
"statement_index_in_chunk": idx + 1
}
await self.progress_callback(
- "knowledge_extraction_result",
- f"陈述句提取中 ({completed_chunks}/{total_chunks})",
+ "knowledge_extraction_result",
+ f"陈述句提取中 ({completed_chunks}/{total_chunks})",
stmt_result
)
-
+
return statements
except Exception as e:
logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}")
@@ -381,7 +404,7 @@ class ExtractionOrchestrator:
# 保存陈述句到文件(试运行和正式模式都需要)
self.statement_extractor.save_statements(all_statements)
-
+
logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句")
# 试运行模式下,所有分块提取完成后发送完成事件
@@ -395,7 +418,7 @@ class ExtractionOrchestrator:
return dialog_data_list
async def _extract_triplets(
- self, dialog_data_list: List[DialogData]
+ self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
"""
从对话中提取三元组(流式输出版本:边提取边发送进度)
@@ -411,7 +434,7 @@ class ExtractionOrchestrator:
# 收集所有陈述句及其元数据
all_statements = []
statement_metadata = [] # (dialog_idx, statement_id, chunk_content)
-
+
for d_idx, dialog in enumerate(dialog_data_list):
for chunk in dialog.chunks:
for statement in chunk.statements:
@@ -419,7 +442,7 @@ class ExtractionOrchestrator:
statement_metadata.append((d_idx, statement.id))
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组")
-
+
# 用于跟踪已完成的陈述句数量
completed_statements = 0
len(all_statements)
@@ -430,11 +453,11 @@ class ExtractionOrchestrator:
statement, chunk_content = stmt_data
try:
triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content)
-
+
# 注意:不再发送三元组提取的流式输出
# 三元组提取在后台执行,但不向前端发送详细信息
completed_statements += 1
-
+
return triplet_info
except Exception as e:
logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}")
@@ -450,7 +473,7 @@ class ExtractionOrchestrator:
# 将结果组织成对话级别的映射
triplet_maps = [{} for _ in dialog_data_list]
all_responses = []
-
+
for i, result in enumerate(results):
d_idx, stmt_id = statement_metadata[i]
if isinstance(result, Exception):
@@ -478,7 +501,7 @@ class ExtractionOrchestrator:
return triplet_maps
async def _extract_temporal(
- self, dialog_data_list: List[DialogData]
+ self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
"""
从对话中提取时间信息(流式输出版本:边提取边发送进度)
@@ -502,13 +525,13 @@ class ExtractionOrchestrator:
temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None)
temporal_maps.append(temporal_map)
return temporal_maps
-
+
logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)")
# 收集所有需要提取时间的陈述句
all_statements = []
statement_metadata = [] # (dialog_idx, statement_id, ref_dates)
-
+
for d_idx, dialog in enumerate(dialog_data_list):
# 获取参考日期
ref_dates = {}
@@ -517,11 +540,11 @@ class ExtractionOrchestrator:
ref_dates['conversation_date'] = dialog.metadata['conversation_date']
if 'publication_date' in dialog.metadata:
ref_dates['publication_date'] = dialog.metadata['publication_date']
-
+
if not ref_dates:
from datetime import datetime
ref_dates = {"today": datetime.now().strftime("%Y-%m-%d")}
-
+
for chunk in dialog.chunks:
for statement in chunk.statements:
# 跳过 ATEMPORAL 类型的陈述句
@@ -531,7 +554,7 @@ class ExtractionOrchestrator:
statement_metadata.append((d_idx, statement.id))
logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取")
-
+
# 用于跟踪已完成的时间提取数量
completed_temporal = 0
len(all_statements)
@@ -542,11 +565,11 @@ class ExtractionOrchestrator:
statement, ref_dates = stmt_data
try:
temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates)
-
+
# 注意:不再发送时间提取的流式输出
# 时间提取在后台执行,但不向前端发送详细信息
completed_temporal += 1
-
+
return temporal_range
except Exception as e:
logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}")
@@ -559,7 +582,7 @@ class ExtractionOrchestrator:
# 将结果组织成对话级别的映射
temporal_maps = [{} for _ in dialog_data_list]
-
+
for i, result in enumerate(results):
d_idx, stmt_id = statement_metadata[i]
if isinstance(result, Exception):
@@ -585,7 +608,7 @@ class ExtractionOrchestrator:
return temporal_maps
async def _extract_emotions(
- self, dialog_data_list: List[DialogData]
+ self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
"""
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
@@ -601,36 +624,36 @@ class ExtractionOrchestrator:
# 收集所有陈述句及其配置
all_statements = []
statement_metadata = [] # (dialog_idx, statement_id)
-
+
# 获取第一个对话的config_id来加载配置
config_id = None
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
config_id = dialog_data_list[0].config_id
-
+
# 加载MemoryConfig
memory_config = None
if config_id:
try:
from app.db import SessionLocal
from app.repositories.memory_config_repository import MemoryConfigRepository
-
+
db = SessionLocal()
try:
memory_config = MemoryConfigRepository.get_by_id(db, config_id)
finally:
db.close()
-
+
if memory_config and not memory_config.emotion_enabled:
logger.info("情绪提取已在配置中禁用,跳过情绪提取")
return [{} for _ in dialog_data_list]
-
+
except Exception as e:
logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取")
return [{} for _ in dialog_data_list]
else:
logger.info("未找到config_id,跳过情绪提取")
return [{} for _ in dialog_data_list]
-
+
# 如果配置未启用情绪提取,直接返回空映射
if not memory_config or not memory_config.emotion_enabled:
logger.info("情绪提取未启用,跳过")
@@ -639,7 +662,7 @@ class ExtractionOrchestrator:
# 收集所有陈述句(只收集 speaker 为 "user" 的)
total_statements = 0
filtered_statements = 0
-
+
for d_idx, dialog in enumerate(dialog_data_list):
for chunk in dialog.chunks:
for statement in chunk.statements:
@@ -655,12 +678,12 @@ class ExtractionOrchestrator:
# 初始化情绪提取服务
# 如果 emotion_model_id 为空,回退到工作空间默认 LLM
from app.services.emotion_extraction_service import EmotionExtractionService
-
+
emotion_model_id = memory_config.emotion_model_id
if not emotion_model_id and memory_config.workspace_id:
from app.repositories.workspace_repository import get_workspace_models_configs
from app.db import SessionLocal
-
+
db = SessionLocal()
try:
workspace_models = get_workspace_models_configs(db, memory_config.workspace_id)
@@ -669,7 +692,7 @@ class ExtractionOrchestrator:
logger.info(f"emotion_model_id 为空,使用工作空间默认 LLM: {emotion_model_id}")
finally:
db.close()
-
+
emotion_service = EmotionExtractionService(
llm_id=emotion_model_id if emotion_model_id else None
)
@@ -689,7 +712,7 @@ class ExtractionOrchestrator:
# 将结果组织成对话级别的映射
emotion_maps = [{} for _ in dialog_data_list]
successful_extractions = 0
-
+
for i, result in enumerate(results):
d_idx, stmt_id = statement_metadata[i]
if isinstance(result, Exception):
@@ -706,7 +729,7 @@ class ExtractionOrchestrator:
return emotion_maps
async def _parallel_extract_and_embed(
- self, dialog_data_list: List[DialogData]
+ self, dialog_data_list: List[DialogData]
) -> Tuple[
List[Dict[str, Any]],
List[Dict[str, Any]],
@@ -757,7 +780,7 @@ class ExtractionOrchestrator:
triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list]
temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list]
emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list]
-
+
if isinstance(results[3], Exception):
logger.error(f"基础嵌入生成失败: {results[3]}")
statement_embedding_maps = [{} for _ in dialog_data_list]
@@ -777,7 +800,7 @@ class ExtractionOrchestrator:
)
async def _generate_basic_embeddings(
- self, dialog_data_list: List[DialogData]
+ self, dialog_data_list: List[DialogData]
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
"""
生成基础嵌入向量(陈述句、分块、对话)
@@ -810,7 +833,7 @@ class ExtractionOrchestrator:
if not self.embedding_id:
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
raise ValueError("embedding_id is required but was not provided")
-
+
# 只生成陈述句、分块和对话的嵌入(不包括实体)
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation(
dialog_data_list, self.embedding_id
@@ -836,7 +859,7 @@ class ExtractionOrchestrator:
)
async def _generate_entity_embeddings(
- self, triplet_maps: List[Dict[str, Any]]
+ self, triplet_maps: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
生成实体嵌入向量
@@ -861,7 +884,7 @@ class ExtractionOrchestrator:
if not self.embedding_id:
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
return triplet_maps
-
+
# 生成实体嵌入
updated_triplet_maps = await generate_entity_embeddings_from_triplets(
triplet_maps, self.embedding_id
@@ -874,17 +897,15 @@ class ExtractionOrchestrator:
logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
return triplet_maps
-
-
async def _assign_extracted_data(
- self,
- dialog_data_list: List[DialogData],
- temporal_maps: List[Dict[str, Any]],
- triplet_maps: List[Dict[str, Any]],
- emotion_maps: List[Dict[str, Any]],
- statement_embedding_maps: List[Dict[str, List[float]]],
- chunk_embedding_maps: List[Dict[str, List[float]]],
- dialog_embeddings: List[List[float]],
+ self,
+ dialog_data_list: List[DialogData],
+ temporal_maps: List[Dict[str, Any]],
+ triplet_maps: List[Dict[str, Any]],
+ emotion_maps: List[Dict[str, Any]],
+ statement_embedding_maps: List[Dict[str, List[float]]],
+ chunk_embedding_maps: List[Dict[str, List[float]]],
+ dialog_embeddings: List[List[float]],
) -> List[DialogData]:
"""
将提取的数据赋值到语句
@@ -906,12 +927,12 @@ class ExtractionOrchestrator:
# 确保列表长度匹配
expected_length = len(dialog_data_list)
if (
- len(temporal_maps) != expected_length
- or len(triplet_maps) != expected_length
- or len(emotion_maps) != expected_length
- or len(statement_embedding_maps) != expected_length
- or len(chunk_embedding_maps) != expected_length
- or len(dialog_embeddings) != expected_length
+ len(temporal_maps) != expected_length
+ or len(triplet_maps) != expected_length
+ or len(emotion_maps) != expected_length
+ or len(statement_embedding_maps) != expected_length
+ or len(chunk_embedding_maps) != expected_length
+ or len(dialog_embeddings) != expected_length
):
logger.warning(
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
@@ -999,15 +1020,17 @@ class ExtractionOrchestrator:
return dialog_data_list
async def _create_nodes_and_edges(
- self, dialog_data_list: List[DialogData]
+ self, dialog_data_list: List[DialogData]
) -> Tuple[
List[DialogueNode],
List[ChunkNode],
List[StatementNode],
List[ExtractedEntityNode],
+ List[PerceptualNode],
List[StatementChunkEdge],
List[StatementEntityEdge],
List[EntityEntityEdge],
+ List[PerceptualEdge]
]:
"""
创建图数据库节点和边
@@ -1021,7 +1044,7 @@ class ExtractionOrchestrator:
包含所有节点和边的元组
"""
logger.info("开始创建节点和边")
-
+
# 注意:开始消息已在 run 方法中发送,这里不再重复发送
dialogue_nodes = []
@@ -1031,10 +1054,12 @@ class ExtractionOrchestrator:
statement_chunk_edges = []
statement_entity_edges = []
entity_entity_edges = []
+ perceptual_nodes = []
+ perceptual_edges = []
# 用于去重的集合
entity_id_set = set()
-
+
# 用于跟踪进度
total_dialogs = len(dialog_data_list)
processed_dialogs = 0
@@ -1075,6 +1100,45 @@ class ExtractionOrchestrator:
)
chunk_nodes.append(chunk_node)
+ for p, file_type in chunk.files:
+     # Each entry of chunk.files unpacks to (p, file_type); p supplies the fields copied onto the node.
+     meta = p.meta_data or {}
+     content_meta = meta.get("content") or {}  # tolerate a missing or null "content" entry
+     # Embedding is best-effort: a failure is logged and the node is still created with None.
+     # Generate the summary embedding (only when an embedder_client and a summary are available)
+     summary_embedding = None
+     if self.embedder_client and p.summary:
+         try:
+             summary_embedding = (await self.embedder_client.response([p.summary]))[0]
+         except Exception as emb_err:
+             logger.warning(f"Failed to embed perceptual summary: {emb_err}")
+     # Build the perceptual node, then link it to the current chunk with a PerceptualEdge.
+     perceptual = PerceptualNode(
+         name=f"Perceptual_{p.id}",
+         **{
+             "id": str(p.id),
+             "end_user_id": str(p.end_user_id),
+             "perceptual_type": p.perceptual_type,
+             "file_path": p.file_path or "",
+             "file_name": p.file_name or "",
+             "file_ext": p.file_ext or "",
+             "summary": p.summary or "",
+             "keywords": content_meta.get("keywords", []),
+             "topic": content_meta.get("topic", ""),
+             "domain": content_meta.get("domain", ""),
+             "created_at": p.created_time.isoformat() if p.created_time else None,
+             "file_type": file_type,
+             "summary_embedding": summary_embedding,
+         })
+     perceptual_nodes.append(perceptual)
+     perceptual_edges.append(PerceptualEdge(
+         source=perceptual.id,
+         target=chunk.id,
+         end_user_id=dialog_data.end_user_id,
+         run_id=dialog_data.run_id,
+         created_at=dialog_data.created_at,
+     ))
+
# 处理每个陈述句
for statement in chunk.statements:
# 创建陈述句节点
@@ -1083,15 +1147,19 @@ class ExtractionOrchestrator:
name=f"Statement_{statement.id}", # 添加必需的 name 字段
chunk_id=chunk.id,
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
- temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
- connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
+ temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL),
+ # 添加必需的 temporal_info 字段
+ connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong',
+ # 添加必需的 connect_strength 字段
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
statement_embedding=statement.statement_embedding,
- valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
- invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
+ valid_at=statement.temporal_validity.valid_at if hasattr(statement,
+ 'temporal_validity') and statement.temporal_validity else None,
+ invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
+ 'temporal_validity') and statement.temporal_validity else None,
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
@@ -1120,7 +1188,7 @@ class ExtractionOrchestrator:
# 创建实体索引到ID的映射(支持多种索引方式)
entity_idx_to_id = {}
-
+
# 创建实体节点
for entity_idx, entity in enumerate(triplet_info.entities):
# 映射实体索引到实体ID(使用多个键以提高容错性)
@@ -1128,7 +1196,7 @@ class ExtractionOrchestrator:
entity_idx_to_id[entity.entity_idx] = entity.id
# 2. 使用枚举索引(从0开始)
entity_idx_to_id[entity_idx] = entity.id
-
+
if entity.id not in entity_id_set:
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
entity_node = ExtractedEntityNode(
@@ -1141,7 +1209,8 @@ class ExtractionOrchestrator:
example=getattr(entity, 'example', ''), # 新增:传递示例字段
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
- connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
+ connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
+ # 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
@@ -1171,7 +1240,7 @@ class ExtractionOrchestrator:
# 将三元组中的整数索引映射到实体ID
subject_entity_id = entity_idx_to_id.get(triplet.subject_id)
object_entity_id = entity_idx_to_id.get(triplet.object_id)
-
+
# 只有当两个实体ID都存在时才创建边
if subject_entity_id and object_entity_id:
entity_entity_edge = EntityEntityEdge(
@@ -1186,7 +1255,7 @@ class ExtractionOrchestrator:
expired_at=dialog_data.expired_at,
)
entity_entity_edges.append(entity_entity_edge)
-
+
# 流式输出:每创建一个关系边,立即发送进度(限制发送数量)
if self.progress_callback and len(entity_entity_edges) <= 10:
# 获取实体名称
@@ -1202,8 +1271,8 @@ class ExtractionOrchestrator:
"dialog_progress": f"{processed_dialogs}/{total_dialogs}"
}
await self.progress_callback(
- "creating_nodes_edges_result",
- f"关系创建中 ({processed_dialogs}/{total_dialogs})",
+ "creating_nodes_edges_result",
+ f"关系创建中 ({processed_dialogs}/{total_dialogs})",
relationship_result
)
else:
@@ -1211,7 +1280,7 @@ class ExtractionOrchestrator:
missing_subject = "subject" if not subject_entity_id else ""
missing_object = "object" if not object_entity_id else ""
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
-
+
logger.debug(
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
@@ -1228,7 +1297,7 @@ class ExtractionOrchestrator:
f"陈述句-实体边: {len(statement_entity_edges)}, "
f"实体-实体边: {len(entity_entity_edges)}"
)
-
+
# 进度回调:创建节点和边完成,传递结果统计
# 注意:具体的关系创建结果已经在创建过程中实时发送了
if self.progress_callback:
@@ -1248,25 +1317,32 @@ class ExtractionOrchestrator:
chunk_nodes,
statement_nodes,
entity_nodes,
+ perceptual_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
+ perceptual_edges
)
async def _run_dedup_and_write_summary(
- self,
- dialogue_nodes: List[DialogueNode],
- chunk_nodes: List[ChunkNode],
- statement_nodes: List[StatementNode],
- entity_nodes: List[ExtractedEntityNode],
- statement_chunk_edges: List[StatementChunkEdge],
- statement_entity_edges: List[StatementEntityEdge],
- entity_entity_edges: List[EntityEntityEdge],
- dialog_data_list: List[DialogData],
- ) -> Tuple[
- Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
- Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
- Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
+ self,
+ dialogue_nodes: List[DialogueNode],
+ chunk_nodes: List[ChunkNode],
+ statement_nodes: List[StatementNode],
+ entity_nodes: List[ExtractedEntityNode],
+ statement_chunk_edges: List[StatementChunkEdge],
+ statement_entity_edges: List[StatementEntityEdge],
+ entity_entity_edges: List[EntityEntityEdge],
+ dialog_data_list: List[DialogData],
+ ) -> tuple[
+ list[DialogueNode],
+ list[ChunkNode],
+ list[StatementNode],
+ list[ExtractedEntityNode],
+ list[StatementChunkEdge],
+ list[StatementEntityEdge],
+ list[EntityEntityEdge],
+ dict
]:
"""
执行两阶段去重并写入汇总
@@ -1288,11 +1364,11 @@ class ExtractionOrchestrator:
- 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表)
"""
logger.info("开始两阶段实体去重和消歧")
-
+
# 进度回调:发送去重消歧开始消息
if self.progress_callback:
await self.progress_callback("deduplication", "正在去重消歧...")
-
+
logger.info(
f"去重前: {len(entity_nodes)} 个实体节点, "
f"{len(statement_entity_edges)} 条陈述句-实体边, "
@@ -1307,7 +1383,7 @@ class ExtractionOrchestrator:
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
)
-
+
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
entity_nodes,
statement_entity_edges,
@@ -1317,10 +1393,10 @@ class ExtractionOrchestrator:
dedup_config=self.config.deduplication,
llm_client=self.llm_client,
)
-
+
# 保存去重消歧的详细记录到实例变量
self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes)
-
+
result_tuple = (
dialogue_nodes,
chunk_nodes,
@@ -1330,7 +1406,7 @@ class ExtractionOrchestrator:
dedup_statement_entity_edges,
dedup_entity_entity_edges,
)
-
+
final_entity_nodes = dedup_entity_nodes
final_statement_entity_edges = dedup_statement_entity_edges
final_entity_entity_edges = dedup_entity_entity_edges
@@ -1361,7 +1437,7 @@ class ExtractionOrchestrator:
final_entity_entity_edges,
dedup_details,
) = result_tuple
-
+
# 保存去重消歧的详细记录到实例变量
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
@@ -1375,12 +1451,12 @@ class ExtractionOrchestrator:
f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, "
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
)
-
+
# 流式输出:实时输出去重消歧的具体结果
if self.progress_callback:
# 分析实体合并情况(使用内存中的记录)
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
-
+
# 逐个输出去重合并的实体示例
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
dedup_result = {
@@ -1391,10 +1467,10 @@ class ExtractionOrchestrator:
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
}
await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result)
-
+
# 分析实体消歧情况(使用内存中的记录)
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
-
+
# 逐个输出实体消歧的结果
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
disamb_result = {
@@ -1407,14 +1483,13 @@ class ExtractionOrchestrator:
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
}
await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result)
-
+
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
await self._send_dedup_progress_callback(
len(entity_nodes), len(final_entity_nodes),
len(statement_entity_edges), len(final_statement_entity_edges),
len(entity_entity_edges), len(final_entity_entity_edges)
)
-
# 写入提取结果汇总(试运行和正式模式都需要生成)
try:
@@ -1436,10 +1511,10 @@ class ExtractionOrchestrator:
raise
def _save_dedup_details(
- self,
- dedup_details: Dict[str, Any],
- original_entities: List[ExtractedEntityNode],
- final_entities: List[ExtractedEntityNode]
+ self,
+ dedup_details: Dict[str, Any],
+ original_entities: List[ExtractedEntityNode],
+ final_entities: List[ExtractedEntityNode]
):
"""
保存去重消歧的详细记录到实例变量(基于内存数据结构)
@@ -1452,7 +1527,7 @@ class ExtractionOrchestrator:
try:
# 保存ID重定向映射
self.id_redirect_map = dedup_details.get("id_redirect", {})
-
+
# 处理精确匹配的合并记录
exact_merge_map = dedup_details.get("exact_merge_map", {})
for key, info in exact_merge_map.items():
@@ -1466,7 +1541,7 @@ class ExtractionOrchestrator:
"merged_count": len(merged_ids),
"merged_ids": list(merged_ids)
})
-
+
# 处理模糊匹配的合并记录
fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", [])
for record in fuzzy_merge_records:
@@ -1486,7 +1561,7 @@ class ExtractionOrchestrator:
})
except Exception as e:
logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}")
-
+
# 处理LLM去重的合并记录
llm_decision_records = dedup_details.get("llm_decision_records", [])
for record in llm_decision_records:
@@ -1505,7 +1580,7 @@ class ExtractionOrchestrator:
})
except Exception as e:
logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}")
-
+
# 处理消歧记录
disamb_records = dedup_details.get("disamb_records", [])
for record in disamb_records:
@@ -1520,14 +1595,14 @@ class ExtractionOrchestrator:
entity1_type = match.group(2)
match.group(3).strip()
entity2_type = match.group(4)
-
+
# 提取置信度和原因
conf_match = re.search(r"conf=([0-9.]+)", str(record))
confidence = conf_match.group(1) if conf_match else "unknown"
-
+
reason_match = re.search(r"reason=([^|]+)", str(record))
reason = reason_match.group(1).strip() if reason_match else ""
-
+
self.dedup_disamb_records.append({
"entity_name": entity1_name,
"disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}",
@@ -1536,16 +1611,17 @@ class ExtractionOrchestrator:
})
except Exception as e:
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
-
- logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
-
+
+ logger.info(
+ f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
+
except Exception as e:
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
async def _analyze_entity_merges(
- self,
- original_entities: List[ExtractedEntityNode],
- final_entities: List[ExtractedEntityNode]
+ self,
+ original_entities: List[ExtractedEntityNode],
+ final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]:
"""
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
@@ -1566,28 +1642,28 @@ class ExtractionOrchestrator:
key=lambda x: x.get("merged_count", 0),
reverse=True
)
-
+
merge_info = []
for record in sorted_records:
merge_info.append({
"main_entity_name": record.get("entity_name", "未知实体"),
"merged_count": record.get("merged_count", 1)
})
-
+
return merge_info
-
+
# 如果没有保存的记录,返回空列表
logger.info("未找到实体合并记录")
return []
-
+
except Exception as e:
logger.error(f"分析实体合并情况失败: {e}", exc_info=True)
return []
async def _analyze_entity_disambiguation(
- self,
- original_entities: List[ExtractedEntityNode],
- final_entities: List[ExtractedEntityNode]
+ self,
+ original_entities: List[ExtractedEntityNode],
+ final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]:
"""
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
@@ -1603,11 +1679,11 @@ class ExtractionOrchestrator:
# 直接使用保存的消歧记录
if self.dedup_disamb_records:
return self.dedup_disamb_records
-
+
# 如果没有保存的记录,返回空列表
logger.info("未找到实体消歧记录")
return []
-
+
except Exception as e:
logger.error(f"分析实体消歧情况失败: {e}", exc_info=True)
return []
@@ -1624,7 +1700,7 @@ class ExtractionOrchestrator:
"""
type_mapping = {
"Person": "人物实体节点",
- "Organization": "组织实体节点",
+ "Organization": "组织实体节点",
"ORG": "组织实体节点",
"Location": "地点实体节点",
"LOC": "地点实体节点",
@@ -1645,9 +1721,9 @@ class ExtractionOrchestrator:
return type_mapping.get(entity_type, f"{entity_type}实体节点")
async def _output_relationship_creation_results(
- self,
- entity_entity_edges: List[EntityEntityEdge],
- entity_nodes: List[ExtractedEntityNode]
+ self,
+ entity_entity_edges: List[EntityEntityEdge],
+ entity_nodes: List[ExtractedEntityNode]
):
"""
输出关系创建结果
@@ -1659,13 +1735,13 @@ class ExtractionOrchestrator:
try:
# 创建实体ID到名称的映射
entity_id_to_name = {node.id: node.name for node in entity_nodes}
-
+
# 输出关系创建结果
for i, edge in enumerate(entity_entity_edges[:10]): # 只输出前10个关系
source_name = entity_id_to_name.get(edge.source, f"Entity_{edge.source}")
target_name = entity_id_to_name.get(edge.target, f"Entity_{edge.target}")
relation_type = edge.relation_type
-
+
relationship_result = {
"result_type": "relationship_creation",
"relationship_index": i + 1,
@@ -1674,20 +1750,20 @@ class ExtractionOrchestrator:
"target_entity": target_name,
"relationship_text": f"{source_name} -[{relation_type}]-> {target_name}"
}
-
+
await self.progress_callback("creating_nodes_edges_result", "关系创建", relationship_result)
-
+
except Exception as e:
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
async def _send_dedup_progress_callback(
- self,
- original_entities: int,
- final_entities: int,
- original_stmt_edges: int,
- final_stmt_edges: int,
- original_ent_edges: int,
- final_ent_edges: int,
+ self,
+ original_entities: int,
+ final_entities: int,
+ original_stmt_edges: int,
+ final_stmt_edges: int,
+ original_ent_edges: int,
+ final_ent_edges: int,
):
"""
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
@@ -1703,19 +1779,20 @@ class ExtractionOrchestrator:
try:
# 解析去重消歧报告文件,获取具体的去重和消歧效果
dedup_details = await self._parse_dedup_report()
-
+
# 计算去重效果统计
entities_reduced = original_entities - final_entities
stmt_edges_reduced = original_stmt_edges - final_stmt_edges
ent_edges_reduced = original_ent_edges - final_ent_edges
-
+
# 构建进度回调数据
dedup_stats = {
"entities": {
"original_count": original_entities,
"final_count": final_entities,
"reduced_count": entities_reduced,
- "reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
+ "reduction_rate": round(entities_reduced / original_entities * 100,
+ 1) if original_entities > 0 else 0,
},
"statement_entity_edges": {
"original_count": original_stmt_edges,
@@ -1734,9 +1811,9 @@ class ExtractionOrchestrator:
"total_disambiguations": dedup_details.get("total_disambiguations", 0),
}
}
-
+
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats)
-
+
except Exception as e:
logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True)
# 即使解析失败,也发送基本的统计信息
@@ -1766,12 +1843,12 @@ class ExtractionOrchestrator:
disamb_examples = []
total_merges = 0
total_disambiguations = 0
-
+
# 处理合并记录
for record in self.dedup_merge_records:
merge_count = record.get("merged_count", 0)
total_merges += merge_count
-
+
dedup_examples.append({
"type": record.get("type", "未知"),
"entity_name": record.get("entity_name", "未知实体"),
@@ -1779,30 +1856,31 @@ class ExtractionOrchestrator:
"merge_count": merge_count,
"description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个"
})
-
+
# 处理消歧记录
for record in self.dedup_disamb_records:
total_disambiguations += 1
-
+
# 从消歧类型中提取实体类型信息
disamb_type = record.get("disamb_type", "")
entity_name = record.get("entity_name", "未知实体")
-
+
disamb_examples.append({
"entity1_name": entity_name,
- "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
+ "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:",
+ "").strip() if "vs" in disamb_type else "未知",
"entity2_name": entity_name,
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
"description": f"{entity_name},消歧区分成功"
})
-
+
return {
"dedup_examples": dedup_examples[:5], # 只返回前5个示例
"disamb_examples": disamb_examples[:5], # 只返回前5个示例
"total_merges": total_merges,
"total_disambiguations": total_disambiguations,
}
-
+
except Exception as e:
logger.error(f"获取去重报告失败: {e}", exc_info=True)
return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0}
@@ -1815,9 +1893,9 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs(
- chunker_strategy: str = "RecursiveChunker",
- end_user_id: str = "group_1",
- indices: Optional[List[int]] = None,
+ chunker_strategy: str = "RecursiveChunker",
+ end_user_id: str = "group_1",
+ indices: Optional[List[int]] = None,
) -> List[DialogData]:
"""从测试数据生成分块对话
@@ -1831,7 +1909,7 @@ async def get_chunked_dialogs(
"""
import json
import re
-
+
# 加载测试数据
testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json")
with open(testdata_path, "r", encoding="utf-8") as f:
@@ -1845,7 +1923,7 @@ async def get_chunked_dialogs(
else:
# 默认使用所有数据
selected_data = test_data
-
+
for data in selected_data:
# 解析对话上下文
context_text = data["context"]
@@ -1861,7 +1939,7 @@ async def get_chunked_dialogs(
if m:
y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3))
conv_date = f"{y:04d}-{mo:02d}-{d:02d}"
-
+
dialog_metadata: Dict[str, Any] = {}
if conv_date:
dialog_metadata["conversation_date"] = conv_date
@@ -1890,7 +1968,7 @@ async def get_chunked_dialogs(
end_user_id=end_user_id,
metadata=dialog_metadata,
)
-
+
# 创建分块器并处理对话
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
@@ -1913,7 +1991,7 @@ async def get_chunked_dialogs(
from app.core.config import settings
settings.ensure_memory_output_dir()
output_path = settings.get_memory_output_path("chunker_test_output.txt")
-
+
import json
with open(output_path, "w", encoding="utf-8") as f:
json.dump(
@@ -1924,10 +2002,10 @@ async def get_chunked_dialogs(
def preprocess_data(
- input_path: Optional[str] = None,
- output_path: Optional[str] = None,
- skip_cleaning: bool = True,
- indices: Optional[List[int]] = None
+ input_path: Optional[str] = None,
+ output_path: Optional[str] = None,
+ skip_cleaning: bool = True,
+ indices: Optional[List[int]] = None
) -> List[DialogData]:
"""数据预处理
@@ -1946,7 +2024,8 @@ def preprocess_data(
)
preprocessor = DataPreprocessor()
try:
- cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
+ cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path,
+ skip_cleaning=skip_cleaning, indices=indices)
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
return cleaned_data
except Exception as e:
@@ -1955,9 +2034,9 @@ def preprocess_data(
async def get_chunked_dialogs_from_preprocessed(
- data: List[DialogData],
- chunker_strategy: str = "RecursiveChunker",
- llm_client: Optional[Any] = None,
+ data: List[DialogData],
+ chunker_strategy: str = "RecursiveChunker",
+ llm_client: Optional[Any] = None,
) -> List[DialogData]:
"""从预处理后的数据中生成分块
@@ -1972,31 +2051,31 @@ async def get_chunked_dialogs_from_preprocessed(
logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
if not data:
raise ValueError("预处理数据为空,无法进行分块")
-
+
all_chunked_dialogs: List[DialogData] = []
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
-
+
for dialog_data in data:
chunker = DialogueChunker(chunker_strategy, llm_client=llm_client)
chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = chunks
all_chunked_dialogs.append(dialog_data)
-
+
return all_chunked_dialogs
async def get_chunked_dialogs_with_preprocessing(
- chunker_strategy: str = "RecursiveChunker",
- end_user_id: str = "default",
- user_id: str = "default",
- apply_id: str = "default",
- indices: Optional[List[int]] = None,
- input_data_path: Optional[str] = None,
- llm_client: Optional[Any] = None,
- skip_cleaning: bool = True,
- pruning_config: Optional[Dict] = None,
+ chunker_strategy: str = "RecursiveChunker",
+ end_user_id: str = "default",
+ user_id: str = "default",
+ apply_id: str = "default",
+ indices: Optional[List[int]] = None,
+ input_data_path: Optional[str] = None,
+ llm_client: Optional[Any] = None,
+ skip_cleaning: bool = True,
+ pruning_config: Optional[Dict] = None,
) -> List[DialogData]:
"""包含数据预处理步骤的完整分块流程
@@ -2020,7 +2099,7 @@ async def get_chunked_dialogs_with_preprocessing(
input_data_path = os.path.join(
os.path.dirname(__file__), "../../data", "testdata.json"
)
-
+
# 步骤1: 数据预处理(包含索引筛选)
from app.core.config import settings
settings.ensure_memory_output_dir()
@@ -2030,37 +2109,38 @@ async def get_chunked_dialogs_with_preprocessing(
skip_cleaning=skip_cleaning,
indices=indices,
)
-
+
# 设置 end_user_id
for dd in preprocessed_data:
dd.end_user_id = end_user_id
-
+
# 步骤2: 语义剪枝
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
SemanticPruner,
)
from app.core.memory.models.config_models import PruningConfig
-
+
# 构建剪枝配置
if pruning_config:
# 使用传入的配置
config = PruningConfig(**pruning_config)
- logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
+ logger.debug(
+ f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
else:
# 使用默认配置(关闭剪枝)
config = None
logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)")
-
+
pruner = SemanticPruner(config=config, llm_client=llm_client)
-
+
# 记录单对话场景下剪枝前的消息数量
single_dialog_original_msgs = None
if len(preprocessed_data) == 1 and preprocessed_data[0].context:
single_dialog_original_msgs = len(preprocessed_data[0].context.msgs)
preprocessed_data = await pruner.prune_dataset(preprocessed_data)
-
+
# 单对话:打印清洗与剪枝信息
if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None:
remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0
@@ -2071,7 +2151,7 @@ async def get_chunked_dialogs_with_preprocessing(
)
else:
logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
-
+
# 保存剪枝后的数据
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
@@ -2084,7 +2164,7 @@ async def get_chunked_dialogs_with_preprocessing(
logger.error(f"保存剪枝结果失败:{se}")
except Exception as e:
logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
-
+
# 步骤3: 对话分块
return await get_chunked_dialogs_from_preprocessed(
preprocessed_data,
diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py
index 72f3641e..33838061 100644
--- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py
+++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py
@@ -5,8 +5,11 @@
"""
import asyncio
+import logging
from typing import Any, Dict, List, Tuple
+logger = logging.getLogger(__name__)
+
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.message_models import DialogData
from app.core.models.base import RedBearModelConfig
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
return await self.embedder_client.response(texts)
# 分批并行处理
- print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
+ logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
- print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
+ logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
# 并行发送所有批次
batch_results = await asyncio.gather(*[
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
for batch_result in batch_results:
embeddings.extend(batch_result)
- print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
+ logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
return embeddings
async def generate_statement_embeddings(
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
Returns:
每个对话的陈述句嵌入向量映射列表
"""
- print("\n=== 生成陈述句嵌入向量 ===")
+ logger.debug("=== 生成陈述句嵌入向量 ===")
# 收集所有陈述句
all_statements = []
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
stmt_embedding_maps[d_idx][stmt_id] = embedding
- print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
+ logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
return stmt_embedding_maps
async def generate_chunk_embeddings(
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
Returns:
每个对话的分块嵌入向量映射列表
"""
- print("\n=== 生成分块嵌入向量 ===")
+ logger.debug("=== 生成分块嵌入向量 ===")
# 收集所有分块
all_chunks = []
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
chunk_embedding_maps[d_idx][chunk_id] = embedding
- print(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
+ logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
return chunk_embedding_maps
async def generate_dialog_embeddings(
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
Returns:
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
"""
- print("\n=== 生成所有嵌入向量 ===")
+ logger.debug("=== 生成所有嵌入向量 ===")
# 并发生成陈述句和分块嵌入向量
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
# 对话嵌入向量(当前跳过)
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
- print(
- f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
- )
+ logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量")
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
Returns:
更新后的三元组映射列表(实体包含嵌入向量)
"""
- print("\n=== 生成实体嵌入向量 ===")
+ logger.debug("=== 生成实体嵌入向量 ===")
entity_texts: List[str] = []
entity_refs: List[Any] = []
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
entity_refs.append(ent)
if not entity_texts:
- print("没有找到需要生成嵌入向量的实体")
+ logger.debug("没有找到需要生成嵌入向量的实体")
return triplet_maps
# 批量生成嵌入向量
@@ -227,13 +228,13 @@ class EmbeddingGenerator:
# 打印前几个嵌入向量的维度
for i in range(min(5, len(embeddings))):
- print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
+ logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
# 将嵌入向量赋值给实体
for ent, emb in zip(entity_refs, embeddings):
setattr(ent, "name_embedding", emb)
- print(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
+ logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
return triplet_maps
@@ -296,7 +297,7 @@ async def embedding_generation_all(
Returns:
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
"""
- print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
+ logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
generator = EmbeddingGenerator(embedding_id)
diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py
index 443ee36a..5e39ba36 100644
--- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py
+++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
response_model=MemorySummaryResponse,
)
summary_text = structured.summary.strip()
-
# Generate title and type for the summary
title = None
episodic_type = None
diff --git a/api/app/core/models/__init__.py b/api/app/core/models/__init__.py
index f54afc08..f98d073f 100644
--- a/api/app/core/models/__init__.py
+++ b/api/app/core/models/__init__.py
@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
from .llm import RedBearLLM
from .embedding import RedBearEmbeddings
from .rerank import RedBearRerank
+from .generation import RedBearImageGenerator, RedBearVideoGenerator
__all__ = [
"RedBearModelConfig",
@@ -9,5 +10,7 @@ __all__ = [
"RedBearEmbeddings",
"RedBearRerank",
"RedBearModelFactory",
- "get_provider_llm_class"
+ "get_provider_llm_class",
+ "RedBearImageGenerator",
+ "RedBearVideoGenerator"
]
\ No newline at end of file
diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py
index 4a453c6b..80117f27 100644
--- a/api/app/core/models/base.py
+++ b/api/app/core/models/base.py
@@ -67,7 +67,7 @@ class RedBearModelFactory:
**config.extra_params
}
- if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
+ if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
# 使用 httpx.Timeout 对象来设置详细的超时配置
# 这样可以分别控制连接超时和读取超时
import httpx
@@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
return ChatOpenAI
- if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
+ if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
if type == ModelType.LLM:
return OpenAI
elif type == ModelType.CHAT:
return ChatOpenAI
+ else:
+ raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
elif provider == ModelProvider.DASHSCOPE:
return ChatTongyi
elif provider == ModelProvider.OLLAMA:
diff --git a/api/app/core/models/embedding.py b/api/app/core/models/embedding.py
index 16af2567..3269e1d0 100644
--- a/api/app/core/models/embedding.py
+++ b/api/app/core/models/embedding.py
@@ -1,23 +1,190 @@
-from typing import Any, Dict, List, Optional, TypeVar, Callable
+from typing import Any, Dict, List, Optional, Union
from langchain_core.embeddings import Embeddings
-from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory
+from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
+from app.models.models_model import ModelProvider
+
class RedBearEmbeddings(Embeddings):
- """Embedding → 完全符合 LangChain Embeddings"""
+    """Unified embedding wrapper whose backend is chosen from ``config.provider``.
+
+    The Volcano provider is served by the Ark SDK's multimodal embedding
+    endpoint (text, image and video inputs); every other provider is delegated
+    to the LangChain embedding class from ``get_provider_embedding_class``.
+    """
+
+    def __init__(self, config: RedBearModelConfig):
- self._model = self._create_model(config)
self._config = config
+        self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
+
+        # Exactly one of ``_client`` / ``_model`` is set; the other stays None.
+        if self._is_volcano:
+            self._client = self._create_volcano_client(config)
+            self._model = None
+        else:
+            self._model = self._create_model(config)
+            self._client = None
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
- """根据配置创建模型"""
+        """Build the LangChain embedding model described by ``config``."""
embedding_class = get_provider_embedding_class(config.provider)
model_params = RedBearModelFactory.get_model_params(config)
return embedding_class(**model_params)
+
+    def _create_volcano_client(self, config: RedBearModelConfig):
+        """Create the Volcano Ark client (the SDK is imported lazily)."""
+        from volcenginesdkarkruntime import Ark
+        return Ark(api_key=config.api_key, base_url=config.base_url)
+
+    def _embed_one_multimodal(
+        self,
+        contents: List[Dict[str, Any]],
+        **kwargs
+    ) -> List[float]:
+        """Send ONE multimodal request and return the vector of its response."""
+        response = self._client.multimodal_embeddings.create(
+            model=self._config.model_name,
+            input=contents,
+            **kwargs
+        )
+        return response.data.embedding
+
+    # ==================== LangChain standard interface ====================
+
def embed_documents(self, texts: list[str]) -> list[list[float]]:
- return self._model.embed_documents(texts)
+        """Embed a list of texts, one vector per text (LangChain interface).
+
+        On the Volcano path each text is sent as its own request: a response
+        exposes a single ``data.embedding``, so sending the whole list at once
+        would yield one vector for N texts and break callers that pair
+        vectors with documents by position. An empty list returns ``[]``
+        without calling the service.
+        """
+        if not self._is_volcano:
+            return self._model.embed_documents(texts)
+        return [
+            self._embed_one_multimodal(
+                [{"type": "text", "text": text}],
+                encoding_format="float"
+            )
+            for text in texts
+        ]
def embed_query(self, text: str) -> List[float]:
- return self._model.embed_query(text)
+        """Embed a single query text (LangChain interface)."""
+        if not self._is_volcano:
+            return self._model.embed_query(text)
+        vectors = self.embed_documents([text])
+        return vectors[0] if vectors else []
+
+    # ==================== Multimodal extensions ====================
+
+    def embed_multimodal(
+        self,
+        contents: List[Dict[str, Any]],
+        **kwargs
+    ) -> List[List[float]]:
+        """Embed ``contents`` with a single multimodal request (Volcano only).
+
+        Args:
+            contents: Content parts, each shaped like one of:
+                - text:  {"type": "text", "text": "..."}
+                - image: {"type": "image_url", "image_url": {"url": "..."}}
+                - video: {"type": "video_url", "video_url": {"url": "..."}}
+            **kwargs: Extra arguments forwarded to the SDK call.
+
+        Returns:
+            A one-element list holding the single vector the response carries
+            for the whole ``contents`` list.
+
+        Raises:
+            NotImplementedError: If the provider is not Volcano.
+        """
+        if not self._is_volcano:
+            raise NotImplementedError(
+                f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
+            )
+
+        return [self._embed_one_multimodal(contents, **kwargs)]
+
+    async def aembed_multimodal(
+        self,
+        contents: List[Dict[str, Any]],
+        **kwargs
+    ) -> List[List[float]]:
+        """Async variant of :meth:`embed_multimodal`.
+
+        The synchronous SDK call runs in a worker thread so the event loop is
+        not blocked while the request is in flight.
+        """
+        import asyncio
+
+        return await asyncio.to_thread(self.embed_multimodal, contents, **kwargs)
+
+    def embed_text(self, text: str, **kwargs) -> List[float]:
+        """Embed one text; ``kwargs`` are only used on the Volcano path."""
+        if not self._is_volcano:
+            return self.embed_query(text)
+        vectors = self.embed_multimodal(
+            [{"type": "text", "text": text}],
+            **kwargs
+        )
+        return vectors[0] if vectors else []
+
+    def embed_image(self, image_url: str, **kwargs) -> List[float]:
+        """Embed one image given its URL (Volcano only).
+
+        Raises:
+            NotImplementedError: If the provider is not Volcano.
+        """
+        if not self._is_volcano:
+            raise NotImplementedError(
+                f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
+            )
+
+        vectors = self.embed_multimodal(
+            [{"type": "image_url", "image_url": {"url": image_url}}],
+            **kwargs
+        )
+        return vectors[0] if vectors else []
+
+    def embed_video(self, video_url: str, **kwargs) -> List[float]:
+        """Embed one video given its URL (Volcano only).
+
+        Raises:
+            NotImplementedError: If the provider is not Volcano.
+        """
+        if not self._is_volcano:
+            raise NotImplementedError(
+                f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
+            )
+
+        vectors = self.embed_multimodal(
+            [{"type": "video_url", "video_url": {"url": video_url}}],
+            **kwargs
+        )
+        return vectors[0] if vectors else []
+
+    def embed_batch(
+        self,
+        items: List[Union[str, Dict[str, Any]]],
+        **kwargs
+    ) -> List[List[float]]:
+        """Embed a batch of items, returning one vector per item.
+
+        Args:
+            items: Plain strings and/or content dicts in the format accepted
+                by :meth:`embed_multimodal`.
+            **kwargs: Forwarded to the SDK call for mixed-type batches; not
+                used when every item is a string.
+
+        Returns:
+            Vectors in the same order as ``items``.
+
+        Raises:
+            NotImplementedError: If a dict item is present and the provider
+                is not Volcano.
+            ValueError: If an item is neither ``str`` nor ``dict``.
+        """
+        # All strings: the standard LangChain path is enough.
+        if all(isinstance(item, str) for item in items):
+            return self.embed_documents(items)
+
+        # Dict items need the multimodal endpoint.
+        if not self._is_volcano:
+            raise NotImplementedError(
+                f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
+            )
+
+        # Normalise and validate every item before the first network call.
+        contents = []
+        for item in items:
+            if isinstance(item, str):
+                contents.append({"type": "text", "text": item})
+            elif isinstance(item, dict):
+                contents.append(item)
+            else:
+                raise ValueError(f"不支持的输入类型: {type(item)}")
+
+        # One request per item so the result lines up with ``items``; a single
+        # request for the whole list would come back as one vector only.
+        return [self._embed_one_multimodal([part], **kwargs) for part in contents]
+
+    # ==================== Utilities ====================
+
+    def is_multimodal_supported(self) -> bool:
+        """Return True when the multimodal (Volcano) backend is in use."""
+        return self._is_volcano
+
+    def get_provider(self) -> str:
+        """Return the provider string exactly as given in the config."""
+        return self._config.provider
+
+
+# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
+RedBearMultimodalEmbeddings = RedBearEmbeddings
diff --git a/api/app/core/models/generation.py b/api/app/core/models/generation.py
new file mode 100644
index 00000000..b6388d3f
--- /dev/null
+++ b/api/app/core/models/generation.py
@@ -0,0 +1,344 @@
+"""
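+Image and video generation model wrappers.
+
+Supported providers:
+- Volcano (Volcengine): via volcenginesdkarkruntime
+- OpenAI: not implemented yet; any provider other than Volcano raises PROVIDER_NOT_SUPPORTED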
+"""
+from typing import Any, Dict, Optional
+
+from volcenginesdkarkruntime import Ark
+from volcenginesdkarkruntime.types.images.images import (
+ SequentialImageGenerationOptions,
+ ContentGenerationTool,
+ OptimizePromptOptions
+)
+
+from app.core.models.base import RedBearModelConfig
+from app.core.exceptions import BusinessException
+from app.core.error_codes import BizCode
+from app.models.models_model import ModelProvider
+
+
+class RedBearImageGenerator:
+    """Image generation wrapper; only the Volcano (Ark SDK) provider is wired up."""
+
+    def __init__(self, config: RedBearModelConfig):
+        self._config = config
+        self._client = self._create_client(config)
+
+    def _create_client(self, config: RedBearModelConfig):
+        """Create the SDK client for ``config.provider``.
+
+        Raises:
+            BusinessException: ``PROVIDER_NOT_SUPPORTED`` for any provider
+                other than Volcano.
+        """
+        provider = config.provider.lower()
+
+        if provider == ModelProvider.VOLCANO:
+            return Ark(api_key=config.api_key, base_url=config.base_url)
+        # TODO: add an OpenAI client here once that provider is supported.
+        raise BusinessException(
+            f"不支持的图片生成提供商: {provider}",
+            code=BizCode.PROVIDER_NOT_SUPPORTED
+        )
+
+    def generate(
+        self,
+        prompt: str,
+        image: Optional[Any] = None,
+        size: Optional[str] = "2K",
+        output_format: str = "png",
+        response_format: str = "url",
+        watermark: bool = False,
+        sequential_image_generation: Optional[str] = None,
+        sequential_image_generation_options: Optional[Dict] = None,
+        tools: Optional[list] = None,
+        optimize_prompt_options: Optional[Dict] = None,
+        stream: bool = False,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """Generate one or more images.
+
+        Args:
+            prompt: Text prompt.
+            image: Reference image URL or list of URLs (image-to-image /
+                multi-image fusion).
+            size: Image size such as "2K", "2048x2048" or "1920x1080"
+                (at least 3,686,400 pixels).
+            output_format: Output format, e.g. "png" or "jpg".
+            response_format: "url" or "b64_json".
+            watermark: Whether to add a watermark.
+            sequential_image_generation: Image-set mode, "auto" or "disabled".
+            sequential_image_generation_options: Image-set options such as
+                {"max_images": 4}; only sent when
+                ``sequential_image_generation`` is set.
+            tools: Tool list such as [{"type": "web_search"}]; dict entries
+                are converted to ``ContentGenerationTool``.
+            optimize_prompt_options: Prompt optimisation options such as
+                {"mode": "fast"}.
+            stream: Whether to request streaming generation.
+            **kwargs: Extra SDK parameters; they override the values above.
+
+        Returns:
+            ``response.model_dump()`` when the SDK response provides it,
+            otherwise the raw response object (NOTE(review): presumably the
+            streaming case - confirm against the SDK).
+
+        Raises:
+            BusinessException: ``PROVIDER_NOT_SUPPORTED`` for non-Volcano
+                providers.
+        """
+        provider = self._config.provider.lower()
+
+        if provider != ModelProvider.VOLCANO:
+            raise BusinessException(
+                f"不支持的提供商: {provider}",
+                code=BizCode.PROVIDER_NOT_SUPPORTED
+            )
+
+        params = {
+            "model": self._config.model_name,
+            "prompt": prompt,
+            "size": size,
+            "output_format": output_format,
+            "response_format": response_format,
+            "watermark": watermark,
+        }
+
+        if image is not None:
+            params["image"] = image
+
+        if sequential_image_generation:
+            params["sequential_image_generation"] = sequential_image_generation
+            if sequential_image_generation_options:
+                params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
+                    **sequential_image_generation_options
+                )
+
+        if tools:
+            params["tools"] = [
+                ContentGenerationTool(**tool) if isinstance(tool, dict) else tool
+                for tool in tools
+            ]
+
+        if optimize_prompt_options:
+            params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)
+
+        if stream:
+            params["stream"] = True
+
+        params.update(kwargs)
+        response = self._client.images.generate(**params)
+
+        return response.model_dump() if hasattr(response, 'model_dump') else response
+
+    async def agenerate(
+        self,
+        prompt: str,
+        image: Optional[Any] = None,
+        size: Optional[str] = "2K",
+        output_format: str = "png",
+        response_format: str = "url",
+        watermark: bool = False,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """Async variant of :meth:`generate`.
+
+        The synchronous SDK call runs in a worker thread so the event loop is
+        not blocked. Remaining ``generate`` options can be passed via
+        ``kwargs``.
+        """
+        import asyncio
+
+        return await asyncio.to_thread(
+            self.generate,
+            prompt, image, size, output_format, response_format, watermark,
+            **kwargs
+        )
+
+
+class RedBearVideoGenerator:
+    """Video generation wrapper; only the Volcano (Ark SDK) provider is wired up."""
+
+    def __init__(self, config: RedBearModelConfig):
+        self._config = config
+        self._client = self._create_client(config)
+
+    def _create_client(self, config: RedBearModelConfig):
+        """Create the SDK client for ``config.provider``.
+
+        Raises:
+            BusinessException: ``PROVIDER_NOT_SUPPORTED`` for any provider
+                other than Volcano.
+        """
+        provider = config.provider.lower()
+
+        if provider == ModelProvider.VOLCANO:
+            return Ark(api_key=config.api_key, base_url=config.base_url)
+        raise BusinessException(
+            f"不支持的视频生成提供商: {provider}",
+            code=BizCode.PROVIDER_NOT_SUPPORTED
+        )
+
+    def _ensure_volcano(self) -> None:
+        """Raise ``PROVIDER_NOT_SUPPORTED`` unless the provider is Volcano."""
+        provider = self._config.provider.lower()
+        if provider != ModelProvider.VOLCANO:
+            raise BusinessException(
+                f"不支持的提供商: {provider}",
+                code=BizCode.PROVIDER_NOT_SUPPORTED
+            )
+
+    @staticmethod
+    def _dump(response: Any) -> Any:
+        """Return ``response.model_dump()`` when available, else ``response``."""
+        return response.model_dump() if hasattr(response, 'model_dump') else response
+
+    def generate(
+        self,
+        prompt: str,
+        image_url: Optional[str] = None,
+        first_frame_url: Optional[str] = None,
+        last_frame_url: Optional[str] = None,
+        reference_images: Optional[list] = None,
+        draft_task_id: Optional[str] = None,
+        duration: Optional[int] = None,
+        frames: Optional[int] = None,
+        ratio: Optional[str] = None,
+        resolution: Optional[str] = None,
+        generate_audio: bool = False,
+        watermark: bool = False,
+        camera_fixed: bool = False,
+        seed: Optional[int] = None,
+        return_last_frame: bool = False,
+        service_tier: str = "default",
+        execution_expires_after: Optional[int] = None,
+        draft: bool = False,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """Create a video generation task.
+
+        Args:
+            prompt: Text prompt.
+            image_url: First-frame image URL (image-to-video from first frame).
+            first_frame_url: First-frame image URL (first/last-frame mode).
+            last_frame_url: Last-frame image URL (first/last-frame mode).
+            reference_images: Reference image URLs (reference-image mode).
+            draft_task_id: Draft task ID; when set, the request content is
+                only the draft reference and the prompt/images are not sent.
+            duration: Video length in seconds; use either this or ``frames``.
+            frames: Number of frames; use either this or ``duration``.
+            ratio: Aspect ratio such as "16:9", "9:16" or "adaptive".
+            resolution: Resolution such as "720p" or "1080p".
+            generate_audio: Whether to generate audio.
+            watermark: Whether to add a watermark (always sent).
+            camera_fixed: Whether to keep the camera fixed.
+            seed: Random seed; sent whenever it is not None (0 included).
+            return_last_frame: Whether to return the last frame.
+            service_tier: "default" or "flex" (offline inference).
+            execution_expires_after: Task expiry in seconds.
+            draft: Whether to generate a draft (sample) video.
+            **kwargs: Extra SDK parameters; they override the values above.
+
+        Returns:
+            The dumped task-creation response (contains the task ID; poll
+            :meth:`get_task_status` for the result).
+
+        Raises:
+            BusinessException: ``PROVIDER_NOT_SUPPORTED`` for non-Volcano
+                providers.
+        """
+        self._ensure_volcano()
+
+        if draft_task_id:
+            content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]
+        else:
+            content = [{"type": "text", "text": prompt}]
+            if image_url:
+                content.append({"type": "image_url", "image_url": {"url": image_url}})
+            if first_frame_url:
+                content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
+            if last_frame_url:
+                content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})
+            for ref_url in reference_images or []:
+                content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"})
+
+        params = {"model": self._config.model_name, "content": content, "watermark": watermark}
+
+        # Options that are only sent when truthy, so a False/None/0 value
+        # leaves the service-side default in place.
+        # NOTE(review): this means an explicit False (e.g. generate_audio)
+        # is never transmitted - confirm that matches the service defaults.
+        when_truthy = {
+            "duration": duration,
+            "frames": frames,
+            "ratio": ratio,
+            "resolution": resolution,
+            "generate_audio": generate_audio,
+            "camera_fixed": camera_fixed,
+            "return_last_frame": return_last_frame,
+            "execution_expires_after": execution_expires_after,
+            "draft": draft,
+        }
+        params.update({key: value for key, value in when_truthy.items() if value})
+
+        if seed is not None:
+            params["seed"] = seed
+        if service_tier != "default":
+            params["service_tier"] = service_tier
+
+        params.update(kwargs)
+        response = self._client.content_generation.tasks.create(**params)
+
+        return self._dump(response)
+
+    async def agenerate(
+        self,
+        prompt: str,
+        image_url: Optional[str] = None,
+        duration: Optional[int] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """Async variant of :meth:`generate`.
+
+        The synchronous SDK call runs in a worker thread so the event loop is
+        not blocked. Remaining ``generate`` options can be passed via
+        ``kwargs``.
+        """
+        import asyncio
+
+        return await asyncio.to_thread(
+            self.generate, prompt, image_url=image_url, duration=duration, **kwargs
+        )
+
+    def get_task_status(self, task_id: str) -> Dict[str, Any]:
+        """Fetch the status of a video generation task.
+
+        Args:
+            task_id: Task ID.
+
+        Returns:
+            The dumped task status response.
+
+        Raises:
+            BusinessException: ``PROVIDER_NOT_SUPPORTED`` for non-Volcano
+                providers.
+        """
+        self._ensure_volcano()
+        response = self._client.content_generation.tasks.get(task_id=task_id)
+        return self._dump(response)
+
+    async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
+        """Async variant of :meth:`get_task_status` (runs in a worker thread)."""
+        import asyncio
+
+        return await asyncio.to_thread(self.get_task_status, task_id)
+
+    def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
+        """List video generation tasks.
+
+        Args:
+            page_size: Number of tasks per page.
+            status: Optional status filter such as "succeeded", "failed" or
+                "pending".
+            **kwargs: Extra SDK parameters; they override the values above.
+
+        Returns:
+            The dumped task list response.
+
+        Raises:
+            BusinessException: ``PROVIDER_NOT_SUPPORTED`` for non-Volcano
+                providers.
+        """
+        self._ensure_volcano()
+        params = {"page_size": page_size}
+        if status:
+            params["status"] = status
+        params.update(kwargs)
+        response = self._client.content_generation.tasks.list(**params)
+        return self._dump(response)
+
+    def delete_task(self, task_id: str) -> None:
+        """Delete or cancel a video generation task.
+
+        Args:
+            task_id: Task ID.
+
+        Raises:
+            BusinessException: ``PROVIDER_NOT_SUPPORTED`` for non-Volcano
+                providers.
+        """
+        self._ensure_volcano()
+        self._client.content_generation.tasks.delete(task_id=task_id)
diff --git a/api/app/core/models/scripts/volcano_models.yaml b/api/app/core/models/scripts/volcano_models.yaml
new file mode 100644
index 00000000..24609f5a
--- /dev/null
+++ b/api/app/core/models/scripts/volcano_models.yaml
@@ -0,0 +1,334 @@
+provider: volcano
+models:
+# Doubao-Seed 2.0 系列
+- name: doubao-seed-2-0-pro-260215
+ type: chat
+ provider: volcano
+ description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ logo: volcano
+
+- name: doubao-seed-2-0-lite-260215
+ type: chat
+ provider: volcano
+ description: 面向高频企业场景兼顾性能与成本的均衡型模型,综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ logo: volcano
+
+- name: doubao-seed-2-0-mini-260215
+ type: chat
+ provider: volcano
+ description: 面向低时延、高并发与成本敏感场景,提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解,适合成本和速度优先的轻量级任务。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ logo: volcano
+
+- name: doubao-seed-2-0-code-preview-260215
+ type: chat
+ provider: volcano
+ description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills,可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ - 代码模型
+ logo: volcano
+
+# Doubao-Seed 1.x 系列
+- name: doubao-seed-1-8-251228
+ type: chat
+ provider: volcano
+ description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上,Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面,视觉基础能力显著提升,可低帧率理解超长视频,视频运动理解、复杂空间理解及文档结构化解析能力也有所优化,还原生支持智能上下文管理,用户可配置上下文策略。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ logo: volcano
+
+- name: doubao-seed-1-6-251015
+ type: chat
+ provider: volcano
+ description: Doubao-Seed-1.6全新多模态深度思考模型,同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ logo: volcano
+
+- name: doubao-seed-1-6-lite-251015
+ type: chat
+ provider: volcano
+ description: 更高性价比,常见任务的最佳选择,支持minimal、low、medium、high 四种reasoning_effort思考深度
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ logo: volcano
+
+- name: doubao-seed-1-6-flash-250828
+ type: chat
+ provider: volcano
+ description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型,TPOT低至10ms; 同时支持文本和视觉理解,文本理解能力超过上一代lite,视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ logo: volcano
+
+- name: doubao-seed-code-preview-251028
+ type: chat
+ provider: volcano
+ description: 面向Agentic编程任务进行了深度优化。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ - 代码模型
+ logo: volcano
+
+- name: doubao-seed-1-6-vision-250815
+ type: chat
+ provider: volcano
+ description: 全新Doubao-Seed-1.6系列视觉深度思考模型,视觉理解能力显著增强,并支持image_process视觉工具
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 大语言模型
+ - 多模态模型
+ logo: volcano
+
+# Doubao 1.5 系列
+- name: doubao-1-5-vision-pro-32k-250115
+ type: chat
+ provider: volcano
+ description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ is_omni: false
+ tags:
+ - 大语言模型
+ - 多模态模型
+ logo: volcano
+
+- name: doubao-1-5-pro-32k-250115
+ type: chat
+ provider: volcano
+ description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
+ is_deprecated: false
+ is_official: true
+ capability: []
+ is_omni: false
+ tags:
+ - 大语言模型
+ logo: volcano
+
+- name: doubao-1-5-lite-32k-250115
+ type: chat
+ provider: volcano
+ description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
+ is_deprecated: false
+ is_official: true
+ capability: []
+ is_omni: false
+ tags:
+ - 大语言模型
+ logo: volcano
+
+# Doubao-Seedance 视频生成系列
+- name: doubao-seedance-1-5-pro-251215
+ type: video
+ provider: volcano
+ description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ is_omni: false
+ tags:
+ - 视频生成
+ logo: volcano
+
+- name: doubao-seedance-1-0-pro-250528
+ type: video
+ provider: volcano
+ description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ is_omni: false
+ tags:
+ - 视频生成
+ logo: volcano
+
+- name: doubao-seedance-1-0-pro-fast-251015
+ type: video
+ provider: volcano
+ description: 一款价格触底、效能封顶的全面模型,在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ is_omni: false
+ tags:
+ - 视频生成
+ logo: volcano
+
+- name: doubao-seedance-1-0-lite-i2v-250428
+ type: video
+ provider: volcano
+ description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ is_omni: false
+ tags:
+ - 视频生成
+ - 图生视频
+ logo: volcano
+
+- name: doubao-seedance-1-0-lite-t2v-250428
+ type: video
+ provider: volcano
+ description: 基于文本提示词生成视频
+ is_deprecated: false
+ is_official: true
+ capability: []
+ is_omni: false
+ tags:
+ - 视频生成
+ - 文生视频
+ logo: volcano
+
+# Doubao-Seedream 图像生成系列
+- name: doubao-seedream-5-0-260128
+ type: image
+ provider: volcano
+ description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ is_omni: false
+ tags:
+ - 图像生成
+ logo: volcano
+
+- name: doubao-seedream-4-5-251128
+ type: image
+ provider: volcano
+ description: 字节跳动最新推出的图像多模态模型,整合了文生图、图生图、组图输出等能力,融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ is_omni: false
+ tags:
+ - 图像生成
+ logo: volcano
+
+- name: doubao-seedream-4-0-250828
+ type: image
+ provider: volcano
+ description: 基于领先架构的SOTA级多模态图像创作模型,其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一,原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ is_omni: false
+ tags:
+ - 图像生成
+ logo: volcano
+
+- name: doubao-seedream-3-0-t2i-250415
+ type: image
+ provider: volcano
+ description: 一款支持原生高分辨率的中英双语图像生成基础模型,综合能力媲美GPT-4o,处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。
+ is_deprecated: false
+ is_official: true
+ capability: []
+ is_omni: false
+ tags:
+ - 图像生成
+ - 文生图
+ logo: volcano
+
+# Doubao 翻译系列
+- name: doubao-seed-translation-250915
+ type: chat
+ provider: volcano
+ description: 通用多语言翻译模型,支持30余种语言互译,支持 4K 上下文窗口,输出长度支持最大 3K tokens
+ is_deprecated: false
+ is_official: true
+ capability: []
+ is_omni: false
+ tags:
+ - 翻译模型
+ logo: volcano
+
+# Doubao Embedding 系列
+- name: doubao-embedding-vision-251215
+ type: embedding
+ provider: volcano
+ description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。
+ is_deprecated: false
+ is_official: true
+ capability:
+ - vision
+ - video
+ is_omni: false
+ tags:
+ - 向量模型
+ - 多模态模型
+ logo: volcano
diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py
index 198d1473..386920e0 100644
--- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py
+++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py
@@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel):
class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
super().__init__(index_name.lower())
- # self.embeddings = XinferenceEmbeddings(
- # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port
- # model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
- # )
- # Remove debug printing to avoid leaking sensitive information
- # print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
+
+ # 初始化 Embedding 模型(自动支持火山引擎多模态)
self.embeddings = RedBearEmbeddings(RedBearModelConfig(
model_name=embedding_config.model_name,
provider=embedding_config.provider,
api_key=embedding_config.api_key,
base_url=embedding_config.api_base
))
- # self.reranker = XinferenceRerank(
- # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
- # model_uid="bge-reranker-large"
- # )
- # Remove debug printing to avoid leaking sensitive information
- # print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
+ self.is_multimodal_embedding = self.embeddings.is_multimodal_supported()
+
self.reranker = RedBearRerank(RedBearModelConfig(
model_name=reranker_config.model_name,
provider=reranker_config.provider,
@@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector):
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
# 实现 Elasticsearch 保存向量
texts = [chunk.page_content for chunk in chunks]
- embeddings = self.embeddings.embed_documents(list(texts))
+        if self.is_multimodal_embedding:
+            # Volcano multimodal embedding path.
+            # NOTE(review): embed_batch hands an all-str list straight to
+            # embed_documents, so both branches end in the same call.
+            embeddings = self.embeddings.embed_batch(texts)
+        else:
+            embeddings = self.embeddings.embed_documents(list(texts))
self.create(chunks, embeddings, **kwargs)
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
@@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector):
updated count.
"""
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
- chunk.vector = self.embeddings.embed_query(chunk.page_content)
+ if self.is_multimodal_embedding:
+ # 火山引擎多模态 Embedding
+ chunk.vector = self.embeddings.embed_text(chunk.page_content)
+ else:
+ chunk.vector = self.embeddings.embed_query(chunk.page_content)
body = {
"script": {
@@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
"""Search the nearest neighbors to a vector."""
- query_vector = self.embeddings.embed_query(query)
+ if self.is_multimodal_embedding:
+ # 火山引擎多模态 Embedding
+ query_vector = self.embeddings.embed_text(query)
+ else:
+ query_vector = self.embeddings.embed_query(query)
top_k = kwargs.get("top_k", 1024)
score_threshold = float(kwargs.get("score_threshold") or 0.3)
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
diff --git a/api/app/core/storage/base.py b/api/app/core/storage/base.py
index 8ab0fcde..09824c3f 100644
--- a/api/app/core/storage/base.py
+++ b/api/app/core/storage/base.py
@@ -109,17 +109,13 @@ class StorageBackend(ABC):
pass
@abstractmethod
- async def get_url(self, file_key: str, expires: int = 3600) -> str:
- """
- Get an access URL for the file.
-
- Args:
- file_key: Unique identifier for the file in the storage system.
- expires: URL validity period in seconds (default: 1 hour).
-
- Returns:
- URL for accessing the file.
- """
+ async def get_url(
+ self,
+ file_key: str,
+ expires: int = 3600,
+ file_name: Optional[str] = None
+ ) -> str:
+ """Get an access URL for the file; a non-empty file_name asks the backend to serve it as a download."""
pass
async def get_permanent_url(self, file_key: str) -> Optional[str]:
diff --git a/api/app/core/storage/local.py b/api/app/core/storage/local.py
index 4b8ae829..13adfc20 100644
--- a/api/app/core/storage/local.py
+++ b/api/app/core/storage/local.py
@@ -210,7 +210,12 @@ class LocalStorage(StorageBackend):
cause=e,
)
- async def get_url(self, file_key: str, expires: int = 3600) -> str:
+ async def get_url(
+ self,
+ file_key: str,
+ expires: int = 3600,
+ file_name: Optional[str] = None
+ ) -> str:
"""
Get an access URL for the file.
@@ -220,6 +225,7 @@ class LocalStorage(StorageBackend):
Args:
file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (not used for local storage).
+ file_name: If set, adds Content-Disposition: attachment to force download.
Returns:
A relative URL path for accessing the file.
diff --git a/api/app/core/storage/oss.py b/api/app/core/storage/oss.py
index 27669ffa..1db86fef 100644
--- a/api/app/core/storage/oss.py
+++ b/api/app/core/storage/oss.py
@@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK.
import io
import logging
+import urllib.parse
from typing import AsyncIterator, Optional
import oss2
@@ -242,24 +243,33 @@ class OSSStorage(StorageBackend):
logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
return False
- async def get_url(self, file_key: str, expires: int = 3600) -> str:
+ async def get_url(
+ self,
+ file_key: str,
+ expires: int = 3600,
+ file_name: Optional[str] = None,
+ ) -> str:
"""
Get a presigned URL for accessing the file.
Args:
file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour).
+ file_name: If set, adds Content-Disposition: attachment to force download.
Returns:
A presigned URL for accessing the file.
"""
try:
- url = self.bucket.sign_url("GET", file_key, expires)
+ params = {}
+ if file_name:
+ filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
+ params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
+ url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None)
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url
except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
- # Return a basic URL format as fallback
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
async def get_permanent_url(self, file_key: str) -> str:
diff --git a/api/app/core/storage/s3.py b/api/app/core/storage/s3.py
index c7b33ffe..f156f4a7 100644
--- a/api/app/core/storage/s3.py
+++ b/api/app/core/storage/s3.py
@@ -6,6 +6,7 @@ using the boto3 SDK.
"""
import io
+import urllib.parse
import logging
from typing import AsyncIterator, Optional
@@ -352,31 +353,37 @@ class S3Storage(StorageBackend):
logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
return False
- async def get_url(self, file_key: str, expires: int = 3600) -> str:
+ async def get_url(
+ self,
+ file_key: str,
+ expires: int = 3600,
+ file_name: Optional[str] = None,
+ ) -> str:
"""
Get a presigned URL for accessing the file.
Args:
file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour).
+ file_name: If set, adds Content-Disposition: attachment to force download.
Returns:
A presigned URL for accessing the file.
"""
try:
+ params = {"Bucket": self.bucket_name, "Key": file_key}
+ if file_name:
+ filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
+ params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
url = self.client.generate_presigned_url(
"get_object",
- Params={
- "Bucket": self.bucket_name,
- "Key": file_key,
- },
+ Params=params,
ExpiresIn=expires,
)
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url
except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
- # Return a basic URL format as fallback
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
async def get_permanent_url(self, file_key: str) -> str:
diff --git a/api/app/core/workflow/adapters/base_adapter.py b/api/app/core/workflow/adapters/base_adapter.py
index 49321b89..2e24d085 100644
--- a/api/app/core/workflow/adapters/base_adapter.py
+++ b/api/app/core/workflow/adapters/base_adapter.py
@@ -9,7 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field
-from app.core.workflow.adapters.errors import ExceptionDefineition
+from app.core.workflow.adapters.errors import ExceptionDefinition
from app.schemas.workflow_schema import (
EdgeDefinition,
NodeDefinition,
@@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list)
- warnings: list[ExceptionDefineition] = Field(default_factory=list)
- errors: list[ExceptionDefineition] = Field(default_factory=list)
+ warnings: list[ExceptionDefinition] = Field(default_factory=list)
+ errors: list[ExceptionDefinition] = Field(default_factory=list)
class WorkflowImportResult(BaseModel):
@@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list)
- warnings: list[ExceptionDefineition] = Field(default_factory=list)
- errors: list[ExceptionDefineition] = Field(default_factory=list)
+ warnings: list[ExceptionDefinition] = Field(default_factory=list)
+ errors: list[ExceptionDefinition] = Field(default_factory=list)
class BasePlatformAdapter(ABC):
diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py
index 467beb07..4fa9508b 100644
--- a/api/app/core/workflow/adapters/dify/converter.py
+++ b/api/app/core/workflow/adapters/dify/converter.py
@@ -9,9 +9,9 @@ from urllib.parse import quote
from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import (
- UnsupportVariableType,
- UnknowModelWarning,
- ExceptionDefineition,
+ UnsupportedVariableType,
+ UnknownModelWarning,
+ ExceptionDefinition,
ExceptionType
)
from app.core.workflow.nodes.assigner.config import AssignmentItem
@@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import (
HttpFormData,
HttpTimeOutConfig,
HttpRetryConfig,
- HttpErrorDefaultTamplete,
+ HttpErrorDefaultTemplate,
HttpErrorHandleConfig
)
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
@@ -108,7 +108,7 @@ class DifyConverter(BaseConverter):
try:
return config.model_validate(value)
except Exception as e:
- self.errors.append(ExceptionDefineition(
+ self.errors.append(ExceptionDefinition(
type=ExceptionType.CONFIG,
node_id=node_id,
node_name=node_name,
@@ -138,7 +138,7 @@ class DifyConverter(BaseConverter):
var_selector = mapping.get(var_selector, var_selector)
return var_selector
- def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
+ def _process_list_variable_literal(self, variable_selector: list) -> str | None:
if not self.process_var_selector(".".join(variable_selector)):
return None
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
@@ -269,7 +269,7 @@ class DifyConverter(BaseConverter):
var_type = self.variable_type_map(var["type"])
if not var_type:
self.errors.append(
- UnsupportVariableType(
+ UnsupportedVariableType(
scope=node["id"],
name=var["variable"],
var_type=var["type"],
@@ -281,7 +281,7 @@ class DifyConverter(BaseConverter):
if var_type in ["file", "array[file]"]:
self.errors.append(
- ExceptionDefineition(
+ ExceptionDefinition(
type=ExceptionType.VARIABLE,
node_id=node["id"],
node_name=node_data["title"],
@@ -311,7 +311,7 @@ class DifyConverter(BaseConverter):
def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(
- UnknowModelWarning(
+ UnknownModelWarning(
node_id=node["id"],
node_name=node_data["title"],
model_name=node_data["model"].get("name")
@@ -327,7 +327,7 @@ class DifyConverter(BaseConverter):
)
result = QuestionClassifierNodeConfig.model_construct(
- input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
+ input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories,
).model_dump()
@@ -337,13 +337,13 @@ class DifyConverter(BaseConverter):
def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(
- UnknowModelWarning(
+ UnknownModelWarning(
node_id=node["id"],
node_name=node_data["title"],
model_name=node_data["model"].get("name")
)
)
- context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
+ context = self._process_list_variable_literal(node_data["context"]["variable_selector"])
memory = MemoryWindowSetting(
enable=bool(node_data.get("memory")),
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
@@ -367,7 +367,7 @@ class DifyConverter(BaseConverter):
)
)
vision = node_data["vision"]["enabled"]
- vision_input = self._process_list_variable_litearl(
+ vision_input = self._process_list_variable_literal(
node_data["vision"]["configs"]["variable_selector"]
) if vision else None
result = LLMNodeConfig.model_construct(
@@ -433,7 +433,7 @@ class DifyConverter(BaseConverter):
conditions.append(
LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]),
- left=self._process_list_variable_litearl(condition["variable_selector"]),
+ left=self._process_list_variable_literal(condition["variable_selector"]),
right=self.trans_variable_format(
right_value
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
@@ -453,7 +453,7 @@ class DifyConverter(BaseConverter):
right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE:
- right_value = self._process_list_variable_litearl(variable.get("value", ""))
+ right_value = self._process_list_variable_literal(variable.get("value", ""))
else:
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append(
@@ -475,10 +475,10 @@ class DifyConverter(BaseConverter):
def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"]
result = IterationNodeConfig.model_construct(
- input=self._process_list_variable_litearl(node_data["iterator_selector"]),
+ input=self._process_list_variable_literal(node_data["iterator_selector"]),
parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"],
- output=self._process_list_variable_litearl(node_data["output_selector"]),
+ output=self._process_list_variable_literal(node_data["output_selector"]),
output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"],
).model_dump()
@@ -494,8 +494,8 @@ class DifyConverter(BaseConverter):
continue
assignments.append(
AssignmentItem(
- variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
- value=self._process_list_variable_litearl(
+ variable_selector=self._process_list_variable_literal(assignment["variable_selector"]),
+ value=self._process_list_variable_literal(
assignment["value"]
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
operation=self.convert_assignment_operator(assignment["operation"])
@@ -514,7 +514,7 @@ class DifyConverter(BaseConverter):
input_variables.append(
InputVariable.model_construct(
name=input_variable["variable"],
- variable=self._process_list_variable_litearl(input_variable["value_selector"]),
+ variable=self._process_list_variable_literal(input_variable["value_selector"]),
)
)
@@ -570,7 +570,7 @@ class DifyConverter(BaseConverter):
else:
if node_data["body"]["data"]:
body_content = (node_data["body"]["data"][0].get("value") or
- self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
+ self._process_list_variable_literal(node_data["body"]["data"][0].get("file")))
else:
body_content = ""
@@ -585,7 +585,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1])
else:
- self.warnings.append(ExceptionDefineition(
+ self.warnings.append(ExceptionDefinition(
type=ExceptionType.CONFIG,
node_id=node["id"],
node_name=node_data["title"],
@@ -603,7 +603,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1])
else:
- self.warnings.append(ExceptionDefineition(
+ self.warnings.append(ExceptionDefinition(
type=ExceptionType.CONFIG,
node_id=node["id"],
node_name=node_data["title"],
@@ -625,7 +625,7 @@ class DifyConverter(BaseConverter):
default_header = var["value"]
elif var["key"] == "status_code":
default_status_code = var["value"]
- default_value = HttpErrorDefaultTamplete(
+ default_value = HttpErrorDefaultTemplate(
body=default_body,
headers=default_header,
status_code=default_status_code,
@@ -668,7 +668,7 @@ class DifyConverter(BaseConverter):
for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"],
- value=self._process_list_variable_litearl(variable["value_selector"])
+ value=self._process_list_variable_literal(variable["value_selector"])
))
result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"],
@@ -679,14 +679,14 @@ class DifyConverter(BaseConverter):
def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"]
- self.warnings.append(ExceptionDefineition(
+ self.warnings.append(ExceptionDefinition(
node_id=node["id"],
node_name=node_data["title"],
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.",
))
result = KnowledgeRetrievalNodeConfig.model_construct(
- query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
+ query=self._process_list_variable_literal(node_data["query_variable_selector"]),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
@@ -695,7 +695,7 @@ class DifyConverter(BaseConverter):
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(
- UnknowModelWarning(
+ UnknownModelWarning(
node_id=node["id"],
node_name=node_data["title"],
model_name=node_data["model"].get("name")
@@ -712,7 +712,7 @@ class DifyConverter(BaseConverter):
)
)
result = ParameterExtractorNodeConfig.model_construct(
- text=self._process_list_variable_litearl(node_data["query"]),
+ text=self._process_list_variable_literal(node_data["query"]),
params=params,
prompt=node_data.get("instruction")
).model_dump()
@@ -727,14 +727,14 @@ class DifyConverter(BaseConverter):
group_type = {}
if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables = [
- self._process_list_variable_litearl(variable)
+ self._process_list_variable_literal(variable)
for variable in node_data["variables"]
]
group_type["output"] = node_data["output_type"]
else:
for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [
- self._process_list_variable_litearl(variable)
+ self._process_list_variable_literal(variable)
for variable in group["variables"]
]
group_type[group["group_name"]] = group["output_type"]
@@ -751,7 +751,7 @@ class DifyConverter(BaseConverter):
def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"]
- self.warnings.append(ExceptionDefineition(
+ self.warnings.append(ExceptionDefinition(
node_id=node["id"],
node_name=node_data["title"],
type=ExceptionType.CONFIG,
diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py
index 10397ad0..abd95408 100644
--- a/api/app/core/workflow/adapters/dify/dify_adapter.py
+++ b/api/app/core/workflow/adapters/dify/dify_adapter.py
@@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import (
WorkflowParserResult
)
from app.core.workflow.adapters.dify.converter import DifyConverter
-from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
+from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import (
NodeDefinition,
@@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
if not all(field in self.config for field in require_fields):
return False
if self.config.get("app", {}).get("mode") == "workflow":
- self.errors.append(ExceptionDefineition(
+ self.errors.append(ExceptionDefinition(
type=ExceptionType.PLATFORM,
detail="workflow mode is not supported"
))
@@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
edge = self._convert_edge(edge)
if edge:
self.edges.append(edge)
- #
+
for variable in self.config.get("workflow").get("conversation_variables"):
con_var = self._convert_variable(variable)
if variable:
self.conv_variables.append(con_var)
- #
+
# for variables in config.get("workflow").get("environment_variables"):
# variable = self._convert_variable(variables)
# conv_variables.append(variable)
@@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"y": node["position"]["y"] + position["y"]
}
self.errors.append(
- ExceptionDefineition(
+ ExceptionDefinition(
type=ExceptionType.NODE,
node_id=node_id,
detail="parent cycle node not found"
@@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
node_data = node["data"]
converter = self.get_node_convert(node_type)
if node_type == NodeType.UNKNOWN:
- self.errors.append(ExceptionDefineition(
+ self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE,
node_id=node["id"],
node_name=node["data"]["title"],
@@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
))
return converter(node)
except Exception as e:
- self.errors.append(ExceptionDefineition(
+ self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE,
node_id=node["id"],
node_name=node["data"]["title"],
@@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
try:
-
source = edge["source"]
target = edge["target"]
label = None
@@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
label=label,
)
except Exception as e:
- self.errors.append(ExceptionDefineition(
+ self.errors.append(ExceptionDefinition(
type=ExceptionType.EDGE,
detail=f"convert edge error - {e}",
))
@@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
description=variable.get("description")
)
except Exception as e:
- self.errors.append(ExceptionDefineition(
+ self.errors.append(ExceptionDefinition(
type=ExceptionType.VARIABLE,
name=variable.get("name"),
detail=f"convert variable error - {e}",
diff --git a/api/app/core/workflow/adapters/errors.py b/api/app/core/workflow/adapters/errors.py
index c0340a5e..cb743c68 100644
--- a/api/app/core/workflow/adapters/errors.py
+++ b/api/app/core/workflow/adapters/errors.py
@@ -18,7 +18,7 @@ class ExceptionType(StrEnum):
UNKNOWN = "unknown"
-class ExceptionDefineition(BaseModel):
+class ExceptionDefinition(BaseModel):
type: ExceptionType
detail: str
@@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel):
name: str | None = None
-class UnknowModelWarning(ExceptionDefineition):
+class UnknownModelWarning(ExceptionDefinition):
type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id, node_name, model_name):
@@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition):
)
-class UnknowError(ExceptionDefineition):
+class UnknownError(ExceptionDefinition):
type: ExceptionType = ExceptionType.UNKNOWN
def __init__(self, detail: str, **kwargs):
super().__init__(detail=detail, **kwargs)
-class UnsupportPlatform(ExceptionDefineition):
+class UnsupportedPlatform(ExceptionDefinition):
type: ExceptionType = ExceptionType.PLATFORM
def __init__(self, platform: str):
- super().__init__(detail=f"Unsupport platform {platform}")
+ super().__init__(detail=f"Unsupported platform {platform}")
-class UnsupportVariableType(ExceptionDefineition):
+class UnsupportedVariableType(ExceptionDefinition):
type: ExceptionType = ExceptionType.VARIABLE
def __init__(self, scope, name, var_type: str, **kwargs):
- super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs)
+ super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs)
-class InvalidConfiguration(ExceptionDefineition):
+class InvalidConfiguration(ExceptionDefinition):
type: ExceptionType = ExceptionType.CONFIG
def __init__(self):
super().__init__(detail="Invalid workflow configuration format")
-class UnsupportNodeType(ExceptionDefineition):
+class UnsupportedNodeType(ExceptionDefinition):
type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id: str, node_type: str):
- super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")
+ super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}")
diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py
index 3516cb58..a2608a01 100644
--- a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py
+++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py
@@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import (
BasePlatformAdapter,
WorkflowParserResult
)
-from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
+from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
@@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try:
node_type = self.map_node_type(node["type"])
if node_type == NodeType.UNKNOWN:
- self.errors.append(UnsupportNodeType(
+ self.errors.append(UnsupportedNodeType(
node_id=node_id,
node_type=node["type"]
))
@@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
return NodeDefinition(**node)
except Exception as e:
- self.errors.append(ExceptionDefineition(
+ self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE,
node_id=node_id,
node_name=node_name,
@@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
try:
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
- self.warnings.append(ExceptionDefineition(
+ self.warnings.append(ExceptionDefinition(
type=ExceptionType.EDGE,
detail=f"edge {edge.get('id')} skipped: source or target node not found"
))
return None
return EdgeDefinition(**edge)
except Exception as e:
- self.errors.append(ExceptionDefineition(
+ self.errors.append(ExceptionDefinition(
type=ExceptionType.EDGE,
detail=f"convert edge error - {e}"
))
@@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try:
return VariableDefinition(**variable)
except Exception as e:
- self.warnings.append(ExceptionDefineition(
+ self.warnings.append(ExceptionDefinition(
type=ExceptionType.VARIABLE,
name=variable.get("name"),
detail=f"convert variable error - {e}"
diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py
index 031c7025..e96e0bf2 100644
--- a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py
+++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py
@@ -1,6 +1,6 @@
# -*- coding: UTF-8 -*-
from app.core.workflow.adapters.base_converter import BaseConverter
-from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
+from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.configs import (
StartNodeConfig,
@@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter):
try:
return config_cls.model_validate(value)
except Exception as e:
- self.errors.append(ExceptionDefineition(
+ self.errors.append(ExceptionDefinition(
type=ExceptionType.CONFIG,
node_id=node_id,
node_name=node_name,
diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py
index 674c45d0..daef6e82 100644
--- a/api/app/core/workflow/engine/graph_builder.py
+++ b/api/app/core/workflow/engine/graph_builder.py
@@ -7,7 +7,7 @@ import re
import uuid
from collections import defaultdict
from functools import lru_cache
-from typing import Any, Iterable
+from typing import Any, Iterable, Callable
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END
@@ -41,48 +41,31 @@ class GraphBuilder:
self,
workflow_config: dict[str, Any],
stream: bool = False,
- subgraph: bool = False,
+ cycle: str = '',
variable_pool: VariablePool | None = None
):
self.workflow_config = workflow_config
self.stream = stream
- self.subgraph = subgraph
+ self.cycle = cycle
self.start_node_id: str | None = None
- self.node_map = {node["id"]: node for node in self.nodes}
+ self.node_map: dict[str, dict] = {}
self.end_node_map: dict[str, StreamOutputConfig] = {}
- self._find_upstream_activation_dep = lru_cache(
- maxsize=len(self.nodes) * 2
- )(self._find_upstream_activation_dep)
+ self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
if variable_pool:
self.variable_pool = variable_pool
else:
self.variable_pool = VariablePool()
- self.graph = StateGraph(WorkflowState)
- self.add_nodes()
- self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
- self.end_nodes = [
- node
- for node in self.nodes
- if node.get("type") == "end" and node.get("id") in self.reachable_nodes
- ]
- self.add_edges()
- # EDGES MUST BE ADDED AFTER NODES ARE ADDED.
-
+ self.graph: StateGraph | None = None
+ self.nodes: list = []
+ self.edges: list = []
+ self.reachable_nodes: set[str] | None = None
+ self.end_nodes: list[dict] = []
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
- self._build_reverse_adj()
- self._analyze_end_node_output()
-
- @property
- def nodes(self) -> list[dict[str, Any]]:
- return self.workflow_config.get("nodes", [])
-
- @property
- def edges(self) -> list[dict[str, Any]]:
- return self.workflow_config.get("edges", [])
+ self._adj: dict[str, list[str]] = defaultdict(list)
def get_node_type(self, node_id: str) -> str:
"""Retrieve the type of node given its ID.
@@ -108,13 +91,14 @@ class GraphBuilder:
result[node[0]].append(node[1])
return result
- def _build_reverse_adj(self):
+ def _build_adj(self):
for edge in self.edges:
if edge["source"] not in self.reachable_nodes:
continue
self._reverse_adj[edge.get("target")].append({
"id": edge["source"], "branch": edge.get("label")
})
+ self._adj[edge.get("source")].append(edge["target"])
def _find_upstream_activation_dep(
self,
@@ -302,22 +286,13 @@ class GraphBuilder:
"""
for node in self.nodes:
node_type = node.get("type")
- if node_type == NodeType.NOTES:
- continue
node_id = node.get("id")
- cycle_node = node.get("cycle")
- if cycle_node:
- # Nodes within a loop subgraph are constructed by CycleGraphNode
- if not self.subgraph:
- continue
-
- # Record start and end node IDs
- if node_type in [NodeType.START, NodeType.CYCLE_START]:
- self.start_node_id = node_id
+ if node_id not in self.reachable_nodes:
+ continue
# Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
- node_instance = NodeFactory.create_node(node, self.workflow_config)
+ node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
if node_type in BRANCH_NODES:
@@ -390,6 +365,8 @@ class GraphBuilder:
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
+ if source not in self.reachable_nodes or target not in self.reachable_nodes:
+ continue
condition = edge.get("condition")
edge_type = edge.get("type")
@@ -411,11 +388,12 @@ class GraphBuilder:
# Add conditional edges
for source_node, branches in conditional_edges.items():
def make_router(src, branch_list):
- """reate a router function for each source node that routes to a NOP node for later merging."""
+ """Create a router function for each source node that routes to a NOP node for later merging."""
def make_branch_node(node_name, targets):
def node(s):
- # NOTE: NOP NODE MUST NOT MODIFY STATE
+ # NOTE: NOP NODE USED FOR ROUTING ONLY.
+ # MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS.
return {
"activate": {
node_id: s["activate"][node_name]
@@ -502,14 +480,52 @@ class GraphBuilder:
logger.debug(f"Added waiting edge: {sources} -> {target}")
# Connect End nodes to the global END node
- for end_node in self.end_nodes:
- end_node_id = end_node.get("id")
- if end_node_id:
- self.graph.add_edge(end_node_id, END)
- logger.debug(f"Added edge: {end_node_id} -> END")
+ for node in self.reachable_nodes:
+ if not self._adj[node]:
+ self.graph.add_edge(node, END)
return
def build(self) -> CompiledStateGraph:
+ nodes = self.workflow_config.get("nodes", [])
+ edges = self.workflow_config.get("edges", [])
+
+ for node in nodes:
+ if (node.get("cycle") or '') == self.cycle:
+ node_type = node.get("type")
+ if node_type in [NodeType.START, NodeType.CYCLE_START]:
+ self.start_node_id = node.get("id")
+ elif node_type == NodeType.NOTES:
+ continue
+ self.nodes.append(node)
+ self.node_map[node.get("id")] = node
+
+ for edge in edges:
+ source_in = edge.get("source") in self.node_map
+ target_in = edge.get("target") in self.node_map
+ if source_in ^ target_in:
+ raise ValueError(
+ f"Cycle node is connected to external node, "
+ f"source: {edge.get('source')}, target: {edge.get('target')}"
+ )
+
+ if source_in and target_in:
+ self.edges.append(edge)
+
+ self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
+ self.end_nodes = [
+ node
+ for node in self.nodes
+ if node.get("type") == "end" and node.get("id") in self.reachable_nodes
+ ]
+ self._build_adj()
+ self._find_upstream_activation_dep: Callable = lru_cache(
+ maxsize=len(self.nodes)*2
+ )(self._find_upstream_activation_dep)
+
+ self.graph = StateGraph(WorkflowState)
+ self.add_nodes()
+ self.add_edges()
+
+ self._analyze_end_node_output()
checkpointer = InMemorySaver()
- self.graph = self.graph.compile(checkpointer=checkpointer)
- return self.graph
+ return self.graph.compile(checkpointer=checkpointer)
diff --git a/api/app/core/workflow/engine/result_builder.py b/api/app/core/workflow/engine/result_builder.py
index e5a03c1c..be0c957a 100644
--- a/api/app/core/workflow/engine/result_builder.py
+++ b/api/app/core/workflow/engine/result_builder.py
@@ -2,6 +2,7 @@
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33
+from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.engine.variable_pool import VariablePool
@@ -9,6 +10,7 @@ class WorkflowResultBuilder:
def build_final_output(
self,
result: dict,
+ execution_context: ExecutionContext,
variable_pool: VariablePool,
elapsed_time: float,
final_output: str,
@@ -26,6 +28,8 @@ class WorkflowResultBuilder:
- "node_outputs" (dict): Outputs of executed nodes.
- "messages" (list): Conversation messages exchanged during execution.
- "error" (str, optional): Error message if any node failed.
+ execution_context (ExecutionContext): The execution context containing metadata like
+ execution ID, workspace ID, and user ID.
variable_pool (VariablePool): Variable Pool
elapsed_time (float): Total execution time in seconds.
final_output (Any): The aggregated or final output content of the workflow
@@ -48,18 +52,23 @@ class WorkflowResultBuilder:
"""
node_outputs = result.get("node_outputs", {})
token_usage = self.aggregate_token_usage(node_outputs)
- conversation_id = variable_pool.get_value("sys.conversation_id")
+ conversation_vars = {}
+ sys_vars = {}
+
+ if variable_pool:
+ conversation_vars = variable_pool.get_all_conversation_vars()
+ sys_vars = variable_pool.get_all_system_vars()
return {
"status": "completed" if success else "failed",
"output": final_output,
"variables": {
- "conv": variable_pool.get_all_conversation_vars(),
- "sys": variable_pool.get_all_system_vars()
+ "conv": conversation_vars,
+ "sys": sys_vars
},
"node_outputs": node_outputs,
"messages": result.get("messages", []),
- "conversation_id": conversation_id,
+ "conversation_id": execution_context.conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": result.get("error"),
diff --git a/api/app/core/workflow/engine/runtime_schema.py b/api/app/core/workflow/engine/runtime_schema.py
index e4bf65af..036ce0e8 100644
--- a/api/app/core/workflow/engine/runtime_schema.py
+++ b/api/app/core/workflow/engine/runtime_schema.py
@@ -12,14 +12,29 @@ class ExecutionContext(BaseModel):
execution_id: str
workspace_id: str
user_id: str
+ conversation_id: str
+ memory_storage_type: str
+ user_rag_memory_id: str
checkpoint_config: RunnableConfig
@classmethod
- def create(cls, execution_id: str, workspace_id: str, user_id: str):
+ def create(
+ cls,
+ execution_id: str,
+ workspace_id: str,
+ user_id: str,
+ conversation_id: str,
+ memory_storage_type: str,
+ user_rag_memory_id: str
+ ):
return cls(
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id,
+ conversation_id=conversation_id,
+ memory_storage_type=memory_storage_type,
+ user_rag_memory_id=user_rag_memory_id,
+
checkpoint_config=RunnableConfig(
configurable={
"thread_id": uuid.uuid4(),
diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py
index 0a4a1463..2da0d3a8 100644
--- a/api/app/core/workflow/engine/state_manager.py
+++ b/api/app/core/workflow/engine/state_manager.py
@@ -33,6 +33,8 @@ class WorkflowState(dict):
"workspace_id",
"user_id",
"activate",
+ "memory_storage_type",
+ "user_rag_memory_id"
})
__optional_keys__ = frozenset({
"error",
@@ -62,6 +64,9 @@ class WorkflowState(dict):
# node activate status
activate: Annotated[dict[str, bool], merge_activate_state]
+ memory_storage_type: str
+ user_rag_memory_id: str
+
class WorkflowStateManager:
def create_initial_state(
@@ -85,7 +90,9 @@ class WorkflowStateManager:
looping=0,
activate={
start_node_id: True
- }
+ },
+ memory_storage_type=execution_context.memory_storage_type,
+ user_rag_memory_id=execution_context.user_rag_memory_id
)
@staticmethod
diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py
index 6685a49e..dcc92fdb 100644
--- a/api/app/core/workflow/engine/stream_output_coordinator.py
+++ b/api/app/core/workflow/engine/stream_output_coordinator.py
@@ -3,7 +3,7 @@
# @Email: 1533512157@qq.com
# @Time : 2026/2/9 15:11
import re
-from queue import Queue
+from collections import deque
from typing import AsyncGenerator
from pydantic import BaseModel, Field, PrivateAttr
@@ -256,7 +256,7 @@ class StreamOutputCoordinator:
def __init__(self):
self.end_outputs: dict[str, StreamOutputConfig] = {}
self.activate_end: str | None = None
- self.output_queue: Queue = Queue()
+ self.output_queue: deque[str] = deque()
self.processed_outputs = []
def initialize_end_outputs(
@@ -266,7 +266,7 @@ class StreamOutputCoordinator:
self.end_outputs = end_node_map
self.processed_outputs = []
self.activate_end = None
- self.output_queue = Queue()
+ self.output_queue = deque()
@property
def current_activate_end_info(self):
@@ -296,13 +296,13 @@ class StreamOutputCoordinator:
scope (str): The node ID or scope that has completed execution.
status (str | None): Optional status of the node (used for branch/control nodes).
"""
- for node in self.end_outputs.keys():
+ for node in self.end_outputs:
self.end_outputs[node].update_activate(scope, status)
if self.end_outputs[node].activate and node not in self.processed_outputs:
- self.output_queue.put(node)
+ self.output_queue.append(node)
self.processed_outputs.append(node)
- if self.activate_end is None and not self.output_queue.empty():
- self.activate_end = self.output_queue.get_nowait()
+ if self.activate_end is None and self.output_queue:
+ self.activate_end = self.output_queue.popleft()
async def emit_activate_chunk(
self,
@@ -414,8 +414,8 @@ class StreamOutputCoordinator:
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
yield msg_event
- if not self.output_queue.empty():
- self.activate_end = self.output_queue.get_nowait()
+ if self.output_queue:
+ self.activate_end = self.output_queue.popleft()
# Move to next active End node if current one is done
if not self.activate_end and self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0]
diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py
index cf6f4a7b..60f1257e 100644
--- a/api/app/core/workflow/engine/variable_pool.py
+++ b/api/app/core/workflow/engine/variable_pool.py
@@ -13,7 +13,7 @@ from pydantic import BaseModel
from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
-from app.core.workflow.variable.variable_objects import T, create_variable_instance
+from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
logger = logging.getLogger(__name__)
@@ -373,6 +373,16 @@ class VariablePool:
def copy(self, pool: 'VariablePool'):
self.variables = deepcopy(pool.variables)
+ def is_file_variable(self, selector):
+ variable_struct = self.get_instance(selector, default=None, strict=False)
+ if variable_struct is None:
+ return False
+ if isinstance(variable_struct, FileVariable):
+ return True
+ elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
+ return True
+ return False
+
def to_dict(self) -> dict[str, Any]:
"""导出为字典
diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py
index c9ed6e65..0a820826 100644
--- a/api/app/core/workflow/executor.py
+++ b/api/app/core/workflow/executor.py
@@ -3,6 +3,7 @@
# @Email: 1533512157@qq.com
# @Time : 2026/2/9 13:51
import datetime
+import time
import logging
from typing import Any
@@ -82,13 +83,15 @@ class WorkflowExecutor:
CompiledStateGraph: The compiled and ready-to-run state graph.
"""
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
+ start_time = time.time()
builder = GraphBuilder(
self.workflow_config,
stream=stream,
)
+
+ self.graph = builder.build()
self.start_node_id = builder.start_node_id
self.variable_pool = builder.variable_pool
- self.graph = builder.build()
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
self.event_handler = EventStreamHandler(
@@ -96,7 +99,8 @@ class WorkflowExecutor:
variable_pool=self.variable_pool,
execution_id=self.execution_context.execution_id
)
- logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
+ logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, "
+ f"cost: {time.time() - start_time:.4f}s")
return self.graph
@@ -134,94 +138,12 @@ class WorkflowExecutor:
return event.get("data")
return self.result_builder.build_final_output(
{"error": "Workflow execution did not end as expected"},
+ self.execution_context,
self.variable_pool,
(datetime.datetime.now() - start).total_seconds(),
"",
success=False
)
- # logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
- #
- # start_time = datetime.datetime.now()
- #
- # # Execute the workflow
- # try:
- # # Build the workflow graph
- # graph = self.build_graph()
- #
- # # Initialize the variable pool with input data
- # await self.variable_initializer.initialize(
- # variable_pool=self.variable_pool,
- # input_data=input_data,
- # execution_context=self.execution_context
- # )
- # initial_state = self.state_manager.create_initial_state(
- # workflow_config=self.workflow_config,
- # input_data=input_data,
- # execution_context=self.execution_context,
- # start_node_id=self.start_node_id
- # )
- #
- # result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
- #
- # # Aggregate output from all End nodes
- # full_content = ''
- # for end_id in self.stream_coordinator.end_outputs.keys():
- # full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
- #
- # # Append messages for user and assistant
- # if input_data.get("files"):
- # result["messages"].extend(
- # [
- # {
- # "role": "user",
- # "content": input_data.get("message", '')
- # },
- # {
- # "role": "user",
- # "content": input_data.get("files")
- # },
- # {
- # "role": "assistant",
- # "content": full_content
- # }
- # ]
- # )
- # else:
- # result["messages"].extend(
- # [
- # {
- # "role": "user",
- # "content": input_data.get("message", '')
- # },
- # {
- # "role": "assistant",
- # "content": full_content
- # }
- # ]
- # )
- # # Calculate elapsed time
- # end_time = datetime.datetime.now()
- # elapsed_time = (end_time - start_time).total_seconds()
- #
- # logger.info(
- # f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
- #
- # return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
- #
- # except Exception as e:
- # end_time = datetime.datetime.now()
- # elapsed_time = (end_time - start_time).total_seconds()
- #
- # logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
- # exc_info=True)
- # return {
- # "status": "failed",
- # "error": str(e),
- # "output": None,
- # "node_outputs": {},
- # "elapsed_time": elapsed_time,
- # "token_usage": None
- # }
async def execute_stream(
self,
@@ -255,7 +177,7 @@ class WorkflowExecutor:
"data": {
"execution_id": self.execution_context.execution_id,
"workspace_id": self.execution_context.workspace_id,
- "conversation_id": input_data.get("conversation_id"),
+ "conversation_id": self.execution_context.conversation_id,
"timestamp": int(start_time.timestamp() * 1000)
}
}
@@ -376,6 +298,7 @@ class WorkflowExecutor:
"event": "workflow_end",
"data": self.result_builder.build_final_output(
result,
+ self.execution_context,
self.variable_pool,
elapsed_time,
full_content,
@@ -396,6 +319,7 @@ class WorkflowExecutor:
"event": "workflow_end",
"data": self.result_builder.build_final_output(
result,
+ self.execution_context,
self.variable_pool,
elapsed_time,
full_content,
@@ -409,7 +333,9 @@ async def execute_workflow(
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
- user_id: str
+ user_id: str,
+ memory_storage_type: str,
+ user_rag_memory_id: str
) -> dict[str, Any]:
"""
Execute a workflow (convenience function, non-streaming).
@@ -420,6 +346,8 @@ async def execute_workflow(
execution_id (str): Execution ID.
workspace_id (str): Workspace ID.
user_id (str): User ID.
+ user_rag_memory_id (str): RAG knowledge DB ID.
+ memory_storage_type (str): Memory storage type, "neo4j" or "rag".
Returns:
dict: Workflow execution result.
@@ -427,7 +355,10 @@ async def execute_workflow(
execution_context = ExecutionContext.create(
execution_id=execution_id,
workspace_id=workspace_id,
- user_id=user_id
+ user_id=user_id,
+ conversation_id=input_data.get("conversation_id"),
+ memory_storage_type=memory_storage_type,
+ user_rag_memory_id=user_rag_memory_id
)
executor = WorkflowExecutor(
workflow_config=workflow_config,
@@ -441,7 +372,9 @@ async def execute_workflow_stream(
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
- user_id: str
+ user_id: str,
+ memory_storage_type: str,
+ user_rag_memory_id: str
):
"""
Execute a workflow in streaming mode (convenience function).
@@ -452,6 +385,8 @@ async def execute_workflow_stream(
execution_id (str): Execution ID.
workspace_id (str): Workspace ID.
user_id (str): User ID.
+ user_rag_memory_id (str): RAG knowledge DB ID.
+ memory_storage_type (str): Memory storage type, "neo4j" or "rag".
Yields:
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
@@ -459,7 +394,10 @@ async def execute_workflow_stream(
execution_context = ExecutionContext.create(
execution_id=execution_id,
workspace_id=workspace_id,
- user_id=user_id
+ user_id=user_id,
+ memory_storage_type=memory_storage_type,
+ conversation_id=input_data.get("conversation_id"),
+ user_rag_memory_id=user_rag_memory_id
)
executor = WorkflowExecutor(
workflow_config=workflow_config,
diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py
index 8959e27c..7b146a9c 100644
--- a/api/app/core/workflow/nodes/agent/node.py
+++ b/api/app/core/workflow/nodes/agent/node.py
@@ -64,9 +64,7 @@ class AgentNode(BaseNode):
if not release:
raise ValueError(f"Agent 不存在: {agent_id}")
-
-
return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py
index 4c897d5a..f5bdf000 100644
--- a/api/app/core/workflow/nodes/assigner/node.py
+++ b/api/app/core/workflow/nodes/assigner/node.py
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode):
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None
diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py
index 0e3fecee..8567ebbe 100644
--- a/api/app/core/workflow/nodes/base_node.py
+++ b/api/app/core/workflow/nodes/base_node.py
@@ -28,7 +28,7 @@ class BaseNode(ABC):
All node types should inherit from this class and implement the `execute` method.
"""
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""Initialize the node.
Args:
@@ -41,6 +41,7 @@ class BaseNode(ABC):
self.node_type = node_config["type"]
self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id)
+ self.down_stream_nodes = down_stream_nodes
# 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
@@ -93,18 +94,16 @@ class BaseNode(ABC):
dict: A dict with a single key 'activate', mapping node IDs to
their activation status (True/False).
"""
- edges = self.workflow_config.get("edges")
- under_stream_nodes = [
- edge.get("target")
- for edge in edges
- if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
- ]
- return {
- "activate": {
- node_id: self.check_activate(state)
- for node_id in under_stream_nodes
- } | {self.node_id: self.check_activate(state)}
- }
+ activate_flag = self.check_activate(state)
+
+ if self.node_type not in BRANCH_NODES:
+ activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
+ else:
+ activate = {}
+
+ activate[self.node_id] = activate_flag
+
+ return {"activate": activate}
@abstractmethod
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -315,8 +314,8 @@ class BaseNode(ABC):
elapsed_time = (time.time() - start_time) * 1000
- logger.info(f"Node {self.node_id} streaming execution finished, "
- f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
+ logger.debug(f"Node {self.node_id} streaming execution finished, "
+ f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
@@ -428,8 +427,8 @@ class BaseNode(ABC):
when an error edge exists. If no error edge exists, this method
raises an exception to stop the workflow.
"""
- # Check if the node has an error edge defined
- error_edge = self._find_error_edge()
+ # # Check if the node has an error edge defined
+ # error_edge = self._find_error_edge()
# Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool)
@@ -447,27 +446,26 @@ class BaseNode(ABC):
"error": error_message
}
- if error_edge:
- # If an error edge exists, log a warning and continue to error node
- logger.warning(
- f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
- )
- return {
- "node_outputs": {
- self.node_id: node_output
- },
- "error": error_message,
- "error_node": self.node_id
- }
- else:
- # If no error edge, send the error via stream writer and stop the workflow
- writer = get_stream_writer()
- writer({
- "type": "node_error",
- **node_output
- })
- logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
- raise Exception(f"Node {self.node_id} execution failed: {error_message}")
+ # if error_edge:
+ # # If an error edge exists, log a warning and continue to error node
+ # logger.warning(
+ # f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
+ # )
+ # return {
+ # "node_outputs": {
+ # self.node_id: node_output
+ # },
+ # "error": error_message,
+ # "error_node": self.node_id
+ # }
+ # else:
+ writer = get_stream_writer()
+ writer({
+ "type": "node_error",
+ **node_output
+ })
+ logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
+ raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit).
@@ -623,7 +621,6 @@ class BaseNode(ABC):
async def process_message(
api_config: ModelInfo,
content: str | dict | FileObject,
- end_user_id: str,
enable_file=False
) -> list | str | None:
provider = api_config.provider
@@ -642,10 +639,10 @@ class BaseNode(ABC):
return content
elif isinstance(content, FileObject):
- if content.content_cache.get(provider):
- return content.content_cache[provider]
+ if content.content_cache.get(f"{provider}_{api_config.is_omni}"):
+ return content.content_cache[f"{provider}_{api_config.is_omni}"]
with get_db_read() as db:
- multimodel_service = MultimodalService(db, api_config=api_config)
+ multimodal_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput(
type=content.type,
url=content.url,
@@ -654,16 +651,15 @@ class BaseNode(ABC):
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
)
file_obj.set_content(content.get_content())
- message = await multimodel_service.process_files(
- end_user_id,
+ message = await multimodal_service.process_files(
[file_obj],
)
content.set_content(file_obj.get_content())
if message:
- content.content_cache[provider] = message
+ content.content_cache[f"{provider}_{api_config.is_omni}"] = message
return message
return None
- raise TypeError(f'Unexpect input value type - {type(content)}')
+ raise TypeError(f'Unexpected input value type - {type(content)}')
@staticmethod
def process_model_output(content) -> str:
diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py
index 1e055002..d89b208b 100644
--- a/api/app/core/workflow/nodes/code/node.py
+++ b/api/app/core/workflow/nodes/code/node.py
@@ -51,8 +51,8 @@ console.log(result)
class CodeNode(BaseNode):
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: CodeNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py
index 71e0dbdb..fc80939f 100644
--- a/api/app/core/workflow/nodes/cycle_graph/node.py
+++ b/api/app/core/workflow/nodes/cycle_graph/node.py
@@ -30,17 +30,13 @@ class CycleGraphNode(BaseNode):
It acts as a container and execution controller for a subgraph.
"""
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
-
- self.cycle_nodes = list() # Nodes belonging to this cycle
- self.cycle_edges = list() # Edges connecting nodes within the cycle
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
+ self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.start_node_id = None # ID of the start node within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None
self.child_variable_pool: VariablePool | None = None
- self.build_graph()
- self.iteration_flag = True
def _output_types(self) -> dict[str, VariableType]:
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
@@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode):
else:
remain_edges.append(edge)
- # Update workflow_config by removing cycle nodes and internal edges
- self.workflow_config["nodes"] = [
- node for node in nodes if node.get("cycle") != self.node_id
- ]
- self.workflow_config["edges"] = remain_edges
+ # # Update workflow_config by removing cycle nodes and internal edges
+ # self.workflow_config["nodes"] = [
+ # node for node in nodes if node.get("cycle") != self.node_id
+ # ]
+ # self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges
@@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode):
3. Compile the graph for runtime execution
"""
from app.core.workflow.engine.graph_builder import GraphBuilder
- self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
+
self.child_variable_pool = VariablePool()
builder = GraphBuilder(
{
"nodes": self.cycle_nodes,
"edges": self.cycle_edges,
},
- subgraph=True,
- variable_pool=self.child_variable_pool
+ variable_pool=self.child_variable_pool,
+ cycle=self.node_id
)
- self.start_node_id = builder.start_node_id
self.graph = builder.build()
+ self.start_node_id = builder.start_node_id
self.child_variable_pool = builder.variable_pool
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
Raises:
RuntimeError: If the node type is unsupported.
"""
+ self.build_graph()
if self.node_type == NodeType.LOOP:
return await LoopRuntime(
start_id=self.start_node_id,
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
+ self.build_graph()
if self.node_type == NodeType.LOOP:
yield {
"__final__": True,
diff --git a/api/app/core/workflow/nodes/document_extractor/__init__.py b/api/app/core/workflow/nodes/document_extractor/__init__.py
new file mode 100644
index 00000000..c51bc2c0
--- /dev/null
+++ b/api/app/core/workflow/nodes/document_extractor/__init__.py
@@ -0,0 +1,4 @@
+from .config import DocExtractorNodeConfig
+from .node import DocExtractorNode
+
+__all__ = ["DocExtractorNode", "DocExtractorNodeConfig"]
diff --git a/api/app/core/workflow/nodes/document_extractor/config.py b/api/app/core/workflow/nodes/document_extractor/config.py
new file mode 100644
index 00000000..69f7f76d
--- /dev/null
+++ b/api/app/core/workflow/nodes/document_extractor/config.py
@@ -0,0 +1,18 @@
+from pydantic import Field
+from app.core.workflow.nodes.base_config import BaseNodeConfig
+
+
+class DocExtractorNodeConfig(BaseNodeConfig):
+ file_selector: str = Field(
+ ...,
+ description="File variable selector, e.g. {{ sys.files }} or {{ node_id.file }}"
+ )
+
+ class Config:
+ json_schema_extra = {
+ "examples": [
+ {
+ "file_selector": "{{ sys.files }}"
+ }
+ ]
+ }
diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py
new file mode 100644
index 00000000..40641f3c
--- /dev/null
+++ b/api/app/core/workflow/nodes/document_extractor/node.py
@@ -0,0 +1,103 @@
+import logging
+from typing import Any
+
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
+from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
+from app.core.workflow.variable.base_variable import VariableType, FileObject
+from app.db import get_db_read
+from app.schemas.app_schema import FileInput, FileType, TransferMethod
+
+logger = logging.getLogger(__name__)
+
+
+def _file_object_to_file_input(f: FileObject) -> FileInput:
+ """Convert workflow FileObject to multimodal FileInput."""
+ return FileInput(
+ type=FileType.DOCUMENT,
+ transfer_method=TransferMethod(f.transfer_method),
+ url=f.url or None,
+ upload_file_id=f.file_id or None,
+ file_type=f.origin_file_type or "",
+ )
+
+
+def _normalise_files(val: Any) -> list[FileObject]:
+ if isinstance(val, FileObject):
+ return [val]
+ if isinstance(val, dict) and val.get("is_file"):
+ return [FileObject(**val)]
+ if isinstance(val, list):
+ result: list[FileObject] = []
+ for item in val:
+ if isinstance(item, FileObject):
+ result.append(item)
+ elif isinstance(item, dict) and item.get("is_file"):
+ result.append(FileObject(**item))
+ else:
+ logger.warning("Ignoring non-file entry in file list for document extractor: %r", item)
+ return result
+ return []
+
+
+class DocExtractorNode(BaseNode):
+ """Document Extractor Node.
+
+ Reads one or more file variables and extracts their text content
+ by delegating to MultimodalService._extract_document_text.
+
+ Outputs:
+ text (string) – full concatenated text of all input files
+ chunks (array[string]) – per-file extracted text ("" for a file that failed to extract)
+ """
+
+ def _output_types(self) -> dict[str, VariableType]:
+ return {
+ "text": VariableType.STRING,
+ "chunks": VariableType.ARRAY_STRING,
+ }
+
+ def _extract_output(self, business_result: Any) -> Any:
+ return business_result
+
+ def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
+ return {"file_selector": self.config.get("file_selector")}
+
+ async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
+ config = DocExtractorNodeConfig(**self.config)
+
+ raw_val = self.get_variable(config.file_selector, variable_pool, strict=False)
+ if raw_val is None:
+ logger.warning(f"Node {self.node_id}: file variable '{config.file_selector}' is empty")
+ return {"text": "", "chunks": []}
+
+ files = _normalise_files(raw_val)
+ if not files:
+ return {"text": "", "chunks": []}
+
+ chunks: list[str] = []
+ with get_db_read() as db:
+ from app.services.multimodal_service import MultimodalService
+ svc = MultimodalService(db)
+ for f in files:
+ try:
+ file_input = _file_object_to_file_input(f)
+ # Ensure URL is populated for local files
+ if not file_input.url:
+ file_input.url = await svc.get_file_url(file_input)
+ # Reuse cached bytes if already fetched
+ if f.get_content():
+ file_input.set_content(f.get_content())
+ text = await svc._extract_document_text(file_input)
+ chunks.append(text)
+ except Exception as e:
+ logger.error(
+ f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
+ exc_info=True,
+ )
+ chunks.append("")
+
+ full_text = "\n\n".join(c for c in chunks if c)
+ logger.info(f"Node {self.node_id}: extracted {len(files)} file(s), total chars={len(full_text)}")
+ return {"text": full_text, "chunks": chunks}
diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py
index 5c2a6c2a..02df5091 100644
--- a/api/app/core/workflow/nodes/end/config.py
+++ b/api/app/core/workflow/nodes/end/config.py
@@ -1,9 +1,7 @@
"""End 节点配置"""
-
from pydantic import Field
-from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
-from app.core.workflow.variable.base_variable import VariableType
+from app.core.workflow.nodes.base_config import BaseNodeConfig
class EndNodeConfig(BaseNodeConfig):
diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py
index 2799316a..770cf328 100644
--- a/api/app/core/workflow/nodes/end/node.py
+++ b/api/app/core/workflow/nodes/end/node.py
@@ -36,8 +36,6 @@ class EndNode(BaseNode):
Returns:
最终输出字符串
"""
- logger.info(f"节点 {self.node_id} (End) 开始执行")
-
# 获取配置的输出模板
output_template = self.config.get("output")
@@ -46,11 +44,4 @@ class EndNode(BaseNode):
output = self._render_template(output_template, variable_pool, strict=False)
else:
output = ""
-
- # 统计信息(用于日志)
- node_outputs = state.get("node_outputs", {})
- total_nodes = len(node_outputs)
-
- logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
-
return output
diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py
index 43ab593b..529cd0b3 100644
--- a/api/app/core/workflow/nodes/enums.py
+++ b/api/app/core/workflow/nodes/enums.py
@@ -23,12 +23,13 @@ class NodeType(StrEnum):
BREAK = "break"
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
+ DOCUMENT_EXTRACTOR = "document-extractor"
UNKNOWN = "unknown"
NOTES = "notes"
-BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
+BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
class ComparisonOperator(StrEnum):
diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py
index fe38fafb..e1b84f0c 100644
--- a/api/app/core/workflow/nodes/http_request/config.py
+++ b/api/app/core/workflow/nodes/http_request/config.py
@@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel):
)
-class HttpErrorDefaultTamplete(BaseModel):
+class HttpErrorDefaultTemplate(BaseModel):
body: str = Field(
default="",
description="Default body returned on HTTP error",
@@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel):
description="Error handling strategy: 'none', 'default', or 'branch'",
)
- default: HttpErrorDefaultTamplete | None = Field(
+ default: HttpErrorDefaultTemplate | None = Field(
default=None,
description="Default response template for error handling",
)
diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py
index 23378c83..086bee4a 100644
--- a/api/app/core/workflow/nodes/http_request/node.py
+++ b/api/app/core/workflow/nodes/http_request/node.py
@@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
-from app.core.workflow.utils.file_processer import mime_to_file_type
+from app.core.workflow.utils.file_processor import mime_to_file_type
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.schemas import FileType, TransferMethod
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
or a branch identifier string when error branching is enabled.
"""
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: HttpRequestNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py
index 5d2bdf9a..ec46b20b 100644
--- a/api/app/core/workflow/nodes/if_else/node.py
+++ b/api/app/core/workflow/nodes/if_else/node.py
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode):
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: IfElseNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py
index e13709d4..abf21524 100644
--- a/api/app/core/workflow/nodes/jinja_render/node.py
+++ b/api/app/core/workflow/nodes/jinja_render/node.py
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode):
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: JinjaRenderNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py
index d3e9efd9..92699cb4 100644
--- a/api/app/core/workflow/nodes/knowledge/node.py
+++ b/api/app/core/workflow/nodes/knowledge/node.py
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode):
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None
diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py
index b293d1f4..a691001f 100644
--- a/api/app/core/workflow/nodes/llm/node.py
+++ b/api/app/core/workflow/nodes/llm/node.py
@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息(AIMessage)
"""
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: LLMNodeConfig | None = None
self.messages = []
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages
-
if messages_config:
# 使用 LangChain 消息格式
messages = []
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool)
- user_id = self.get_variable("sys.user_id", variable_pool)
# 根据角色创建对应的消息对象
if role == "system":
messages.append({
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
"content": await self.process_message(
model_info,
content,
- user_id,
self.typed_config.vision,
)
})
elif role in ["user", "human"]:
messages.append({
"role": "user",
- "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
+ "content": await self.process_message(model_info, content, self.typed_config.vision)
})
elif role in ["ai", "assistant"]:
messages.append({
"role": "assistant",
- "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
+ "content": await self.process_message(model_info, content, self.typed_config.vision)
})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({
"role": "user",
- "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
+ "content": await self.process_message(model_info, content, self.typed_config.vision)
})
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value:
- content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
+ content = await self.process_message(model_info, file.value, self.typed_config.vision)
if content:
file_content.extend(content)
if messages and messages[-1]["role"] == 'user':
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list):
file_content = []
for file in message["content"]:
- content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
+ content = await self.process_message(model_info, file, self.typed_config.vision)
if content:
file_content.extend(content)
history_message.append(
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
message["content"] = await self.process_message(
model_info,
message["content"],
- user_id,
self.typed_config.vision
)
history_message.append(message)
diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py
index 1d42e82e..73c52b79 100644
--- a/api/app/core/workflow/nodes/memory/node.py
+++ b/api/app/core/workflow/nodes/memory/node.py
@@ -1,3 +1,4 @@
+import re
from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
@@ -5,14 +6,16 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.variable.base_variable import VariableType
+from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.db import get_db_read
+from app.schemas import FileInput
from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task
class MemoryReadNode(BaseNode):
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryReadNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
@@ -36,19 +39,32 @@ class MemoryReadNode(BaseNode):
search_switch=self.typed_config.search_switch,
history=[],
db=db,
- storage_type="neo4j",
- user_rag_memory_id=""
+ storage_type=state["memory_storage_type"],
+ user_rag_memory_id=state["user_rag_memory_id"]
)
class MemoryWriteNode(BaseNode):
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryWriteNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
+ @staticmethod
+ def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
+ variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
+ variable_pattern = re.compile(variable_pattern_string)
+ variables = variable_pattern.findall(content)
+ file_variables = []
+ for variable in variables:
+ if variable_pool.is_file_variable(variable):
+ file_variables.append(variable)
+ for var in file_variables:
+ content = content.replace(var, "")
+ return file_variables, content
+
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = MemoryWriteNodeConfig(**self.config)
end_user_id = self.get_variable("sys.user_id", variable_pool)
@@ -63,17 +79,42 @@ class MemoryWriteNode(BaseNode):
})
for message in self.typed_config.messages:
+ file_variables, content = self._extract_multimodal_memory_variables(
+ message.content,
+ variable_pool
+ )
+ file_info = []
+ for var in file_variables:
+ instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
+ if isinstance(instence, FileVariable):
+ file_info.append(FileInput(
+ type=instence.value.type,
+ transfer_method=instence.value.transfer_method,
+ upload_file_id=instence.value.file_id,
+ url=instence.value.url,
+ file_type=instence.value.origin_file_type
+ ).model_dump())
+ elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
+ for file_instence in instence.value:
+ file_info.append(FileInput(
+ type=file_instence.value.type,
+ transfer_method=file_instence.value.transfer_method,
+ upload_file_id=file_instence.value.file_id,
+ url=file_instence.value.url,
+ file_type=file_instence.value.origin_file_type
+ ).model_dump())
messages.append({
"role": message.role,
- "content": self._render_template(message.content, variable_pool)
+ "content": self._render_template(content, variable_pool),
+ "files": file_info
})
write_message_task.delay(
- end_user_id,
- messages,
- str(self.typed_config.config_id),
- "neo4j",
- ""
+ end_user_id=end_user_id,
+ message=messages,
+ config_id=str(self.typed_config.config_id),
+ storage_type=state["memory_storage_type"],
+ user_rag_memory_id=state["user_rag_memory_id"]
)
return "success"
diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py
index 864e3251..49add867 100644
--- a/api/app/core/workflow/nodes/node_factory.py
+++ b/api/app/core/workflow/nodes/node_factory.py
@@ -26,6 +26,7 @@ from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode
+from app.core.workflow.nodes.document_extractor import DocExtractorNode
logger = logging.getLogger(__name__)
@@ -49,7 +50,8 @@ WorkflowNode = Union[
ToolNode,
MemoryReadNode,
MemoryWriteNode,
- CodeNode
+ CodeNode,
+ DocExtractorNode
]
@@ -81,6 +83,7 @@ class NodeFactory:
NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode,
NodeType.CODE: CodeNode,
+ NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
}
@classmethod
@@ -104,13 +107,15 @@ class NodeFactory:
def create_node(
cls,
node_config: dict[str, Any],
- workflow_config: dict[str, Any]
+ workflow_config: dict[str, Any],
+ down_stream_nodes: list[str]
) -> WorkflowNode | None:
"""创建节点实例
Args:
node_config: 节点配置
workflow_config: 工作流配置
+ down_stream_nodes: 下游节点
Returns:
节点实例或 None(对于不支持的节点类型)
@@ -127,7 +132,7 @@ class NodeFactory:
# 创建节点实例
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
- return node_class(node_config, workflow_config)
+ return node_class(node_config, workflow_config, down_stream_nodes)
@classmethod
def get_supported_types(cls) -> list[str]:
diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py
index acac09e4..3dc5fcc3 100644
--- a/api/app/core/workflow/nodes/parameter_extractor/node.py
+++ b/api/app/core/workflow/nodes/parameter_extractor/node.py
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode):
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {}
diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py
index 5cebd886..31fadaf6 100644
--- a/api/app/core/workflow/nodes/question_classifier/node.py
+++ b/api/app/core/workflow/nodes/question_classifier/node.py
@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode):
"""问题分类器节点"""
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
self.response_metadata = {}
diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py
index a9618f7b..7a324cc4 100644
--- a/api/app/core/workflow/nodes/start/node.py
+++ b/api/app/core/workflow/nodes/start/node.py
@@ -27,14 +27,8 @@ class StartNode(BaseNode):
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
"""
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- """初始化 Start 节点
-
- Args:
- node_config: 节点配置
- workflow_config: 工作流配置
- """
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
# 解析并验证配置
self.typed_config: StartNodeConfig | None = None
@@ -62,7 +56,6 @@ class StartNode(BaseNode):
包含系统参数、会话变量和自定义变量的字典
"""
self.typed_config = StartNodeConfig(**self.config)
- logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(variable_pool)
@@ -77,9 +70,9 @@ class StartNode(BaseNode):
**custom_vars # 自定义变量作为节点输出的一部分
}
- logger.info(
- f"节点 {self.node_id} (Start) 执行完成,"
- f"输出了 {len(custom_vars)} 个自定义变量"
+ logger.debug(
+ f"Node {self.node_id} (Start) execution completed, "
+ f"outputting {len(custom_vars)} custom variables"
)
return result
diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py
index 0e9d3c62..72c5c6a8 100644
--- a/api/app/core/workflow/nodes/tool/node.py
+++ b/api/app/core/workflow/nodes/tool/node.py
@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
class ToolNode(BaseNode):
"""工具节点"""
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ToolNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py
index de82f8ff..9a9c5566 100644
--- a/api/app/core/workflow/nodes/variable_aggregator/node.py
+++ b/api/app/core/workflow/nodes/variable_aggregator/node.py
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode):
- def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
- super().__init__(node_config, workflow_config)
+ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
+ super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: VariableAggregatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
diff --git a/api/app/core/workflow/utils/file_processer.py b/api/app/core/workflow/utils/file_processor.py
similarity index 100%
rename from api/app/core/workflow/utils/file_processer.py
rename to api/app/core/workflow/utils/file_processor.py
diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py
index 424fdf20..6a73efc4 100644
--- a/api/app/core/workflow/utils/template_renderer.py
+++ b/api/app/core/workflow/utils/template_renderer.py
@@ -153,7 +153,8 @@ class TemplateRenderer:
# 全局渲染器实例(严格模式)
-_default_renderer = TemplateRenderer(strict=True)
+_strict_renderer = TemplateRenderer(strict=True)
+_lenient_renderer = TemplateRenderer(strict=False)
def render_template(
@@ -184,7 +185,7 @@ def render_template(
... )
'请分析: 这是一段文本'
"""
- renderer = TemplateRenderer(strict=strict)
+ renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars)
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
Returns:
错误列表
"""
- return _default_renderer.validate(template)
+ return _strict_renderer.validate(template)
diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py
index fe4aea19..0ad74865 100644
--- a/api/app/core/workflow/validator.py
+++ b/api/app/core/workflow/validator.py
@@ -6,6 +6,7 @@
import copy
import logging
+from collections import defaultdict, deque
from typing import Any, Union, TYPE_CHECKING
from app.core.workflow.nodes.enums import NodeType
@@ -119,7 +120,6 @@ class WorkflowValidator:
errors = []
graphs = cls.get_subgraph(workflow_config)
- logger.info(graphs)
for index, graph in enumerate(graphs):
nodes = graph.get("nodes", [])
edges = graph.get("edges", [])
@@ -183,7 +183,7 @@ class WorkflowValidator:
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
- f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
+ f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
@@ -204,18 +204,18 @@ class WorkflowValidator:
Returns:
可达节点 ID 集合
"""
+ adj = defaultdict(list)
+ for edge in edges:
+ adj[edge["source"]].append(edge["target"])
+
reachable = {start_id}
- queue = [start_id]
-
+ queue = deque([start_id])
while queue:
- current = queue.pop(0)
- for edge in edges:
- if edge.get("source") == current:
- target = edge.get("target")
- if target and target not in reachable:
- reachable.add(target)
- queue.append(target)
-
+ current = queue.popleft()
+ for target in adj[current]:
+ if target not in reachable:
+ reachable.add(target)
+ queue.append(target)
return reachable
@staticmethod
@@ -229,10 +229,6 @@ class WorkflowValidator:
Returns:
(has_cycle, cycle_path): 是否有循环和循环路径
"""
- # 排除 loop 类型的节点
- loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
-
- # 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {}
for edge in edges:
source = edge.get("source")
@@ -243,10 +239,6 @@ class WorkflowValidator:
if edge_type == "error":
continue
- # 如果涉及 loop 节点,跳过
- if source in loop_nodes or target in loop_nodes:
- continue
-
if source and target:
if source not in graph:
graph[source] = []
diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py
index 5e8e3f1e..79e023c1 100644
--- a/api/app/core/workflow/variable/variable_objects.py
+++ b/api/app/core/workflow/variable/variable_objects.py
@@ -54,7 +54,7 @@ class DictVariable(BaseVariable):
def valid_value(self, value) -> dict:
if not isinstance(value, dict):
- raise TypeError(f"Value must be a dict - {type(value)}:{value}")
+ raise TypeError(f"Value must be a dict - {type(value)}:{value}")
return value
def to_literal(self) -> str:
diff --git a/api/app/models/memory_config_model.py b/api/app/models/memory_config_model.py
index 1095a386..616f7f3a 100644
--- a/api/app/models/memory_config_model.py
+++ b/api/app/models/memory_config_model.py
@@ -30,6 +30,9 @@ class MemoryConfig(Base):
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
+ vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
+ audio_id = Column(String, nullable=True, comment="语音模型配置ID")
+ video_id = Column(String, nullable=True, comment="视频模型配置ID")
# 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py
index 23fafcef..69bedc3d 100644
--- a/api/app/models/models_model.py
+++ b/api/app/models/models_model.py
@@ -2,10 +2,11 @@ import datetime
import uuid
from enum import StrEnum
-from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
-from sqlalchemy.dialects.postgresql import UUID, JSON
+from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text
+from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
+
from app.db import Base
@@ -26,9 +27,9 @@ class ModelType(StrEnum):
RERANK = "rerank"
# TTS = "tts"
# SPEECH2TEXT = "speech2text"
- # IMAGE = "image"
+ IMAGE = "image"
# AUDIO = "audio"
- # VISION = "vision"
+ VIDEO = "video"
class ModelProvider(StrEnum):
@@ -45,6 +46,7 @@ class ModelProvider(StrEnum):
XINFERENCE = "xinference"
GPUSTACK = "gpustack"
BEDROCK = "bedrock"
+ VOLCANO = "volcano"
COMPOSITE = "composite"
diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py
index 044857d2..8f101eb5 100644
--- a/api/app/models/tenant_model.py
+++ b/api/app/models/tenant_model.py
@@ -23,6 +23,21 @@ class Tenants(Base):
# 国际化语言配置字段
default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言
supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表
+
+ # 租户联系信息
+ contact_name = Column(String(100), nullable=True) # 联系人姓名
+ contact_email = Column(String(255), nullable=True) # 联系人邮箱
+ contact_phone = Column(String(50), nullable=True) # 联系人电话
+
+ # 租户套餐信息
+ plan = Column(String(50), nullable=True) # 套餐类型
+ plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间
+ api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制
+ status = Column(String(50), nullable=True, default='active') # 租户状态
+
+ # 租户功能开关字段
+ feature_billing = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用收费管理菜单")
+ feature_user_management = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用用户管理菜单")
# Relationship to users - one tenant has many users
users = relationship("User", back_populates="tenant")
diff --git a/api/app/models/user_model.py b/api/app/models/user_model.py
index b6de28ec..81319789 100644
--- a/api/app/models/user_model.py
+++ b/api/app/models/user_model.py
@@ -9,7 +9,7 @@ class User(Base):
__tablename__ = "users"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
- username = Column(String, unique=True, index=True, nullable=False)
+ username = Column(String, index=True, nullable=False) # 社区版:用户名不唯一,仅邮箱唯一
email = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
is_active = Column(Boolean, default=True, nullable=False)
diff --git a/api/app/repositories/home_page_repository.py b/api/app/repositories/home_page_repository.py
index bcb3b622..6d74bcaf 100644
--- a/api/app/repositories/home_page_repository.py
+++ b/api/app/repositories/home_page_repository.py
@@ -2,7 +2,7 @@ from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func
from uuid import UUID
-from typing import Dict
+from typing import Dict, Optional, Any
from app.models.end_user_model import EndUser
from app.models.user_model import User
@@ -190,4 +190,63 @@ class HomePageRepository:
user_count_dict = {workspace_id: count for workspace_id, count in user_counts}
- return workspaces, app_count_dict, user_count_dict
\ No newline at end of file
+ return workspaces, app_count_dict, user_count_dict
+
+ @staticmethod
+ def get_version_introduction(db: Session, version: str) -> Optional[Dict[str, Any]]:
+ """
+ 从数据库获取版本说明(优先读取已发布的版本)
+ 使用反射方式读取表结构,不依赖 premium 模型类
+
+ Args:
+ db: 数据库会话
+ version: 版本号,如 "v0.2.7"
+
+ Returns:
+ 版本说明字典,格式与 version_info.json 一致
+ 如果数据库中没有该版本,返回 None
+ """
+ try:
+ from sqlalchemy import Table, MetaData
+
+ metadata = MetaData()
+ version_notes = Table('version_notes', metadata, autoload_with=db.engine)
+ version_note_items = Table('version_note_items', metadata, autoload_with=db.engine)
+
+ note = db.query(version_notes).filter(
+ version_notes.c.version == version,
+ version_notes.c.is_published == True
+ ).first()
+
+ if not note:
+ return None
+
+ items = db.query(version_note_items).filter(
+ version_note_items.c.note_id == note.id
+ ).order_by(version_note_items.c.sort_order).all()
+
+ core_upgrades = []
+ for item in items:
+ title = item.title
+ content = item.content
+ if content:
+ core_upgrades.append(f"{title}\n{content}")
+ else:
+ core_upgrades.append(title)
+
+ return {
+ "introduction": {
+ "codeName": "",
+ "releaseDate": note.release_date.isoformat() if note.release_date else "",
+ "upgradePosition": "",
+ "coreUpgrades": core_upgrades
+ },
+ "introduction_en": {
+ "codeName": "",
+ "releaseDate": note.release_date.isoformat() if note.release_date else "",
+ "upgradePosition": "",
+ "coreUpgrades": core_upgrades
+ }
+ }
+ except Exception:
+ return None
\ No newline at end of file
diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py
index 22f13449..e64d19a3 100644
--- a/api/app/repositories/memory_config_repository.py
+++ b/api/app/repositories/memory_config_repository.py
@@ -9,21 +9,22 @@ Classes:
"""
import uuid
-from uuid import UUID
from typing import Dict, List, Optional, Tuple
+from uuid import UUID
+
+from sqlalchemy import desc, select
+from sqlalchemy.orm import Session
+
from app.core.exceptions import BusinessException
from app.core.logging_config import get_config_logger, get_db_logger
from app.models.memory_config_model import MemoryConfig
+from app.models.workspace_model import Workspace
from app.schemas.memory_storage_schema import (
- ConfigKey,
ConfigParamsCreate,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
)
-from sqlalchemy import desc, select
-from sqlalchemy.orm import Session
-
from app.utils.config_utils import resolve_config_id
# 获取数据库专用日志器
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
return memory_config_obj
@staticmethod
- def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
+ def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
Args:
@@ -309,57 +310,21 @@ class MemoryConfigRepository:
Returns:
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
-
- Raises:
- ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
try:
- db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
+ stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id)
+ db_config = db.execute(stmt).scalar_one_or_none()
if not db_config:
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None
- # 更新字段映射
- field_mapping = {
- # 模型选择
- "llm_id": "llm_id",
- "embedding_id": "embedding_id",
- "rerank_id": "rerank_id",
- # 记忆萃取引擎
- "enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
- "enable_llm_disambiguation": "enable_llm_disambiguation",
- "deep_retrieval": "deep_retrieval",
- "t_type_strict": "t_type_strict",
- "t_name_strict": "t_name_strict",
- "t_overall": "t_overall",
- "state": "state",
- "chunker_strategy": "chunker_strategy",
- # 句子提取
- "statement_granularity": "statement_granularity",
- "include_dialogue_context": "include_dialogue_context",
- "max_context": "max_context",
- # 剪枝配置
- "pruning_enabled": "pruning_enabled",
- "pruning_scene": "pruning_scene",
- "pruning_threshold": "pruning_threshold",
- # 自我反思配置
- "enable_self_reflexion": "enable_self_reflexion",
- "iteration_period": "iteration_period",
- "reflexion_range": "reflexion_range",
- "baseline": "baseline",
- }
+ update_data = update.model_dump(exclude_unset=True)
+ update_data.pop("config_id", None)
- has_update = False
- for api_field, db_field in field_mapping.items():
- value = getattr(update, api_field, None)
- if value is not None:
- setattr(db_config, db_field, value)
- has_update = True
-
- if not has_update:
- raise ValueError("No fields to update")
+ for field, value in update_data.items():
+ setattr(db_config, field, value)
db.commit()
db.refresh(db_config)
@@ -443,6 +408,9 @@ class MemoryConfigRepository:
"llm_id": db_config.llm_id,
"embedding_id": db_config.embedding_id,
"rerank_id": db_config.rerank_id,
+ "vision_id": db_config.vision_id,
+ "audio_id": db_config.audio_id,
+ "video_id": db_config.video_id,
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
"enable_llm_disambiguation": db_config.enable_llm_disambiguation,
"deep_retrieval": db_config.deep_retrieval,
@@ -527,7 +495,10 @@ class MemoryConfigRepository:
raise
@staticmethod
- def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
+ def get_config_with_workspace(
+ db: Session,
+ config_id: uuid.UUID | int | str
+ ) -> Optional[tuple[MemoryConfig, Workspace]]:
"""Get memory config and its associated workspace information
Args:
@@ -542,8 +513,6 @@ class MemoryConfigRepository:
"""
import time
- from app.models.workspace_model import Workspace
-
start_time = time.time()
config_id = resolve_config_id(config_id, db)
@@ -630,7 +599,7 @@ class MemoryConfigRepository:
db_logger.debug(
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
- return (config, workspace)
+ return config, workspace
except ValueError:
# Re-raise known business exceptions
@@ -666,7 +635,7 @@ class MemoryConfigRepository:
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
"""
from app.models.ontology_scene import OntologyScene
-
+
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
try:
@@ -730,7 +699,7 @@ class MemoryConfigRepository:
Optional[MemoryConfig]: 默认配置对象,不存在则返回None
"""
db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}")
-
+
try:
# 优先查找显式标记为默认的配置
stmt = (
@@ -742,13 +711,13 @@ class MemoryConfigRepository:
)
.limit(1)
)
-
+
config = db.scalars(stmt).first()
-
+
if config:
db_logger.debug(f"找到默认配置: config_id={config.config_id}")
return config
-
+
# 回退:获取最早创建的活跃配置
stmt = (
select(MemoryConfig)
@@ -759,25 +728,25 @@ class MemoryConfigRepository:
.order_by(MemoryConfig.created_at.asc())
.limit(1)
)
-
+
config = db.scalars(stmt).first()
-
+
if config:
db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}")
else:
db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}")
-
+
return config
-
+
except Exception as e:
db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}")
raise
@staticmethod
def get_with_fallback(
- db: Session,
- config_id: Optional[uuid.UUID],
- workspace_id: uuid.UUID
+ db: Session,
+ config_id: Optional[uuid.UUID],
+ workspace_id: uuid.UUID
) -> Optional[MemoryConfig]:
"""获取记忆配置,支持回退到工作空间默认配置
@@ -792,19 +761,18 @@ class MemoryConfigRepository:
Optional[MemoryConfig]: 配置对象,如果都不存在则返回None
"""
db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}")
-
+
if not config_id:
db_logger.debug("config_id 为空,使用工作空间默认配置")
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
-
+
config = db.get(MemoryConfig, config_id)
-
+
if config:
return config
-
+
db_logger.warning(
f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}"
)
-
- return MemoryConfigRepository.get_workspace_default(db, workspace_id)
+ return MemoryConfigRepository.get_workspace_default(db, workspace_id)
diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py
index f49227d3..8c477d39 100644
--- a/api/app/repositories/model_repository.py
+++ b/api/app/repositories/model_repository.py
@@ -1,14 +1,15 @@
-from sqlalchemy.orm import Session, joinedload, selectinload
-from sqlalchemy import and_, or_, func, desc, select
-from typing import List, Optional, Dict, Any, Tuple
import uuid
+from typing import List, Optional, Dict, Any, Tuple
+from sqlalchemy import and_, or_, func, desc
+from sqlalchemy.orm import Session, joinedload
+
+from app.core.logging_config import get_db_logger
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
from app.schemas.model_schema import (
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery, ModelConfigQueryNew
)
-from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
@@ -137,6 +138,9 @@ class ModelConfigRepository:
type_values.append(ModelType.LLM)
filters.append(ModelConfig.type.in_(type_values))
+ if query.capability:
+ filters.append(ModelConfig.capability.contains(query.capability))
+
if query.is_active is not None:
filters.append(ModelConfig.is_active == query.is_active)
@@ -435,7 +439,6 @@ class ModelConfigRepository:
ModelConfig.is_public
),
ModelConfig.provider == provider,
- ModelConfig.is_active,
~ModelConfig.is_composite
)
).all()
diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py
index 42c178b3..1939a062 100644
--- a/api/app/repositories/neo4j/add_nodes.py
+++ b/api/app/repositories/neo4j/add_nodes.py
@@ -1,17 +1,22 @@
+import logging
from typing import List, Optional
-from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
+from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
+ MEMORY_SUMMARY_NODE_SAVE
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
+logger = logging.getLogger(__name__)
+
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
- """Delete all nodes in the database."""
- result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
- print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
+ """Delete every node (and its relationships) whose end_user_id matches; the id is bound as a query parameter, never interpolated."""
+ result = await connector.execute_query("MATCH (n {end_user_id: $end_user_id}) DETACH DELETE n", end_user_id=end_user_id)
+ logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result
+
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add dialogue nodes to Neo4j database.
@@ -23,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
List of created node UUIDs or None if failed
"""
if not dialogues:
- print("No dialogues to save")
+ logger.info("No dialogues to save")
return []
try:
@@ -48,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
)
created_uuids = [record["uuid"] for record in result]
- print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
+ logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
return created_uuids
except Exception as e:
- print(f"Error creating dialogue nodes: {e}")
+ logger.error(f"Error creating dialogue nodes: {e}", exc_info=True)
return None
@@ -67,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
List of created node UUIDs or None if failed
"""
if not statements:
- print("No statements to save")
+ logger.info("No statements to save")
return []
try:
@@ -120,13 +125,14 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
)
created_uuids = [record["uuid"] for record in result]
- print(f"Successfully created {len(created_uuids)} statement nodes")
+ logger.info(f"Successfully created {len(created_uuids)} statement nodes")
return created_uuids
except Exception as e:
- print(f"Error creating statement nodes: {e}")
+ logger.error(f"Error creating statement nodes: {e}", exc_info=True)
return None
+
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add chunk nodes to Neo4j in batch.
@@ -138,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
List of created chunk UUIDs or None if failed
"""
if not chunks:
- print("No chunk nodes to add")
+ logger.info("No chunk nodes to add")
return []
try:
@@ -171,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
)
created_uuids = [record["uuid"] for record in result]
- print(f"Successfully created {len(created_uuids)} chunk nodes")
+ logger.info(f"Successfully created {len(created_uuids)} chunk nodes")
return created_uuids
except Exception as e:
- print(f"Error creating chunk nodes: {e}")
+ logger.error(f"Error creating chunk nodes: {e}", exc_info=True)
return None
-
-async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
+async def add_memory_summary_nodes(
+ summaries: List[MemorySummaryNode],
+ connector: Neo4jConnector
+) -> Optional[List[str]]:
"""Add memory summary nodes to Neo4j in batch.
Args:
@@ -191,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
List of created summary node ids or None if failed
"""
if not summaries:
- print("No memory summary nodes to add")
+ logger.info("No memory summary nodes to add")
return []
try:
@@ -211,16 +219,14 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
"config_id": s.config_id, # 添加 config_id
})
-
+
result = await connector.execute_query(
MEMORY_SUMMARY_NODE_SAVE,
summaries=flattened
)
created_ids = [record.get("uuid") for record in result]
- print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
+ logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
return created_ids
except Exception as e:
- print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
+ logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}", exc_info=True)
return None
-
-
diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py
index 7273340e..bd448c99 100644
--- a/api/app/repositories/neo4j/community_repository.py
+++ b/api/app/repositories/neo4j/community_repository.py
@@ -300,7 +300,7 @@ class CommunityRepository:
)
return bool(result)
except Exception as e:
- logger.error(f"update_community_metadata failed: {e}")
+ logger.error(f"update_community_metadata failed: {e}", exc_info=True)
return False
async def batch_update_community_metadata(
diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py
index 0cdaeb59..1f699ad8 100644
--- a/api/app/repositories/neo4j/cypher_queries.py
+++ b/api/app/repositories/neo4j/cypher_queries.py
@@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id,
RETURN elementId(r) AS uuid
"""
-
# Entity Merge Query
MERGE_ENTITIES = """
MATCH (canonical:ExtractedEntity {id: $canonical_id})
@@ -829,9 +828,8 @@ neo4j_query_all = """
other as entity2
"""
-
'''针对当前节点下扩长的句子,实体和总结'''
-Memory_Timeline_ExtractedEntity="""
+Memory_Timeline_ExtractedEntity = """
MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) = $id
AND (ms:ExtractedEntity OR ms:MemorySummary)
@@ -869,7 +867,7 @@ RETURN
"""
-Memory_Timeline_MemorySummary="""
+Memory_Timeline_MemorySummary = """
MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) =$id
AND (ms:MemorySummary OR ms:ExtractedEntity)
@@ -904,7 +902,7 @@ RETURN
}
) AS statement;
"""
-Memory_Timeline_Statement="""
+Memory_Timeline_Statement = """
MATCH (n)
WHERE elementId(n) = $id
@@ -947,7 +945,7 @@ RETURN
"""
'''针对当前节点,主要获取更加完整的句子节点'''
-Memory_Space_Emotion_Statement="""
+Memory_Space_Emotion_Statement = """
MATCH (n)
WHERE elementId(n) = $id
RETURN
@@ -957,7 +955,7 @@ RETURN
n.statement AS statement;
"""
-Memory_Space_Emotion_MemorySummary="""
+Memory_Space_Emotion_MemorySummary = """
MATCH (n)-[]-(e)
WHERE elementId(n) = $id
AND EXISTS {
@@ -970,7 +968,7 @@ RETURN DISTINCT
e.emotion_type AS emotion_type,
e.statement AS statement;
"""
-Memory_Space_Emotion_ExtractedEntity="""
+Memory_Space_Emotion_ExtractedEntity = """
MATCH (n)-[]-(e)
WHERE elementId(n) = $id
AND EXISTS {
@@ -985,18 +983,18 @@ RETURN DISTINCT
'''获取实体'''
-Memory_Space_User="""
+Memory_Space_User = """
MATCH (n)-[r]->(m)
WHERE n.end_user_id = $end_user_id AND m.name="用户"
return DISTINCT elementId(m) as id
"""
-Memory_Space_Entity="""
+Memory_Space_Entity = """
MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN
DISTINCT m.name as name,m.end_user_id as end_user_id
"""
-Memory_Space_Associative="""
+Memory_Space_Associative = """
MATCH (u)-[]-(x)-[]-(h)
WHERE elementId(u) = $user_id
AND elementId(h) = $id
@@ -1005,61 +1003,69 @@ RETURN DISTINCT
"""
Graph_Node_query = """
- MATCH (n:MemorySummary)
- WHERE n.end_user_id = $end_user_id
- RETURN
- elementId(n) AS id,
- labels(n) AS labels,
- properties(n) AS properties,
- 0 AS priority
- LIMIT $limit
+MATCH (n:MemorySummary)
+WHERE n.end_user_id = $end_user_id
+RETURN
+ elementId(n) AS id,
+ labels(n) AS labels,
+ properties(n) AS properties,
+ 0 AS priority
+LIMIT $limit
- UNION ALL
+UNION ALL
- MATCH (n:Dialogue)
- WHERE n.end_user_id = $end_user_id
- RETURN
- elementId(n) AS id,
- labels(n) AS labels,
- properties(n) AS properties,
- 1 AS priority
- LIMIT 1
+MATCH (n:Dialogue)
+WHERE n.end_user_id = $end_user_id
+RETURN
+ elementId(n) AS id,
+ labels(n) AS labels,
+ properties(n) AS properties,
+ 1 AS priority
+LIMIT 1
- UNION ALL
+UNION ALL
- MATCH (n:Statement)
- WHERE n.end_user_id = $end_user_id
- RETURN
- elementId(n) AS id,
- labels(n) AS labels,
- properties(n) AS properties,
- 1 AS priority
- LIMIT $limit
+MATCH (n:Statement)
+WHERE n.end_user_id = $end_user_id
+RETURN
+ elementId(n) AS id,
+ labels(n) AS labels,
+ properties(n) AS properties,
+ 1 AS priority
+LIMIT $limit
- UNION ALL
+UNION ALL
- MATCH (n:ExtractedEntity)
- WHERE n.end_user_id = $end_user_id
- RETURN
- elementId(n) AS id,
- labels(n) AS labels,
- properties(n) AS properties,
- 2 AS priority
- LIMIT $limit
+MATCH (n:ExtractedEntity)
+WHERE n.end_user_id = $end_user_id
+RETURN
+ elementId(n) AS id,
+ labels(n) AS labels,
+ properties(n) AS properties,
+ 2 AS priority
+LIMIT $limit
- UNION ALL
+UNION ALL
- MATCH (n:Chunk)
- WHERE n.end_user_id = $end_user_id
- RETURN
- elementId(n) AS id,
- labels(n) AS labels,
- properties(n) AS properties,
- 3 AS priority
- LIMIT $limit
+MATCH (n:Chunk)
+WHERE n.end_user_id = $end_user_id
+RETURN
+ elementId(n) AS id,
+ labels(n) AS labels,
+ properties(n) AS properties,
+ 3 AS priority
+LIMIT $limit
- """
+UNION ALL
+MATCH (n:Perceptual)
+WHERE n.end_user_id = $end_user_id
+RETURN
+ elementId(n) AS id,
+ labels(n) AS labels,
+ properties(n) AS properties,
+ 4 AS priority
+"""
# ============================================================
# Community 节点 & BELONGS_TO_COMMUNITY 边
@@ -1069,6 +1075,7 @@ Graph_Node_query = """
COMMUNITY_NODE_UPSERT = """
MERGE (c:Community {community_id: $community_id})
+ON CREATE SET c.id = $community_id
SET c.end_user_id = $end_user_id,
c.member_count = $member_count,
c.updated_at = datetime()
@@ -1175,7 +1182,8 @@ RETURN c.community_id AS community_id, cnt AS member_count
UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
-SET c.name = $name,
+SET c.id = coalesce(c.id, $community_id),
+ c.name = $name,
c.summary = $summary,
c.core_entities = $core_entities,
c.summary_embedding = $summary_embedding,
@@ -1186,7 +1194,8 @@ RETURN c.community_id AS community_id
BATCH_UPDATE_COMMUNITY_METADATA = """
UNWIND $communities AS row
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
-SET c.name = row.name,
+SET c.id = coalesce(c.id, row.community_id),
+ c.name = row.name,
c.summary = row.summary,
c.core_entities = row.core_entities,
c.summary_embedding = row.summary_embedding,
@@ -1270,6 +1279,40 @@ RETURN
startNode(r) = e AS r_from_e
"""
+CHECK_COMMUNITY_IS_COMPLETE = """
+MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
+RETURN (
+ c.name IS NOT NULL AND c.name <> '' AND
+ c.summary IS NOT NULL AND c.summary <> '' AND
+ c.core_entities IS NOT NULL
+) AS is_complete
+"""
+
+CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
+MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
+RETURN (
+ c.name IS NOT NULL AND c.name <> '' AND
+ c.summary IS NOT NULL AND c.summary <> '' AND
+ c.core_entities IS NOT NULL AND
+ c.summary_embedding IS NOT NULL
+) AS is_complete
+"""
+
+GET_INCOMPLETE_COMMUNITIES = """
+MATCH (c:Community {end_user_id: $end_user_id})
+WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
+ OR c.name = '' OR c.summary = ''
+RETURN c.community_id AS community_id
+"""
+
+GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
+MATCH (c:Community {end_user_id: $end_user_id})
+WHERE c.name IS NULL OR c.name = ''
+ OR c.summary IS NULL OR c.summary = ''
+ OR c.core_entities IS NULL
+ OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
+RETURN c.community_id AS community_id
+"""
# Community keyword search: matches name or summary via fulltext index
SEARCH_COMMUNITIES_BY_KEYWORD = """
@@ -1327,37 +1370,35 @@ ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit
"""
-CHECK_COMMUNITY_IS_COMPLETE = """
-MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
-RETURN (
- c.name IS NOT NULL AND c.name <> '' AND
- c.summary IS NOT NULL AND c.summary <> '' AND
- c.core_entities IS NOT NULL
-) AS is_complete
+# 感知记忆节点保存
+PERCEPTUAL_NODE_SAVE = """
+UNWIND $perceptuals AS p
+MERGE (n:Perceptual {id: p.id})
+SET n += {
+ id: p.id,
+ end_user_id: p.end_user_id,
+ perceptual_type: p.perceptual_type,
+ file_path: p.file_path,
+ file_name: p.file_name,
+ file_ext: p.file_ext,
+ summary: p.summary,
+ keywords: p.keywords,
+ topic: p.topic,
+ domain: p.domain,
+ created_at: p.created_at,
+ file_type: p.file_type,
+ summary_embedding: p.summary_embedding
+}
+RETURN n.id AS uuid
"""
-CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
-MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
-RETURN (
- c.name IS NOT NULL AND c.name <> '' AND
- c.summary IS NOT NULL AND c.summary <> '' AND
- c.core_entities IS NOT NULL AND
- c.summary_embedding IS NOT NULL
-) AS is_complete
-"""
-
-GET_INCOMPLETE_COMMUNITIES = """
-MATCH (c:Community {end_user_id: $end_user_id})
-WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
- OR c.name = '' OR c.summary = ''
-RETURN c.community_id AS community_id
-"""
-
-GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
-MATCH (c:Community {end_user_id: $end_user_id})
-WHERE c.name IS NULL OR c.name = ''
- OR c.summary IS NULL OR c.summary = ''
- OR c.core_entities IS NULL
- OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
-RETURN c.community_id AS community_id
+# Chunk-to-Perceptual edge: (Chunk)-[:HAS_PERCEPTUAL]->(Perceptual)
+PERCEPTUAL_CHUNK_EDGE_SAVE = """
+UNWIND $edges AS edge
+MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
+MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id})
+MERGE (c)-[r:HAS_PERCEPTUAL]->(p)
+ON CREATE SET r.end_user_id = edge.end_user_id,
+ r.created_at = edge.created_at
+RETURN elementId(r) AS uuid
"""
diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py
index 34497d5b..adc266fe 100644
--- a/api/app/repositories/neo4j/graph_saver.py
+++ b/api/app/repositories/neo4j/graph_saver.py
@@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import (
StatementNode,
ExtractedEntityNode,
EntityEntityEdge,
+ PerceptualNode,
+ PerceptualEdge,
)
import logging
+
logger = logging.getLogger(__name__)
+
+
async def save_entities_and_relationships(
- entity_nodes: List[ExtractedEntityNode],
- entity_entity_edges: List[EntityEntityEdge],
- connector: Neo4jConnector
+ entity_nodes: List[ExtractedEntityNode],
+ entity_entity_edges: List[EntityEntityEdge],
+ connector: Neo4jConnector
):
"""Save entities and their relationships using graph models"""
all_entities = [entity.model_dump() for entity in entity_nodes]
@@ -73,8 +78,8 @@ async def save_entities_and_relationships(
async def save_chunk_nodes(
- chunk_nodes: List[ChunkNode],
- connector: Neo4jConnector
+ chunk_nodes: List[ChunkNode],
+ connector: Neo4jConnector
):
"""Save chunk nodes using graph models"""
if not chunk_nodes:
@@ -89,8 +94,8 @@ async def save_chunk_nodes(
async def save_statement_chunk_edges(
- statement_chunk_edges: List[StatementChunkEdge],
- connector: Neo4jConnector
+ statement_chunk_edges: List[StatementChunkEdge],
+ connector: Neo4jConnector
):
"""Save statement-chunk edges using graph models"""
if not statement_chunk_edges:
@@ -118,8 +123,8 @@ async def save_statement_chunk_edges(
async def save_statement_entity_edges(
- statement_entity_edges: List[StatementEntityEdge],
- connector: Neo4jConnector
+ statement_entity_edges: List[StatementEntityEdge],
+ connector: Neo4jConnector
):
"""Save statement-entity edges using graph models"""
if not statement_entity_edges:
@@ -142,7 +147,7 @@ async def save_statement_entity_edges(
if all_se_edges:
try:
await connector.execute_query(
- STATEMENT_ENTITY_EDGE_SAVE,
+ STATEMENT_ENTITY_EDGE_SAVE,
relationships=all_se_edges
)
except Exception:
@@ -154,24 +159,28 @@ async def save_dialog_and_statements_to_neo4j(
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
+ perceptual_nodes: List[PerceptualNode],
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
+ perceptual_edges: List[PerceptualEdge],
connector: Neo4jConnector,
) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
- schedule_clustering_after_write() 显式触发。
+ _trigger_clustering_sync() 显式触发。
Args:
dialogue_nodes: List of DialogueNode objects to save
chunk_nodes: List of ChunkNode objects to save
statement_nodes: List of StatementNode objects to save
entity_nodes: List of ExtractedEntityNode objects to save
+ perceptual_nodes: List of PerceptualNode objects to save
entity_edges: List of EntityEntityEdge objects to save
statement_chunk_edges: List of StatementChunkEdge objects to save
statement_entity_edges: List of StatementEntityEdge objects to save
+ perceptual_edges: List of PerceptualEdge objects to save
connector: Neo4j connector instance
Returns:
@@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j(
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
dialogue_uuids = [record["uuid"] async for record in result]
results['dialogues'] = dialogue_uuids
- print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
+ logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
# 2. Save all chunk nodes in batch
if chunk_nodes:
@@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j(
results['chunks'] = chunk_uuids
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
+ if perceptual_nodes:
+ from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
+ perceptual_data = [node.model_dump() for node in perceptual_nodes]
+ result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data)
+ perceptual_uuids = [record["uuid"] async for record in result]
+ results["perceptuals"] = perceptual_uuids
+ logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")
+
# 3. Save all statement nodes in batch
if statement_nodes:
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
@@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j(
results['statement_entity_edges'] = se_uuids
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
+ if perceptual_edges:
+ from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
+ perceptual_edge_data = []
+ for edge in perceptual_edges:
+ logger.debug("Preparing perceptual-chunk edge: source=%s, target=%s", edge.source, edge.target)
+ perceptual_edge_data.append({
+ "perceptual_id": edge.source,
+ "chunk_id": edge.target,
+ "end_user_id": edge.end_user_id,
+ "created_at": edge.created_at.isoformat() if edge.created_at else None,
+ })
+ result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data)
+ perceptual_edges_uuids = [record["uuid"] async for record in result]
+ results['perceptual_chunk_edges'] = perceptual_edges_uuids
+ logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j")
+
return results
try:
@@ -303,16 +336,13 @@ async def save_dialog_and_statements_to_neo4j(
return False
-def schedule_clustering_after_write(
- entity_nodes: List,
- llm_model_id: Optional[str] = None,
- embedding_model_id: Optional[str] = None,
+async def _trigger_clustering_sync(
+ entity_nodes: List,
+ llm_model_id: Optional[str] = None,
+ embedding_model_id: Optional[str] = None,
) -> None:
"""
- 写入 Neo4j 成功后,调度后台聚类任务。
-
- 可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。
- 使用 asyncio.create_task 异步触发,不阻塞写入响应。
+ 同步等待聚类完成,避免与其他 LLM 任务并发冲突。
"""
if not entity_nodes:
return
@@ -324,15 +354,16 @@ def schedule_clustering_after_write(
end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes]
- logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
- asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
+ logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
+ await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
+ embedding_model_id=embedding_model_id)
async def _trigger_clustering(
- new_entity_ids: List[str],
- end_user_id: str,
- llm_model_id: Optional[str] = None,
- embedding_model_id: Optional[str] = None,
+ new_entity_ids: List[str],
+ end_user_id: str,
+ llm_model_id: Optional[str] = None,
+ embedding_model_id: Optional[str] = None,
) -> None:
"""
聚类触发函数,自动判断全量初始化还是增量更新。
diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py
index 1582d862..e34945eb 100644
--- a/api/app/schemas/app_schema.py
+++ b/api/app/schemas/app_schema.py
@@ -196,6 +196,13 @@ class CitationConfig(BaseModel):
enabled: bool = Field(default=False)
+class Citation(BaseModel):
+ document_id: str
+ file_name: str
+ knowledge_id: str
+ score: float
+
+
class WebSearchConfig(BaseModel):
"""联网搜索配置"""
enabled: bool = Field(default=False)
diff --git a/api/app/schemas/memory_config_schema.py b/api/app/schemas/memory_config_schema.py
index 8d7490fe..e186e54b 100644
--- a/api/app/schemas/memory_config_schema.py
+++ b/api/app/schemas/memory_config_schema.py
@@ -387,6 +387,12 @@ class MemoryConfig:
rerank_model_id: Optional[UUID] = None
rerank_model_name: Optional[str] = None
+ video_model_id: Optional[UUID] = None
+ video_model_name: Optional[str] = None
+ vision_model_id: Optional[UUID] = None
+ vision_model_name: Optional[str] = None
+ audio_model_id: Optional[UUID] = None
+ audio_model_name: Optional[str] = None
llm_params: Dict[str, Any] = field(default_factory=dict)
embedding_params: Dict[str, Any] = field(default_factory=dict)
diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py
index 046b79e7..711b6de9 100644
--- a/api/app/schemas/memory_storage_schema.py
+++ b/api/app/schemas/memory_storage_schema.py
@@ -8,9 +8,6 @@ import uuid
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
-
-
-
# ============================================================================
# 从 json_schema.py 迁移的 Schema
# ============================================================================
@@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel):
class ConflictResultSchema(BaseModel):
"""Schema for the conflict result data in the reflexion_data.json file."""
- data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.")
+ data: List[BaseDataSchema] = Field(...,
+ description="The conflict memory data. Only contains conflicting records when conflict is True.")
conflict: bool = Field(..., description="Whether the memory is in conflict.")
- quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
- memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
+ quality_assessment: Optional[QualityAssessmentSchema] = Field(None,
+ description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
+ memory_verify: Optional[MemoryVerifySchema] = Field(None,
+ description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
@model_validator(mode="before")
def _normalize_data(cls, v):
@@ -101,16 +101,19 @@ class ChangeRecordSchema(BaseModel):
- entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式
"""
field: List[Dict[str, Any]] = Field(
- ...,
+ ...,
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
)
+
class ResolvedSchema(BaseModel):
"""Schema for the resolved memory data in the reflexion_data"""
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
- resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
- change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.")
+ resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None,
+ description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
+ change: Optional[List[ChangeRecordSchema]] = Field(None,
+ description="List of detailed change records with IDs and field information.")
class SingleReflexionResultSchema(BaseModel):
@@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel):
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
type: str = Field("reflexion_result", description="The type identifier.")
+
class ReflexionResultSchema(BaseModel):
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
- results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
+ results: List[SingleReflexionResultSchema] = Field(...,
+ description="List of individual conflict resolution results, grouped by conflict type.")
@model_validator(mode="before")
def _normalize_resolved(cls, v):
@@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel):
# Composite key identifying a config row
class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid")
- config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
- user_id: str = Field("user_id", description="用户标识(字符串)")
- apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
+ config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
+ user_id: str | None = Field(default=None, description="用户标识(字符串)")
+ apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)")
# Allowed chunking strategies (extendable later)
@@ -228,23 +233,25 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
config_name: str = Field("配置名称", description="配置名称(字符串)")
config_desc: str = Field("配置描述", description="配置描述(字符串)")
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)")
-
+
# 本体场景关联(可选)
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表")
-
+
# 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入)
pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充")
-
+
# 模型配置字段(可选,用于手动指定或自动填充)
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
+
+
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)")
- config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
+ config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
@@ -255,8 +262,11 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
- config_id:Union[uuid.UUID, int, str] = None
+ config_id: Union[uuid.UUID, int, str] = None
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
+ audio_id: Optional[str] = Field(None, description="语音模型ID")
+ vision_id: Optional[str] = Field(None, description="视觉模型ID")
+ video_id: Optional[str] = Field(None, description="视频模型ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
enable_llm_dedup_blockwise: Optional[bool] = None
@@ -322,14 +332,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
# 遗忘引擎配置参数更新模型
- config_id:Union[uuid.UUID, int, str] = None
+ config_id: Union[uuid.UUID, int, str] = None
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
- config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
+ config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -364,11 +374,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None)
def fail(
- msg: str,
- error_code: str = "ERROR",
- data: Optional[Any] = None,
- time: Optional[int] = None,
- query_preview: Optional[str] = None,
+ msg: str,
+ error_code: str = "ERROR",
+ data: Optional[Any] = None,
+ time: Optional[int] = None,
+ query_preview: Optional[str] = None,
) -> ApiResponse:
payload = data
if query_preview is not None:
@@ -387,12 +397,13 @@ def fail(
time=time or _now_ms(),
)
+
class GenerateCacheRequest(BaseModel):
"""缓存生成请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
-
+
end_user_id: Optional[str] = Field(
- None,
+ None,
description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成"
)
@@ -404,7 +415,7 @@ class GenerateCacheRequest(BaseModel):
class ForgettingTriggerRequest(BaseModel):
"""手动触发遗忘周期请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
-
+
end_user_id: str = Field(..., description="组ID(即终端用户ID,必填)")
max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)")
min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)")
@@ -413,7 +424,7 @@ class ForgettingTriggerRequest(BaseModel):
class ForgettingConfigResponse(BaseModel):
"""遗忘引擎配置响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
-
+
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
decay_constant: float = Field(..., description="衰减常数 d")
lambda_time: float = Field(..., description="时间衰减参数")
@@ -432,7 +443,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
"""遗忘引擎配置更新请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
- config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)")
+ config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
@@ -448,7 +459,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
class ForgettingCycleHistoryPoint(BaseModel):
"""遗忘周期历史数据点模型(用于趋势图)"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
-
+
date: str = Field(..., description="日期(格式: '1/1', '1/2')")
merged_count: int = Field(..., description="每日融合节点数")
average_activation: Optional[float] = Field(None, description="平均激活值")
@@ -459,7 +470,7 @@ class ForgettingCycleHistoryPoint(BaseModel):
class PendingForgettingNode(BaseModel):
"""待遗忘节点模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
-
+
node_id: str = Field(..., description="节点ID")
node_type: str = Field(..., description="节点类型:statement/entity/summary")
content_summary: str = Field(..., description="内容摘要")
@@ -472,7 +483,8 @@ class ForgettingStatsResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, extra="forbid")
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
- recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
+ recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
+ description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
timestamp: int = Field(..., description="统计时间(时间戳)")
@@ -480,7 +492,7 @@ class ForgettingStatsResponse(BaseModel):
class ForgettingReportResponse(BaseModel):
"""遗忘周期报告响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
-
+
merged_count: int = Field(..., description="融合的节点对数量")
nodes_before: int = Field(..., description="遗忘前的节点总数")
nodes_after: int = Field(..., description="遗忘后的节点总数")
@@ -495,7 +507,7 @@ class ForgettingReportResponse(BaseModel):
class ForgettingCurvePoint(BaseModel):
"""遗忘曲线数据点模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
-
+
day: int = Field(..., description="天数")
activation: float = Field(..., description="激活值")
retention_rate: float = Field(..., description="保持率(与激活值相同)")
@@ -504,7 +516,7 @@ class ForgettingCurvePoint(BaseModel):
class ForgettingCurveRequest(BaseModel):
"""遗忘曲线请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
-
+
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
@@ -513,6 +525,6 @@ class ForgettingCurveRequest(BaseModel):
class ForgettingCurveResponse(BaseModel):
"""遗忘曲线响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
-
+
curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表")
config: Dict[str, Any] = Field(..., description="使用的配置参数")
diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py
index 058f082d..668a84a8 100644
--- a/api/app/schemas/model_schema.py
+++ b/api/app/schemas/model_schema.py
@@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase):
updated_at: datetime.datetime
api_keys: List["ModelApiKey"] = []
+    @staticmethod
+    def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str:  # keep `prefix` leading / `suffix` trailing chars, star the rest
+        if not key or len(key) <= prefix + suffix:  # None/empty, or too short to reveal any character safely
+            return "*" * len(key or "")  # `key or ""` avoids len(None) raising TypeError
+        return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[len(key) - suffix:]  # explicit start index: key[-0:] would append the whole key when suffix == 0
+
@field_validator("api_keys", mode="after")
@classmethod
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
@@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase):
def _serialize_created_at(self, dt: datetime.datetime | None):
return int(dt.timestamp() * 1000) if dt else None
+ @field_serializer("api_keys", when_used="json")
+ def _serialize_api_keys(self, api_keys: List["ModelApiKey"]):
+ result = []
+ for api_key in api_keys:
+ data = api_key.model_dump()
+ data["api_key"] = self.mask_api_key(api_key.api_key)
+ result.append(data)
+ return result
+
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@@ -165,20 +180,20 @@ class ModelApiKey(ModelApiKeyBase):
if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict):
self.model_config_ids = [
mc.id for mc in self.model_configs
- if hasattr(mc, 'id')
- and not getattr(mc, 'is_composite', False)
- and getattr(mc, 'name', None) == self.model_name
+ if hasattr(mc, 'id')
+ and not getattr(mc, 'is_composite', False)
+ and getattr(mc, 'name', None) == self.model_name
]
# 情况2:字典列表
elif isinstance(self.model_configs, list):
self.model_config_ids = [
mc['id'] if isinstance(mc, dict) else mc.id
for mc in self.model_configs
- if ((isinstance(mc, dict)
- and 'id' in mc
+ if ((isinstance(mc, dict)
+ and 'id' in mc
and not mc.get('is_composite', False)
- and mc.get('name') == self.model_name) or
- (hasattr(mc, 'id')
+ and mc.get('name') == self.model_name) or
+ (hasattr(mc, 'id')
and not getattr(mc, 'is_composite', False)
and getattr(mc, 'name', None) == self.model_name))
]
@@ -193,11 +208,10 @@ class ModelApiKey(ModelApiKeyBase):
validate_assignment=True # 确保赋值触发校验
)
-
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
-
+
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel):
"""模型配置查询Schema"""
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
+ capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)")
is_active: Optional[bool] = Field(None, description="激活状态筛选")
is_public: Optional[bool] = Field(None, description="公开状态筛选")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
@@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel):
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
+
class ModelMarketplace(BaseModel):
"""模型广场响应Schema"""
llm_models: List[ModelConfig] = []
@@ -304,7 +320,7 @@ class ModelBaseUpdate(BaseModel):
class ModelBase(BaseModel):
"""基础模型Schema"""
model_config = ConfigDict(from_attributes=True)
-
+
id: uuid.UUID
name: str
type: str
@@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel):
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
+
class ModelInfo(BaseModel):
"""模型信息Schema"""
model_name: str = Field(..., description="模型名称")
@@ -336,4 +353,3 @@ class ModelInfo(BaseModel):
is_omni: bool = Field(default=False, description="是否为omni模型")
model_type: ModelType = Field(..., description="模型类型")
capability: List[str] = Field(default_factory=list, description="模型能力列表")
-
diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py
index 6fcf680b..3dda6fc0 100644
--- a/api/app/services/app_chat_service.py
+++ b/api/app/services/app_chat_service.py
@@ -82,6 +82,12 @@ class AppChatService:
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
+ # opening_statement:首轮对话注入开场白
+ is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
+ system_prompt = self.agent_service._inject_opening_statement(
+ features_config, system_prompt, is_new_conversation
+ )
+
# 准备工具列表
tools = []
@@ -93,7 +99,8 @@ class AppChatService:
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
- tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
+ kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
+ tools.extend(kb_tools)
memory_flag = False
if memory:
memory_tools, memory_flag = self.agent_service.load_memory_config(
@@ -129,45 +136,18 @@ class AppChatService:
)
# 加载历史消息
- messages = self.conversation_service.get_messages(
+ history = await self.conversation_service.get_conversation_history(
conversation_id=conversation_id,
- limit=10
+ max_history=10,
+ current_provider=api_key_obj.provider,
+ current_is_omni=api_key_obj.is_omni
)
- history = []
- for msg in messages:
- content = [{"type": "text", "text": msg.content}]
-
- # 处理 meta_data 中的 files
- if msg.meta_data and msg.meta_data.get("files"):
- files = msg.meta_data.get("files", [])
- # 使用 MultimodalService 处理文件
- multimodal_service = MultimodalService(self.db, api_config=model_info)
-
- # 将 files 转换为 FileInput 格式
- file_inputs = []
- for file in files:
- from app.schemas.app_schema import FileInput, TransferMethod
- file_input = FileInput(
- type=file.get("type"),
- transfer_method=TransferMethod.REMOTE_URL,
- url=file.get("url")
- )
- file_inputs.append(file_input)
-
- history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
-
- content.extend(history_processed_files)
-
- history.append({
- "role": msg.role,
- "content": content
- })
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db, model_info)
- processed_files = await multimodal_service.process_files(user_id, files)
+ processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent(支持多模态)
@@ -206,7 +186,8 @@ class AppChatService:
# 构建用户消息内容(含多模态文件)
human_meta = {
- "files": []
+ "files": [],
+ "history_files": {}
}
assistant_meta = {
"model": api_key_obj.model_name,
@@ -221,6 +202,13 @@ class AppChatService:
"url": f.url
})
+ if processed_files:
+ human_meta["history_files"] = {
+ "content": processed_files,
+ "provider": api_key_obj.provider,
+ "is_omni": api_key_obj.is_omni
+ }
+
# 保存消息
if audio_url:
assistant_meta["audio_url"] = audio_url
@@ -249,8 +237,9 @@ class AppChatService:
}),
"elapsed_time": elapsed_time,
"suggested_questions": suggested_questions,
- "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
+ "citations": self.agent_service._filter_citations(features_config, citations_collector),
"audio_url": audio_url,
+ "audio_status": "pending"
}
async def agnet_chat_stream(
@@ -301,6 +290,12 @@ class AppChatService:
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
+ # opening_statement:首轮对话注入开场白
+ is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
+ system_prompt = self.agent_service._inject_opening_statement(
+ features_config, system_prompt, is_new_conversation
+ )
+
# 准备工具列表
tools = []
@@ -313,7 +308,8 @@ class AppChatService:
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
- tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
+ kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
+ tools.extend(kb_tools)
# 添加长期记忆工具
memory_flag = False
if memory:
@@ -350,45 +346,18 @@ class AppChatService:
)
# 加载历史消息
- messages = self.conversation_service.get_messages(
+ history = await self.conversation_service.get_conversation_history(
conversation_id=conversation_id,
- limit=10
+ max_history=10,
+ current_provider=api_key_obj.provider,
+ current_is_omni=api_key_obj.is_omni
)
- history = []
- for msg in messages:
- content = [{"type": "text", "text": msg.content}]
-
- # 处理 meta_data 中的 files
- if msg.meta_data and msg.meta_data.get("files"):
- history_files = msg.meta_data.get("files", [])
- # 使用 MultimodalService 处理文件
- multimodal_service = MultimodalService(self.db, api_config=model_info)
-
- # 将 files 转换为 FileInput 格式
- file_inputs = []
- for file in history_files:
- from app.schemas.app_schema import FileInput, TransferMethod
- file_input = FileInput(
- type=file.get("type"),
- transfer_method=TransferMethod.REMOTE_URL,
- url=file.get("url")
- )
- file_inputs.append(file_input)
-
- history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
-
- content.extend(history_processed_files)
-
- history.append({
- "role": msg.role,
- "content": content
- })
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db, model_info)
- processed_files = await multimodal_service.process_files(user_id, files)
+ processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent(支持多模态),同时并行启动 TTS
@@ -433,7 +402,7 @@ class AppChatService:
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
- # 发送结束事件(包含 suggested_questions、tts、citations)
+ # 发送结束事件(包含 suggested_questions、tts、audio_status、citations)
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
sq_config = features_config.get("suggested_questions_after_answer", {})
if isinstance(sq_config, dict) and sq_config.get("enabled"):
@@ -443,11 +412,23 @@ class AppChatService:
"api_base": api_key_obj.api_base}, {}
)
end_data["audio_url"] = stream_audio_url
- end_data["citations"] = self.agent_service._filter_citations(features_config, [])
+ # 检查TTS是否已完成(非阻塞,不取消任务)
+ audio_status = "pending"
+ if tts_task is not None and tts_task.done():
+ # 任务已完成,检查是否有异常
+ try:
+ tts_task.result()
+ audio_status = "completed"
+ except Exception as e:
+ logger.warning(f"TTS任务异常: {e}")
+ audio_status = "failed"
+ end_data["audio_status"] = audio_status if stream_audio_url else None
+ end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector)
# 保存消息
human_meta = {
- "files":[]
+ "files":[],
+ "history_files": {}
}
assistant_meta = {
"model": api_key_obj.model_name,
@@ -457,11 +438,16 @@ class AppChatService:
if files:
for f in files:
- # url = await MultimodalService(self.db).get_file_url(f)
human_meta["files"].append({
"type": f.type,
"url": f.url
})
+ if processed_files:
+ human_meta["history_files"] = {
+ "content": processed_files,
+ "provider": api_key_obj.provider,
+ "is_omni": api_key_obj.is_omni
+ }
if stream_audio_url:
assistant_meta["audio_url"] = stream_audio_url
diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py
index 19aaac42..4dcabff8 100644
--- a/api/app/services/app_service.py
+++ b/api/app/services/app_service.py
@@ -1638,7 +1638,7 @@ class AppService:
# ==================== 记忆配置提取方法 ====================
- def _extract_memory_config_id(
+ def _get_memory_config_id_from_release(
self,
app_type: str,
config: Dict[str, Any]
@@ -1863,7 +1863,7 @@ class AppService:
self.db.flush() # 先 flush,确保 release 已插入数据库
# 提取记忆配置ID并更新终端用户
- memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config)
+ memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(app.type, config)
# 如果检测到旧格式 int 数据,回退到工作空间默认配置
if is_legacy_int and not memory_config_id:
@@ -2001,7 +2001,7 @@ class AppService:
raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}")
# 提取记忆配置ID并更新终端用户
- memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config)
+ memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(release.type, release.config)
# 如果检测到旧格式 int 数据,回退到工作空间默认配置
if is_legacy_int and not memory_config_id:
diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py
index f8a01a40..014d96b7 100644
--- a/api/app/services/conversation_service.py
+++ b/api/app/services/conversation_service.py
@@ -274,7 +274,8 @@ class ConversationService:
self,
conversation_id: uuid.UUID,
max_history: Optional[int] = None,
- api_config: Optional[ModelInfo] = None
+ current_provider: Optional[str] = None,
+ current_is_omni: Optional[bool] = None
) -> List[dict]:
"""
Retrieve historical conversation messages formatted as dictionaries.
@@ -282,7 +283,8 @@ class ConversationService:
Args:
conversation_id (uuid.UUID): Conversation UUID.
max_history (Optional[int]): Maximum number of messages to retrieve.
- api_config (Optional[ModelInfo]): Model API configuration for multimodal processing.
+ current_provider (Optional[str]): Current provider for file handling.
+ current_is_omni (Optional[bool]): Current omni flag for file handling.
Returns:
List[dict]: List of message dictionaries with keys 'role' and 'content'.
@@ -292,38 +294,30 @@ class ConversationService:
limit=max_history
)
- # 转换为字典格式
history = []
for msg in messages:
- content = [{"type": "text", "text": msg.content}]
-
- # 处理 meta_data 中的 files
- if msg.meta_data and msg.meta_data.get("files"):
- files = msg.meta_data.get("files", [])
- if api_config:
- # 使用 MultimodalService 处理文件
- from app.services.multimodal_service import MultimodalService
- multimodal_service = MultimodalService(self.db, api_config=api_config)
-
- # 将 files 转换为 FileInput 格式
- file_inputs = []
- for file in files:
- from app.schemas.app_schema import FileInput, TransferMethod
- file_input = FileInput(
- type=file.get("type"),
- transfer_method=TransferMethod.REMOTE_URL,
- url=file.get("url")
- )
- file_inputs.append(file_input)
-
- processed_files = await multimodal_service.history_process_files(files=file_inputs)
-
- content.extend(processed_files)
-
- history.append({
+ msg_dict = {
"role": msg.role,
- "content": content
- })
+ "content": [{"type": "text", "text": msg.content}]
+ }
+
+            # Attach the multimodal file parts that were stored with user messages
+            if msg.role == "user" and msg.meta_data:
+                history_files = msg.meta_data.get("history_files") or {}
+
+                if history_files and current_provider and current_is_omni is not None:
+                    # Stored parts were produced for a specific provider / omni setting;
+                    # they are only reused when both match the current model.
+                    stored_provider = history_files.get("provider")
+                    stored_is_omni = history_files.get("is_omni")
+
+                    # On a mismatch keep the text-only message instead of dropping the
+                    # whole user turn from the history.
+                    if stored_provider == current_provider and stored_is_omni == current_is_omni:
+                        # "content" may be missing or None; extend(None) would raise TypeError
+                        msg_dict["content"].extend(history_files.get("content") or [])
+
+            history.append(msg_dict)
return history
@@ -539,6 +533,7 @@ class ConversationService:
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
+ is_omni = api_config.is_omni
model_type = config.type
llm = RedBearLLM(
@@ -546,7 +541,8 @@ class ConversationService:
model_name=model_name,
provider=provider,
api_key=api_key,
- base_url=api_base
+ base_url=api_base,
+ is_omni=is_omni
),
type=ModelType(model_type)
)
@@ -554,15 +550,8 @@ class ConversationService:
conversation_messages = await self.get_conversation_history(
conversation_id=conversation_id,
max_history=20,
- api_config=ModelInfo(
- model_name=model_name,
- provider=provider,
- api_key=api_key,
- api_base=api_base,
- capability=api_config.capability,
- is_omni=api_config.is_omni,
- model_type=model_type
- )
+ current_provider=provider,
+ current_is_omni=is_omni
)
if len(conversation_messages) == 0:
return ConversationOut(
diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py
index 5989f0f8..ac34b4de 100644
--- a/api/app/services/draft_run_service.py
+++ b/api/app/services/draft_run_service.py
@@ -26,7 +26,7 @@ from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig, ModelType
from app.repositories.tool_repository import ToolRepository
-from app.schemas.app_schema import FileInput
+from app.schemas.app_schema import FileInput, Citation
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
@@ -190,13 +190,19 @@ def create_web_search_tool(web_search_config: Dict[str, Any]):
return web_search_tool
-def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
+def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None):
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
Args:
kb_config: 知识库配置
kb_ids: 知识库ID列表
user_id: 用户ID
+ citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充)
+ 列表元素类型为 Citation,包含字段:
+ - document_id: 文档唯一标识
+ - file_name: 文件名
+ - knowledge_id: 知识库 ID
+ - score: 检索相关性得分
Returns:
检索到的相关知识内容
@@ -229,6 +235,21 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
}
)
+                # Collect citation info, de-duplicated by document id across tool calls
+                if citations_collector is not None:
+                    seen_doc_ids = {c.document_id for c in citations_collector}  # entries are Citation models (no dict-style .get)
+                    for chunk in retrieve_chunks_result:
+                        meta = chunk.metadata or {}
+                        doc_id = meta.get("document_id") or meta.get("doc_id")
+                        if doc_id and doc_id not in seen_doc_ids:
+                            seen_doc_ids.add(doc_id)
+                            citations_collector.append(Citation(
+                                document_id=doc_id,
+                                file_name=meta.get("file_name", ""),
+                                knowledge_id=str(meta.get("knowledge_id", "")),
+                                score=meta.get("score", 0)
+                            ))
+
return f"检索到以下相关信息:\n\n{context}"
else:
logger.warning("知识库检索未找到结果")
@@ -320,26 +341,26 @@ class AgentRunService:
self,
knowledge_retrieval_config: dict | None,
user_id
- ) -> list:
+ ) -> tuple[list, list]:
+ """返回 (tools, citations_collector)"""
if not knowledge_retrieval_config:
- return []
+ return [], []
+ citations_collector = []
tools = []
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
- kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
+ kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
- # 创建知识库检索工具
- kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
+ kb_tool = create_knowledge_retrieval_tool(
+ knowledge_retrieval_config, kb_ids, user_id,
+ citations_collector=citations_collector
+ )
tools.append(kb_tool)
-
logger.debug(
"已添加知识库检索工具",
- extra={
- "kb_ids": kb_ids,
- "tool_count": len(tools)
- }
+ extra={"kb_ids": kb_ids, "tool_count": len(tools)}
)
- return tools
+ return tools, citations_collector
def load_memory_config(
self,
@@ -441,12 +462,12 @@ class AgentRunService:
@staticmethod
def _filter_citations(
features_config: Dict[str, Any],
- citations: List[Any]
+ citations: List[Citation]
) -> List[Any]:
"""根据 citation 开关决定是否返回引用来源"""
citation_cfg = features_config.get("citation", {})
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
- return citations
+ return [cit.model_dump() for cit in citations]
return []
async def run(
@@ -549,7 +570,8 @@ class AgentRunService:
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
- tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
+ kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
+ tools.extend(kb_tools)
# 添加长期记忆工具
memory_flag = False
if memory:
@@ -592,8 +614,9 @@ class AgentRunService:
# 6. 加载历史消息
history = await self._load_conversation_history(
conversation_id=conversation_id,
- api_config=model_info,
- max_history=10
+ max_history=10,
+ current_provider=api_key_config.get("provider"),
+ current_is_omni=api_key_config.get("is_omni", False)
)
# 6. 处理多模态文件
@@ -602,7 +625,7 @@ class AgentRunService:
# 获取 provider 信息
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
- processed_files = await multimodal_service.process_files(user_id, files)
+ processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
# 7. 知识库检索
@@ -661,7 +684,10 @@ class AgentRunService:
})
},
files=files,
- audio_url=audio_url
+ processed_files=processed_files,
+ audio_url=audio_url,
+ provider=api_key_config.get("provider"),
+ is_omni=api_key_config.get("is_omni", False)
)
response = {
@@ -676,8 +702,9 @@ class AgentRunService:
"suggested_questions": await self._generate_suggested_questions(
features_config, result["content"], api_key_config, effective_params
) if not sub_agent else [],
- "citations": self._filter_citations(features_config, result.get("citations", [])),
+ "citations": self._filter_citations(features_config, citations_collector),
"audio_url": audio_url,
+ "audio_status": "pending"
}
logger.info(
@@ -785,7 +812,8 @@ class AgentRunService:
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
- tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
+ kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
+ tools.extend(kb_tools)
# 添加长期记忆工具
memory_flag = False
@@ -830,8 +858,9 @@ class AgentRunService:
# 6. 加载历史消息
history = await self._load_conversation_history(
conversation_id=conversation_id,
- api_config=model_info,
- max_history=memory_config.get("max_history", 10)
+ max_history=memory_config.get("max_history", 10),
+ current_provider=api_key_config.get("provider"),
+ current_is_omni=api_key_config.get("is_omni", False)
)
# 6. 处理多模态文件
@@ -840,7 +869,7 @@ class AgentRunService:
# 获取 provider 信息
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
- processed_files = await multimodal_service.process_files(user_id, files)
+ processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
# 7. 知识库检索
@@ -909,10 +938,13 @@ class AgentRunService:
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
},
files=files,
- audio_url=stream_audio_url
+ processed_files=processed_files,
+ audio_url=stream_audio_url,
+ provider=api_key_config.get("provider"),
+ is_omni=api_key_config.get("is_omni", False)
)
- # 12. 发送结束事件(包含 suggested_questions 和 tts)
+ # 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status)
end_data: Dict[str, Any] = {
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
@@ -923,7 +955,18 @@ class AgentRunService:
features_config, full_content, api_key_config, effective_params
)
end_data["audio_url"] = stream_audio_url
- end_data["citations"] = self._filter_citations(features_config, [])
+ # 检查TTS是否已完成(非阻塞,不取消任务)
+ audio_status = "pending"
+ if tts_task is not None and tts_task.done():
+ # 任务已完成,检查是否有异常
+ try:
+ tts_task.result()
+ audio_status = "completed"
+ except Exception as e:
+ logger.warning(f"TTS任务异常: {e}")
+ audio_status = "failed"
+ end_data["audio_status"] = audio_status if stream_audio_url else None
+ end_data["citations"] = self._filter_citations(features_config, citations_collector)
yield self._format_sse_event("end", end_data)
logger.info(
@@ -1119,14 +1162,17 @@ class AgentRunService:
async def _load_conversation_history(
self,
conversation_id: str,
- api_config: ModelInfo | None = None,
- max_history: int = 10
+ max_history: int = 10,
+ current_provider: Optional[str] = None,
+ current_is_omni: Optional[bool] = None
) -> List[Dict[str, str]]:
- """加载会话历史消息
+ """加载会话历史消息,并根据当前模型配置处理多模态文件
Args:
conversation_id: 会话ID
max_history: 最大历史消息数量
+ current_provider: 当前模型的provider
+ current_is_omni: 当前模型的is_omni
Returns:
List[Dict]: 历史消息列表
@@ -1138,7 +1184,8 @@ class AgentRunService:
history = await conversation_service.get_conversation_history(
conversation_id=uuid.UUID(conversation_id),
max_history=max_history,
- api_config=api_config
+ current_provider=current_provider,
+ current_is_omni=current_is_omni
)
logger.debug(
@@ -1166,7 +1213,10 @@ class AgentRunService:
app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
files: Optional[List[FileInput]] = None,
- audio_url: Optional[str] = None
+ processed_files: Optional[List[Dict[str, Any]]] = None,
+ audio_url: Optional[str] = None,
+ provider: Optional[str] = None,
+ is_omni: Optional[bool] = None
) -> None:
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
@@ -1177,6 +1227,11 @@ class AgentRunService:
app_id: 应用ID(未使用,保留用于兼容性)
user_id: 用户ID(未使用,保留用于兼容性)
meta_data: token消耗
+ files: 原始文件输入
+ processed_files: 处理后的文件
+ audio_url: 音频URL
+ provider: 模型供应商
+ is_omni: 是否为全模态模型
"""
try:
from app.services.conversation_service import ConversationService
@@ -1186,15 +1241,24 @@ class AgentRunService:
# 保存消息(会话已经存在)
human_meta = {
- "files": []
+ "files": [],
+ "history_files": {}
}
if files:
for f in files:
- # url = await MultimodalService(self.db).get_file_url(f)
human_meta["files"].append({
"type": f.type,
"url": f.url
})
+
+ # 保存 history_files,包含 provider 和 is_omni 信息
+ if processed_files:
+ human_meta["history_files"] = {
+ "content": processed_files,
+ "provider": provider,
+ "is_omni": is_omni
+ }
+
# 保存用户消息
conversation_service.add_message(
conversation_id=conv_uuid,
@@ -1420,8 +1484,9 @@ class AgentRunService:
workspace_id: Optional[uuid.UUID] = None,
) -> tuple[Optional[str], Optional[asyncio.Task]]:
"""文本流式输入并行合成音频。
- 返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。
+ 返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。
调用方向 text_queue put 文本 chunk,结束时 put None。
+ 前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。
"""
tts_config = features_config.get("text_to_speech", {})
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
@@ -1808,6 +1873,7 @@ class AgentRunService:
),
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
"audio_url": result.get("audio_url"),
+ "audio_status": result.get("audio_status"),
"citations": result.get("citations", []),
"suggested_questions": result.get("suggested_questions", []),
"error": None
@@ -1885,6 +1951,7 @@ class AgentRunService:
"results": [{
**r,
"audio_url": r.get("audio_url"),
+ "audio_status": r.get("audio_status"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
} for r in results],
@@ -2016,6 +2083,7 @@ class AgentRunService:
full_content = ""
returned_conversation_id = model_conversation_id
audio_url = None
+ audio_status = None
citations = []
suggested_questions = []
@@ -2074,6 +2142,7 @@ class AgentRunService:
# 从 end 事件中提取 features 输出字段
if event_type == "end" and event_data:
audio_url = event_data.get("audio_url")
+ audio_status = event_data.get("audio_status")
citations = event_data.get("citations", [])
suggested_questions = event_data.get("suggested_questions", [])
@@ -2103,6 +2172,7 @@ class AgentRunService:
"message": full_content,
"elapsed_time": elapsed,
"audio_url": audio_url,
+ "audio_status": audio_status,
"citations": citations,
"suggested_questions": suggested_questions,
"error": None
@@ -2117,6 +2187,7 @@ class AgentRunService:
"elapsed_time": elapsed,
"message_length": len(full_content),
"audio_url": audio_url,
+ "audio_status": audio_status,
"citations": citations,
"suggested_questions": suggested_questions,
"timestamp": time.time()
@@ -2253,6 +2324,7 @@ class AgentRunService:
"message": r.get("message"),
"elapsed_time": r.get("elapsed_time", 0),
"audio_url": r.get("audio_url"),
+ "audio_status": r.get("audio_status"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
"error": r.get("error")
diff --git a/api/app/services/file_storage_service.py b/api/app/services/file_storage_service.py
index 2ebc5d9a..5897936b 100644
--- a/api/app/services/file_storage_service.py
+++ b/api/app/services/file_storage_service.py
@@ -325,27 +325,30 @@ class FileStorageService:
)
raise
- async def get_file_url(self, file_key: str, expires: int = 3600) -> str:
+ async def get_file_url(
+ self,
+ file_key: str,
+ expires: int = 3600,
+ file_name: Optional[str] = None,
+ ) -> str:
"""
Get an access URL for a file.
Args:
file_key: The file key.
expires: URL validity period in seconds (default: 1 hour).
+ file_name: If set, adds Content-Disposition: attachment to force download.
Returns:
URL for accessing the file.
"""
logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s")
-
try:
- url = await self.storage.get_url(file_key, expires)
+ url = await self.storage.get_url(file_key, expires, file_name=file_name)
logger.debug(f"File URL generated: file_key={file_key}")
return url
except Exception as e:
- logger.error(
- f"Error getting file URL: file_key={file_key}, error={str(e)}"
- )
+ logger.error(f"Error getting file URL: file_key={file_key}, error={str(e)}")
raise
diff --git a/api/app/services/generation_service.py b/api/app/services/generation_service.py
new file mode 100644
index 00000000..2505793c
--- /dev/null
+++ b/api/app/services/generation_service.py
@@ -0,0 +1,151 @@
+"""
+Image and video generation service.
+
+Provides a unified generation interface that supports multiple providers.
+"""
+from typing import Dict, Any, Optional
+from sqlalchemy.orm import Session
+import uuid
+
+from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator
+from app.core.exceptions import BusinessException
+from app.core.error_codes import BizCode
+from app.models.models_model import ModelType
+from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
+from app.services.model_service import ModelApiKeyService
+
+
+class GenerationService:
+    """Generation service for images and videos."""
+
+    def __init__(self, db: Session):
+        self.db = db
+
+    def _build_generator_config(
+        self,
+        model_config_id: str,
+        expected_type: Optional[ModelType] = None
+    ) -> RedBearModelConfig:
+        """
+        Resolve a model configuration ID into a provider configuration.
+
+        Args:
+            model_config_id: Model configuration ID (UUID string)
+            expected_type: If given, the stored model type must equal it
+
+        Returns:
+            RedBearModelConfig built from an available API key
+
+        Raises:
+            BusinessException: malformed ID, unknown model configuration,
+                model type mismatch, or no available API key
+        """
+        # Parse the ID once; a malformed value is reported as a business
+        # error instead of leaking a bare ValueError to the caller.
+        try:
+            config_uuid = uuid.UUID(model_config_id)
+        except (ValueError, TypeError, AttributeError) as exc:
+            raise BusinessException(
+                f"模型配置ID格式无效: {model_config_id}",
+                code=BizCode.INVALID_PARAMETER
+            ) from exc
+
+        model_config = ModelConfigRepository.get_by_id(self.db, config_uuid)
+        if not model_config:
+            raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
+
+        if expected_type is not None and model_config.type != expected_type:
+            raise BusinessException(
+                f"模型类型错误,期望 {expected_type},实际 {model_config.type}",
+                code=BizCode.INVALID_PARAMETER
+            )
+
+        api_key_info = ModelApiKeyService.get_available_api_key(self.db, config_uuid)
+        if not api_key_info:
+            raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
+
+        return RedBearModelConfig(
+            model_name=api_key_info.model_name,
+            provider=api_key_info.provider,
+            api_key=api_key_info.api_key,
+            base_url=api_key_info.api_base,
+            extra_params=api_key_info.config or {}
+        )
+
+    async def generate_image(
+        self,
+        model_config_id: str,
+        prompt: str,
+        size: Optional[str] = "2k",
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Generate an image.
+
+        Args:
+            model_config_id: Model configuration ID
+            prompt: Prompt text
+            size: Image size
+            **kwargs: Extra arguments forwarded to the generator
+
+        Returns:
+            Generation result
+
+        Raises:
+            BusinessException: if the model configuration cannot be resolved
+                or is not an image model
+        """
+        config = self._build_generator_config(model_config_id, ModelType.IMAGE)
+        generator = RedBearImageGenerator(config)
+        return await generator.agenerate(prompt, size, **kwargs)
+
+    async def generate_video(
+        self,
+        model_config_id: str,
+        prompt: str,
+        duration: Optional[int] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Generate a video.
+
+        Args:
+            model_config_id: Model configuration ID
+            prompt: Prompt text
+            duration: Video duration in seconds
+            **kwargs: Extra arguments forwarded to the generator
+
+        Returns:
+            Generation result (including the task ID)
+
+        Raises:
+            BusinessException: if the model configuration cannot be resolved
+                or is not a video model
+        """
+        config = self._build_generator_config(model_config_id, ModelType.VIDEO)
+        generator = RedBearVideoGenerator(config)
+        return await generator.agenerate(prompt, duration, **kwargs)
+
+    async def get_video_task_status(
+        self,
+        model_config_id: str,
+        task_id: str
+    ) -> Dict[str, Any]:
+        """
+        Query the status of a video generation task.
+
+        Args:
+            model_config_id: Model configuration ID
+            task_id: Task ID
+
+        Returns:
+            Task status information
+
+        Raises:
+            BusinessException: if the model configuration cannot be resolved
+        """
+        # NOTE(review): unlike generate_video, the model type is not checked
+        # here (unchanged behavior) - confirm whether that is intentional.
+        config = self._build_generator_config(model_config_id)
+        generator = RedBearVideoGenerator(config)
+        return await generator.aget_task_status(task_id)
diff --git a/api/app/services/home_page_service.py b/api/app/services/home_page_service.py
index 8326ad40..4e6bf664 100644
--- a/api/app/services/home_page_service.py
+++ b/api/app/services/home_page_service.py
@@ -94,29 +94,38 @@ class HomePageService:
@staticmethod
def load_version_introduction(version: str) -> Dict[str, Any]:
"""
- 从 JSON 文件加载对应版本的介绍
+ 加载对应版本的介绍(优先从数据库读取,fallback 到 JSON 文件)
:param version: 系统版本号(如 "0.2.0")
:return: 对应版本的详细介绍
"""
- # 2. 定义 JSON 文件路径(简化路径处理,保留绝对路径调试特性)
+ from copy import deepcopy
+ from app.db import SessionLocal
+ from app.repositories.home_page_repository import HomePageRepository
+
+ result = deepcopy(HomePageService.DEFAULT_RETURN_DATA)
+
+ try:
+ db = SessionLocal()
+ try:
+ db_result = HomePageRepository.get_version_introduction(db, version)
+ if db_result:
+ return db_result
+ finally:
+ db.close()
+ except Exception as e:
+ pass
+
json_abs_path = Path(__file__).parent.parent / "version_info.json"
json_abs_path = json_abs_path.resolve()
- # 3. 初始化返回结果(深拷贝默认模板,避免修改原常量)
- from copy import deepcopy
- result = deepcopy(HomePageService.DEFAULT_RETURN_DATA)
-
try:
- # 4. 简化文件存在性判断(合并逻辑,减少分支)
if not json_abs_path.exists():
result["message"] = f"版本介绍文件不存在:{json_abs_path}"
return result
- # 5. 读取并解析 JSON 文件(简化文件操作流程)
with open(json_abs_path, "r", encoding="utf-8") as f:
changelogs = json.load(f)
- # 6. 简化版本匹配逻辑,直接返回结果或更新提示信息
if version in changelogs:
return changelogs[version]
result["message"] = f"暂未查询到 {version} 版本的详细介绍"
diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py
index af9a04e2..289fd74c 100644
--- a/api/app/services/memory_agent_service.py
+++ b/api/app/services/memory_agent_service.py
@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
from uuid import UUID
import redis
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
+from app.cache import InterestMemoryCache
from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
-from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results,
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
+from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
+from app.schemas import FileInput
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
+from app.services.memory_perceptual_service import MemoryPerceptualService
try:
from app.core.memory.utils.log.audit_logger import audit_logger
@@ -267,8 +270,16 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
- async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
- db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
+ async def write_memory(
+ self,
+ end_user_id: str,
+ messages: list[dict],
+ config_id: Optional[uuid.UUID] | int,
+ db: Session,
+ storage_type: str,
+ user_rag_memory_id: str,
+ language: str = "zh"
+ ) -> str:
"""
Process write operation with config_id
@@ -297,8 +308,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
- raise ValueError(
- f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
+ raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
+ f"Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -334,48 +345,58 @@ class MemoryAgentService:
raise ValueError(error_msg)
+ perceptual_serivce = MemoryPerceptualService(db)
+ for message in messages:
+ message["file_content"] = []
+ for file in (message.get("files") or []):
+ file_object = await perceptual_serivce.generate_perceptual_memory(
+ end_user_id=end_user_id,
+ memory_config=memory_config,
+ file=FileInput(**file)
+ )
+ if file_object is None:
+ continue
+ message["file_content"].append((file_object, file["type"]))
+ logger.info(messages)
+
+ message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
try:
if storage_type == "rag":
# For RAG storage, convert messages to single string
- message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
- result = await write_rag(end_user_id, message_text, user_rag_memory_id)
- return result
+ await write_rag(end_user_id, message_text, user_rag_memory_id)
+ return "success"
else:
- async with make_write_graph() as graph:
- config = {"configurable": {"thread_id": end_user_id}}
- # Convert structured messages to LangChain messages
- langchain_messages = []
- for msg in messages:
- if msg['role'] == 'user':
- langchain_messages.append(HumanMessage(content=msg['content']))
- elif msg['role'] == 'assistant':
- langchain_messages.append(AIMessage(content=msg['content']))
- print(100 * '-')
- print(langchain_messages)
- print(100 * '-')
- # 初始状态 - 包含所有必要字段
- initial_state = {
- "messages": langchain_messages,
- "end_user_id": end_user_id,
- "memory_config": memory_config,
- "language": language
+ await write_neo4j(
+ end_user_id=end_user_id,
+ messages=messages,
+ memory_config=memory_config,
+ ref_id='',
+ language=language
+ )
+ for lang in ["zh", "en"]:
+ deleted = await InterestMemoryCache.delete_interest_distribution(
+ end_user_id, lang
+ )
+ if deleted:
+ logger.info(
+ f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
+ for message in messages:
+ message["file_content"] = [
+ perceptual[0].file_path for perceptual in message["file_content"]
+ ]
+ return self.writer_messages_deal(
+ "success",
+ start_time,
+ end_user_id,
+ config_id,
+ message_text,
+ {
+ "status": "success",
+ "data": messages,
+ "config_id": memory_config.config_id,
+ "config_name": memory_config.config_name
}
-
- # 获取节点更新信息
- async for update_event in graph.astream(
- initial_state,
- stream_mode="updates",
- config=config
- ):
- for node_name, node_data in update_event.items():
- if 'save_neo4j' == node_name:
- massages = node_data
- massagesstatus = massages.get('write_result')['status']
- contents = massages.get('write_result')
- # Convert messages back to string for logging
- message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
- return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
- contents)
+ )
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
@@ -586,7 +607,7 @@ class MemoryAgentService:
retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串
- if retrieved_content == []:
+ if not retrieved_content:
retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存
diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py
index 01bc6267..9282fc28 100644
--- a/api/app/services/memory_api_service.py
+++ b/api/app/services/memory_api_service.py
@@ -28,7 +28,7 @@ class MemoryAPIService:
2. Maps end_user_id to end_user_id for memory operations
3. Delegates to MemoryAgentService for actual memory read/write operations
"""
-
+
def __init__(self, db: Session):
"""Initialize MemoryAPIService.
@@ -36,11 +36,11 @@ class MemoryAPIService:
db: SQLAlchemy database session
"""
self.db = db
-
+
def validate_end_user(
- self,
- end_user_id: str,
- workspace_id: uuid.UUID
+ self,
+ end_user_id: str,
+ workspace_id: uuid.UUID
) -> EndUser:
"""Validate that end_user exists and belongs to the workspace.
@@ -56,7 +56,7 @@ class MemoryAPIService:
BusinessException: If end_user not in authorized workspace
"""
logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}")
-
+
# Query end_user by ID
try:
end_user_uuid = uuid.UUID(end_user_id)
@@ -66,7 +66,7 @@ class MemoryAPIService:
message=f"Invalid end_user_id format: {end_user_id}",
code=BizCode.INVALID_PARAMETER
)
-
+
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first()
if not end_user:
@@ -75,13 +75,13 @@ class MemoryAPIService:
resource_type="EndUser",
resource_id=end_user_id
)
-
+
# Verify end_user belongs to the workspace via App relationship
app = self.db.query(App).filter(
App.id == end_user.app_id,
App.is_active.is_(True)
).first()
-
+
if not app:
logger.warning(f"App not found for end_user: {end_user_id}")
# raise ResourceNotFoundException(
@@ -99,7 +99,7 @@ class MemoryAPIService:
# message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}",
# code=BizCode.FORBIDDEN
# )
-
+
logger.info(f"End user {end_user_id} validated successfully")
return end_user
@@ -125,13 +125,13 @@ class MemoryAPIService:
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
async def write_memory(
- self,
- workspace_id: uuid.UUID,
- end_user_id: str,
- message: str,
- config_id: str,
- storage_type: str = "neo4j",
- user_rag_memory_id: Optional[str] = None,
+ self,
+ workspace_id: uuid.UUID,
+ end_user_id: str,
+ message: str,
+ config_id: str,
+ storage_type: str = "neo4j",
+ user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Write memory with validation.
@@ -154,13 +154,13 @@ class MemoryAPIService:
BusinessException: If end_user not in authorized workspace or write fails
"""
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
-
+
# Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id)
-
+
# Update end user's memory_config_id
self._update_end_user_config(end_user_id, config_id)
-
+
try:
# Delegate to MemoryAgentService
# Convert string message to list[dict] format expected by MemoryAgentService
@@ -171,11 +171,11 @@ class MemoryAPIService:
config_id=config_id,
db=self.db,
storage_type=storage_type,
- user_rag_memory_id=user_rag_memory_id or ""
+ user_rag_memory_id=user_rag_memory_id or "",
)
-
+
logger.info(f"Memory write successful for end_user: {end_user_id}")
-
+
# result may be a string "success" or a dict with a "status" key
# Preserve the full dict so callers don't silently lose extra fields
# (e.g. error codes, metadata) returned by MemoryAgentService.
@@ -189,7 +189,7 @@ class MemoryAPIService:
"status": result if isinstance(result, str) else "success",
"end_user_id": end_user_id,
}
-
+
except ConfigurationError as e:
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
raise BusinessException(
@@ -204,16 +204,16 @@ class MemoryAPIService:
message=f"Memory write failed: {str(e)}",
code=BizCode.MEMORY_WRITE_FAILED
)
-
+
async def read_memory(
- self,
- workspace_id: uuid.UUID,
- end_user_id: str,
- message: str,
- search_switch: str = "0",
- config_id: str = "",
- storage_type: str = "neo4j",
- user_rag_memory_id: Optional[str] = None,
+ self,
+ workspace_id: uuid.UUID,
+ end_user_id: str,
+ message: str,
+ search_switch: str = "0",
+ config_id: str = "",
+ storage_type: str = "neo4j",
+ user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Read memory with validation.
@@ -237,14 +237,13 @@ class MemoryAPIService:
BusinessException: If end_user not in authorized workspace or read fails
"""
logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}")
-
+
# Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id)
-
+
# Update end user's memory_config_id
self._update_end_user_config(end_user_id, config_id)
-
try:
# Delegate to MemoryAgentService
result = await MemoryAgentService().read_memory(
@@ -257,15 +256,15 @@ class MemoryAPIService:
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id or ""
)
-
+
logger.info(f"Memory read successful for end_user: {end_user_id}")
-
+
return {
"answer": result.get("answer", ""),
"intermediate_outputs": result.get("intermediate_outputs", []),
"end_user_id": end_user_id
}
-
+
except ConfigurationError as e:
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
raise BusinessException(
@@ -282,8 +281,8 @@ class MemoryAPIService:
)
def list_memory_configs(
- self,
- workspace_id: uuid.UUID,
+ self,
+ workspace_id: uuid.UUID,
) -> Dict[str, Any]:
"""List all memory configs for a workspace.
diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py
index a3751c07..66c110b1 100644
--- a/api/app/services/memory_config_service.py
+++ b/api/app/services/memory_config_service.py
@@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None):
"""Validate configuration ID format (supports both UUID and integer)."""
if isinstance(config_id, uuid.UUID):
return config_id
-
+
if config_id is None:
raise InvalidConfigError(
"Configuration ID cannot be None",
@@ -52,26 +52,30 @@ def _validate_config_id(config_id, db: Session = None):
field_name="config_id",
invalid_value=config_id,
)
- # 如果提供了数据库会话,尝试通过 user_id 查询 config_id
+ # 如果提供了数据库会话,尝试通过 config_id_old 查询 config_id
if db is not None:
- # 查询 user_id 匹配的记录
- stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == str(config_id))
+ # 查询 config_id_old 匹配的记录
+ stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == config_id)
result = db.execute(stmt).scalars().first()
if result:
- logger.info(f"Found config_id {result.config_id} for user_id {config_id}")
+ logger.info(f"Found config_id {result.config_id} for config_id_old {config_id}")
return result.config_id
- return config_id
+ raise InvalidConfigError(
+ f"未找到 config_id_old={config_id} 对应的配置",
+ field_name="config_id",
+ invalid_value=config_id,
+ )
if isinstance(config_id, str):
config_id_stripped = config_id.strip()
-
+
# Try parsing as UUID first
try:
return uuid.UUID(config_id_stripped)
except ValueError:
pass
-
+
# Fall back to integer parsing
try:
parsed_id = int(config_id_stripped)
@@ -81,18 +85,22 @@ def _validate_config_id(config_id, db: Session = None):
field_name="config_id",
invalid_value=config_id,
)
-
+
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
if db is not None:
- # 查询 user_id 匹配的记录
- stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id))
+ # 查询 config_id_old 匹配的记录
+ stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == parsed_id)
result = db.execute(stmt).scalars().first()
-
+
if result:
- logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}")
+ logger.info(f"Found config_id {result.config_id} for config_id_old {parsed_id}")
return result.config_id
- return parsed_id
+ raise InvalidConfigError(
+ f"未找到 config_id_old={parsed_id} 对应的配置",
+ field_name="config_id",
+ invalid_value=config_id,
+ )
except ValueError:
raise InvalidConfigError(
f"Invalid configuration ID format: '{config_id}' (must be UUID or positive integer)",
@@ -154,10 +162,10 @@ class MemoryConfigService:
self.db = db
def load_memory_config(
- self,
- config_id: Optional[UUID] = None,
- workspace_id: Optional[UUID] = None,
- service_name: str = "MemoryConfigService",
+ self,
+ config_id: Optional[UUID] = None,
+ workspace_id: Optional[UUID] = None,
+ service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database with optional fallback.
@@ -194,14 +202,14 @@ class MemoryConfigService:
try:
# Use get_config_with_fallback if workspace_id is provided
memory_config = None
+ validated_config_id = None
if workspace_id:
- validated_config_id = None
if config_id:
try:
validated_config_id = _validate_config_id(config_id, self.db)
except Exception:
validated_config_id = None
-
+
memory_config = self.get_config_with_fallback(
memory_config_id=validated_config_id,
workspace_id=workspace_id
@@ -210,7 +218,7 @@ class MemoryConfigService:
validated_config_id = _validate_config_id(config_id, self.db)
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
-
+
if not memory_config:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
@@ -233,7 +241,7 @@ class MemoryConfigService:
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
-
+
if not result:
raise ConfigurationError(
f"Workspace not found for config {memory_config.config_id}"
@@ -243,10 +251,10 @@ class MemoryConfigService:
# Helper function to validate model with workspace fallback
def _validate_model_with_fallback(
- model_id: str,
- model_type: str,
- workspace_default: str,
- required: bool = False
+ model_id: str,
+ model_type: str,
+ workspace_default: str,
+ required: bool = False
) -> tuple:
"""Validate model ID, falling back to workspace default if invalid.
@@ -275,7 +283,7 @@ class MemoryConfigService:
logger.warning(
f"{model_type} model validation failed, trying workspace default: {e}"
)
-
+
# Fallback to workspace default
if workspace_default:
try:
@@ -297,7 +305,7 @@ class MemoryConfigService:
logger.error(f"Workspace default {model_type} model also invalid: {e}")
if required:
raise
-
+
if required:
raise InvalidConfigError(
f"{model_type.title()} model is required but not configured",
@@ -306,7 +314,7 @@ class MemoryConfigService:
config_id=validated_config_id,
workspace_id=workspace.id
)
-
+
return None, None
# Step 2: Validate embedding model with workspace fallback
@@ -343,6 +351,35 @@ class MemoryConfigService:
if memory_config.rerank_id or workspace.rerank:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
+ vision_uuid, vision_name = validate_and_resolve_model_id(
+ memory_config.vision_id,
+ "llm",
+ self.db,
+ workspace.tenant_id,
+ required=False,
+ config_id=validated_config_id,
+ workspace_id=workspace.id,
+ )
+
+ audio_uuid, audio_name = validate_and_resolve_model_id(
+ memory_config.audio_id,
+ "llm",
+ self.db,
+ workspace.tenant_id,
+ required=False,
+ config_id=validated_config_id,
+ workspace_id=workspace.id,
+ )
+
+ video_uuid, video_name = validate_and_resolve_model_id(
+ memory_config.video_id,
+ "llm",
+ self.db,
+ workspace.tenant_id,
+ required=False,
+ config_id=validated_config_id,
+ workspace_id=workspace.id,
+ )
# Create immutable MemoryConfig object
config = MemoryConfig(
config_id=memory_config.config_id,
@@ -356,6 +393,12 @@ class MemoryConfigService:
embedding_model_name=embedding_name,
rerank_model_id=rerank_uuid,
rerank_model_name=rerank_name,
+ video_model_id=video_uuid,
+ video_model_name=video_name,
+ vision_model_id=vision_uuid,
+ vision_model_name=vision_name,
+ audio_model_id=audio_uuid,
+ audio_model_name=audio_name,
storage_type=workspace.storage_type or "neo4j",
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
reflexion_enabled=memory_config.enable_self_reflexion or False,
@@ -364,24 +407,31 @@ class MemoryConfigService:
reflexion_baseline=memory_config.baseline or "Time",
loaded_at=datetime.now(),
# Pipeline config: Deduplication
- enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
- enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
+ enable_llm_dedup_blockwise=bool(
+ memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
+ enable_llm_disambiguation=bool(
+ memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction
- statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
- include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
- max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
+ statement_granularity=int(
+ memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
+ include_dialogue_context=bool(
+ memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
+ max_dialogue_context_chars=int(
+ memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning
- pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
+ pruning_enabled=bool(
+ memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
- pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
+ pruning_threshold=float(
+ memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
# Ontology scene association
scene_id=memory_config.scene_id,
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
@@ -448,9 +498,9 @@ class MemoryConfigService:
if not config:
logger.warning(f"Model ID {model_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
-
+
api_config: ModelApiKey = config.api_keys[0]
-
+
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
@@ -481,9 +531,9 @@ class MemoryConfigService:
if not config:
logger.warning(f"Embedding model ID {embedding_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
-
+
api_config: ModelApiKey = config.api_keys[0]
-
+
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
@@ -571,25 +621,25 @@ class MemoryConfigService:
"""
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.repositories.ontology_class_repository import OntologyClassRepository
-
+
if not memory_config.scene_id:
logger.debug("No scene_id configured, skipping ontology type fetch")
return None
-
+
try:
ontology_repo = OntologyClassRepository(self.db)
ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id)
-
+
if not ontology_classes:
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
return None
-
+
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
logger.info(
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
)
return ontology_types
-
+
except Exception as e:
logger.warning(
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
@@ -598,8 +648,8 @@ class MemoryConfigService:
return None
def get_workspace_default_config(
- self,
- workspace_id: UUID
+ self,
+ workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get workspace default memory config.
@@ -613,19 +663,19 @@ class MemoryConfigService:
Optional[MemoryConfigModel]: Default config or None if no configs exist
"""
config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id)
-
+
if not config:
logger.warning(
"No active memory config found for workspace fallback",
extra={"workspace_id": str(workspace_id)}
)
-
+
return config
def get_config_with_fallback(
- self,
- memory_config_id: Optional[UUID],
- workspace_id: UUID
+ self,
+ memory_config_id: Optional[UUID],
+ workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get memory config with fallback to workspace default.
@@ -644,13 +694,13 @@ class MemoryConfigService:
"No memory config ID provided, using workspace default",
extra={"workspace_id": str(workspace_id)}
)
-
+
config = MemoryConfigRepository.get_with_fallback(
self.db,
memory_config_id,
workspace_id
)
-
+
if not config and memory_config_id:
logger.warning(
"Memory config not found, falling back to workspace default",
@@ -659,13 +709,13 @@ class MemoryConfigService:
"workspace_id": str(workspace_id)
}
)
-
+
return config
def delete_config(
- self,
- config_id: UUID | int,
- force: bool = False
+ self,
+ config_id: UUID | int,
+ force: bool = False
) -> dict:
"""Delete memory config with protection against in-use configs.
@@ -687,7 +737,7 @@ class MemoryConfigService:
from app.core.exceptions import ResourceNotFoundException
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
from app.repositories.end_user_repository import EndUserRepository
-
+
# 处理旧格式 int 类型的 config_id
if isinstance(config_id, int):
logger.warning(
@@ -699,11 +749,11 @@ class MemoryConfigService:
"message": "旧格式配置ID不支持删除操作,请使用新版配置",
"legacy_int_id": config_id
}
-
+
config = self.db.get(MemoryConfigModel, config_id)
if not config:
raise ResourceNotFoundException("MemoryConfig", str(config_id))
-
+
# Check if this is the default config - default configs cannot be deleted
if config.is_default:
logger.warning(
@@ -715,11 +765,11 @@ class MemoryConfigService:
"message": "默认配置不允许删除",
"is_default": True
}
-
+
# Use repository to count connected end users
end_user_repo = EndUserRepository(self.db)
connected_count = end_user_repo.count_by_memory_config_id(config_id)
-
+
if connected_count > 0 and not force:
logger.warning(
"Attempted to delete memory config with connected end users",
@@ -728,18 +778,18 @@ class MemoryConfigService:
"connected_count": connected_count
}
)
-
+
return {
"status": "warning",
"message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置",
"connected_count": connected_count,
"force_required": True
}
-
+
# Force delete: use repository to clear end user references first
if connected_count > 0 and force:
cleared_count = end_user_repo.clear_memory_config_id(config_id)
-
+
logger.warning(
"Force deleting memory config, clearing end user references",
extra={
@@ -747,11 +797,11 @@ class MemoryConfigService:
"cleared_end_users": cleared_count
}
)
-
+
try:
self.db.delete(config)
self.db.commit()
-
+
logger.info(
"Memory config deleted",
extra={
@@ -760,16 +810,16 @@ class MemoryConfigService:
"affected_users": connected_count
}
)
-
+
return {
"status": "success",
"message": "记忆配置删除成功",
"affected_users": connected_count
}
-
+
except IntegrityError as e:
self.db.rollback()
-
+
# Handle foreign key violation gracefully
error_str = str(e.orig) if e.orig else str(e)
if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower():
@@ -785,7 +835,7 @@ class MemoryConfigService:
"message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除",
"force_required": True
}
-
+
# Re-raise other integrity errors
logger.error(
"Delete failed due to integrity error",
@@ -800,9 +850,9 @@ class MemoryConfigService:
# ==================== 记忆配置提取方法 ====================
def extract_memory_config_id(
- self,
- app_type: str,
- config: dict
+ self,
+ app_type: str,
+ config: dict
) -> tuple[Optional[uuid.UUID], bool]:
"""从发布配置中提取 memory_config_id(根据应用类型分发)
@@ -827,9 +877,26 @@ class MemoryConfigService:
logger.warning(f"不支持的应用类型,无法提取记忆配置: app_type={app_type}")
return None, False
+ def _resolve_config_id_old(self, config_id_old: int) -> Optional[uuid.UUID]:
+ """Look up the UUID config_id of the MemoryConfig row whose config_id_old matches.
+
+ Args:
+ config_id_old: Legacy integer configuration ID.
+
+ Returns:
+ The matching row's UUID config_id, or None when no row is found.
+ """
+ from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
+ result = self.db.query(MemoryConfigModel).filter(
+ MemoryConfigModel.config_id_old == config_id_old
+ ).first()  # only the first matching row is used; None when nothing matches
+ if result:
+ return result.config_id
+ return None
+
def _extract_memory_config_id_from_agent(
- self,
- config: dict
+ self,
+ config: dict
) -> tuple[Optional[uuid.UUID], bool]:
"""从 Agent 应用配置中提取 memory_config_id
@@ -858,10 +925,11 @@ class MemoryConfigService:
elif isinstance(memory_value, str):
# Check if it's a numeric string (legacy int format)
if memory_value.isdigit():
- logger.warning(
- f"Agent 配置中 memory_config_id 为旧格式 int 字符串,将使用工作空间默认配置: "
- f"value={memory_value}"
- )
+ resolved = self._resolve_config_id_old(int(memory_value))
+ if resolved:
+ logger.info(f"Resolved legacy config_id_old={memory_value} to config_id={resolved}")
+ return resolved, False
+ logger.warning(f"未找到 config_id_old={memory_value} 对应的配置,将使用工作空间默认配置")
return None, True
try:
return uuid.UUID(memory_value), False
@@ -869,11 +937,11 @@ class MemoryConfigService:
logger.warning(f"Invalid UUID string: {memory_value}")
return None, False
elif isinstance(memory_value, int):
- # 旧数据存储为 int,需要回退到工作空间默认配置
- logger.warning(
- f"Agent 配置中 memory_config_id 为旧格式 int,将使用工作空间默认配置: "
- f"value={memory_value}"
- )
+ resolved = self._resolve_config_id_old(memory_value)
+ if resolved:
+ logger.info(f"Resolved legacy config_id_old={memory_value} to config_id={resolved}")
+ return resolved, False
+ logger.warning(f"未找到 config_id_old={memory_value} 对应的配置,将使用工作空间默认配置")
return None, True
else:
logger.warning(
@@ -888,8 +956,8 @@ class MemoryConfigService:
return None, False
def _extract_memory_config_id_from_workflow(
- self,
- config: dict
+ self,
+ config: dict
) -> tuple[Optional[uuid.UUID], bool]:
"""从 Workflow 应用配置中提取 memory_config_id
@@ -905,14 +973,14 @@ class MemoryConfigService:
- is_legacy_int: 是否检测到旧格式 int 数据
"""
nodes = config.get("nodes", [])
-
+
for node in nodes:
node_type = node.get("type", "")
-
+
# 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite)
if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]:
config_id = node.get("config", {}).get("config_id")
-
+
if config_id:
try:
# 处理字符串、UUID 和 int(旧数据兼容)三种情况
@@ -921,10 +989,16 @@ class MemoryConfigService:
elif isinstance(config_id, str):
return uuid.UUID(config_id), False
elif isinstance(config_id, int):
- # 旧数据存储为 int,需要回退到工作空间默认配置
+ resolved = self._resolve_config_id_old(config_id)
+ if resolved:
+ logger.info(
+ f"Resolved workflow legacy config_id_old={config_id} to config_id={resolved}: "
+ f"node_id={node.get('id')}, node_type={node_type}"
+ )
+ return resolved, False
logger.warning(
- f"工作流记忆节点 config_id 为旧格式 int,将使用工作空间默认配置: "
- f"node_id={node.get('id')}, node_type={node_type}, value={config_id}"
+ f"未找到工作流记忆节点 config_id_old={config_id} 对应的配置,将使用工作空间默认配置: "
+ f"node_id={node.get('id')}, node_type={node_type}"
)
return None, True
else:
@@ -937,6 +1011,6 @@ class MemoryConfigService:
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
f"node_type={node_type}, error={str(e)}"
)
-
+
logger.debug("工作流配置中未找到记忆节点")
return None, False
diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py
index a0bcc1a1..11118571 100644
--- a/api/app/services/memory_forget_service.py
+++ b/api/app/services/memory_forget_service.py
@@ -315,6 +315,12 @@ class MemoryForgetService:
# 获取遗忘引擎组件
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
+ # 如果参数为 None,使用配置中的默认值
+ if max_merge_batch_size is None:
+ max_merge_batch_size = config.get('max_merge_batch_size', 100)
+ if min_days_since_access is None:
+ min_days_since_access = config.get('min_days_since_access', 30)
+
# 记录执行开始时间
execution_time = datetime.now()
diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py
index b8961d33..523adadb 100644
--- a/api/app/services/memory_konwledges_server.py
+++ b/api/app/services/memory_konwledges_server.py
@@ -341,7 +341,7 @@ async def memory_konwledges_up(
)
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
- return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
+ return db_document
async def create_document_chunk(
@@ -350,7 +350,7 @@ async def create_document_chunk(
create_data: ChunkCreate,
db: Session,
current_user: User
-):
+) -> DocumentChunk:
"""
创建文档块
@@ -439,10 +439,10 @@ async def create_document_chunk(
db_document.chunk_num += 1
db.commit()
- return success(data=chunk, msg="文档块创建成功")
+ return chunk
-async def write_rag(end_user_id, message, user_rag_memory_id):
+async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk:
"""
将消息写入 RAG 知识库
@@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======', document)
api_logger.info(f"查找文档结果: document_id={document}")
+ create_chunks = ChunkCreate(content=message)
if document is not None:
# 文档已存在,直接添加新块
api_logger.info(f"文档已存在,添加新块: document_id={document}")
- create_chunks = ChunkCreate(content=message)
result = await create_document_chunk(
kb_id=kb_uuid,
document_id=uuid.UUID(document),
@@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
else:
# 文档不存在,创建新文档
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
- result = await memory_konwledges_up(
+ document = await memory_konwledges_up(
kb_id=user_rag_memory_id,
parent_id=user_rag_memory_id,
create_data=create_data,
db=db,
current_user=current_user
)
+ result = await create_document_chunk(
+ kb_id=kb_uuid,
+ document_id=document.id,
+ create_data=create_chunks,
+ db=db,
+ current_user=current_user
+ )
# 重新查询刚创建的文档ID
new_document_id = find_document_id_by_kb_and_filename(
db=db,
diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py
index 8a7c86e2..3ee238e2 100644
--- a/api/app/services/memory_perceptual_service.py
+++ b/api/app/services/memory_perceptual_service.py
@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
-from app.models import FileMetadata
+from app.models import FileMetadata, ModelApiKey, ModelType
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
from app.models.prompt_optimizer_model import RoleType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
-from app.schemas import FileType
+from app.schemas import FileType, FileInput
+from app.schemas.memory_config_schema import MemoryConfig
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualTimelineResponse,
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
AudioModal, Content, VideoModal, TextModal
)
from app.schemas.model_schema import ModelInfo
+from app.services.model_service import ModelApiKeyService
+from app.services.multimodal_service import MultimodalService
business_logger = get_business_logger()
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
+ def _get_mutlimodal_client(  # selects an API key by file type and wraps it in a RedBearLLM
+ self,
+ file_type: FileType,
+ config: MemoryConfig
+ ) -> tuple[RedBearLLM | None, ModelApiKey | None]:
+ model_config = None  # stays None when file_type matches none of the branches below
+ if file_type == FileType.AUDIO:  # audio files use config.audio_model_id
+ model_config = ModelApiKeyService.get_available_api_key(
+ self.db,
+ config.audio_model_id
+ )
+ elif file_type == FileType.VIDEO:  # video files use config.video_model_id
+ model_config = ModelApiKeyService.get_available_api_key(
+ self.db,
+ config.video_model_id
+ )
+ elif file_type == FileType.DOCUMENT:  # documents use config.llm_model_id
+ model_config = ModelApiKeyService.get_available_api_key(
+ self.db,
+ config.llm_model_id
+ )
+ elif file_type == FileType.IMAGE:  # images use config.vision_model_id
+ model_config = ModelApiKeyService.get_available_api_key(
+ self.db,
+ config.vision_model_id
+ )
+ llm = None
+ if model_config:  # a client is built only when the lookup returned a truthy value
+ llm = RedBearLLM(
+ RedBearModelConfig(
+ model_name=model_config.model_name,
+ provider=model_config.provider,
+ api_key=model_config.api_key,
+ base_url=model_config.api_base,
+ is_omni=model_config.is_omni
+ )
+ )
+ return llm, model_config  # NOTE(review): both may be None here; callers should check before use
+
async def generate_perceptual_memory(
self,
end_user_id: str,
- model_config: ModelInfo,
- file_type: str,
- file_url: str,
- file_message: dict,
+ memory_config: MemoryConfig,
+ file: FileInput
):
- memories = self.repository.get_by_url(file_url)
+ memories = self.repository.get_by_url(file.url)
if memories:
- business_logger.info(f"Perceptual memory already exists: {file_url}")
+ business_logger.info(f"Perceptual memory already exists: {file.url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
- self.repository.create_perceptual_memory(
+ memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
@@ -219,20 +259,33 @@ class MemoryPerceptualService:
meta_data=memory_cache.meta_data
)
self.db.commit()
-
- return
- llm = RedBearLLM(RedBearModelConfig(
+ return memory
+ else:
+ for memory in memories:
+ if memory.end_user_id == uuid.UUID(end_user_id):
+ return memory
+ llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
+ multimodel_service = MultimodalService(self.db, ModelInfo(
model_name=model_config.model_name,
provider=model_config.provider,
api_key=model_config.api_key,
- base_url=model_config.api_base,
- is_omni=model_config.is_omni
- ), type=model_config.model_type)
+ api_base=model_config.api_base,
+ is_omni=model_config.is_omni,
+ capability=model_config.capability,
+ model_type=ModelType.LLM
+ ))
+ file_message = await multimodel_service.process_files(
+ files=[file]
+ )
+ if not file_message:
+ business_logger.warning(f"Unsupported file type {file}, model capability: {model_config.capability}")
+ return None
+ file_message = file_message[0]
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
- rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
+ rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
messages = [
@@ -242,8 +295,22 @@ class MemoryPerceptualService:
]}
]
result = await llm.ainvoke(messages)
- content = json_repair.repair_json(result.content, return_objects=True)
- path = urlparse(file_url).path
+ content = result.content
+ final_output = ""
+ if isinstance(content, list):
+ for msg in content:
+ if isinstance(msg, dict):
+ final_output += msg.get("text", "")
+ elif isinstance(msg, str):
+ final_output += msg
+ elif isinstance(content, dict):
+ final_output += content.get("text", "")
+ elif isinstance(content, str):
+ final_output = content
+ else:
+ raise ValueError(f"Unexcept Model Output Type: {result.content}")
+ content = json_repair.repair_json(final_output, return_objects=True)
+ path = urlparse(file.url).path
filename = os.path.basename(path)
filename = unquote(filename)
file_ext = os.path.splitext(filename)[1]
@@ -252,21 +319,21 @@ class MemoryPerceptualService:
stmt = select(FileMetadata).where(
FileMetadata.id == file_id
)
- file = self.db.execute(stmt).scalar_one_or_none()
+ file_obj = self.db.execute(stmt).scalar_one_or_none()
- if file:
- filename = file.file_name
- file_ext = file.file_ext
+ if file_obj:
+ filename = file_obj.file_name
+ file_ext = file_obj.file_ext
except ValueError:
business_logger.debug(f"Remote file, file_id={filename}")
if not file_ext:
- if file_type == FileType.AUDIO:
+ if file.type == FileType.AUDIO:
file_ext = ".mp3"
- elif file_type == FileType.VIDEO:
+ elif file.type == FileType.VIDEO:
file_ext = ".mp4"
- elif file_type == FileType.DOCUMENT:
+ elif file.type == FileType.DOCUMENT:
file_ext = ".txt"
- elif file_type == FileType.IMAGE:
+ elif file.type == FileType.IMAGE:
file_ext = ".jpg"
filename += file_ext
file_content = {
@@ -274,11 +341,11 @@ class MemoryPerceptualService:
"topic": content.get("topic"),
"domain": content.get("domain")
}
- if file_type in [FileType.IMAGE, FileType.VIDEO]:
+ if file.type in [FileType.IMAGE, FileType.VIDEO]:
file_modalities = {
"scene": content.get("scene", [])
}
- elif file_type in [FileType.DOCUMENT]:
+ elif file.type in [FileType.DOCUMENT]:
file_modalities = {
"section_count": content.get("section_count", 0),
"title": content.get("title", ""),
@@ -288,10 +355,10 @@ class MemoryPerceptualService:
file_modalities = {
"speaker_count": content.get("speaker_count", 0)
}
- self.repository.create_perceptual_memory(
+ memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
- perceptual_type=PerceptualType.trans_from_file_type(file_type),
- file_path=file_url,
+ perceptual_type=PerceptualType.trans_from_file_type(file.type),
+ file_path=file.url,
file_name=filename,
file_ext=file_ext,
summary=content.get('summary', ""),
@@ -301,3 +368,4 @@ class MemoryPerceptualService:
}
)
self.db.commit()
+ return memory
diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py
index 6e7c1ad4..58f3e8bd 100644
--- a/api/app/services/memory_storage_service.py
+++ b/api/app/services/memory_storage_service.py
@@ -11,9 +11,11 @@ import time
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional
+from dotenv import load_dotenv
+from sqlalchemy.orm import Session
+
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import (
- get_hot_memory_tags,
get_raw_tags_from_db,
filter_tags_with_llm,
)
@@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import (
)
from app.services.memory_config_service import MemoryConfigService
from app.utils.sse_utils import format_sse_message
-from dotenv import load_dotenv
-from sqlalchemy.orm import Session
logger = get_logger(__name__)
config_logger = get_config_logger()
@@ -45,10 +45,10 @@ _neo4j_connector = Neo4jConnector()
class MemoryStorageService:
"""Service for memory storage operations"""
-
+
def __init__(self):
logger.info("MemoryStorageService initialized")
-
+
async def get_storage_info(self) -> dict:
"""
Example wrapper method - retrieves storage information
@@ -59,17 +59,17 @@ class MemoryStorageService:
Storage information dictionary
"""
logger.info("Getting storage info ")
-
+
# Empty wrapper - implement your logic here
result = {
"status": "active",
"message": "This is an example wrapper"
}
-
- return result
-
-class DataConfigService: # 数据配置服务类(PostgreSQL)
+ return result
+
+
+class DataConfigService: # 数据配置服务类(PostgreSQL)
"""Service layer for config params CRUD.
使用 SQLAlchemy ORM 进行数据库操作。
@@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
return data_list
# --- Create ---
- def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
+ def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
# 业务层检查同一工作空间下是否已存在同名配置
if params.workspace_id and params.config_name:
from app.models.memory_config_model import MemoryConfig
@@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
return None
# --- Delete ---
- def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
+ def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
success = MemoryConfigRepository.delete(self.db, key.config_id)
if not success:
raise ValueError("未找到配置")
return {"affected": 1}
# --- Update ---
- def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
+ def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
config = MemoryConfigRepository.update(self.db, update)
if not config:
raise ValueError("未找到配置")
return {"affected": 1}
- def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
+ def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
config = MemoryConfigRepository.update_extracted(self.db, update)
if not config:
raise ValueError("未找到配置")
@@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
# 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
# --- Read ---
- def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
+ def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
if not result:
raise ValueError("未找到配置")
return result
# --- Read All ---
- def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
+ def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数
results = MemoryConfigRepository.get_all(self.db, workspace_id)
# 检查并修正 pruning_scene 与 scene_name 不一致的记录
@@ -241,13 +241,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
except (ValueError, TypeError):
config_id_old = None
-
- if config_id_old:
- memory_config=config_id_old
- else:
- memory_config=config.config_id
config_dict = {
- "config_id": memory_config,
+ "config_id": str(config.config_id),
"config_name": config.config_name,
"config_desc": config.config_desc,
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
@@ -289,7 +284,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
return self._convert_timestamps_to_format(data_list)
-
async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]:
"""
流式执行试运行,产生 SSE 格式的进度事件
@@ -311,14 +305,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
"""
from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2])
-
+
try:
# 发出初始进度事件
yield format_sse_message("starting", {
"message": "开始试运行...",
"time": int(time.time() * 1000)
})
-
+
# 步骤 1: 配置加载和验证(数据库优先)
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
cid: Optional[str] = payload_cid if payload_cid else None
@@ -344,27 +338,28 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
# 关联了本体场景,优先使用 custom_text
if hasattr(payload, 'custom_text') and payload.custom_text:
dialogue_text = payload.custom_text.strip()
- logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
+ logger.info(
+ f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
else:
# 如果没有提供 custom_text,回退到 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
- logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
+ logger.info(
+ f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
else:
# 没有关联本体场景,使用 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}")
-
+
# 验证最终使用的文本不为空
if not dialogue_text:
raise ValueError("试运行模式必须提供有效的文本内容(dialogue_text 或 custom_text)")
-
- logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
+ logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
# 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事件
progress_queue: asyncio.Queue = asyncio.Queue()
-
+
async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
"""
进度回调函数,将进度事件放入队列
@@ -375,14 +370,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
data: 可选的结果数据(用于传递节点执行结果)
"""
await progress_queue.put((stage, message, data))
-
+
# 步骤 3: 在后台任务中执行管线
async def run_pipeline():
"""在后台执行管线并捕获异常"""
try:
from app.services.pilot_run_service import run_pilot_extraction
-
- logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
+
+ logger.info(
+ f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await run_pilot_extraction(
memory_config=memory_config,
dialogue_text=dialogue_text,
@@ -391,60 +387,60 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
language=language,
)
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
-
+
# 标记管线完成
await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
except Exception as e:
# 将异常放入队列
await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))
-
+
# 启动后台任务
pipeline_task = asyncio.create_task(run_pipeline())
-
+
# 步骤 4: 从队列中读取进度事件并发出
while True:
try:
# 等待进度事件,设置超时以检测客户端断开
stage, message, data = await asyncio.wait_for(
- progress_queue.get(),
+ progress_queue.get(),
timeout=0.5
)
-
+
# 检查特殊标记
if stage == "__PIPELINE_COMPLETE__":
break
elif stage == "__PIPELINE_ERROR__":
raise RuntimeError(message)
-
+
# 构建进度事件数据
progress_data = {
"message": message,
"time": int(time.time() * 1000)
}
-
+
# 如果有结果数据,添加到事件中
if data:
progress_data["data"] = data
-
+
# 发出进度事件,使用 stage 作为事件类型
yield format_sse_message(stage, progress_data)
-
+
except TimeoutError:
# 超时,继续等待(这允许检测客户端断开)
continue
-
+
# 等待管线任务完成
await pipeline_task
-
+
# 步骤 5: 读取提取结果
from app.core.config import settings
result_path = settings.get_memory_output_path("extracted_result.json")
if not os.path.isfile(result_path):
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
-
+
with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf)
-
+
# 步骤 6: 计算本体覆盖率并合并到结果中
result_data = {
"config_id": cid,
@@ -460,15 +456,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
result_data["ontology_coverage"] = ontology_coverage
except Exception as cov_err:
logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True)
-
+
yield format_sse_message("result", result_data)
-
+
# 步骤 7: 发出完成事件
yield format_sse_message("done", {
"message": "试运行完成",
"time": int(time.time() * 1000)
})
-
+
except asyncio.CancelledError:
# 客户端断开连接
logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
@@ -483,11 +479,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
"time": int(time.time() * 1000)
})
-
async def _compute_ontology_coverage(
- self,
- extracted_result: Dict[str, Any],
- memory_config,
+ self,
+ extracted_result: Dict[str, Any],
+ memory_config,
) -> Optional[Dict[str, Any]]:
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
@@ -580,8 +575,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
# Ensure env for connector (e.g., NEO4J_PASSWORD)
-load_dotenv()
-_neo4j_connector = Neo4jConnector()
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
@@ -664,7 +657,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
# 检查结果是否为空或长度不足
if not result or len(result) < 4:
data = {
- "total": 0,
+ "total": 0,
"distribution": [
{"type": "dialogue", "count": 0},
{"type": "chunk", "count": 0},
@@ -701,10 +694,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
)
return result
+
async def analytics_hot_memory_tags(
- db: Session,
- current_user: User,
- limit: int = 10
+ db: Session,
+ current_user: User,
+ limit: int = 10
) -> List[Dict[str, Any]]:
"""
获取热门记忆标签,按数量排序并返回前N个
@@ -721,27 +715,27 @@ async def analytics_hot_memory_tags(
from app.services.memory_dashboard_service import get_workspace_end_users
# 使用 asyncio.to_thread 避免阻塞事件循环
end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user)
-
+
if not end_users:
return []
-
+
# 步骤1: 收集所有用户的原始标签(不调用LLM)
connector = Neo4jConnector()
try:
all_raw_tags = []
for end_user in end_users:
raw_tags = await get_raw_tags_from_db(
- connector,
- str(end_user.id),
- limit=raw_limit,
+ connector,
+ str(end_user.id),
+ limit=raw_limit,
by_user=False
)
if raw_tags:
all_raw_tags.extend(raw_tags)
-
+
if not all_raw_tags:
return []
-
+
# 步骤2: 聚合相同标签的频率
tag_frequency_map = {}
for tag_name, frequency in all_raw_tags:
@@ -749,36 +743,36 @@ async def analytics_hot_memory_tags(
tag_frequency_map[tag_name] += frequency
else:
tag_frequency_map[tag_name] = frequency
-
+
# 步骤3: 按频率降序排序,取前raw_limit个
sorted_tags = sorted(
- tag_frequency_map.items(),
- key=lambda x: x[1],
+ tag_frequency_map.items(),
+ key=lambda x: x[1],
reverse=True
)[:raw_limit]
-
+
if not sorted_tags:
return []
-
+
# 步骤4: 只调用一次LLM进行筛选
tag_names = [tag for tag, _ in sorted_tags]
-
+
# 使用第一个用户的end_user_id来获取LLM配置
# 因为同一工作空间下的用户应该使用相同的配置
first_end_user_id = str(end_users[0].id)
filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id)
-
+
# 步骤5: 根据LLM筛选结果构建最终列表(保留频率)
final_tags = []
for tag, freq in sorted_tags:
if tag in filtered_tag_names:
final_tags.append((tag, freq))
-
+
# 步骤6: 只返回前limit个
top_tags = final_tags[:limit]
-
+
return [{"name": t, "frequency": f} for t, f in top_tags]
-
+
finally:
await connector.close()
@@ -815,11 +809,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
source = "log"
total = (
- stats.get("chunk_count", 0)
- + stats.get("statements_count", 0)
- + stats.get("triplet_entities_count", 0)
- + stats.get("triplet_relations_count", 0)
- + stats.get("temporal_count", 0)
+ stats.get("chunk_count", 0)
+ + stats.get("statements_count", 0)
+ + stats.get("triplet_entities_count", 0)
+ + stats.get("triplet_relations_count", 0)
+ + stats.get("temporal_count", 0)
)
# 计算"最新一次活动多久前"(仅日志来源时有效)
@@ -845,5 +839,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
return data
-
-
diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py
index a7398504..b98674ba 100644
--- a/api/app/services/model_service.py
+++ b/api/app/services/model_service.py
@@ -154,10 +154,17 @@ class ModelConfigService:
}
elif model_type_lower == "embedding":
- # Embedding 模型验证(在线程中运行同步方法)
+ # Embedding 模型验证
+ # 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"]
- vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
+
+ # 火山引擎使用 embed_batch,其他使用 embed_documents
+ if provider.lower() == "volcano":
+ vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
+ else:
+ vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
+
elapsed_time = time.time() - start_time
return {
@@ -193,6 +200,56 @@ class ModelConfigService:
},
"error": None
}
+
+ elif model_type_lower == "image":
+ # 图片生成模型验证
+ from app.core.models.generation import RedBearImageGenerator
+
+ generator = RedBearImageGenerator(model_config)
+ result = await generator.agenerate(
+ prompt="a cute panda",
+ size="2K"
+ )
+ elapsed_time = time.time() - start_time
+ logger.info(f"成功生成图片,结果: {result}")
+
+ return {
+ "valid": True,
+ "message": "图片生成模型配置验证成功",
+ "response": f"成功生成图片,结果: {result}",
+ "elapsed_time": elapsed_time,
+ "usage": {
+ "prompt_length": len("a cute panda"),
+ "image_count": 1
+ },
+ "error": None
+ }
+
+ elif model_type_lower == "video":
+ # 视频生成模型验证
+ from app.core.models.generation import RedBearVideoGenerator
+
+ generator = RedBearVideoGenerator(model_config)
+ result = await generator.agenerate(
+ prompt="a cute panda playing in bamboo forest",
+ duration=5
+ )
+ elapsed_time = time.time() - start_time
+
+ # 视频生成是异步任务,返回任务ID
+ task_id = result.get("task_id") if isinstance(result, dict) else None
+
+ return {
+ "valid": True,
+ "message": "视频生成模型配置验证成功",
+ "response": f"成功创建视频生成任务,任务ID: {task_id}",
+ "elapsed_time": elapsed_time,
+ "usage": {
+ "prompt_length": len("a cute panda playing in bamboo forest"),
+ "task_id": task_id
+ },
+ "error": None
+ }
else:
return {
diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py
index 6cb0a7f0..4cf3d89d 100644
--- a/api/app/services/multimodal_service.py
+++ b/api/app/services/multimodal_service.py
@@ -9,17 +9,15 @@
- OpenAI: 支持 URL 和 base64 格式
"""
import base64
+import csv
import io
-import uuid
+import json
import zipfile
-import chardet
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
-import csv
-import json
-
import PyPDF2
+import chardet
import httpx
import magic
import openpyxl
@@ -35,7 +33,6 @@ from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.schemas.model_schema import ModelInfo
from app.services.audio_transcription_service import AudioTranscriptionService
-from app.tasks import write_perceptual_memory
logger = get_business_logger()
@@ -297,6 +294,7 @@ PROVIDER_STRATEGIES = {
"bedrock": BedrockFormatStrategy,
"anthropic": BedrockFormatStrategy,
"openai": OpenAIFormatStrategy,
+ "volcano": OpenAIFormatStrategy,
}
@@ -342,92 +340,14 @@ class MultimodalService:
async def process_files(
self,
- end_user_id: uuid.UUID | str,
files: Optional[List[FileInput]],
-
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式
Args:
- end_user_id: 用户ID
files: 文件输入列表
- Returns:
- List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式)
- """
- if not files:
- return []
- if isinstance(end_user_id, uuid.UUID):
- end_user_id = str(end_user_id)
-
- # 获取对应的策略
- # dashscope 的 omni 模型使用 OpenAI 兼容格式
- if self.provider == "dashscope" and self.is_omni:
- strategy_class = OpenAIFormatStrategy
- else:
- strategy_class = PROVIDER_STRATEGIES.get(self.provider)
- if not strategy_class:
- logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
- strategy_class = DashScopeFormatStrategy
-
- result = []
- for idx, file in enumerate(files):
- strategy = strategy_class(file)
- if not file.url:
- file.url = await self.get_file_url(file)
- try:
- if file.type == FileType.IMAGE and "vision" in self.capability:
- is_support, content = await self._process_image(file, strategy)
- result.append(content)
- if is_support:
- self.write_perceptual_memory(end_user_id, file.type, file.url, content)
- elif file.type == FileType.DOCUMENT:
- is_support, content = await self._process_document(file, strategy)
- result.append(content)
- if is_support:
- self.write_perceptual_memory(end_user_id, file.type, file.url, content)
- elif file.type == FileType.AUDIO and "audio" in self.capability:
- is_support, content = await self._process_audio(file, strategy)
- result.append(content)
- if is_support:
- self.write_perceptual_memory(end_user_id, file.type, file.url, content)
- elif file.type == FileType.VIDEO and "video" in self.capability:
- is_support, content = await self._process_video(file, strategy)
- result.append(content)
- if is_support:
- self.write_perceptual_memory(end_user_id, file.type, file.url, content)
- else:
- logger.warning(f"不支持的文件类型: {file.type}")
- except Exception as e:
- logger.error(
- f"处理文件失败",
- extra={
- "file_index": idx,
- "file_type": file.type,
- "error": str(e)
- },
- exc_info=True
- )
- # 继续处理其他文件,不中断整个流程
- result.append({
- "type": "text",
- "text": f"[文件处理失败: {str(e)}]"
- })
-
- logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
- return result
-
- async def history_process_files(
- self,
- files: Optional[List[FileInput]],
- ) -> List[Dict[str, Any]]:
- """
- 处理文件列表,返回 LLM 可用的格式
-
- Args:
- files: 文件输入列表
-
Returns:
List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式)
"""
@@ -483,17 +403,6 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
return result
- def write_perceptual_memory(
- self,
- end_user_id: str,
- file_type: str,
- file_url: str,
- file_message: dict
- ):
- """写入感知记忆"""
- if end_user_id and self.api_config:
- write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
-
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
"""
处理图片文件
diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py
index fc749157..4617946b 100644
--- a/api/app/services/pilot_run_service.py
+++ b/api/app/services/pilot_run_service.py
@@ -297,9 +297,12 @@ async def run_pilot_extraction(
chunk_nodes,
statement_nodes,
entity_nodes,
+ _,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
+ _,
+ _
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py
index 12e0c324..585fdd78 100644
--- a/api/app/services/user_memory_service.py
+++ b/api/app/services/user_memory_service.py
@@ -1887,7 +1887,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
"Chunk": ["content", "created_at"],
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
- "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
+ "MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段
+ "Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"]
}
# 获取该节点类型的白名单字段
diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py
index e23b1ac3..3122d282 100644
--- a/api/app/services/user_service.py
+++ b/api/app/services/user_service.py
@@ -78,18 +78,7 @@ def create_user(db: Session, user: UserCreate) -> User:
business_logger.info(f"创建用户: {user.username}, email: {user.email}")
try:
- # 检查用户名是否已存在
- business_logger.debug(f"检查用户名是否已存在: {user.username}")
- db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
- if db_user_by_username:
- business_logger.warning(f"用户名已存在: {user.username}")
- raise BusinessException(
- "用户名已存在",
- code=BizCode.DUPLICATE_NAME,
- context={"username": user.username, "email": user.email}
- )
-
- # 检查邮箱是否已注册
+ # 检查邮箱是否已注册(邮箱保持唯一)
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
if db_user_by_email:
@@ -164,22 +153,7 @@ def create_superuser(db: Session, user: UserCreate, current_user: User) -> User:
)
try:
- # 检查用户名是否已存在
- business_logger.debug(f"检查用户名是否已存在: {user.username}")
- db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
- if db_user_by_username:
- business_logger.warning(f"用户名已存在: {user.username}")
- raise BusinessException(
- "用户名已存在",
- code=BizCode.DUPLICATE_NAME,
- context={
- "username": user.username,
- "email": user.email,
- "created_by": str(current_user.id)
- }
- )
-
- # 检查邮箱是否已注册
+ # 检查邮箱是否已注册(邮箱保持唯一)
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
if db_user_by_email:
@@ -276,6 +250,20 @@ def deactivate_user(db: Session, user_id_to_deactivate: uuid.UUID, current_user:
}
)
+ # 检查是否为租户联系人
+ from app.models.tenant_model import Tenants
+ tenant = db.query(Tenants).filter(Tenants.id == db_user.tenant_id).first()
+ if tenant and tenant.contact_email and tenant.contact_email == db_user.email:
+ business_logger.warning(f"尝试停用租户联系人: {db_user.email}, tenant_id={db_user.tenant_id}")
+ raise BusinessException(
+ "该管理员是租户联系人,请先在租户信息中更换联系邮箱,再禁用此管理员",
+ code=BizCode.FORBIDDEN,
+ context={
+ "user_id": str(user_id_to_deactivate),
+ "tenant_id": str(db_user.tenant_id)
+ }
+ )
+
# 停用用户
business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})")
db_user.is_active = False
diff --git a/api/app/services/workflow_import_service.py b/api/app/services/workflow_import_service.py
index 2b36c5ea..fd8f25f3 100644
--- a/api/app/services/workflow_import_service.py
+++ b/api/app/services/workflow_import_service.py
@@ -12,7 +12,7 @@ from app.aioRedis import aio_redis_set, aio_redis_get
from app.core.config import settings
from app.core.exceptions import BusinessException
from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult
-from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration
+from app.core.workflow.adapters.errors import UnsupportedPlatform, InvalidConfiguration
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.schemas import AppCreate
from app.schemas.workflow_schema import WorkflowConfigCreate
@@ -46,7 +46,7 @@ class WorkflowImportService:
success=False,
temp_id=None,
workflow_id=None,
- errors=[UnsupportPlatform(platform=platform)]
+ errors=[UnsupportedPlatform(platform=platform)]
)
adapter = self.registry.get_adapter(platform, config)
diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py
index aee3d75f..c7d7f2b1 100644
--- a/api/app/services/workflow_service.py
+++ b/api/app/services/workflow_service.py
@@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
+from app.repositories import knowledge_repository
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
WorkflowExecutionRepository,
@@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService
+from app.services.workspace_service import get_workspace_storage_type_without_auth
logger = logging.getLogger(__name__)
@@ -540,6 +542,25 @@ class WorkflowService:
mapped = internal_event
return mapped
+ def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
+     """Resolve the memory storage backend to use for a workflow run.
+
+     Args:
+         workspace_id: Workspace whose storage type is looked up.
+
+     Returns:
+         ``(storage_type, user_rag_memory_id)``. ``user_rag_memory_id`` is the
+         id of the workspace's RAG memory knowledge base as a string, or ``""``
+         when the storage type is not ``"rag"`` or that knowledge base is
+         missing. In the missing case ``storage_type`` is replaced by
+         ``'neo4j'``.
+     """
+     # NOTE(review): the spelling "MERORY" is kept byte-for-byte; it presumably
+     # matches the name the knowledge base is created under elsewhere --
+     # confirm before "correcting" it.
+     rag_knowledge_name = "USER_RAG_MERORY"
+     storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
+     user_rag_memory_id = ""
+     if storage_type == "rag":
+         knowledge = knowledge_repository.get_knowledge_by_name(
+             db=self.db,
+             name=rag_knowledge_name,
+             workspace_id=workspace_id
+         )
+         if knowledge:
+             user_rag_memory_id = str(knowledge.id)
+         else:
+             # Report the name that was actually queried so the warning can be
+             # matched against the stored knowledge bases.
+             logger.warning(
+                 f"No knowledge base named '{rag_knowledge_name}' found, "
+                 f"workspace_id: {workspace_id}, will use neo4j storage"
+             )
+             storage_type = 'neo4j'
+     return storage_type, user_rag_memory_id
+
# ==================== 工作流执行 ====================
async def run(
@@ -607,6 +628,7 @@ class WorkflowService:
try:
files = await self._handle_file_input(payload.files)
+ storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files
message_id = uuid.uuid4()
# 更新状态为运行中
@@ -631,7 +653,9 @@ class WorkflowService:
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
- user_id=payload.user_id
+ user_id=payload.user_id,
+ memory_storage_type=storage_type,
+ user_rag_memory_id=user_rag_memory_id
)
# 更新执行结果
if result.get("status") == "completed":
@@ -780,6 +804,7 @@ class WorkflowService:
try:
files = await self._handle_file_input(payload.files)
+ storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files
self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
@@ -801,6 +826,8 @@ class WorkflowService:
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=payload.user_id,
+ memory_storage_type=storage_type,
+ user_rag_memory_id=user_rag_memory_id
):
if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status")
diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py
index cefb8380..90b5cf65 100644
--- a/api/app/services/workspace_service.py
+++ b/api/app/services/workspace_service.py
@@ -863,7 +863,7 @@ def get_workspace_storage_type(
def get_workspace_storage_type_without_auth(
db: Session,
workspace_id: uuid.UUID,
-) -> Optional[str]:
+) -> str:
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
Args:
diff --git a/api/app/tasks.py b/api/app/tasks.py
index 3a237d82..61736275 100644
--- a/api/app/tasks.py
+++ b/api/app/tasks.py
@@ -36,9 +36,11 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
)
from app.db import get_db, get_db_context
from app.models import Document, File, Knowledge
+from app.models.end_user_model import EndUser
from app.schemas import document_schema, file_schema
from app.schemas.model_schema import ModelInfo
-from app.services.memory_agent_service import MemoryAgentService
+from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
+from app.services.memory_forget_service import MemoryForgetService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.utils.config_utils import resolve_config_id
from app.utils.redis_lock import RedisLock
@@ -1073,9 +1075,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
-def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str,
- user_rag_memory_id: str,
- language: str = "zh") -> Dict[str, Any]:
+def write_message_task(
+ self,
+ end_user_id: str,
+ message: list[dict],
+ config_id: str | int,
+ storage_type: str,
+ user_rag_memory_id: str,
+ language: str = "zh"
+) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService.
Args:
end_user_id: Group ID for the memory agent (also used as end_user_id)
@@ -1091,7 +1099,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
Raises:
Exception on failure
"""
-
logger.info(
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
f"config_id={config_id} (type: {type(config_id).__name__}), "
@@ -1105,14 +1112,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
try:
with get_db_context() as db:
actual_config_id = resolve_config_id(config_id, db)
- print(100 * '-')
- print(actual_config_id)
- print(100 * '-')
- logger.info(
- f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
+ logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} "
+ f"(type: {type(actual_config_id).__name__})")
except (ValueError, AttributeError) as e:
- logger.error(
- f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
+ logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} "
+ f"(type: {type(config_id).__name__}), error: {e}")
return {
"status": "FAILURE",
"error": f"Invalid config_id format: {config_id} - {str(e)}",
@@ -1151,8 +1155,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
- logger.info(
- f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
+ logger.info(f"[CELERY WRITE] Task completed successfully "
+ f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
try:
@@ -1167,7 +1171,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
)
except Exception as _e:
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
-
return {
"status": "SUCCESS",
"result": result,
@@ -1859,7 +1862,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
@celery_app.task(
name="app.tasks.run_forgetting_cycle_task",
bind=True,
- ignore_result=True,
+ ignore_result=False, # 改为 False 以便在 Flower 中查看结果
max_retries=0,
acks_late=False,
time_limit=7200,
@@ -1867,68 +1870,77 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
)
def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
"""定时任务:运行遗忘周期
-
- 定期执行遗忘周期,识别并融合低激活值的知识节点。
-
- Args:
- config_id: 配置ID(可选,如果为None则使用默认配置)
-
- Returns:
- 包含任务执行结果的字典
+
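+ Iterates over every end user and runs one forgetting cycle per user, using
+ the memory config resolved for that user.
+
+ NOTE(review): ``config_id`` is still accepted but is not read by this
+ version; confirm that callers no longer rely on it.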
"""
start_time = time.time()
- async def _run() -> Dict[str, Any]:
- from app.services.memory_forget_service import MemoryForgetService
-
+ async def _process_users() -> Dict[str, Any]:
with get_db_context() as db:
- try:
- logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
+ end_users = db.query(EndUser).all()
+ if not end_users:
+ logger.info("没有终端用户,跳过遗忘周期")
+ return {"status": "SUCCESS", "message": "没有终端用户",
+ "report": {"merged_count": 0, "failed_count": 0, "processed_users": 0},
+ "duration_seconds": time.time() - start_time}
- forget_service = MemoryForgetService()
+ logger.info(f"开始处理 {len(end_users)} 个终端用户的遗忘周期")
+ forget_service = MemoryForgetService()
+ total_merged = total_failed = processed_users = 0
+ failed_users = []
- # 运行遗忘周期
- # FIXME: MemeoryForgetService
- report = await forget_service.trigger_forgetting(
- db=db,
- end_user_id=None, # 处理所有组
- config_id=config_id
- )
+ for end_user in end_users:
+ try:
+ # 获取用户配置(自动回退到工作空间默认配置)
+ connected_config = get_end_user_connected_config(str(end_user.id), db)
+ user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db)
+
+ if not user_config_id:
+ failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"})
+ continue
- duration = time.time() - start_time
+ # 执行遗忘周期
+ report = await forget_service.trigger_forgetting_cycle(
+ db=db, end_user_id=str(end_user.id), config_id=user_config_id
+ )
+
+ total_merged += report.get('merged_count', 0)
+ total_failed += report.get('failed_count', 0)
+ processed_users += 1
+
+ logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点")
+
+ except Exception as e:
+ logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True)
+ failed_users.append({"end_user_id": str(end_user.id), "error": str(e)})
- logger.info(
- f"遗忘周期定时任务完成: "
- f"融合 {report['merged_count']} 对节点, "
- f"失败 {report['failed_count']} 对, "
- f"耗时 {duration:.2f} 秒"
- )
+ duration = time.time() - start_time
+ logger.info(f"遗忘周期完成: {processed_users}/{len(end_users)} 用户, "
+ f"融合 {total_merged} 对, 耗时 {duration:.2f}s")
- return {
- "status": "SUCCESS",
- "message": "遗忘周期执行成功",
- "report": report,
- "duration_seconds": duration
- }
-
- except Exception as e:
- duration = time.time() - start_time
- logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
-
- return {
- "status": "FAILED",
- "message": f"遗忘周期执行失败: {str(e)}",
- "duration_seconds": duration
- }
+ return {
+ "status": "SUCCESS",
+ "message": f"处理 {processed_users} 个用户",
+ "report": {
+ "merged_count": total_merged,
+ "failed_count": total_failed,
+ "processed_users": processed_users,
+ "total_users": len(end_users),
+ "failed_users": failed_users
+ },
+ "duration_seconds": duration
+ }
# 运行异步函数
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
try:
- result = loop.run_until_complete(_run())
- return result
- finally:
- loop.close()
+ return asyncio.run(_process_users())
+ except Exception as e:
+ logger.error(f"遗忘周期任务失败: {e}", exc_info=True)
+ return {
+ "status": "FAILED",
+ "message": f"任务失败: {str(e)}",
+ "duration_seconds": time.time() - start_time
+ }
# =============================================================================
@@ -2611,57 +2623,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
}
-@celery_app.task(
- name="app.tasks.write_perceptual_memory",
- bind=True,
- ignore_result=True,
- max_retries=0,
- acks_late=False,
- time_limit=3600,
- soft_time_limit=3300,
-)
-def write_perceptual_memory(
- self,
- end_user_id: str,
- model_api_config: dict,
- file_type: str,
- file_url: str,
- file_message: dict
-):
- """
- Write perceptual memory for a user into PostgreSQL and Neo4j.
-
- This task generates or updates the user's perceptual memory
- in the backend databases. It is intended to be executed asynchronously
- via Celery.
-
- Args:
- end_user_id (uuid.UUID): The unique identifier of the end user.
- model_api_config (ModelInfo): API configuration for the model
- used to generate perceptual memory.
- file_type (str): The file type
- file_url (url): The url of file
- file_message (dict): The file message containing details about the file
- to be processed.
-
- Returns:
- None
- """
- file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
- set_asyncio_event_loop()
- with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
- model_info = ModelInfo(**model_api_config)
- with get_db_context() as db:
- memory_perceptual_service = MemoryPerceptualService(db)
- return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
- end_user_id,
- model_info,
- file_type,
- file_url,
- file_message,
- ))
-
-
# =============================================================================
# 社区聚类补全任务(触发型)
# =============================================================================
@@ -2672,7 +2633,7 @@ def write_perceptual_memory(
ignore_result=False,
max_retries=0,
acks_late=False,
- time_limit=7200, # 2小时硬超时
+ time_limit=7200, # 2小时硬超时
soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
@@ -2760,7 +2721,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
patch_fail = 0
for cid in incomplete_ids:
try:
- await engine._generate_community_metadata(cid, end_user_id)
+ await engine._generate_community_metadata([cid], end_user_id)
patch_ok += 1
except Exception as patch_err:
patch_fail += 1
@@ -2787,7 +2748,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
embedding_model_id=embedding_model_id,
)
- logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
+ logger.info(
+ f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
await engine.full_clustering(end_user_id)
initialized += 1
logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")
@@ -2810,12 +2772,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
}
try:
- try:
- import nest_asyncio
- nest_asyncio.apply()
- except ImportError:
- pass
-
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
diff --git a/api/migrations/versions/05a681a6ca93_202603231611.py b/api/migrations/versions/05a681a6ca93_202603231611.py
new file mode 100644
index 00000000..5ab9c4de
--- /dev/null
+++ b/api/migrations/versions/05a681a6ca93_202603231611.py
@@ -0,0 +1,32 @@
+"""202603231611
+
+Revision ID: 05a681a6ca93
+Revises: 74b51dfece29
+Create Date: 2026-03-23 16:12:44.110292
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '05a681a6ca93'
+down_revision: Union[str, None] = '74b51dfece29'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ """Replace the unique index on ``users.username`` with a non-unique one."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ # Same index name and column; only the uniqueness flag changes.
+ op.drop_index(op.f('ix_users_username'), table_name='users')
+ op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ """Recreate ``ix_users_username`` on ``users.username`` as a unique index."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ # NOTE(review): building the unique index presumably fails if duplicate
+ # usernames were stored after the upgrade -- confirm a cleanup plan.
+ op.drop_index(op.f('ix_users_username'), table_name='users')
+ op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/1ea8fe97b5b7_202603252115.py b/api/migrations/versions/1ea8fe97b5b7_202603252115.py
new file mode 100644
index 00000000..1f0df3e7
--- /dev/null
+++ b/api/migrations/versions/1ea8fe97b5b7_202603252115.py
@@ -0,0 +1,42 @@
+"""202603252115
+
+Revision ID: 1ea8fe97b5b7
+Revises: e28bcc212da5
+Create Date: 2026-03-25 21:14:41.825048
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '1ea8fe97b5b7'
+down_revision: Union[str, None] = 'e28bcc212da5'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ """Add contact, plan, rate-limit and status columns to ``tenants``.
+
+ All seven columns are created as nullable.
+ """
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('tenants', sa.Column('contact_name', sa.String(length=100), nullable=True))
+ op.add_column('tenants', sa.Column('contact_email', sa.String(length=255), nullable=True))
+ op.add_column('tenants', sa.Column('contact_phone', sa.String(length=50), nullable=True))
+ op.add_column('tenants', sa.Column('plan', sa.String(length=50), nullable=True))
+ op.add_column('tenants', sa.Column('plan_expired_at', sa.DateTime(), nullable=True))
+ op.add_column('tenants', sa.Column('api_ops_rate_limit', sa.String(length=100), nullable=True))
+ op.add_column('tenants', sa.Column('status', sa.String(length=50), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ """Drop the seven ``tenants`` columns added by ``upgrade``, in reverse order."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('tenants', 'status')
+ op.drop_column('tenants', 'api_ops_rate_limit')
+ op.drop_column('tenants', 'plan_expired_at')
+ op.drop_column('tenants', 'plan')
+ op.drop_column('tenants', 'contact_phone')
+ op.drop_column('tenants', 'contact_email')
+ op.drop_column('tenants', 'contact_name')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/adaefcbe2aa1_202603261630.py b/api/migrations/versions/adaefcbe2aa1_202603261630.py
new file mode 100644
index 00000000..b8235dd7
--- /dev/null
+++ b/api/migrations/versions/adaefcbe2aa1_202603261630.py
@@ -0,0 +1,32 @@
+"""202603261630
+
+Revision ID: adaefcbe2aa1
+Revises: 1ea8fe97b5b7
+Create Date: 2026-03-26 16:27:17.590077
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'adaefcbe2aa1'
+down_revision: Union[str, None] = '1ea8fe97b5b7'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ """Add the ``feature_billing`` and ``feature_user_management`` flags to ``tenants``."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ # Both columns are NOT NULL with a server-side default of 'false'.
+ op.add_column('tenants', sa.Column('feature_billing', sa.Boolean(), server_default='false', nullable=False, comment='是否启用收费管理菜单'))
+ op.add_column('tenants', sa.Column('feature_user_management', sa.Boolean(), server_default='false', nullable=False, comment='是否启用用户管理菜单'))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ """Drop the two ``tenants`` feature-flag columns added by ``upgrade``."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('tenants', 'feature_user_management')
+ op.drop_column('tenants', 'feature_billing')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/e28bcc212da5_202603241530.py b/api/migrations/versions/e28bcc212da5_202603241530.py
new file mode 100644
index 00000000..00173522
--- /dev/null
+++ b/api/migrations/versions/e28bcc212da5_202603241530.py
@@ -0,0 +1,34 @@
+"""202603241530
+
+Revision ID: e28bcc212da5
+Revises: 05a681a6ca93
+Create Date: 2026-03-24 15:32:14.461480
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'e28bcc212da5'
+down_revision: Union[str, None] = '05a681a6ca93'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ """Add nullable ``vision_id``, ``audio_id`` and ``video_id`` columns to ``memory_config``."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('memory_config', sa.Column('vision_id', sa.String(), nullable=True, comment='视觉模型配置ID'))
+ op.add_column('memory_config', sa.Column('audio_id', sa.String(), nullable=True, comment='语音模型配置ID'))
+ op.add_column('memory_config', sa.Column('video_id', sa.String(), nullable=True, comment='视频模型配置ID'))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ """Drop the three ``memory_config`` model-id columns added by ``upgrade``, in reverse order."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('memory_config', 'video_id')
+ op.drop_column('memory_config', 'audio_id')
+ op.drop_column('memory_config', 'vision_id')
+ # ### end Alembic commands ###
diff --git a/api/pyproject.toml b/api/pyproject.toml
index e6fddea8..8ced574c 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -147,6 +147,7 @@ dependencies = [
"modelscope>=1.34.0",
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
"python-magic-bin>=0.4.14; sys_platform=='win32'",
+ "volcengine-python-sdk[ark]==5.0.19"
]
[tool.pytest.ini_options]
diff --git a/web/package.json b/web/package.json
index db6a8408..0284f397 100644
--- a/web/package.json
+++ b/web/package.json
@@ -30,7 +30,7 @@
"@lexical/list": "^0.39.0",
"@lexical/react": "^0.39.0",
"@lexical/rich-text": "^0.39.0",
- "antd": "^5.27.4",
+ "antd": "^5.29.2",
"axios": "^1.12.2",
"clsx": "^2.1.1",
"codemirror": "^6.0.2",
diff --git a/web/src/App.tsx b/web/src/App.tsx
index 1d298358..a10f9409 100644
--- a/web/src/App.tsx
+++ b/web/src/App.tsx
@@ -21,7 +21,6 @@ import { useTranslation } from 'react-i18next';
import { lightTheme } from './styles/antdThemeConfig.ts'
import router from './routes';
import { useI18n } from '@/store/locale'
-import LayoutBg from '@/components/Layout/LayoutBg'
import dayjs from 'dayjs'
import 'dayjs/locale/en'
import 'dayjs/locale/zh-cn'
@@ -61,7 +60,6 @@ function App() {
theme={lightTheme}
>
-
}>
{
}
})
}
+// Fetch workspace API call statistics; `data` carries the start_date / end_date range
+export const getWorkspaceApiStatistics = (data: { start_date: number; end_date: number; }) =>
+  request.get(`/apps/workspace/api-statistics`, data)
// Export application
export const appExport = (app_id: string, appName: string, data?: { release_id: string }) => {
return request.getDownloadFile(`/apps/${app_id}/export`, `${appName}.yml`, data)
@@ -165,4 +169,9 @@ export const cancelShare = (app_id: string, target_workspace_id?: string) => {
export const cancelSpaceShare = (target_workspace_id?: string) => {
return request.delete(`/apps/share/${target_workspace_id}`)
}
-
+// Application conversation logs endpoint URL
+export const getAppLogsUrl = (app_id: string) => `/apps/${app_id}/logs`
+// Get full conversation message history
+export const getAppLogDetail = (app_id: string, conversation_id: string) => {
+ return request.get(`/apps/${app_id}/logs/${conversation_id}`)
+}
\ No newline at end of file
diff --git a/web/src/api/fileStorage.ts b/web/src/api/fileStorage.ts
index ce133565..83f5b212 100644
--- a/web/src/api/fileStorage.ts
+++ b/web/src/api/fileStorage.ts
@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 13:59:56
* @Last Modified by: ZhaoYing
- * @Last Modified time: 2026-02-09 16:24:05
+ * @Last Modified time: 2026-03-23 18:05:43
*/
import { request, API_PREFIX } from '@/utils/request'
@@ -32,4 +32,13 @@ export const deleteFile = (fileId: string) => {
}
export const shareFileUploadUrlWithoutApiPrefix = `/storage/share/files`
-export const shareFileUploadUrl = `${API_PREFIX}${shareFileUploadUrlWithoutApiPrefix}`
\ No newline at end of file
+export const shareFileUploadUrl = `${API_PREFIX}${shareFileUploadUrlWithoutApiPrefix}`
+
+// Get file info by file URL
+export const getFileInfoByUrl = (url: string) => {
+ return request.get('/storage/files/info-by-url', {url})
+}
+// Get file status
+export const getFileStatusById = (file_id: string) => {
+ return request.get(`/storage/files/${file_id}/status`)
+}
\ No newline at end of file
diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts
index 9a464893..1ec2d7dc 100644
--- a/web/src/api/memory.ts
+++ b/web/src/api/memory.ts
@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 14:00:06
* @Last Modified by: ZhaoYing
- * @Last Modified time: 2026-03-19 18:35:10
+ * @Last Modified time: 2026-03-24 17:48:01
*/
import { request } from '@/utils/request'
import type { AxiosRequestConfig } from 'axios'
@@ -87,12 +87,13 @@ export const getUserSummary = (end_user_id: string) => {
export const getNodeStatistics = (end_user_id: string) => {
return request.get(`/memory-storage/analytics/node_statistics`, { end_user_id })
}
-// Basic information
-export const getEndUserProfile = (end_user_id: string) => {
- return request.get(`/memory-storage/read_end_user/profile`, { end_user_id })
+// Get end-user alias and info
+export const getEndUserInfo = (end_user_id: string) => {
+ return request.get(`/memory-storage/end_user_info`, { end_user_id })
}
-export const updatedEndUserProfile = (values: EndUser) => {
- return request.post(`/memory-storage/updated_end_user/profile`, values)
+// Update end-user alias and info
+export const updatedEndUserInfo = (values: EndUser) => {
+ return request.post(`/memory-storage/end_user_info/updated`, values)
}
// User Memory - Relationship network
export const getMemorySearchEdges = (end_user_id: string, config?: AxiosRequestConfig) => {
diff --git a/web/src/assets/font/MiSans/MiSans-Bold.woff2 b/web/src/assets/font/MiSans/MiSans-Bold.woff2
new file mode 100644
index 00000000..e4a21bee
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Bold.woff2 differ
diff --git a/web/src/assets/font/MiSans/MiSans-Demibold.woff2 b/web/src/assets/font/MiSans/MiSans-Demibold.woff2
new file mode 100644
index 00000000..70205afb
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Demibold.woff2 differ
diff --git a/web/src/assets/font/MiSans/MiSans-ExtraLight.woff2 b/web/src/assets/font/MiSans/MiSans-ExtraLight.woff2
new file mode 100644
index 00000000..45d16c98
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-ExtraLight.woff2 differ
diff --git a/web/src/assets/font/MiSans/MiSans-Heavy.woff2 b/web/src/assets/font/MiSans/MiSans-Heavy.woff2
new file mode 100644
index 00000000..09ee22e3
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Heavy.woff2 differ
diff --git a/web/src/assets/font/MiSans/MiSans-Light.woff2 b/web/src/assets/font/MiSans/MiSans-Light.woff2
new file mode 100644
index 00000000..a2bb950b
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Light.woff2 differ
diff --git a/web/src/assets/font/MiSans/MiSans-Medium.woff2 b/web/src/assets/font/MiSans/MiSans-Medium.woff2
new file mode 100644
index 00000000..617f7407
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Medium.woff2 differ
diff --git a/web/src/assets/font/MiSans/MiSans-Normal.woff2 b/web/src/assets/font/MiSans/MiSans-Normal.woff2
new file mode 100644
index 00000000..d24e89dd
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Normal.woff2 differ
diff --git a/web/src/assets/font/MiSans/MiSans-Regular.woff2 b/web/src/assets/font/MiSans/MiSans-Regular.woff2
new file mode 100644
index 00000000..6a699b50
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Regular.woff2 differ
diff --git a/web/src/assets/font/MiSans/MiSans-Semibold.woff2 b/web/src/assets/font/MiSans/MiSans-Semibold.woff2
new file mode 100644
index 00000000..34f43f7c
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Semibold.woff2 differ
diff --git a/web/src/assets/font/MiSans/MiSans-Thin.woff2 b/web/src/assets/font/MiSans/MiSans-Thin.woff2
new file mode 100644
index 00000000..ec8a3b55
Binary files /dev/null and b/web/src/assets/font/MiSans/MiSans-Thin.woff2 differ
diff --git a/web/src/assets/font/MiSans/index.ts b/web/src/assets/font/MiSans/index.ts
new file mode 100644
index 00000000..e69de29b
diff --git a/web/src/assets/images/CloudUploadOutlined.svg b/web/src/assets/images/CloudUploadOutlined.svg
new file mode 100644
index 00000000..86fdf286
--- /dev/null
+++ b/web/src/assets/images/CloudUploadOutlined.svg
@@ -0,0 +1,26 @@
+
+
\ No newline at end of file
diff --git a/web/src/assets/images/application/arrow_right.svg b/web/src/assets/images/application/arrow_right.svg
new file mode 100644
index 00000000..06400efc
--- /dev/null
+++ b/web/src/assets/images/application/arrow_right.svg
@@ -0,0 +1,17 @@
+
+
\ No newline at end of file
diff --git a/web/src/assets/images/application/clean.svg b/web/src/assets/images/application/clean.svg
index 5d134404..a728abaa 100644
--- a/web/src/assets/images/application/clean.svg
+++ b/web/src/assets/images/application/clean.svg
@@ -1,13 +1,15 @@