Merge branch 'develop' of https://github.com/SuanmoSuanyangTechnology/MemoryBear into feature/app-message-log

This commit is contained in:
wxy
2026-03-26 17:09:05 +08:00
668 changed files with 20166 additions and 11942 deletions

View File

@@ -1,5 +1,6 @@
import os import os
import platform import platform
import re
from datetime import timedelta from datetime import timedelta
from urllib.parse import quote from urllib.parse import quote
@@ -11,21 +12,24 @@ from app.core.logging_config import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def _mask_url(url: str) -> str:
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
# macOS fork() safety - must be set before any Celery initialization # macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例 # 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB由 CELERY_BROKER_DB 指定) # broker: 优先使用环境变量 CELERY_BROKER_URL支持 amqp:// 等任意协议),
# backend: 结果存储(使用 Redis DB由 CELERY_BACKEND_DB 指定) # 未配置则回退到 Redis 方案
# backend: 结果存储(使用 Redis
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND # NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md # 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
# Build canonical broker/backend URLs and force them into os.environ so that _broker_url = os.getenv("CELERY_BROKER_URL") or \
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first) f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
# cannot be overridden by stray env vars.
# See: https://github.com/celery/celery/issues/4284
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}" _backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url os.environ["CELERY_RESULT_BACKEND"] = _backend_url
@@ -45,8 +49,8 @@ celery_app = Celery(
logger.info( logger.info(
"Celery app initialized", "Celery app initialized",
extra={ extra={
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"), "broker": _mask_url(_broker_url),
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"), "backend": _mask_url(_backend_url),
}, },
) )
# Default queue for unrouted tasks # Default queue for unrouted tasks
@@ -77,6 +81,7 @@ celery_app.conf.update(
# Worker 设置 (per-worker settings are in docker-compose command line) # Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
# 结果过期时间 # 结果过期时间
result_expires=3600, # 结果保存1小时 result_expires=3600, # 结果保存1小时

View File

@@ -57,6 +57,7 @@ def list_apps(
page: int = 1, page: int = 1,
pagesize: int = 10, pagesize: int = 10,
ids: Optional[str] = None, ids: Optional[str] = None,
api_key: Optional[str] = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
@@ -65,10 +66,25 @@ def list_apps(
- 默认包含本工作空间的应用和分享给本工作空间的应用 - 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页 - 当提供 ids 参数时,按逗号分割获取指定应用,不分页
- 当提供 api_key 参数时,查找该 API Key 关联的应用
""" """
from sqlalchemy import select as sa_select
from app.models.api_key_model import ApiKey
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
service = app_service.AppService(db) service = app_service.AppService(db)
# 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程
if api_key:
matched_id = db.execute(
sa_select(ApiKey.resource_id).where(
ApiKey.workspace_id == workspace_id,
ApiKey.api_key == api_key,
ApiKey.resource_id.isnot(None),
)
).scalar_one_or_none()
ids = str(matched_id) if matched_id else ""
# 当 ids 存在且不为 None 时,根据 ids 获取应用 # 当 ids 存在且不为 None 时,根据 ids 获取应用
if ids is not None: if ids is not None:
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]

View File

@@ -14,6 +14,9 @@ Routes:
import os import os
import uuid import uuid
from typing import Any from typing import Any
import httpx
import mimetypes
from urllib.parse import urlparse, unquote
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, RedirectResponse from fastapi.responses import FileResponse, RedirectResponse
@@ -290,6 +293,101 @@ async def upload_file_with_share_token(
) )
@router.get("/files/info-by-url", response_model=ApiResponse)
async def get_file_info_by_url(
url: str,
):
"""
Get file information by network URL (no authentication required).
Fetches file metadata from a remote URL via HTTP HEAD request.
Falls back to GET request if HEAD is not supported.
Returns file type, name, and size.
Args:
url: The network URL of the file.
Returns:
ApiResponse with file information.
"""
api_logger.info(f"File info by URL request: url={url}")
try:
async with httpx.AsyncClient(timeout=10.0) as client:
# Try HEAD request first
response = await client.head(url, follow_redirects=True)
# If HEAD fails, try GET request (some servers don't support HEAD)
if response.status_code != 200:
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
response = await client.get(url, follow_redirects=True)
if response.status_code != 200:
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unable to access file: HTTP {response.status_code}"
)
# Get file size from Content-Length header or actual content
file_size = response.headers.get("Content-Length")
if file_size:
file_size = int(file_size)
elif hasattr(response, 'content'):
file_size = len(response.content)
else:
file_size = None
# Get content type from Content-Type header
content_type = response.headers.get("Content-Type", "application/octet-stream")
# Remove charset and other parameters from content type
content_type = content_type.split(';')[0].strip()
# Extract filename from Content-Disposition or URL
file_name = None
content_disposition = response.headers.get("Content-Disposition")
if content_disposition and "filename=" in content_disposition:
parts = content_disposition.split("filename=")
if len(parts) > 1:
file_name = parts[1].strip('"').strip("'")
if not file_name:
parsed_url = urlparse(url)
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
# Extract file extension from filename
_, file_ext = os.path.splitext(file_name)
# If no extension found, infer from content type
if not file_ext:
ext = mimetypes.guess_extension(content_type)
if ext:
file_ext = ext
file_name = f"{file_name}{file_ext}"
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
return success(
data={
"url": url,
"file_name": file_name,
"file_ext": file_ext.lower() if file_ext else "",
"file_size": file_size,
"content_type": content_type,
},
msg="File information retrieved successfully"
)
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Unexpected error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file information: {str(e)}"
)
@router.get("/files/{file_id}", response_model=Any) @router.get("/files/{file_id}", response_model=Any)
async def download_file( async def download_file(
request: Request, request: Request,
@@ -476,8 +574,12 @@ async def get_file_url(
# For local storage, generate signed URL with expiration # For local storage, generate signed URL with expiration
url = generate_signed_url(str(file_id), expires) url = generate_signed_url(str(file_id), expires)
else: else:
# For remote storage (OSS/S3), get presigned URL # For remote storage (OSS/S3), get presigned URL with forced download
url = await storage_service.get_file_url(file_key, expires=expires) url = await storage_service.get_file_url(
file_key,
expires=expires,
file_name=file_metadata.file_name,
)
url = _match_scheme(request, url) url = _match_scheme(request, url)
api_logger.info(f"Generated file URL: file_id={file_id}") api_logger.info(f"Generated file URL: file_id={file_id}")
@@ -688,7 +790,7 @@ async def permanent_download_file(
# For remote storage, redirect to presigned URL with long expiration # For remote storage, redirect to presigned URL with long expiration
try: try:
# Use a very long expiration (7 days max for most cloud providers) # Use a very long expiration (7 days max for most cloud providers)
presigned_url = await storage_service.get_file_url(file_key, expires=604800) presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
presigned_url = _match_scheme(request, presigned_url) presigned_url = _match_scheme(request, presigned_url)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e: except Exception as e:
@@ -697,3 +799,44 @@ async def permanent_download_file(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}" detail=f"Failed to retrieve file: {str(e)}"
) )
@router.get("/files/{file_id}/status", response_model=ApiResponse)
async def get_file_status(
file_id: uuid.UUID,
db: Session = Depends(get_db),
):
"""
Get file upload/processing status (no authentication required).
This endpoint is used to check if a file (e.g., TTS audio) is ready.
Returns status: pending, completed, or failed.
Args:
file_id: The UUID of the file.
db: Database session.
Returns:
ApiResponse with file status and metadata.
"""
api_logger.info(f"File status request: file_id={file_id}")
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
return success(
data={
"file_id": str(file_id),
"status": file_metadata.status,
"file_name": file_metadata.file_name,
"file_size": file_metadata.file_size,
"content_type": file_metadata.content_type,
},
msg="File status retrieved successfully"
)

View File

@@ -91,9 +91,11 @@ async def get_mcp_servers(
try: try:
cookies = api.get_cookies(token) cookies = api.get_cookies(token)
headers=api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
r = api.session.put( r = api.session.put(
url=api.mcp_base_url, url=api.mcp_base_url,
headers=api.builder_headers(api.headers), headers=headers,
json=body, json=body,
cookies=cookies) cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
@@ -173,6 +175,7 @@ async def get_operational_mcp_servers(
url = f'{api.mcp_base_url}/operational' url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
try: try:
cookies = api.get_cookies(access_token=token, cookies_required=True) cookies = api.get_cookies(access_token=token, cookies_required=True)
@@ -260,7 +263,9 @@ async def create_mcp_market_config(
api.login(create_data.token) api.login(create_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(create_data.token) cookies = api.get_cookies(create_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {create_data.token}'
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
except Exception as e: except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
@@ -290,9 +295,11 @@ async def create_mcp_market_config(
'search': "" 'search': ""
} }
cookies = api.get_cookies(token) cookies = api.get_cookies(token)
headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
r = api.session.put( r = api.session.put(
url=api.mcp_base_url, url=api.mcp_base_url,
headers=api.builder_headers(api.headers), headers=headers,
json=body, json=body,
cookies=cookies) cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
@@ -393,7 +400,9 @@ async def update_mcp_market_config(
api.login(update_data.token) api.login(update_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(update_data.token) cookies = api.get_cookies(update_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {update_data.token}'
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
except Exception as e: except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")

View File

@@ -118,142 +118,142 @@ async def download_log(
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
@router.post("/writer_service", response_model=ApiResponse) # @router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard() # @cur_workspace_access_guard()
async def write_server( # async def write_server(
user_input: Write_UserInput, # user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), # language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), # db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # current_user: User = Depends(get_current_user)
): # ):
""" # """
Write service endpoint - processes write operations synchronously # Write service endpoint - processes write operations synchronously
#
Args: # Args:
user_input: Write request containing message and end_user_id # user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 # language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
#
Returns: # Returns:
Response with write operation status # Response with write operation status
""" # """
# 使用集中化的语言校验 # # 使用集中化的语言校验
language = get_language_from_header(language_type) # language = get_language_from_header(language_type)
#
config_id = user_input.config_id # config_id = user_input.config_id
workspace_id = current_user.current_workspace_id # workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") # api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
#
# 获取 storage_type如果为 None 则使用默认值 # # 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( # storage_type = workspace_service.get_workspace_storage_type(
db=db, # db=db,
workspace_id=workspace_id, # workspace_id=workspace_id,
user=current_user # user=current_user
) # )
if storage_type is None: storage_type = 'neo4j' # if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = '' # user_rag_memory_id = ''
#
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id # # 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag': # if storage_type == 'rag':
if workspace_id: # if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name( # knowledge = knowledge_repository.get_knowledge_by_name(
db=db, # db=db,
name="USER_RAG_MERORY", # name="USER_RAG_MERORY",
workspace_id=workspace_id # workspace_id=workspace_id
) # )
if knowledge: # if knowledge:
user_rag_memory_id = str(knowledge.id) # user_rag_memory_id = str(knowledge.id)
else: # else:
api_logger.warning( # api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储") # f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j' # storage_type = 'neo4j'
else: # else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") # api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j' # storage_type = 'neo4j'
#
api_logger.info( # api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") # f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try: # try:
messages_list = memory_agent_service.get_messages_list(user_input) # messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory( # result = await memory_agent_service.write_memory(
user_input.end_user_id, # user_input.end_user_id,
messages_list, # messages_list,
config_id, # config_id,
db, # db,
storage_type, # storage_type,
user_rag_memory_id, # user_rag_memory_id,
language # language
) # )
#
return success(data=result, msg="写入成功") # return success(data=result, msg="写入成功")
except BaseException as e: # except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'): # if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] # error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages) # detailed_error = "; ".join(error_messages)
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) # api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) # return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
api_logger.error(f"Write operation error: {str(e)}", exc_info=True) # api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) # return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
#
#
@router.post("/writer_service_async", response_model=ApiResponse) # @router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard() # @cur_workspace_access_guard()
async def write_server_async( # async def write_server_async(
user_input: Write_UserInput, # user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), # language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), # db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # current_user: User = Depends(get_current_user)
): # ):
""" # """
Async write service endpoint - enqueues write processing to Celery # Async write service endpoint - enqueues write processing to Celery
#
Args: # Args:
user_input: Write request containing message and end_user_id # user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 # language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
#
Returns: # Returns:
Task ID for tracking async operation # Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result # Use GET /memory/write_result/{task_id} to check task status and get result
""" # """
# 使用集中化的语言校验 # # 使用集中化的语言校验
language = get_language_from_header(language_type) # language = get_language_from_header(language_type)
#
config_id = user_input.config_id # config_id = user_input.config_id
workspace_id = current_user.current_workspace_id # workspace_id = current_user.current_workspace_id
api_logger.info( # api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") # f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
#
# 获取 storage_type如果为 None 则使用默认值 # # 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( # storage_type = workspace_service.get_workspace_storage_type(
db=db, # db=db,
workspace_id=workspace_id, # workspace_id=workspace_id,
user=current_user # user=current_user
) # )
if storage_type is None: storage_type = 'neo4j' # if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = '' # user_rag_memory_id = ''
if workspace_id: # if workspace_id:
#
knowledge = knowledge_repository.get_knowledge_by_name( # knowledge = knowledge_repository.get_knowledge_by_name(
db=db, # db=db,
name="USER_RAG_MERORY", # name="USER_RAG_MERORY",
workspace_id=workspace_id # workspace_id=workspace_id
) # )
if knowledge: user_rag_memory_id = str(knowledge.id) # if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") # api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try: # try:
# 获取标准化的消息列表 # # 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input) # messages_list = memory_agent_service.get_messages_list(user_input)
#
task = celery_app.send_task( # task = celery_app.send_task(
"app.core.memory.agent.write_message", # "app.core.memory.agent.write_message",
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] # args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
) # )
api_logger.info(f"Write task queued: {task.id}") # api_logger.info(f"Write task queued: {task.id}")
#
return success(data={"task_id": task.id}, msg="写入任务已提交") # return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e: # except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}") # api_logger.error(f"Async write operation failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) # return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/read_service", response_model=ApiResponse) @router.post("/read_service", response_model=ApiResponse)

View File

@@ -663,9 +663,12 @@ async def dashboard_data(
rag_data["total_memory"] = total_chunk rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量 # total_app: 统计当前空间下的所有app数量
from app.repositories import app_repository # 包含自有app + 被分享给本工作空间的app
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) from app.services import app_service as _app_svc
rag_data["total_app"] = len(apps_orm) _, total_app = _app_svc.AppService(db).list_apps(
workspace_id=workspace_id, include_shared=True, pagesize=1
)
rag_data["total_app"] = total_app
# total_knowledge: 使用 total_kb总知识库数 # total_knowledge: 使用 total_kb总知识库数
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user) total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
@@ -687,7 +690,7 @@ async def dashboard_data(
api_logger.warning(f"获取RAG模式API调用统计失败使用默认值: {str(e)}") api_logger.warning(f"获取RAG模式API调用统计失败使用默认值: {str(e)}")
rag_data["total_api_call"] = 0 rag_data["total_api_call"] = 0
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}") api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
except Exception as e: except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}") api_logger.warning(f"获取RAG相关数据失败: {str(e)}")

View File

@@ -54,8 +54,8 @@ router = APIRouter(
@router.get("/info", response_model=ApiResponse) @router.get("/info", response_model=ApiResponse)
async def get_storage_info( async def get_storage_info(
storage_id: str, storage_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Example wrapper endpoint - retrieves storage information Example wrapper endpoint - retrieves storage information
@@ -75,17 +75,12 @@ async def get_storage_info(
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
def create_config( def create_config(
payload: ConfigParamsCreate, payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
@@ -107,9 +102,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}") api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
@@ -119,9 +116,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}") api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@@ -129,10 +128,10 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config( def delete_config(
config_id: UUID|int, config_id: UUID | int,
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""删除记忆配置(带终端用户保护) """删除记忆配置(带终端用户保护)
@@ -145,7 +144,7 @@ def delete_config(
force: 设置为 true 可强制删除(即使有终端用户正在使用) force: 设置为 true 可强制删除(即使有终端用户正在使用)
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
config_id=resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
@@ -203,9 +202,9 @@ def delete_config(
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc @router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config( def update_config(
payload: ConfigUpdate, payload: ConfigUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db) payload.config_id = resolve_config_id(payload.config_id, db)
@@ -217,7 +216,8 @@ def update_config(
# 校验至少有一个字段需要更新 # 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
"config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try: try:
@@ -231,9 +231,9 @@ def update_config(
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选 @router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
def update_config_extracted( def update_config_extracted(
payload: ConfigUpdateExtracted, payload: ConfigUpdateExtracted,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db) payload.config_id = resolve_config_id(payload.config_id, db)
@@ -256,11 +256,11 @@ def update_config_extracted(
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py # 遗忘引擎配置接口已迁移到 memory_forget_controller.py
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config # 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted( def read_config_extracted(
config_id: UUID | int, config_id: UUID | int,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
@@ -278,10 +278,11 @@ def read_config_extracted(
api_logger.error(f"Read config extracted failed: {str(e)}") api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
def read_all_config( def read_all_config(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -303,10 +304,10 @@ def read_all_config(
@router.post("/pilot_run", response_model=None) @router.post("/pilot_run", response_model=None)
async def pilot_run( async def pilot_run(
payload: ConfigPilotRun, payload: ConfigPilotRun,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> StreamingResponse: ) -> StreamingResponse:
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
@@ -333,9 +334,9 @@ async def pilot_run(
@router.get("/search/kb_type_distribution", response_model=ApiResponse) @router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution( async def get_kb_type_distribution(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}") api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
try: try:
result = await kb_type_distribution(end_user_id) result = await kb_type_distribution(end_user_id)
@@ -347,9 +348,9 @@ async def get_kb_type_distribution(
@router.get("/search/dialogue", response_model=ApiResponse) @router.get("/search/dialogue", response_model=ApiResponse)
async def search_dialogues_num( async def search_dialogues_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}") api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
try: try:
result = await search_dialogue(end_user_id) result = await search_dialogue(end_user_id)
@@ -361,9 +362,9 @@ async def search_dialogues_num(
@router.get("/search/chunk", response_model=ApiResponse) @router.get("/search/chunk", response_model=ApiResponse)
async def search_chunks_num( async def search_chunks_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}") api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
try: try:
result = await search_chunk(end_user_id) result = await search_chunk(end_user_id)
@@ -375,9 +376,9 @@ async def search_chunks_num(
@router.get("/search/statement", response_model=ApiResponse) @router.get("/search/statement", response_model=ApiResponse)
async def search_statements_num( async def search_statements_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}") api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
try: try:
result = await search_statement(end_user_id) result = await search_statement(end_user_id)
@@ -389,9 +390,9 @@ async def search_statements_num(
@router.get("/search/entity", response_model=ApiResponse) @router.get("/search/entity", response_model=ApiResponse)
async def search_entities_num( async def search_entities_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}") api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
try: try:
result = await search_entity(end_user_id) result = await search_entity(end_user_id)
@@ -403,9 +404,9 @@ async def search_entities_num(
@router.get("/search", response_model=ApiResponse) @router.get("/search", response_model=ApiResponse)
async def search_all_num( async def search_all_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}") api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try: try:
result = await search_all(end_user_id) result = await search_all(end_user_id)
@@ -417,9 +418,9 @@ async def search_all_num(
@router.get("/search/detials", response_model=ApiResponse) @router.get("/search/detials", response_model=ApiResponse)
async def search_entities_detials( async def search_entities_detials(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search details requested for end_user_id: {end_user_id}") api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
try: try:
result = await search_detials(end_user_id) result = await search_detials(end_user_id)
@@ -431,9 +432,9 @@ async def search_entities_detials(
@router.get("/search/edges", response_model=ApiResponse) @router.get("/search/edges", response_model=ApiResponse)
async def search_entity_edges( async def search_entity_edges(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}") api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
try: try:
result = await search_edges(end_user_id) result = await search_edges(end_user_id)
@@ -443,14 +444,12 @@ async def search_entity_edges(
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse) @router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
async def get_hot_memory_tags_api( async def get_hot_memory_tags_api(
limit: int = 10, limit: int = 10,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
""" """
获取热门记忆标签带Redis缓存 获取热门记忆标签带Redis缓存
@@ -505,8 +504,8 @@ async def get_hot_memory_tags_api(
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse) @router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache( async def clear_hot_memory_tags_cache(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
""" """
清除热门标签缓存 清除热门标签缓存
@@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache(
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse) @router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api( async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
@@ -553,4 +552,3 @@ async def get_recent_activity_stats_api(
except Exception as e: except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}") api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))

View File

@@ -42,6 +42,7 @@ def get_model_strategies():
@router.get("", response_model=ApiResponse) @router.get("", response_model=ApiResponse)
def get_model_list( def get_model_list(
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"), type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"),
@@ -74,10 +75,21 @@ def get_model_list(
unique_flat_type = list(dict.fromkeys(flat_type)) unique_flat_type = list(dict.fromkeys(flat_type))
type_list = [ModelType(t.lower()) for t in unique_flat_type] type_list = [ModelType(t.lower()) for t in unique_flat_type]
capability_list = []
if capability is not None:
flat_capability = []
for item in capability:
split_items = [c.strip() for c in item.split(', ') if c.strip()]
flat_capability.extend(split_items)
unique_flat_capability = list(dict.fromkeys(flat_capability))
capability_list = unique_flat_capability
api_logger.error(f"获取模型type_list: {type_list}") api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery( query = model_schema.ModelConfigQuery(
type=type_list, type=type_list,
provider=provider, provider=provider,
capability=capability_list,
is_active=is_active, is_active=is_active,
is_public=is_public, is_public=is_public,
search=search, search=search,

View File

@@ -669,6 +669,7 @@ async def config_query(
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": release.config.get("variables"), "variables": release.config.get("variables"),
"memory": release.config.get("memory", {}).get("enabled"),
"features": release.config.get("features") "features": release.config.get("features")
} }
elif release.app.type == AppType.MULTI_AGENT: elif release.app.type == AppType.MULTI_AGENT:

View File

@@ -5,7 +5,7 @@
from typing import Optional from typing import Optional
import datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends,Header from fastapi import APIRouter, Depends, Header
from app.db import get_db from app.db import get_db
from app.core.language_utils import get_language_from_header from app.core.language_utils import get_language_from_header
@@ -19,7 +19,7 @@ from app.services.user_memory_service import (
analytics_graph_data, analytics_graph_data,
analytics_community_graph_data, analytics_community_graph_data,
) )
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.repositories.workspace_repository import WorkspaceRepository from app.repositories.workspace_repository import WorkspaceRepository
@@ -45,9 +45,9 @@ router = APIRouter(
@router.get("/analytics/memory_insight/report", response_model=ApiResponse) @router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api( async def get_memory_insight_report_api(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
获取缓存的记忆洞察报告 获取缓存的记忆洞察报告
@@ -73,10 +73,10 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse) @router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api( async def get_user_summary_api(
end_user_id: str, end_user_id: str,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
获取缓存的用户摘要 获取缓存的用户摘要
@@ -102,7 +102,7 @@ async def get_user_summary_api(
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try: try:
# 调用服务层获取缓存数据 # 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language) result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
if result["is_cached"]: if result["is_cached"]:
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
@@ -117,10 +117,10 @@ async def get_user_summary_api(
@router.post("/analytics/generate_cache", response_model=ApiResponse) @router.post("/analytics/generate_cache", response_model=ApiResponse)
async def generate_cache_api( async def generate_cache_api(
request: GenerateCacheRequest, request: GenerateCacheRequest,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
手动触发缓存生成 手动触发缓存生成
@@ -155,10 +155,12 @@ async def generate_cache_api(
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}") api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
# 生成记忆洞察 # 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language) insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
language=language)
# 生成用户摘要 # 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language) summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
language=language)
# 构建响应 # 构建响应
result = { result = {
@@ -209,9 +211,9 @@ async def generate_cache_api(
@router.get("/analytics/node_statistics", response_model=ApiResponse) @router.get("/analytics/node_statistics", response_model=ApiResponse)
async def get_node_statistics_api( async def get_node_statistics_api(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -220,7 +222,8 @@ async def get_node_statistics_api(
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") api_logger.info(
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
try: try:
# 调用新的记忆类型统计函数 # 调用新的记忆类型统计函数
@@ -228,21 +231,23 @@ async def get_node_statistics_api(
# 计算总数用于日志 # 计算总数用于日志
total_count = sum(item["count"] for item in result) total_count = sum(item["count"] for item in result)
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") api_logger.info(
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
return success(data=result, msg="查询成功") return success(data=result, msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}") api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
@router.get("/analytics/graph_data", response_model=ApiResponse) @router.get("/analytics/graph_data", response_model=ApiResponse)
async def get_graph_data_api( async def get_graph_data_api(
end_user_id: str, end_user_id: str,
node_types: Optional[str] = None, node_types: Optional[str] = None,
limit: int = 100, limit: int = 100,
depth: int = 1, depth: int = 1,
center_node_id: Optional[str] = None, center_node_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -298,9 +303,9 @@ async def get_graph_data_api(
@router.get("/analytics/community_graph", response_model=ApiResponse) @router.get("/analytics/community_graph", response_model=ApiResponse)
async def get_community_graph_data_api( async def get_community_graph_data_api(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -334,9 +339,9 @@ async def get_community_graph_data_api(
@router.get("/read_end_user/profile", response_model=ApiResponse) @router.get("/read_end_user/profile", response_model=ApiResponse)
async def get_end_user_profile( async def get_end_user_profile(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db) workspace_repo = WorkspaceRepository(db)
@@ -385,9 +390,9 @@ async def get_end_user_profile(
@router.post("/updated_end_user/profile", response_model=ApiResponse) @router.post("/updated_end_user/profile", response_model=ApiResponse)
async def update_end_user_profile( async def update_end_user_profile(
profile_update: EndUserProfileUpdate, profile_update: EndUserProfileUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
更新终端用户的基本信息 更新终端用户的基本信息
@@ -427,15 +432,18 @@ async def update_end_user_profile(
# 只有未预期的错误才使用 INTERNAL_ERROR # 只有未预期的错误才使用 INTERNAL_ERROR
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg) return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
@router.get("/memory_space/timeline_memories", response_model=ApiResponse) @router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"), async def memory_space_timeline_of_shared_memories(
current_user: User = Depends(get_current_user), id: str, label: str,
db: Session = Depends(get_db), language_type: str = Header(default=None, alias="X-Language-Type"),
): current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
workspace_id=current_user.current_workspace_id workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db) workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
@@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language) timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
return success(data=timeline_memories_result, msg="共同记忆时间线") return success(data=timeline_memories_result, msg="共同记忆时间线")
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse) @router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str, async def memory_space_relationship_evolution(id: str, label: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
try: try:
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}") api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")

View File

@@ -598,8 +598,10 @@ class LangChainAgent:
for msg in reversed(output_messages): for msg in reversed(output_messages):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", total_tokens = response_meta.get("token_usage", {}).get(
0) if response_meta else 0 "total_tokens",
0
) if response_meta else 0
yield total_tokens yield total_tokens
break break
if memory_flag: if memory_flag:

View File

@@ -231,8 +231,8 @@ class Settings:
# Celery configuration (internal) # Celery configuration (internal)
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持 # NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
# 详见 docs/celery-env-bug-report.md # 详见 docs/celery-env-bug-report.md
# 默认使用 Redis DB 3 (broker)DB 4 (backend),与业务缓存 (DB 1/2) 隔离 # 默认使用 Redis 作为 broker 和 backend与业务缓存隔离
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰 # 如需使用 RabbitMQ在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3")) REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4")) REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))

View File

@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
# Fallback to console only if file write fails # Fallback to console only if file write fails
print(f"Warning: Could not write to timing log: {e}") print(f"Warning: Could not write to timing log: {e}")
# Always print to console (backward compatible behavior) # Always log at INFO level (avoids Celery treating stdout as WARNING)
print(f"{step_name}: {duration:.2f}s") _timing_logger = logging.getLogger(__name__)
_timing_logger.info(f"{step_name}: {duration:.2f}s")
def get_agent_logger(name: str = "agent_service", def get_agent_logger(name: str = "agent_service",

View File

@@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope): elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J') logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages) formatted_messages = redis_messages
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'): if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id config_id = memory_config.config_id

View File

@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "group_1", end_user_id: str = "group_1",
messages: list = None, messages: list = None,
ref_id: str = "wyl_20251027", ref_id: str = "",
config_id: str = None config_id: str = None
) -> List[DialogData]: ) -> List[DialogData]:
"""Generate chunks from structured messages using the specified chunker strategy. """Generate chunks from structured messages using the specified chunker strategy.
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
role = msg['role'] role = msg['role']
content = msg['content'] content = msg['content']
files = msg.get("file_content", [])
if role not in ['user', 'assistant']: if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip(): if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
if not conversation_messages: if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering") raise ValueError("Message list cannot be empty after filtering")

View File

@@ -6,6 +6,7 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
""" """
import asyncio import asyncio
import time import time
import uuid
from datetime import datetime from datetime import datetime
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -13,28 +14,28 @@ from dotenv import load_dotenv
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
memory_summary_generation
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
load_dotenv() load_dotenv()
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
async def write( async def write(
end_user_id: str, end_user_id: str,
memory_config: MemoryConfig, memory_config: MemoryConfig,
messages: list, messages: list,
ref_id: str = "wyl20251027", ref_id: str = "",
language: str = "zh", language: str = "zh",
) -> None: ) -> None:
""" """
Execute the complete knowledge extraction pipeline. Execute the complete knowledge extraction pipeline.
@@ -43,9 +44,11 @@ async def write(
end_user_id: Group identifier end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027" ref_id: Reference ID, defaults to ""
language: 语言类型 ("zh" 中文, "en" 英文),默认中文 language: 语言类型 ("zh" 中文, "en" 英文),默认中文
""" """
if not ref_id:
ref_id = uuid.uuid4().hex
# Extract config values # Extract config values
embedding_model_id = str(memory_config.embedding_model_id) embedding_model_id = str(memory_config.embedding_model_id)
chunker_strategy = memory_config.chunker_strategy chunker_strategy = memory_config.chunker_strategy
@@ -135,9 +138,11 @@ async def write(
all_chunk_nodes, all_chunk_nodes,
all_statement_nodes, all_statement_nodes,
all_entity_nodes, all_entity_nodes,
all_perceptual_nodes,
all_statement_chunk_edges, all_statement_chunk_edges,
all_statement_entity_edges, all_statement_entity_edges,
all_entity_entity_edges, all_entity_entity_edges,
all_perceptual_edges,
all_dedup_details, all_dedup_details,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
@@ -162,18 +167,21 @@ async def write(
chunk_nodes=all_chunk_nodes, chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes, statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes, entity_nodes=all_entity_nodes,
perceptual_nodes=all_perceptual_nodes,
statement_chunk_edges=all_statement_chunk_edges, statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges, statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges, entity_edges=all_entity_entity_edges,
perceptual_edges=all_perceptual_edges,
connector=neo4j_connector, connector=neo4j_connector,
) )
if success: if success:
logger.info("Successfully saved all data to Neo4j") logger.info("Successfully saved all data to Neo4j")
# 写入成功后,异步触发聚类(不阻塞写入响应 # 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突
schedule_clustering_after_write( await _trigger_clustering_sync(
all_entity_nodes, all_entity_nodes,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, embedding_model_id=str(
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
) )
break break
else: else:
@@ -208,9 +216,8 @@ async def write(
summaries = await memory_summary_generation( summaries = await memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
) )
ms_connector = Neo4jConnector()
try: try:
ms_connector = Neo4jConnector()
await add_memory_summary_nodes(summaries, ms_connector) await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(summaries, ms_connector) await add_memory_summary_statement_edges(summaries, ms_connector)
finally: finally:

View File

@@ -1,10 +1,10 @@
from typing import Any, List
import re
import os
import asyncio import asyncio
import json import json
import numpy as np
import logging import logging
import os
from typing import Any, List
import numpy as np
# Fix tokenizer parallelism warning # Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -246,6 +246,7 @@ class ChunkerClient:
"total_sub_chunks": len(sub_chunks), "total_sub_chunks": len(sub_chunks),
"chunker_strategy": self.chunker_config.chunker_strategy, "chunker_strategy": self.chunker_config.chunker_strategy,
}, },
files=msg.files
) )
dialogue.chunks.append(chunk) dialogue.chunks.append(chunk)
else: else:
@@ -258,6 +259,7 @@ class ChunkerClient:
"message_role": msg.role, "message_role": msg.role,
"chunker_strategy": self.chunker_config.chunker_strategy, "chunker_strategy": self.chunker_config.chunker_strategy,
}, },
files=msg.files
) )
dialogue.chunks.append(chunk) dialogue.chunks.append(chunk)

View File

@@ -2,6 +2,7 @@
OpenAI Embedder 客户端实现 OpenAI Embedder 客户端实现
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。 基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
自动支持火山引擎的多模态 Embedding。
""" """
from typing import List from typing import List
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
) )
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings from app.core.models.embedding import RedBearEmbeddings
from app.models.models_model import ModelProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
- 批量文本嵌入 - 批量文本嵌入
- 自动重试机制 - 自动重试机制
- 错误处理 - 错误处理
- 火山引擎多模态 Embedding自动识别
""" """
def __init__(self, model_config: RedBearModelConfig): def __init__(self, model_config: RedBearModelConfig):
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
""" """
super().__init__(model_config) super().__init__(model_config)
# 初始化 RedBearEmbeddings 模型 # 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
self.model = RedBearEmbeddings( self.model = RedBearEmbeddings(
RedBearModelConfig( RedBearModelConfig(
model_name=self.model_name, model_name=self.model_name,
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
timeout=self.timeout, timeout=self.timeout,
) )
) )
self.is_multimodal = self.model.is_multimodal_supported()
logger.info("OpenAI Embedder 客户端初始化完成") logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
async def response( async def response(
self, self,
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
return [] return []
# 生成嵌入向量 # 生成嵌入向量
embeddings = await self.model.aembed_documents(texts) if self.is_multimodal:
# 火山引擎多模态 Embedding
embeddings = await self.model.aembed_multimodal(
[{"type": "text", "text": text} for text in texts]
)
else:
# 普通 Embedding
embeddings = await self.model.aembed_documents(texts)
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量") logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
return embeddings return embeddings

View File

@@ -114,7 +114,7 @@ class Edge(BaseModel):
end_user_id: str = Field(..., description="The end user ID of the edge.") end_user_id: str = Field(..., description="The end user ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
class ChunkEdge(Edge): class ChunkEdge(Edge):
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
return parse_historical_datetime(v) return parse_historical_datetime(v)
class PerceptualEdge(Edge):
"""Edge connecting perceptual nodes to their source chunks
"""
pass
class Node(BaseModel): class Node(BaseModel):
"""Base class for all graph nodes in the knowledge graph. """Base class for all graph nodes in the knowledge graph.
@@ -206,7 +212,8 @@ class DialogueNode(Node):
ref_id: str = Field(..., description="Reference identifier of the dialog") ref_id: str = Field(..., description="Reference identifier of the dialog")
content: str = Field(..., description="Dialogue content") content: str = Field(..., description="Dialogue content")
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this dialogue (integer or string)")
class StatementNode(Node): class StatementNode(Node):
@@ -281,7 +288,8 @@ class StatementNode(Node):
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this statement (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
@@ -416,7 +424,8 @@ class ExtractedEntityNode(Node):
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity") # fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this entity (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
@@ -453,7 +462,7 @@ class ExtractedEntityNode(Node):
@field_validator('aliases', mode='before') @field_validator('aliases', mode='before')
@classmethod @classmethod
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
"""Validate and clean aliases field using utility function. """Validate and clean aliases field using utility function.
This validator ensures that the aliases field is always a valid list of strings. This validator ensures that the aliases field is always a valid list of strings.
@@ -507,7 +516,8 @@ class MemorySummaryNode(Node):
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this summary (integer or string)")
# ACT-R Forgetting Engine Properties # ACT-R Forgetting Engine Properties
original_statement_id: Optional[str] = Field( original_statement_id: Optional[str] = Field(
@@ -549,3 +559,18 @@ class MemorySummaryNode(Node):
ge=0, ge=0,
description="Total number of times this node has been accessed (reset to 1 on creation)" description="Total number of times this node has been accessed (reset to 1 on creation)"
) )
class PerceptualNode(Node):
"""Node representing a multimodal message in the knowledge graph.
"""
perceptual_type: int
file_path: str
file_name: str
file_ext: str
summary: str
keywords: list[str]
topic: str
domain: str
file_type: str
summary_embedding: list[float] | None

View File

@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
""" """
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
msg: str = Field(..., description="The text content of the message.") msg: str = Field(..., description="The text content of the message.")
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
class TemporalValidityRange(BaseModel): class TemporalValidityRange(BaseModel):
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
content: str = Field(..., description="The content of the chunk as a string.") content: str = Field(..., description="The content of the chunk as a string.")
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod @classmethod

View File

@@ -71,13 +71,11 @@ class LabelPropagationEngine:
connector: Neo4jConnector, connector: Neo4jConnector,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
): ):
self.connector = connector self.connector = connector
self.repo = CommunityRepository(connector) self.repo = CommunityRepository(connector)
self.llm_model_id = llm_model_id self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id self.embedding_model_id = embedding_model_id
self.embedding_model_id = embedding_model_id
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# 公开接口 # 公开接口
@@ -239,6 +237,7 @@ class LabelPropagationEngine:
await self.repo.upsert_community(new_cid, end_user_id, member_count=1) await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}") logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
await self._generate_community_metadata([new_cid], end_user_id)
return return
# 统计邻居社区分布 # 统计邻居社区分布
@@ -273,7 +272,8 @@ class LabelPropagationEngine:
await self._evaluate_merge( await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id list(community_ids_in_neighbors), end_user_id
) )
await self._generate_community_metadata([target_cid], end_user_id) # 新实体加入后成员变化,强制重新生成元数据
await self._generate_community_metadata([target_cid], end_user_id, force=True)
async def _evaluate_merge( async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str self, community_ids: List[str], end_user_id: str
@@ -453,7 +453,7 @@ class LabelPropagationEngine:
return lines return lines
async def _generate_community_metadata( async def _generate_community_metadata(
self, community_ids: List[str], end_user_id: str self, community_ids: List[str], end_user_id: str, force: bool = False
) -> None: ) -> None:
""" """
为一个或多个社区生成并写入元数据。 为一个或多个社区生成并写入元数据。
@@ -462,69 +462,82 @@ class LabelPropagationEngine:
1. 逐个社区调 LLM 生成 name / summary串行 1. 逐个社区调 LLM 生成 name / summary串行
2. 收集所有 summary一次性批量 embed 2. 收集所有 summary一次性批量 embed
3. 单个社区用 update_community_metadata多个用 batch_update_community_metadata 3. 单个社区用 update_community_metadata多个用 batch_update_community_metadata
"""
if not community_ids:
return
Args:
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
"""
from app.db import get_db_context from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# --- 阶段1并发调 LLM 生成每个社区的 name / summary --- async def _build_one(cid: str) -> Optional[Dict]:
async def _build_one(cid: str): try:
members = await self.repo.get_community_members(cid, end_user_id) if not force:
if not members: check_embedding = bool(self.embedding_model_id)
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
return None
members = await self.repo.get_community_members(cid, end_user_id)
if not members:
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
return None
sorted_members = sorted(
members,
key=lambda m: m.get("activation_value") or 0,
reverse=True,
)
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
name = "".join(core_entities[:3]) if core_entities else cid[:8]
summary = f"包含实体:{', '.join(all_names)}"
if self.llm_model_id:
try:
entity_list_str = "\n".join(self._build_entity_lines(members))
relationships = await self.repo.get_community_relationships(cid, end_user_id)
rel_lines = [
f"- {r['subject']}{r['predicate']}{r['object']}"
for r in relationships
if r.get("subject") and r.get("predicate") and r.get("object")
]
rel_section = (
f"\n实体间关系:\n" + "\n".join(rel_lines)
if rel_lines else ""
)
prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过80个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
except Exception as e:
logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
return {
"community_id": cid,
"end_user_id": end_user_id,
"name": name,
"summary": summary,
"core_entities": core_entities,
"summary_embedding": None,
}
except Exception as e:
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
return None return None
sorted_members = sorted(
members,
key=lambda m: m.get("activation_value") or 0,
reverse=True,
)
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
entity_list_str = "\n".join(self._build_entity_lines(members))
# 方案四:注入社区内实体间关系三元组
relationships = await self.repo.get_community_relationships(cid, end_user_id)
rel_lines = [
f"- {r['subject']}{r['predicate']}{r['object']}"
for r in relationships
if r.get("subject") and r.get("predicate") and r.get("object")
]
rel_section = (
f"\n实体间关系:\n" + "\n".join(rel_lines)
if rel_lines else ""
)
prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过80个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
name, summary = "", ""
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
return {
"community_id": cid,
"end_user_id": end_user_id,
"name": name,
"summary": summary,
"core_entities": core_entities,
"summary_embedding": None,
}
results = await asyncio.gather( results = await asyncio.gather(
*[_build_one(cid) for cid in community_ids], *[_build_one(cid) for cid in community_ids],
return_exceptions=True, return_exceptions=True,
@@ -537,15 +550,20 @@ class LabelPropagationEngine:
metadata_list.append(res) metadata_list.append(res)
if not metadata_list: if not metadata_list:
logger.warning(f"[Clustering] 无有效元数据可写入community_ids={community_ids}")
return return
# --- 阶段2批量生成 summary_embedding --- # --- 阶段2批量生成 summary_embedding ---
summaries = [m["summary"] for m in metadata_list] if self.embedding_model_id:
with get_db_context() as db: try:
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) summaries = [m["summary"] for m in metadata_list]
embeddings = await embedder.response(summaries) with get_db_context() as db:
for i, meta in enumerate(metadata_list): embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None embeddings = await embedder.response(summaries)
for i, meta in enumerate(metadata_list):
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
except Exception as e:
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
# --- 阶段3写入单个 or 批量)--- # --- 阶段3写入单个 or 批量)---
if len(metadata_list) == 1: if len(metadata_list) == 1:
@@ -558,16 +576,12 @@ class LabelPropagationEngine:
core_entities=m["core_entities"], core_entities=m["core_entities"],
summary_embedding=m["summary_embedding"], summary_embedding=m["summary_embedding"],
) )
if result: if not result:
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...") logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
else:
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
else: else:
ok = await self.repo.batch_update_community_metadata(metadata_list) ok = await self.repo.batch_update_community_metadata(metadata_list)
if ok: if not ok:
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功") logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
else:
logger.warning(f"[Clustering] 批量写入社区元数据失败")
@staticmethod @staticmethod
def _new_community_id() -> str: def _new_community_id() -> str:

View File

@@ -9,6 +9,7 @@
""" """
import asyncio import asyncio
import logging
import os import os
import hashlib import hashlib
import json import json
@@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
ScenePatterns ScenePatterns
) )
logger = logging.getLogger(__name__)
class DialogExtractionResponse(BaseModel): class DialogExtractionResponse(BaseModel):
"""对话级一次性抽取的结构化返回,用于加速剪枝。 """对话级一次性抽取的结构化返回,用于加速剪枝。
@@ -706,7 +709,7 @@ class SemanticPruner:
# 阈值保护最高0.9 # 阈值保护最高0.9
proportion = float(self.config.pruning_threshold) proportion = float(self.config.pruning_threshold)
if proportion > 0.9: if proportion > 0.9:
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9已自动调整为0.9") logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9已自动调整为0.9")
proportion = 0.9 proportion = 0.9
if proportion < 0.0: if proportion < 0.0:
proportion = 0.0 proportion = 0.0
@@ -905,7 +908,7 @@ class SemanticPruner:
# Safety: avoid empty dataset # Safety: avoid empty dataset
if not result: if not result:
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
return dialogs return dialogs
return result return result
@@ -915,8 +918,7 @@ class SemanticPruner:
try: try:
self.run_logs.append(msg) self.run_logs.append(msg)
except Exception: except Exception:
# 任何异常都不影响打印
pass pass
print(msg) logger.debug(msg)

View File

@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def dedup_layers_and_merge_and_return( async def dedup_layers_and_merge_and_return(
dialogue_nodes: List[DialogueNode], dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
pipeline_config: ExtractionPipelineConfig, pipeline_config: ExtractionPipelineConfig,
connector: Optional[Neo4jConnector] = None, connector: Optional[Neo4jConnector] = None,
llm_client = None, llm_client=None,
) -> Tuple[ ) -> Tuple[
List[DialogueNode], List[DialogueNode],
List[ChunkNode], List[ChunkNode],
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
List[StatementChunkEdge], List[StatementChunkEdge],
List[StatementEntityEdge], List[StatementEntityEdge],
List[EntityEntityEdge], List[EntityEntityEdge],
dict, # 新增:返回去重详情 dict
]: ]:
""" """
执行两层实体去重与融合: 执行两层实体去重与融合:

View File

@@ -32,10 +32,11 @@ from app.core.memory.models.graph_models import (
StatementChunkEdge, StatementChunkEdge,
StatementEntityEdge, StatementEntityEdge,
StatementNode, StatementNode,
PerceptualEdge,
PerceptualNode
) )
from app.core.memory.models.message_models import DialogData from app.core.memory.models.message_models import DialogData
from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.models.variate_config import ( from app.core.memory.models.variate_config import (
ExtractionPipelineConfig, ExtractionPipelineConfig,
) )
@@ -46,7 +47,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb
embedding_generation, embedding_generation,
generate_entity_embeddings_from_triplets, generate_entity_embeddings_from_triplets,
) )
# 导入各个提取模块 # 导入各个提取模块
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import ( from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
StatementExtractor, StatementExtractor,
@@ -90,16 +90,16 @@ class ExtractionOrchestrator:
""" """
def __init__( def __init__(
self, self,
llm_client: LLMClient, llm_client: LLMClient,
embedder_client: OpenAIEmbedderClient, embedder_client: OpenAIEmbedderClient,
connector: Neo4jConnector, connector: Neo4jConnector,
config: Optional[ExtractionPipelineConfig] = None, config: Optional[ExtractionPipelineConfig] = None,
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
embedding_id: Optional[str] = None, embedding_id: Optional[str] = None,
ontology_types: Optional[OntologyTypeList] = None, ontology_types: Optional[OntologyTypeList] = None,
enable_general_types: bool = True, enable_general_types: bool = True,
language: str = "zh", language: str = "zh",
): ):
""" """
初始化流水线编排器 初始化流水线编排器
@@ -157,19 +157,27 @@ class ExtractionOrchestrator:
llm_client=llm_client, llm_client=llm_client,
config=self.config.statement_extraction, config=self.config.statement_extraction,
) )
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language) self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types,
language=language)
self.temporal_extractor = TemporalExtractor(llm_client=llm_client) self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
logger.info("ExtractionOrchestrator 初始化完成") logger.info("ExtractionOrchestrator 初始化完成")
async def run( async def run(
self, self,
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
is_pilot_run: bool = False, is_pilot_run: bool = False,
) -> Tuple[ ) -> tuple[
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], list[DialogueNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[ChunkNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[StatementNode],
list[ExtractedEntityNode],
list[PerceptualNode],
list[StatementChunkEdge],
list[StatementEntityEdge],
list[EntityEntityEdge],
list[PerceptualEdge],
dict
]: ]:
""" """
运行完整的知识提取流水线(优化版:并行执行) 运行完整的知识提取流水线(优化版:并行执行)
@@ -208,7 +216,6 @@ class ExtractionOrchestrator:
for dialog in dialog_data_list: for dialog in dialog_data_list:
for chunk in dialog.chunks: for chunk in dialog.chunks:
all_statements_list.extend(chunk.statements) all_statements_list.extend(chunk.statements)
len(all_statements_list)
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
@@ -230,10 +237,6 @@ class ExtractionOrchestrator:
all_entities_list.extend(triplet_info.entities) all_entities_list.extend(triplet_info.entities)
all_triplets_list.extend(triplet_info.triplets) all_triplets_list.extend(triplet_info.triplets)
len(all_entities_list)
len(all_triplets_list)
sum(len(temporal_map) for temporal_map in temporal_maps)
# 步骤 3: 生成实体嵌入(依赖三元组提取结果) # 步骤 3: 生成实体嵌入(依赖三元组提取结果)
logger.info("步骤 3/6: 生成实体嵌入") logger.info("步骤 3/6: 生成实体嵌入")
triplet_maps = await self._generate_entity_embeddings(triplet_maps) triplet_maps = await self._generate_entity_embeddings(triplet_maps)
@@ -260,9 +263,11 @@ class ExtractionOrchestrator:
chunk_nodes, chunk_nodes,
statement_nodes, statement_nodes,
entity_nodes, entity_nodes,
perceptual_nodes,
statement_chunk_edges, statement_chunk_edges,
statement_entity_edges, statement_entity_edges,
entity_entity_edges, entity_entity_edges,
perceptual_edges
) = await self._create_nodes_and_edges(dialog_data_list) ) = await self._create_nodes_and_edges(dialog_data_list)
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总) # 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
@@ -276,7 +281,16 @@ class ExtractionOrchestrator:
# 注意deduplication 消息已在创建节点和边完成后立即发送 # 注意deduplication 消息已在创建节点和边完成后立即发送
result = await self._run_dedup_and_write_summary( (
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
dialog_data_list,
) = await self._run_dedup_and_write_summary(
dialogue_nodes, dialogue_nodes,
chunk_nodes, chunk_nodes,
statement_nodes, statement_nodes,
@@ -287,17 +301,26 @@ class ExtractionOrchestrator:
dialog_data_list, dialog_data_list,
) )
logger.info(f"知识提取流水线运行完成({mode_str}") logger.info(f"知识提取流水线运行完成({mode_str}")
return result return (
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
perceptual_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
perceptual_edges,
dialog_data_list,
)
except Exception as e: except Exception as e:
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True) logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
raise raise
async def _extract_statements( async def _extract_statements(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[DialogData]: ) -> List[DialogData]:
""" """
从对话中提取陈述句(流式输出版本:边提取边发送进度) 从对话中提取陈述句(流式输出版本:边提取边发送进度)
@@ -395,7 +418,7 @@ class ExtractionOrchestrator:
return dialog_data_list return dialog_data_list
async def _extract_triplets( async def _extract_triplets(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取三元组(流式输出版本:边提取边发送进度) 从对话中提取三元组(流式输出版本:边提取边发送进度)
@@ -478,7 +501,7 @@ class ExtractionOrchestrator:
return triplet_maps return triplet_maps
async def _extract_temporal( async def _extract_temporal(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取时间信息(流式输出版本:边提取边发送进度) 从对话中提取时间信息(流式输出版本:边提取边发送进度)
@@ -585,7 +608,7 @@ class ExtractionOrchestrator:
return temporal_maps return temporal_maps
async def _extract_emotions( async def _extract_emotions(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
@@ -706,7 +729,7 @@ class ExtractionOrchestrator:
return emotion_maps return emotion_maps
async def _parallel_extract_and_embed( async def _parallel_extract_and_embed(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> Tuple[ ) -> Tuple[
List[Dict[str, Any]], List[Dict[str, Any]],
List[Dict[str, Any]], List[Dict[str, Any]],
@@ -777,7 +800,7 @@ class ExtractionOrchestrator:
) )
async def _generate_basic_embeddings( async def _generate_basic_embeddings(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]: ) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
""" """
生成基础嵌入向量(陈述句、分块、对话) 生成基础嵌入向量(陈述句、分块、对话)
@@ -836,7 +859,7 @@ class ExtractionOrchestrator:
) )
async def _generate_entity_embeddings( async def _generate_entity_embeddings(
self, triplet_maps: List[Dict[str, Any]] self, triplet_maps: List[Dict[str, Any]]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
生成实体嵌入向量 生成实体嵌入向量
@@ -874,17 +897,15 @@ class ExtractionOrchestrator:
logger.error(f"实体嵌入生成失败: {e}", exc_info=True) logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
return triplet_maps return triplet_maps
async def _assign_extracted_data( async def _assign_extracted_data(
self, self,
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
temporal_maps: List[Dict[str, Any]], temporal_maps: List[Dict[str, Any]],
triplet_maps: List[Dict[str, Any]], triplet_maps: List[Dict[str, Any]],
emotion_maps: List[Dict[str, Any]], emotion_maps: List[Dict[str, Any]],
statement_embedding_maps: List[Dict[str, List[float]]], statement_embedding_maps: List[Dict[str, List[float]]],
chunk_embedding_maps: List[Dict[str, List[float]]], chunk_embedding_maps: List[Dict[str, List[float]]],
dialog_embeddings: List[List[float]], dialog_embeddings: List[List[float]],
) -> List[DialogData]: ) -> List[DialogData]:
""" """
将提取的数据赋值到语句 将提取的数据赋值到语句
@@ -906,12 +927,12 @@ class ExtractionOrchestrator:
# 确保列表长度匹配 # 确保列表长度匹配
expected_length = len(dialog_data_list) expected_length = len(dialog_data_list)
if ( if (
len(temporal_maps) != expected_length len(temporal_maps) != expected_length
or len(triplet_maps) != expected_length or len(triplet_maps) != expected_length
or len(emotion_maps) != expected_length or len(emotion_maps) != expected_length
or len(statement_embedding_maps) != expected_length or len(statement_embedding_maps) != expected_length
or len(chunk_embedding_maps) != expected_length or len(chunk_embedding_maps) != expected_length
or len(dialog_embeddings) != expected_length or len(dialog_embeddings) != expected_length
): ):
logger.warning( logger.warning(
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
@@ -999,15 +1020,17 @@ class ExtractionOrchestrator:
return dialog_data_list return dialog_data_list
async def _create_nodes_and_edges( async def _create_nodes_and_edges(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> Tuple[ ) -> Tuple[
List[DialogueNode], List[DialogueNode],
List[ChunkNode], List[ChunkNode],
List[StatementNode], List[StatementNode],
List[ExtractedEntityNode], List[ExtractedEntityNode],
List[PerceptualNode],
List[StatementChunkEdge], List[StatementChunkEdge],
List[StatementEntityEdge], List[StatementEntityEdge],
List[EntityEntityEdge], List[EntityEntityEdge],
List[PerceptualEdge]
]: ]:
""" """
创建图数据库节点和边 创建图数据库节点和边
@@ -1031,6 +1054,8 @@ class ExtractionOrchestrator:
statement_chunk_edges = [] statement_chunk_edges = []
statement_entity_edges = [] statement_entity_edges = []
entity_entity_edges = [] entity_entity_edges = []
perceptual_nodes = []
perceptual_edges = []
# 用于去重的集合 # 用于去重的集合
entity_id_set = set() entity_id_set = set()
@@ -1075,6 +1100,45 @@ class ExtractionOrchestrator:
) )
chunk_nodes.append(chunk_node) chunk_nodes.append(chunk_node)
for p, file_type in chunk.files:
meta = p.meta_data or {}
content_meta = meta.get("content", {})
# 生成 summary embedding如果有 embedder_client
summary_embedding = None
if self.embedder_client and p.summary:
try:
summary_embedding = (await self.embedder_client.response([p.summary]))[0]
except Exception as emb_err:
print(f"Failed to embed perceptual summary: {emb_err}")
perceptual = PerceptualNode(
name=f"Perceptual_{p.id}",
**{
"id": str(p.id),
"end_user_id": str(p.end_user_id),
"perceptual_type": p.perceptual_type,
"file_path": p.file_path or "",
"file_name": p.file_name or "",
"file_ext": p.file_ext or "",
"summary": p.summary or "",
"keywords": content_meta.get("keywords", []),
"topic": content_meta.get("topic", ""),
"domain": content_meta.get("domain", ""),
"created_at": p.created_time.isoformat() if p.created_time else None,
"file_type": file_type,
"summary_embedding": summary_embedding,
})
perceptual_nodes.append(perceptual)
perceptual_edges.append(PerceptualEdge(
source=perceptual.id,
target=chunk.id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
created_at=dialog_data.created_at,
))
# 处理每个陈述句 # 处理每个陈述句
for statement in chunk.statements: for statement in chunk.statements:
# 创建陈述句节点 # 创建陈述句节点
@@ -1083,15 +1147,19 @@ class ExtractionOrchestrator:
name=f"Statement_{statement.id}", # 添加必需的 name 字段 name=f"Statement_{statement.id}", # 添加必需的 name 字段
chunk_id=chunk.id, chunk_id=chunk.id,
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL),
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 # 添加必需的 temporal_info 字段
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong',
# 添加必需的 connect_strength 字段
end_user_id=dialog_data.end_user_id, end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement, statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
statement_embedding=statement.statement_embedding, statement_embedding=statement.statement_embedding,
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, valid_at=statement.temporal_validity.valid_at if hasattr(statement,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, 'temporal_validity') and statement.temporal_validity else None,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
'temporal_validity') and statement.temporal_validity else None,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
@@ -1141,7 +1209,8 @@ class ExtractionOrchestrator:
example=getattr(entity, 'example', ''), # 新增:传递示例字段 example=getattr(entity, 'example', ''), # 新增:传递示例字段
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
# 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None), name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
@@ -1248,25 +1317,32 @@ class ExtractionOrchestrator:
chunk_nodes, chunk_nodes,
statement_nodes, statement_nodes,
entity_nodes, entity_nodes,
perceptual_nodes,
statement_chunk_edges, statement_chunk_edges,
statement_entity_edges, statement_entity_edges,
entity_entity_edges, entity_entity_edges,
perceptual_edges
) )
async def _run_dedup_and_write_summary( async def _run_dedup_and_write_summary(
self, self,
dialogue_nodes: List[DialogueNode], dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
) -> Tuple[ ) -> tuple[
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], list[DialogueNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[ChunkNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[StatementNode],
list[ExtractedEntityNode],
list[StatementChunkEdge],
list[StatementEntityEdge],
list[EntityEntityEdge],
dict
]: ]:
""" """
执行两阶段去重并写入汇总 执行两阶段去重并写入汇总
@@ -1415,7 +1491,6 @@ class ExtractionOrchestrator:
len(entity_entity_edges), len(final_entity_entity_edges) len(entity_entity_edges), len(final_entity_entity_edges)
) )
# 写入提取结果汇总(试运行和正式模式都需要生成) # 写入提取结果汇总(试运行和正式模式都需要生成)
try: try:
from app.core.config import settings from app.core.config import settings
@@ -1436,10 +1511,10 @@ class ExtractionOrchestrator:
raise raise
def _save_dedup_details( def _save_dedup_details(
self, self,
dedup_details: Dict[str, Any], dedup_details: Dict[str, Any],
original_entities: List[ExtractedEntityNode], original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode] final_entities: List[ExtractedEntityNode]
): ):
""" """
保存去重消歧的详细记录到实例变量(基于内存数据结构) 保存去重消歧的详细记录到实例变量(基于内存数据结构)
@@ -1537,15 +1612,16 @@ class ExtractionOrchestrator:
except Exception as e: except Exception as e:
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}") logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") logger.info(
f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
except Exception as e: except Exception as e:
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True) logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
async def _analyze_entity_merges( async def _analyze_entity_merges(
self, self,
original_entities: List[ExtractedEntityNode], original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode] final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件) 分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
@@ -1585,9 +1661,9 @@ class ExtractionOrchestrator:
return [] return []
async def _analyze_entity_disambiguation( async def _analyze_entity_disambiguation(
self, self,
original_entities: List[ExtractedEntityNode], original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode] final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件) 分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
@@ -1645,9 +1721,9 @@ class ExtractionOrchestrator:
return type_mapping.get(entity_type, f"{entity_type}实体节点") return type_mapping.get(entity_type, f"{entity_type}实体节点")
async def _output_relationship_creation_results( async def _output_relationship_creation_results(
self, self,
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
entity_nodes: List[ExtractedEntityNode] entity_nodes: List[ExtractedEntityNode]
): ):
""" """
输出关系创建结果 输出关系创建结果
@@ -1681,13 +1757,13 @@ class ExtractionOrchestrator:
logger.error(f"输出关系创建结果失败: {e}", exc_info=True) logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
async def _send_dedup_progress_callback( async def _send_dedup_progress_callback(
self, self,
original_entities: int, original_entities: int,
final_entities: int, final_entities: int,
original_stmt_edges: int, original_stmt_edges: int,
final_stmt_edges: int, final_stmt_edges: int,
original_ent_edges: int, original_ent_edges: int,
final_ent_edges: int, final_ent_edges: int,
): ):
""" """
发送去重消歧完成的进度回调,传递具体的去重和消歧效果 发送去重消歧完成的进度回调,传递具体的去重和消歧效果
@@ -1715,7 +1791,8 @@ class ExtractionOrchestrator:
"original_count": original_entities, "original_count": original_entities,
"final_count": final_entities, "final_count": final_entities,
"reduced_count": entities_reduced, "reduced_count": entities_reduced,
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0, "reduction_rate": round(entities_reduced / original_entities * 100,
1) if original_entities > 0 else 0,
}, },
"statement_entity_edges": { "statement_entity_edges": {
"original_count": original_stmt_edges, "original_count": original_stmt_edges,
@@ -1790,7 +1867,8 @@ class ExtractionOrchestrator:
disamb_examples.append({ disamb_examples.append({
"entity1_name": entity_name, "entity1_name": entity_name,
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知", "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:",
"").strip() if "vs" in disamb_type else "未知",
"entity2_name": entity_name, "entity2_name": entity_name,
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知", "entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
"description": f"{entity_name},消歧区分成功" "description": f"{entity_name},消歧区分成功"
@@ -1815,9 +1893,9 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs( async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "group_1", end_user_id: str = "group_1",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""从测试数据生成分块对话 """从测试数据生成分块对话
@@ -1924,10 +2002,10 @@ async def get_chunked_dialogs(
def preprocess_data( def preprocess_data(
input_path: Optional[str] = None, input_path: Optional[str] = None,
output_path: Optional[str] = None, output_path: Optional[str] = None,
skip_cleaning: bool = True, skip_cleaning: bool = True,
indices: Optional[List[int]] = None indices: Optional[List[int]] = None
) -> List[DialogData]: ) -> List[DialogData]:
"""数据预处理 """数据预处理
@@ -1946,7 +2024,8 @@ def preprocess_data(
) )
preprocessor = DataPreprocessor() preprocessor = DataPreprocessor()
try: try:
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path,
skip_cleaning=skip_cleaning, indices=indices)
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
return cleaned_data return cleaned_data
except Exception as e: except Exception as e:
@@ -1955,9 +2034,9 @@ def preprocess_data(
async def get_chunked_dialogs_from_preprocessed( async def get_chunked_dialogs_from_preprocessed(
data: List[DialogData], data: List[DialogData],
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
llm_client: Optional[Any] = None, llm_client: Optional[Any] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""从预处理后的数据中生成分块 """从预处理后的数据中生成分块
@@ -1988,15 +2067,15 @@ async def get_chunked_dialogs_from_preprocessed(
async def get_chunked_dialogs_with_preprocessing( async def get_chunked_dialogs_with_preprocessing(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "default", end_user_id: str = "default",
user_id: str = "default", user_id: str = "default",
apply_id: str = "default", apply_id: str = "default",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
input_data_path: Optional[str] = None, input_data_path: Optional[str] = None,
llm_client: Optional[Any] = None, llm_client: Optional[Any] = None,
skip_cleaning: bool = True, skip_cleaning: bool = True,
pruning_config: Optional[Dict] = None, pruning_config: Optional[Dict] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""包含数据预处理步骤的完整分块流程 """包含数据预处理步骤的完整分块流程
@@ -2046,7 +2125,8 @@ async def get_chunked_dialogs_with_preprocessing(
if pruning_config: if pruning_config:
# 使用传入的配置 # 使用传入的配置
config = PruningConfig(**pruning_config) config = PruningConfig(**pruning_config)
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") logger.debug(
f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
else: else:
# 使用默认配置(关闭剪枝) # 使用默认配置(关闭剪枝)
config = None config = None

View File

@@ -5,8 +5,11 @@
""" """
import asyncio import asyncio
import logging
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
logger = logging.getLogger(__name__)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.message_models import DialogData from app.core.memory.models.message_models import DialogData
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
return await self.embedder_client.response(texts) return await self.embedder_client.response(texts)
# 分批并行处理 # 分批并行处理
print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理") logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)] batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本") logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
# 并行发送所有批次 # 并行发送所有批次
batch_results = await asyncio.gather(*[ batch_results = await asyncio.gather(*[
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
for batch_result in batch_results: for batch_result in batch_results:
embeddings.extend(batch_result) embeddings.extend(batch_result)
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量") logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
return embeddings return embeddings
async def generate_statement_embeddings( async def generate_statement_embeddings(
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
Returns: Returns:
每个对话的陈述句嵌入向量映射列表 每个对话的陈述句嵌入向量映射列表
""" """
print("\n=== 生成陈述句嵌入向量 ===") logger.debug("=== 生成陈述句嵌入向量 ===")
# 收集所有陈述句 # 收集所有陈述句
all_statements = [] all_statements = []
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
stmt_embedding_maps[d_idx][stmt_id] = embedding stmt_embedding_maps[d_idx][stmt_id] = embedding
print(f"{len(all_statements)} 个陈述句生成了嵌入向量") logger.info(f"{len(all_statements)} 个陈述句生成了嵌入向量")
return stmt_embedding_maps return stmt_embedding_maps
async def generate_chunk_embeddings( async def generate_chunk_embeddings(
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
Returns: Returns:
每个对话的分块嵌入向量映射列表 每个对话的分块嵌入向量映射列表
""" """
print("\n=== 生成分块嵌入向量 ===") logger.debug("=== 生成分块嵌入向量 ===")
# 收集所有分块 # 收集所有分块
all_chunks = [] all_chunks = []
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
chunk_embedding_maps[d_idx][chunk_id] = embedding chunk_embedding_maps[d_idx][chunk_id] = embedding
print(f"{len(all_chunks)} 个分块生成了嵌入向量") logger.info(f"{len(all_chunks)} 个分块生成了嵌入向量")
return chunk_embedding_maps return chunk_embedding_maps
async def generate_dialog_embeddings( async def generate_dialog_embeddings(
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
Returns: Returns:
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表) (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
""" """
print("\n=== 生成所有嵌入向量 ===") logger.debug("=== 生成所有嵌入向量 ===")
# 并发生成陈述句和分块嵌入向量 # 并发生成陈述句和分块嵌入向量
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather( stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
# 对话嵌入向量(当前跳过) # 对话嵌入向量(当前跳过)
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs) dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
print( logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量")
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
)
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
Returns: Returns:
更新后的三元组映射列表(实体包含嵌入向量) 更新后的三元组映射列表(实体包含嵌入向量)
""" """
print("\n=== 生成实体嵌入向量 ===") logger.debug("=== 生成实体嵌入向量 ===")
entity_texts: List[str] = [] entity_texts: List[str] = []
entity_refs: List[Any] = [] entity_refs: List[Any] = []
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
entity_refs.append(ent) entity_refs.append(ent)
if not entity_texts: if not entity_texts:
print("没有找到需要生成嵌入向量的实体") logger.debug("没有找到需要生成嵌入向量的实体")
return triplet_maps return triplet_maps
# 批量生成嵌入向量 # 批量生成嵌入向量
@@ -227,13 +228,13 @@ class EmbeddingGenerator:
# 打印前几个嵌入向量的维度 # 打印前几个嵌入向量的维度
for i in range(min(5, len(embeddings))): for i in range(min(5, len(embeddings))):
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
# 将嵌入向量赋值给实体 # 将嵌入向量赋值给实体
for ent, emb in zip(entity_refs, embeddings): for ent, emb in zip(entity_refs, embeddings):
setattr(ent, "name_embedding", emb) setattr(ent, "name_embedding", emb)
print(f"{len(entity_refs)} 个实体生成了嵌入向量") logger.info(f"{len(entity_refs)} 个实体生成了嵌入向量")
return triplet_maps return triplet_maps
@@ -296,7 +297,7 @@ async def embedding_generation_all(
Returns: Returns:
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表) (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
""" """
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===") logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
generator = EmbeddingGenerator(embedding_id) generator = EmbeddingGenerator(embedding_id)

View File

@@ -188,7 +188,6 @@ async def _process_chunk_summary(
response_model=MemorySummaryResponse, response_model=MemorySummaryResponse,
) )
summary_text = structured.summary.strip() summary_text = structured.summary.strip()
# Generate title and type for the summary # Generate title and type for the summary
title = None title = None
episodic_type = None episodic_type = None

View File

@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
from .llm import RedBearLLM from .llm import RedBearLLM
from .embedding import RedBearEmbeddings from .embedding import RedBearEmbeddings
from .rerank import RedBearRerank from .rerank import RedBearRerank
from .generation import RedBearImageGenerator, RedBearVideoGenerator
__all__ = [ __all__ = [
"RedBearModelConfig", "RedBearModelConfig",
@@ -9,5 +10,7 @@ __all__ = [
"RedBearEmbeddings", "RedBearEmbeddings",
"RedBearRerank", "RedBearRerank",
"RedBearModelFactory", "RedBearModelFactory",
"get_provider_llm_class" "get_provider_llm_class",
"RedBearImageGenerator",
"RedBearVideoGenerator"
] ]

View File

@@ -67,7 +67,7 @@ class RedBearModelFactory:
**config.extra_params **config.extra_params
} }
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
# 使用 httpx.Timeout 对象来设置详细的超时配置 # 使用 httpx.Timeout 对象来设置详细的超时配置
# 这样可以分别控制连接超时和读取超时 # 这样可以分别控制连接超时和读取超时
import httpx import httpx
@@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式 # dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni: if provider == ModelProvider.DASHSCOPE and config.is_omni:
return ChatOpenAI return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
if type == ModelType.LLM: if type == ModelType.LLM:
return OpenAI return OpenAI
elif type == ModelType.CHAT: elif type == ModelType.CHAT:
return ChatOpenAI return ChatOpenAI
else:
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
elif provider == ModelProvider.DASHSCOPE: elif provider == ModelProvider.DASHSCOPE:
return ChatTongyi return ChatTongyi
elif provider == ModelProvider.OLLAMA: elif provider == ModelProvider.OLLAMA:

View File

@@ -1,23 +1,190 @@
from typing import Any, Dict, List, Optional, TypeVar, Callable from typing import Any, Dict, List, Optional, Union
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
from app.models.models_model import ModelProvider
class RedBearEmbeddings(Embeddings): class RedBearEmbeddings(Embeddings):
"""Embedding → 完全符合 LangChain Embeddings""" """统一的 Embedding 类,自动支持多模态(根据 provider 判断)"""
def __init__(self, config: RedBearModelConfig): def __init__(self, config: RedBearModelConfig):
self._model = self._create_model(config)
self._config = config self._config = config
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
if self._is_volcano:
# 火山引擎使用 Ark SDK
self._client = self._create_volcano_client(config)
self._model = None
else:
# 其他 provider 使用 LangChain
self._model = self._create_model(config)
self._client = None
def _create_model(self, config: RedBearModelConfig) -> Embeddings: def _create_model(self, config: RedBearModelConfig) -> Embeddings:
"""根据配置创建模型""" """根据配置创建 LangChain 模型"""
embedding_class = get_provider_embedding_class(config.provider) embedding_class = get_provider_embedding_class(config.provider)
model_params = RedBearModelFactory.get_model_params(config) model_params = RedBearModelFactory.get_model_params(config)
return embedding_class(**model_params) return embedding_class(**model_params)
def _create_volcano_client(self, config: RedBearModelConfig):
"""创建火山引擎客户端"""
from volcenginesdkarkruntime import Ark
return Ark(api_key=config.api_key, base_url=config.base_url)
# ==================== LangChain 标准接口 ====================
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._model.embed_documents(texts) """批量文本向量化LangChain 标准接口)"""
if self._is_volcano:
# 火山引擎多模态 Embedding
contents = [{"type": "text", "text": text} for text in texts]
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
encoding_format="float"
)
return [response.data.embedding]
else:
# 其他 provider
return self._model.embed_documents(texts)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
return self._model.embed_query(text) """单个文本向量化LangChain 标准接口)"""
if self._is_volcano:
# 火山引擎多模态 Embedding
result = self.embed_documents([text])
return result[0] if result else []
else:
# 其他 provider
return self._model.embed_query(text)
# ==================== 多模态扩展方法 ====================
def embed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""
多模态向量化(仅火山引擎支持)
Args:
contents: 内容列表,格式:
- 文本: {"type": "text", "text": "..."}
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
**kwargs: 其他参数
Returns:
向量列表
"""
if not self._is_volcano:
raise NotImplementedError(
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
)
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
**kwargs
)
return [response.data.embedding]
async def aembed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""异步多模态向量化"""
# 火山引擎 SDK 暂不支持异步,使用同步方法
return self.embed_multimodal(contents, **kwargs)
def embed_text(self, text: str, **kwargs) -> List[float]:
"""文本向量化(便捷方法)"""
if self._is_volcano:
result = self.embed_multimodal(
[{"type": "text", "text": text}],
**kwargs
)
return result[0] if result else []
else:
return self.embed_query(text)
def embed_image(self, image_url: str, **kwargs) -> List[float]:
"""图片向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "image_url", "image_url": {"url": image_url}}],
**kwargs
)
return result[0] if result else []
def embed_video(self, video_url: str, **kwargs) -> List[float]:
"""视频向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "video_url", "video_url": {"url": video_url}}],
**kwargs
)
return result[0] if result else []
def embed_batch(
self,
items: List[Union[str, Dict[str, Any]]],
**kwargs
) -> List[List[float]]:
"""
批量向量化(支持混合类型)
Args:
items: 可以是字符串列表或内容字典列表
**kwargs: 其他参数
Returns:
向量列表
"""
# 如果全是字符串,使用标准方法
if all(isinstance(item, str) for item in items):
return self.embed_documents(items)
# 如果包含字典,需要多模态支持
if not self._is_volcano:
raise NotImplementedError(
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
# 标准化输入格式
contents = []
for item in items:
if isinstance(item, str):
contents.append({"type": "text", "text": item})
elif isinstance(item, dict):
contents.append(item)
else:
raise ValueError(f"不支持的输入类型: {type(item)}")
return self.embed_multimodal(contents, **kwargs)
# ==================== 工具方法 ====================
def is_multimodal_supported(self) -> bool:
"""检查是否支持多模态"""
return self._is_volcano
def get_provider(self) -> str:
"""获取 provider"""
return self._config.provider
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
RedBearMultimodalEmbeddings = RedBearEmbeddings

View File

@@ -0,0 +1,344 @@
"""
图片和视频生成模型封装
支持的 Provider:
- Volcano (火山引擎): 使用 volcenginesdkarkruntime
- OpenAI: 使用 openai SDK
"""
from typing import Any, Dict, Optional
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.types.images.images import (
SequentialImageGenerationOptions,
ContentGenerationTool,
OptimizePromptOptions
)
from app.core.models.base import RedBearModelConfig
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.models_model import ModelProvider
class RedBearImageGenerator:
"""图片生成模型封装"""
def __init__(self, config: RedBearModelConfig):
self._config = config
self._client = self._create_client(config)
def _create_client(self, config: RedBearModelConfig):
"""根据 provider 创建客户端"""
provider = config.provider.lower()
if provider == ModelProvider.VOLCANO:
return Ark(api_key=config.api_key, base_url=config.base_url)
# elif provider == ModelProvider.OPENAI:
# from openai import OpenAI
# return OpenAI(api_key=config.api_key, base_url=config.base_url)
else:
raise BusinessException(
f"不支持的图片生成提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def generate(
self,
prompt: str,
image: Optional[Any] = None,
size: Optional[str] = "2K",
output_format: str = "png",
response_format: str = "url",
watermark: bool = False,
sequential_image_generation: Optional[str] = None,
sequential_image_generation_options: Optional[Dict] = None,
tools: Optional[list] = None,
optimize_prompt_options: Optional[Dict] = None,
stream: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
生成图片
Args:
prompt: 提示词
image: 参考图片URL或URL列表图文生图/多图融合)
size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080"至少3686400像素
output_format: 输出格式,如 "png", "jpg"
response_format: 返回格式,"url""b64_json"
watermark: 是否添加水印
sequential_image_generation: 组图生成模式,"auto""disabled"
sequential_image_generation_options: 组图生成选项,如 {"max_images": 4}
tools: 工具列表,如 [{"type": "web_search"}] 用于联网搜索生图
optimize_prompt_options: 提示词优化选项,如 {"mode": "fast"}
stream: 是否使用流式生成
**kwargs: 其他参数
Returns:
生成结果
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
params = {
"model": self._config.model_name,
"prompt": prompt,
"size": size,
"output_format": output_format,
"response_format": response_format,
"watermark": watermark,
}
if image is not None:
params["image"] = image
if sequential_image_generation:
params["sequential_image_generation"] = sequential_image_generation
if sequential_image_generation_options:
params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
**sequential_image_generation_options
)
if tools:
params["tools"] = [ContentGenerationTool(**tool) if isinstance(tool, dict) else tool for tool in tools]
if optimize_prompt_options:
params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)
if stream:
params["stream"] = True
params.update(kwargs)
response = self._client.images.generate(**params)
# elif provider == ModelProvider.OPENAI:
# response = self._client.images.generate(
# model=self._config.model_name,
# prompt=prompt,
# size=size,
# n=n,
# **kwargs
# )
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
return response.model_dump() if hasattr(response, 'model_dump') else response
async def agenerate(
self,
prompt: str,
image: Optional[Any] = None,
size: Optional[str] = "2K",
output_format: str = "png",
response_format: str = "url",
watermark: bool = False,
**kwargs
) -> Dict[str, Any]:
"""异步生成图片"""
return self.generate(prompt, image, size, output_format, response_format, watermark, **kwargs)
class RedBearVideoGenerator:
"""视频生成模型封装"""
def __init__(self, config: RedBearModelConfig):
self._config = config
self._client = self._create_client(config)
def _create_client(self, config: RedBearModelConfig):
"""根据 provider 创建客户端"""
provider = config.provider.lower()
if provider == ModelProvider.VOLCANO:
return Ark(api_key=config.api_key, base_url=config.base_url)
else:
raise BusinessException(
f"不支持的视频生成提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def generate(
self,
prompt: str,
image_url: Optional[str] = None,
first_frame_url: Optional[str] = None,
last_frame_url: Optional[str] = None,
reference_images: Optional[list] = None,
draft_task_id: Optional[str] = None,
duration: Optional[int] = None,
frames: Optional[int] = None,
ratio: Optional[str] = None,
resolution: Optional[str] = None,
generate_audio: bool = False,
watermark: bool = False,
camera_fixed: bool = False,
seed: Optional[int] = None,
return_last_frame: bool = False,
service_tier: str = "default",
execution_expires_after: Optional[int] = None,
draft: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
生成视频
Args:
prompt: 提示词
image_url: 首帧图片URL图生视频-基于首帧)
first_frame_url: 首帧图片URL图生视频-基于首尾帧)
last_frame_url: 尾帧图片URL图生视频-基于首尾帧)
reference_images: 参考图片URL列表图生视频-基于参考图)
draft_task_id: Draft任务ID基于Draft生成正式视频
duration: 视频时长与frames二选一
frames: 视频帧数与duration二选一
ratio: 视频比例,如 "16:9", "9:16", "adaptive"
resolution: 视频分辨率,如 "720p", "1080p"
generate_audio: 是否生成音频
watermark: 是否添加水印
camera_fixed: 是否固定镜头
seed: 随机种子
return_last_frame: 是否返回最后一帧
service_tier: 服务层级,"default""flex"(离线推理)
execution_expires_after: 任务过期时间(秒)
draft: 是否生成样片
**kwargs: 其他参数
Returns:
生成结果包含任务ID需要轮询获取结果
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
content = [{"type": "text", "text": prompt}]
if draft_task_id:
content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]
else:
if image_url:
content.append({"type": "image_url", "image_url": {"url": image_url}})
if first_frame_url:
content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
if last_frame_url:
content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})
if reference_images:
for ref_url in reference_images:
content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"})
params = {"model": self._config.model_name, "content": content, "watermark": watermark}
if duration:
params["duration"] = duration
if frames:
params["frames"] = frames
if ratio:
params["ratio"] = ratio
if resolution:
params["resolution"] = resolution
if generate_audio:
params["generate_audio"] = generate_audio
if camera_fixed:
params["camera_fixed"] = camera_fixed
if seed is not None:
params["seed"] = seed
if return_last_frame:
params["return_last_frame"] = return_last_frame
if service_tier != "default":
params["service_tier"] = service_tier
if execution_expires_after:
params["execution_expires_after"] = execution_expires_after
if draft:
params["draft"] = draft
params.update(kwargs)
response = self._client.content_generation.tasks.create(**params)
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
return response.model_dump() if hasattr(response, 'model_dump') else response
async def agenerate(
self,
prompt: str,
image_url: Optional[str] = None,
duration: Optional[int] = None,
**kwargs
) -> Dict[str, Any]:
"""异步生成视频"""
return self.generate(prompt, image_url=image_url, duration=duration, **kwargs)
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""
查询视频生成任务状态
Args:
task_id: 任务ID
Returns:
任务状态信息
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
response = self._client.content_generation.tasks.get(task_id=task_id)
return response.model_dump() if hasattr(response, 'model_dump') else response
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
"""异步查询任务状态"""
return self.get_task_status(task_id)
def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
"""
查询视频生成任务列表
Args:
page_size: 每页数量
status: 任务状态筛选,如 "succeeded", "failed", "pending"
**kwargs: 其他参数
Returns:
任务列表
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
params = {"page_size": page_size}
if status:
params["status"] = status
params.update(kwargs)
response = self._client.content_generation.tasks.list(**params)
return response.model_dump() if hasattr(response, 'model_dump') else response
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def delete_task(self, task_id: str) -> None:
"""
删除或取消视频生成任务
Args:
task_id: 任务ID
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
self._client.content_generation.tasks.delete(task_id=task_id)
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)

View File

@@ -0,0 +1,334 @@
provider: volcano
models:
# Doubao-Seed 2.0 系列
- name: doubao-seed-2-0-pro-260215
type: chat
provider: volcano
description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-lite-260215
type: chat
provider: volcano
description: 面向高频企业场景兼顾性能与成本的均衡型模型综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-mini-260215
type: chat
provider: volcano
description: 面向低时延、高并发与成本敏感场景提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解适合成本和速度优先的轻量级任务。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-code-preview-260215
type: chat
provider: volcano
description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 代码模型
logo: volcano
# Doubao-Seed 1.x 系列
- name: doubao-seed-1-8-251228
type: chat
provider: volcano
description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面视觉基础能力显著提升可低帧率理解超长视频视频运动理解、复杂空间理解及文档结构化解析能力也有所优化还原生支持智能上下文管理用户可配置上下文策略。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-251015
type: chat
provider: volcano
description: Doubao-Seed-1.6全新多模态深度思考模型同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-lite-251015
type: chat
provider: volcano
description: 更高性价比常见任务的最佳选择支持minimal、low、medium、high 四种reasoning_effort思考深度
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-flash-250828
type: chat
provider: volcano
description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型TPOT低至10ms 同时支持文本和视觉理解文本理解能力超过上一代lite视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-code-preview-251028
type: chat
provider: volcano
description: 面向Agentic编程任务进行了深度优化。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 代码模型
logo: volcano
- name: doubao-seed-1-6-vision-250815
type: chat
provider: volcano
description: 全新Doubao-Seed-1.6系列视觉深度思考模型视觉理解能力显著增强并支持image_process视觉工具
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
logo: volcano
# Doubao 1.5 系列
- name: doubao-1-5-vision-pro-32k-250115
type: chat
provider: volcano
description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- 多模态模型
logo: volcano
- name: doubao-1-5-pro-32k-250115
type: chat
provider: volcano
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-1-5-lite-32k-250115
type: chat
provider: volcano
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: volcano
# Doubao-Seedance 视频生成系列
- name: doubao-seedance-1-5-pro-251215
type: video
provider: volcano
description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-pro-250528
type: video
provider: volcano
description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-pro-fast-251015
type: video
provider: volcano
description: 一款价格触底、效能封顶的全面模型在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-lite-i2v-250428
type: video
provider: volcano
description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
- 图生视频
logo: volcano
- name: doubao-seedance-1-0-lite-t2v-250428
type: video
provider: volcano
description: 基于文本提示词生成视频
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 视频生成
- 文生视频
logo: volcano
# Doubao-Seedream 图像生成系列
- name: doubao-seedream-5-0-260128
type: image
provider: volcano
description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-4-5-251128
type: image
provider: volcano
description: 字节跳动最新推出的图像多模态模型整合了文生图、图生图、组图输出等能力融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-4-0-250828
type: image
provider: volcano
description: 基于领先架构的SOTA级多模态图像创作模型其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-3-0-t2i-250415
type: image
provider: volcano
description: 一款支持原生高分辨率的中英双语图像生成基础模型综合能力媲美GPT-4o处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 图像生成
- 文生图
logo: volcano
# Doubao 翻译系列
- name: doubao-seed-translation-250915
type: chat
provider: volcano
description: 通用多语言翻译模型支持30余种语言互译支持 4K 上下文窗口,输出长度支持最大 3K tokens
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 翻译模型
logo: volcano
# Doubao Embedding 系列
- name: doubao-embedding-vision-251215
type: embedding
provider: volcano
description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 向量模型
- 多模态模型
logo: volcano

View File

@@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel):
class ElasticSearchVector(BaseVector): class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
super().__init__(index_name.lower()) super().__init__(index_name.lower())
# self.embeddings = XinferenceEmbeddings(
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port # 初始化 Embedding 模型(自动支持火山引擎多模态)
# model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
# )
# Remove debug printing to avoid leaking sensitive information
# print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
self.embeddings = RedBearEmbeddings(RedBearModelConfig( self.embeddings = RedBearEmbeddings(RedBearModelConfig(
model_name=embedding_config.model_name, model_name=embedding_config.model_name,
provider=embedding_config.provider, provider=embedding_config.provider,
api_key=embedding_config.api_key, api_key=embedding_config.api_key,
base_url=embedding_config.api_base base_url=embedding_config.api_base
)) ))
# self.reranker = XinferenceRerank( self.is_multimodal_embedding = self.embeddings.is_multimodal_supported()
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
# model_uid="bge-reranker-large"
# )
# Remove debug printing to avoid leaking sensitive information
# print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
self.reranker = RedBearRerank(RedBearModelConfig( self.reranker = RedBearRerank(RedBearModelConfig(
model_name=reranker_config.model_name, model_name=reranker_config.model_name,
provider=reranker_config.provider, provider=reranker_config.provider,
@@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector):
def add_chunks(self, chunks: list[DocumentChunk], **kwargs): def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
# 实现 Elasticsearch 保存向量 # 实现 Elasticsearch 保存向量
texts = [chunk.page_content for chunk in chunks] texts = [chunk.page_content for chunk in chunks]
embeddings = self.embeddings.embed_documents(list(texts)) if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
embeddings = self.embeddings.embed_batch(texts)
else:
embeddings = self.embeddings.embed_documents(list(texts))
self.create(chunks, embeddings, **kwargs) self.create(chunks, embeddings, **kwargs)
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
@@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector):
updated count. updated count.
""" """
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"
chunk.vector = self.embeddings.embed_query(chunk.page_content) if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
chunk.vector = self.embeddings.embed_text(chunk.page_content)
else:
chunk.vector = self.embeddings.embed_query(chunk.page_content)
body = { body = {
"script": { "script": {
@@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]: def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
"""Search the nearest neighbors to a vector.""" """Search the nearest neighbors to a vector."""
query_vector = self.embeddings.embed_query(query) if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
query_vector = self.embeddings.embed_text(query)
else:
query_vector = self.embeddings.embed_query(query)
top_k = kwargs.get("top_k", 1024) top_k = kwargs.get("top_k", 1024)
score_threshold = float(kwargs.get("score_threshold") or 0.3) score_threshold = float(kwargs.get("score_threshold") or 0.3)
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"

View File

@@ -109,17 +109,13 @@ class StorageBackend(ABC):
pass pass
@abstractmethod @abstractmethod
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
""" self,
Get an access URL for the file. file_key: str,
expires: int = 3600,
Args: file_name: Optional[str] = None
file_key: Unique identifier for the file in the storage system. ) -> str:
expires: URL validity period in seconds (default: 1 hour). """Get an access URL for the file."""
Returns:
URL for accessing the file.
"""
pass pass
async def get_permanent_url(self, file_key: str) -> Optional[str]: async def get_permanent_url(self, file_key: str) -> Optional[str]:

View File

@@ -210,7 +210,12 @@ class LocalStorage(StorageBackend):
cause=e, cause=e,
) )
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None
) -> str:
""" """
Get an access URL for the file. Get an access URL for the file.
@@ -220,6 +225,7 @@ class LocalStorage(StorageBackend):
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (not used for local storage). expires: URL validity period in seconds (not used for local storage).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A relative URL path for accessing the file. A relative URL path for accessing the file.

View File

@@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK.
import io import io
import logging import logging
import urllib.parse
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
import oss2 import oss2
@@ -242,24 +243,33 @@ class OSSStorage(StorageBackend):
logger.error(f"Failed to check file existence in OSS {file_key}: {e}") logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
return False return False
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get a presigned URL for accessing the file. Get a presigned URL for accessing the file.
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A presigned URL for accessing the file. A presigned URL for accessing the file.
""" """
try: try:
url = self.bucket.sign_url("GET", file_key, expires) params = {}
if file_name:
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None)
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url return url
except Exception as e: except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}") logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}" return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
async def get_permanent_url(self, file_key: str) -> str: async def get_permanent_url(self, file_key: str) -> str:

View File

@@ -6,6 +6,7 @@ using the boto3 SDK.
""" """
import io import io
import urllib.parse
import logging import logging
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
@@ -352,31 +353,37 @@ class S3Storage(StorageBackend):
logger.error(f"Failed to check file existence in S3 {file_key}: {e}") logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
return False return False
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get a presigned URL for accessing the file. Get a presigned URL for accessing the file.
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A presigned URL for accessing the file. A presigned URL for accessing the file.
""" """
try: try:
params = {"Bucket": self.bucket_name, "Key": file_key}
if file_name:
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
url = self.client.generate_presigned_url( url = self.client.generate_presigned_url(
"get_object", "get_object",
Params={ Params=params,
"Bucket": self.bucket_name,
"Key": file_key,
},
ExpiresIn=expires, ExpiresIn=expires,
) )
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url return url
except Exception as e: except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}") logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}" return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
async def get_permanent_url(self, file_key: str) -> str: async def get_permanent_url(self, file_key: str) -> str:

View File

@@ -9,7 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.core.workflow.adapters.errors import ExceptionDefineition from app.core.workflow.adapters.errors import ExceptionDefinition
from app.schemas.workflow_schema import ( from app.schemas.workflow_schema import (
EdgeDefinition, EdgeDefinition,
NodeDefinition, NodeDefinition,
@@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list) edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefineition] = Field(default_factory=list) warnings: list[ExceptionDefinition] = Field(default_factory=list)
errors: list[ExceptionDefineition] = Field(default_factory=list) errors: list[ExceptionDefinition] = Field(default_factory=list)
class WorkflowImportResult(BaseModel): class WorkflowImportResult(BaseModel):
@@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list) edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefineition] = Field(default_factory=list) warnings: list[ExceptionDefinition] = Field(default_factory=list)
errors: list[ExceptionDefineition] = Field(default_factory=list) errors: list[ExceptionDefinition] = Field(default_factory=list)
class BasePlatformAdapter(ABC): class BasePlatformAdapter(ABC):

View File

@@ -9,9 +9,9 @@ from urllib.parse import quote
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ( from app.core.workflow.adapters.errors import (
UnsupportVariableType, UnsupportedVariableType,
UnknowModelWarning, UnknownModelWarning,
ExceptionDefineition, ExceptionDefinition,
ExceptionType ExceptionType
) )
from app.core.workflow.nodes.assigner.config import AssignmentItem from app.core.workflow.nodes.assigner.config import AssignmentItem
@@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import (
HttpFormData, HttpFormData,
HttpTimeOutConfig, HttpTimeOutConfig,
HttpRetryConfig, HttpRetryConfig,
HttpErrorDefaultTamplete, HttpErrorDefaultTemplate,
HttpErrorHandleConfig HttpErrorHandleConfig
) )
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
@@ -108,7 +108,7 @@ class DifyConverter(BaseConverter):
try: try:
return config.model_validate(value) return config.model_validate(value)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,
@@ -138,7 +138,7 @@ class DifyConverter(BaseConverter):
var_selector = mapping.get(var_selector, var_selector) var_selector = mapping.get(var_selector, var_selector)
return var_selector return var_selector
def _process_list_variable_litearl(self, variable_selector: list) -> str | None: def _process_list_variable_literal(self, variable_selector: list) -> str | None:
if not self.process_var_selector(".".join(variable_selector)): if not self.process_var_selector(".".join(variable_selector)):
return None return None
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}" return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
@@ -269,7 +269,7 @@ class DifyConverter(BaseConverter):
var_type = self.variable_type_map(var["type"]) var_type = self.variable_type_map(var["type"])
if not var_type: if not var_type:
self.errors.append( self.errors.append(
UnsupportVariableType( UnsupportedVariableType(
scope=node["id"], scope=node["id"],
name=var["variable"], name=var["variable"],
var_type=var["type"], var_type=var["type"],
@@ -281,7 +281,7 @@ class DifyConverter(BaseConverter):
if var_type in ["file", "array[file]"]: if var_type in ["file", "array[file]"]:
self.errors.append( self.errors.append(
ExceptionDefineition( ExceptionDefinition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -311,7 +311,7 @@ class DifyConverter(BaseConverter):
def convert_question_classifier_node_config(self, node: dict) -> dict: def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknowModelWarning( UnknownModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
@@ -327,7 +327,7 @@ class DifyConverter(BaseConverter):
) )
result = QuestionClassifierNodeConfig.model_construct( result = QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")), input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")), user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories, categories=categories,
).model_dump() ).model_dump()
@@ -337,13 +337,13 @@ class DifyConverter(BaseConverter):
def convert_llm_node_config(self, node: dict) -> dict: def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknowModelWarning( UnknownModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
) )
) )
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"]) context = self._process_list_variable_literal(node_data["context"]["variable_selector"])
memory = MemoryWindowSetting( memory = MemoryWindowSetting(
enable=bool(node_data.get("memory")), enable=bool(node_data.get("memory")),
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
@@ -367,7 +367,7 @@ class DifyConverter(BaseConverter):
) )
) )
vision = node_data["vision"]["enabled"] vision = node_data["vision"]["enabled"]
vision_input = self._process_list_variable_litearl( vision_input = self._process_list_variable_literal(
node_data["vision"]["configs"]["variable_selector"] node_data["vision"]["configs"]["variable_selector"]
) if vision else None ) if vision else None
result = LLMNodeConfig.model_construct( result = LLMNodeConfig.model_construct(
@@ -433,7 +433,7 @@ class DifyConverter(BaseConverter):
conditions.append( conditions.append(
LoopConditionDetail.model_construct( LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]), operator=self.convert_compare_operator(condition["comparison_operator"]),
left=self._process_list_variable_litearl(condition["variable_selector"]), left=self._process_list_variable_literal(condition["variable_selector"]),
right=self.trans_variable_format( right=self.trans_variable_format(
right_value right_value
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
@@ -453,7 +453,7 @@ class DifyConverter(BaseConverter):
right_input_type = variable["value_type"] right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"]) right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE: if right_input_type == ValueInputType.VARIABLE:
right_value = self._process_list_variable_litearl(variable.get("value", "")) right_value = self._process_list_variable_literal(variable.get("value", ""))
else: else:
right_value = self.convert_variable_type(right_value_type, variable.get("value", "")) right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append( loop_variables.append(
@@ -475,10 +475,10 @@ class DifyConverter(BaseConverter):
def convert_iteration_node_config(self, node: dict) -> dict: def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
result = IterationNodeConfig.model_construct( result = IterationNodeConfig.model_construct(
input=self._process_list_variable_litearl(node_data["iterator_selector"]), input=self._process_list_variable_literal(node_data["iterator_selector"]),
parallel=node_data["is_parallel"], parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"], parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_litearl(node_data["output_selector"]), output=self._process_list_variable_literal(node_data["output_selector"]),
output_type=self.variable_type_map(node_data.get("output_type")), output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"], flatten=node_data["flatten_output"],
).model_dump() ).model_dump()
@@ -494,8 +494,8 @@ class DifyConverter(BaseConverter):
continue continue
assignments.append( assignments.append(
AssignmentItem( AssignmentItem(
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]), variable_selector=self._process_list_variable_literal(assignment["variable_selector"]),
value=self._process_list_variable_litearl( value=self._process_list_variable_literal(
assignment["value"] assignment["value"]
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
operation=self.convert_assignment_operator(assignment["operation"]) operation=self.convert_assignment_operator(assignment["operation"])
@@ -514,7 +514,7 @@ class DifyConverter(BaseConverter):
input_variables.append( input_variables.append(
InputVariable.model_construct( InputVariable.model_construct(
name=input_variable["variable"], name=input_variable["variable"],
variable=self._process_list_variable_litearl(input_variable["value_selector"]), variable=self._process_list_variable_literal(input_variable["value_selector"]),
) )
) )
@@ -570,7 +570,7 @@ class DifyConverter(BaseConverter):
else: else:
if node_data["body"]["data"]: if node_data["body"]["data"]:
body_content = (node_data["body"]["data"][0].get("value") or body_content = (node_data["body"]["data"][0].get("value") or
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file"))) self._process_list_variable_literal(node_data["body"]["data"][0].get("file")))
else: else:
body_content = "" body_content = ""
@@ -585,7 +585,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0]) self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1]) ] = self.trans_variable_format(key_value[1])
else: else:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -603,7 +603,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0]) self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1]) ] = self.trans_variable_format(key_value[1])
else: else:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -625,7 +625,7 @@ class DifyConverter(BaseConverter):
default_header = var["value"] default_header = var["value"]
elif var["key"] == "status_code": elif var["key"] == "status_code":
default_status_code = var["value"] default_status_code = var["value"]
default_value = HttpErrorDefaultTamplete( default_value = HttpErrorDefaultTemplate(
body=default_body, body=default_body,
headers=default_header, headers=default_header,
status_code=default_status_code, status_code=default_status_code,
@@ -668,7 +668,7 @@ class DifyConverter(BaseConverter):
for variable in node_data["variables"]: for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig.model_construct( mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"], name=variable["variable"],
value=self._process_list_variable_litearl(variable["value_selector"]) value=self._process_list_variable_literal(variable["value_selector"])
)) ))
result = JinjaRenderNodeConfig.model_construct( result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"], template=node_data["template"],
@@ -679,14 +679,14 @@ class DifyConverter(BaseConverter):
def convert_knowledge_node_config(self, node: dict) -> dict: def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.", detail=f"Please reconfigure the Knowledge Retrieval node.",
)) ))
result = KnowledgeRetrievalNodeConfig.model_construct( result = KnowledgeRetrievalNodeConfig.model_construct(
query=self._process_list_variable_litearl(node_data["query_variable_selector"]), query=self._process_list_variable_literal(node_data["query_variable_selector"]),
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result) self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
@@ -695,7 +695,7 @@ class DifyConverter(BaseConverter):
def convert_parameter_extractor_node_config(self, node: dict) -> dict: def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknowModelWarning( UnknownModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
@@ -712,7 +712,7 @@ class DifyConverter(BaseConverter):
) )
) )
result = ParameterExtractorNodeConfig.model_construct( result = ParameterExtractorNodeConfig.model_construct(
text=self._process_list_variable_litearl(node_data["query"]), text=self._process_list_variable_literal(node_data["query"]),
params=params, params=params,
prompt=node_data.get("instruction") prompt=node_data.get("instruction")
).model_dump() ).model_dump()
@@ -727,14 +727,14 @@ class DifyConverter(BaseConverter):
group_type = {} group_type = {}
if not advanced_settings or not advanced_settings["group_enabled"]: if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables = [ group_variables = [
self._process_list_variable_litearl(variable) self._process_list_variable_literal(variable)
for variable in node_data["variables"] for variable in node_data["variables"]
] ]
group_type["output"] = node_data["output_type"] group_type["output"] = node_data["output_type"]
else: else:
for group in advanced_settings["groups"]: for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [ group_variables[group["group_name"]] = [
self._process_list_variable_litearl(variable) self._process_list_variable_literal(variable)
for variable in group["variables"] for variable in group["variables"]
] ]
group_type[group["group_name"]] = group["output_type"] group_type[group["group_name"]] = group["output_type"]
@@ -751,7 +751,7 @@ class DifyConverter(BaseConverter):
def convert_tool_node_config(self, node: dict) -> dict: def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,

View File

@@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import (
WorkflowParserResult WorkflowParserResult
) )
from app.core.workflow.adapters.dify.converter import DifyConverter from app.core.workflow.adapters.dify.converter import DifyConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ( from app.schemas.workflow_schema import (
NodeDefinition, NodeDefinition,
@@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
if not all(field in self.config for field in require_fields): if not all(field in self.config for field in require_fields):
return False return False
if self.config.get("app", {}).get("mode") == "workflow": if self.config.get("app", {}).get("mode") == "workflow":
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.PLATFORM, type=ExceptionType.PLATFORM,
detail="workflow mode is not supported" detail="workflow mode is not supported"
)) ))
@@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
edge = self._convert_edge(edge) edge = self._convert_edge(edge)
if edge: if edge:
self.edges.append(edge) self.edges.append(edge)
#
for variable in self.config.get("workflow").get("conversation_variables"): for variable in self.config.get("workflow").get("conversation_variables"):
con_var = self._convert_variable(variable) con_var = self._convert_variable(variable)
if variable: if variable:
self.conv_variables.append(con_var) self.conv_variables.append(con_var)
#
# for variables in config.get("workflow").get("environment_variables"): # for variables in config.get("workflow").get("environment_variables"):
# variable = self._convert_variable(variables) # variable = self._convert_variable(variables)
# conv_variables.append(variable) # conv_variables.append(variable)
@@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"y": node["position"]["y"] + position["y"] "y": node["position"]["y"] + position["y"]
} }
self.errors.append( self.errors.append(
ExceptionDefineition( ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node_id, node_id=node_id,
detail="parent cycle node not found" detail="parent cycle node not found"
@@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
node_data = node["data"] node_data = node["data"]
converter = self.get_node_convert(node_type) converter = self.get_node_convert(node_type)
if node_type == NodeType.UNKNOWN: if node_type == NodeType.UNKNOWN:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
@@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
)) ))
return converter(node) return converter(node)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
@@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
try: try:
source = edge["source"] source = edge["source"]
target = edge["target"] target = edge["target"]
label = None label = None
@@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
label=label, label=label,
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"convert edge error - {e}", detail=f"convert edge error - {e}",
)) ))
@@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
description=variable.get("description") description=variable.get("description")
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
name=variable.get("name"), name=variable.get("name"),
detail=f"convert variable error - {e}", detail=f"convert variable error - {e}",

View File

@@ -18,7 +18,7 @@ class ExceptionType(StrEnum):
UNKNOWN = "unknown" UNKNOWN = "unknown"
class ExceptionDefineition(BaseModel): class ExceptionDefinition(BaseModel):
type: ExceptionType type: ExceptionType
detail: str detail: str
@@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel):
name: str | None = None name: str | None = None
class UnknowModelWarning(ExceptionDefineition): class UnknownModelWarning(ExceptionDefinition):
type: ExceptionType = ExceptionType.NODE type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id, node_name, model_name): def __init__(self, node_id, node_name, model_name):
@@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition):
) )
class UnknowError(ExceptionDefineition): class UnknownError(ExceptionDefinition):
type: ExceptionType = ExceptionType.UNKNOWN type: ExceptionType = ExceptionType.UNKNOWN
def __init__(self, detail: str, **kwargs): def __init__(self, detail: str, **kwargs):
super().__init__(detail=detail, **kwargs) super().__init__(detail=detail, **kwargs)
class UnsupportPlatform(ExceptionDefineition): class UnsupportedPlatform(ExceptionDefinition):
type: ExceptionType = ExceptionType.PLATFORM type: ExceptionType = ExceptionType.PLATFORM
def __init__(self, platform: str): def __init__(self, platform: str):
super().__init__(detail=f"Unsupport platform {platform}") super().__init__(detail=f"Unsupported platform {platform}")
class UnsupportVariableType(ExceptionDefineition): class UnsupportedVariableType(ExceptionDefinition):
type: ExceptionType = ExceptionType.VARIABLE type: ExceptionType = ExceptionType.VARIABLE
def __init__(self, scope, name, var_type: str, **kwargs): def __init__(self, scope, name, var_type: str, **kwargs):
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type[{var_type}]", **kwargs) super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs)
class InvalidConfiguration(ExceptionDefineition): class InvalidConfiguration(ExceptionDefinition):
type: ExceptionType = ExceptionType.CONFIG type: ExceptionType = ExceptionType.CONFIG
def __init__(self): def __init__(self):
super().__init__(detail="Invalid workflow configuration format") super().__init__(detail="Invalid workflow configuration format")
class UnsupportNodeType(ExceptionDefineition): class UnsupportedNodeType(ExceptionDefinition):
type: ExceptionType = ExceptionType.NODE type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id: str, node_type: str): def __init__(self, node_id: str, node_type: str):
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}") super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}")

View File

@@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import (
BasePlatformAdapter, BasePlatformAdapter,
WorkflowParserResult WorkflowParserResult
) )
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
@@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try: try:
node_type = self.map_node_type(node["type"]) node_type = self.map_node_type(node["type"])
if node_type == NodeType.UNKNOWN: if node_type == NodeType.UNKNOWN:
self.errors.append(UnsupportNodeType( self.errors.append(UnsupportedNodeType(
node_id=node_id, node_id=node_id,
node_type=node["type"] node_type=node["type"]
)) ))
@@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
return NodeDefinition(**node) return NodeDefinition(**node)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,
@@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None: def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
try: try:
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids: if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"edge {edge.get('id')} skipped: source or target node not found" detail=f"edge {edge.get('id')} skipped: source or target node not found"
)) ))
return None return None
return EdgeDefinition(**edge) return EdgeDefinition(**edge)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"convert edge error - {e}" detail=f"convert edge error - {e}"
)) ))
@@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try: try:
return VariableDefinition(**variable) return VariableDefinition(**variable)
except Exception as e: except Exception as e:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
name=variable.get("name"), name=variable.get("name"),
detail=f"convert variable error - {e}" detail=f"convert variable error - {e}"

View File

@@ -1,6 +1,6 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.configs import ( from app.core.workflow.nodes.configs import (
StartNodeConfig, StartNodeConfig,
@@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter):
try: try:
return config_cls.model_validate(value) return config_cls.model_validate(value)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,

View File

@@ -7,7 +7,7 @@ import re
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from typing import Any, Iterable from typing import Any, Iterable, Callable
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END from langgraph.graph import START, END
@@ -41,48 +41,31 @@ class GraphBuilder:
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
stream: bool = False, stream: bool = False,
subgraph: bool = False, cycle: str = '',
variable_pool: VariablePool | None = None variable_pool: VariablePool | None = None
): ):
self.workflow_config = workflow_config self.workflow_config = workflow_config
self.stream = stream self.stream = stream
self.subgraph = subgraph self.cycle = cycle
self.start_node_id: str | None = None self.start_node_id: str | None = None
self.node_map = {node["id"]: node for node in self.nodes} self.node_map: dict[str, dict] = {}
self.end_node_map: dict[str, StreamOutputConfig] = {} self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_activation_dep = lru_cache( self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
maxsize=len(self.nodes) * 2
)(self._find_upstream_activation_dep)
if variable_pool: if variable_pool:
self.variable_pool = variable_pool self.variable_pool = variable_pool
else: else:
self.variable_pool = VariablePool() self.variable_pool = VariablePool()
self.graph = StateGraph(WorkflowState) self.graph: StateGraph | None = None
self.add_nodes() self.nodes: list = []
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) self.edges: list = []
self.end_nodes = [ self.reachable_nodes: set[str] | None = None
node self.end_nodes: list[dict] = []
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._reverse_adj: dict[str, list[dict]] = defaultdict(list) self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._build_reverse_adj() self._adj: dict[str, list[str]] = defaultdict(list)
self._analyze_end_node_output()
@property
def nodes(self) -> list[dict[str, Any]]:
return self.workflow_config.get("nodes", [])
@property
def edges(self) -> list[dict[str, Any]]:
return self.workflow_config.get("edges", [])
def get_node_type(self, node_id: str) -> str: def get_node_type(self, node_id: str) -> str:
"""Retrieve the type of node given its ID. """Retrieve the type of node given its ID.
@@ -108,13 +91,14 @@ class GraphBuilder:
result[node[0]].append(node[1]) result[node[0]].append(node[1])
return result return result
def _build_reverse_adj(self): def _build_adj(self):
for edge in self.edges: for edge in self.edges:
if edge["source"] not in self.reachable_nodes: if edge["source"] not in self.reachable_nodes:
continue continue
self._reverse_adj[edge.get("target")].append({ self._reverse_adj[edge.get("target")].append({
"id": edge["source"], "branch": edge.get("label") "id": edge["source"], "branch": edge.get("label")
}) })
self._adj[edge.get("source")].append(edge["target"])
def _find_upstream_activation_dep( def _find_upstream_activation_dep(
self, self,
@@ -302,22 +286,13 @@ class GraphBuilder:
""" """
for node in self.nodes: for node in self.nodes:
node_type = node.get("type") node_type = node.get("type")
if node_type == NodeType.NOTES:
continue
node_id = node.get("id") node_id = node.get("id")
cycle_node = node.get("cycle") if node_id not in self.reachable_nodes:
if cycle_node: continue
# Nodes within a loop subgraph are constructed by CycleGraphNode
if not self.subgraph:
continue
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
# Create node instance (start and end nodes are also created) # Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
node_instance = NodeFactory.create_node(node, self.workflow_config) node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
if node_type in BRANCH_NODES: if node_type in BRANCH_NODES:
@@ -390,6 +365,8 @@ class GraphBuilder:
for edge in self.edges: for edge in self.edges:
source = edge.get("source") source = edge.get("source")
target = edge.get("target") target = edge.get("target")
if source not in self.reachable_nodes or target not in self.reachable_nodes:
continue
condition = edge.get("condition") condition = edge.get("condition")
edge_type = edge.get("type") edge_type = edge.get("type")
@@ -411,11 +388,12 @@ class GraphBuilder:
# Add conditional edges # Add conditional edges
for source_node, branches in conditional_edges.items(): for source_node, branches in conditional_edges.items():
def make_router(src, branch_list): def make_router(src, branch_list):
"""reate a router function for each source node that routes to a NOP node for later merging.""" """Create a router function for each source node that routes to a NOP node for later merging."""
def make_branch_node(node_name, targets): def make_branch_node(node_name, targets):
def node(s): def node(s):
# NOTE: NOP NODE MUST NOT MODIFY STATE # NOTE: NOP NODE USED FOR ROUTING ONLY.
# MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS.
return { return {
"activate": { "activate": {
node_id: s["activate"][node_name] node_id: s["activate"][node_name]
@@ -502,14 +480,52 @@ class GraphBuilder:
logger.debug(f"Added waiting edge: {sources} -> {target}") logger.debug(f"Added waiting edge: {sources} -> {target}")
# Connect End nodes to the global END node # Connect End nodes to the global END node
for end_node in self.end_nodes: for node in self.reachable_nodes:
end_node_id = end_node.get("id") if not self._adj[node]:
if end_node_id: self.graph.add_edge(node, END)
self.graph.add_edge(end_node_id, END)
logger.debug(f"Added edge: {end_node_id} -> END")
return return
def build(self) -> CompiledStateGraph: def build(self) -> CompiledStateGraph:
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
for node in nodes:
if (node.get("cycle") or '') == self.cycle:
node_type = node.get("type")
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node.get("id")
elif node_type == NodeType.NOTES:
continue
self.nodes.append(node)
self.node_map[node.get("id")] = node
for edge in edges:
source_in = edge.get("source") in self.node_map
target_in = edge.get("target") in self.node_map
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
self.edges.append(edge)
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self._build_adj()
self._find_upstream_activation_dep: Callable = lru_cache(
maxsize=len(self.nodes)*2
)(self._find_upstream_activation_dep)
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
self._analyze_end_node_output()
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
self.graph = self.graph.compile(checkpointer=checkpointer) return self.graph.compile(checkpointer=checkpointer)
return self.graph

View File

@@ -2,6 +2,7 @@
# Author: Eternity # Author: Eternity
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33 # @Time : 2026/2/10 13:33
from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
@@ -9,6 +10,7 @@ class WorkflowResultBuilder:
def build_final_output( def build_final_output(
self, self,
result: dict, result: dict,
execution_context: ExecutionContext,
variable_pool: VariablePool, variable_pool: VariablePool,
elapsed_time: float, elapsed_time: float,
final_output: str, final_output: str,
@@ -26,6 +28,8 @@ class WorkflowResultBuilder:
- "node_outputs" (dict): Outputs of executed nodes. - "node_outputs" (dict): Outputs of executed nodes.
- "messages" (list): Conversation messages exchanged during execution. - "messages" (list): Conversation messages exchanged during execution.
- "error" (str, optional): Error message if any node failed. - "error" (str, optional): Error message if any node failed.
execution_context (ExecutionContext): The execution context containing metadata like
execution ID, workspace ID, and user ID.)
variable_pool (VariablePool): Variable Pool variable_pool (VariablePool): Variable Pool
elapsed_time (float): Total execution time in seconds. elapsed_time (float): Total execution time in seconds.
final_output (Any): The aggregated or final output content of the workflow final_output (Any): The aggregated or final output content of the workflow
@@ -48,18 +52,23 @@ class WorkflowResultBuilder:
""" """
node_outputs = result.get("node_outputs", {}) node_outputs = result.get("node_outputs", {})
token_usage = self.aggregate_token_usage(node_outputs) token_usage = self.aggregate_token_usage(node_outputs)
conversation_id = variable_pool.get_value("sys.conversation_id") conversation_vars = {}
sys_vars = {}
if variable_pool:
conversation_vars = variable_pool.get_all_conversation_vars()
sys_vars = variable_pool.get_all_system_vars()
return { return {
"status": "completed" if success else "failed", "status": "completed" if success else "failed",
"output": final_output, "output": final_output,
"variables": { "variables": {
"conv": variable_pool.get_all_conversation_vars(), "conv": conversation_vars,
"sys": variable_pool.get_all_system_vars() "sys": sys_vars
}, },
"node_outputs": node_outputs, "node_outputs": node_outputs,
"messages": result.get("messages", []), "messages": result.get("messages", []),
"conversation_id": conversation_id, "conversation_id": execution_context.conversation_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"token_usage": token_usage, "token_usage": token_usage,
"error": result.get("error"), "error": result.get("error"),

View File

@@ -12,14 +12,29 @@ class ExecutionContext(BaseModel):
execution_id: str execution_id: str
workspace_id: str workspace_id: str
user_id: str user_id: str
conversation_id: str
memory_storage_type: str
user_rag_memory_id: str
checkpoint_config: RunnableConfig checkpoint_config: RunnableConfig
@classmethod @classmethod
def create(cls, execution_id: str, workspace_id: str, user_id: str): def create(
cls,
execution_id: str,
workspace_id: str,
user_id: str,
conversation_id: str,
memory_storage_type: str,
user_rag_memory_id: str
):
return cls( return cls(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
conversation_id=conversation_id,
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id,
checkpoint_config=RunnableConfig( checkpoint_config=RunnableConfig(
configurable={ configurable={
"thread_id": uuid.uuid4(), "thread_id": uuid.uuid4(),

View File

@@ -33,6 +33,8 @@ class WorkflowState(dict):
"workspace_id", "workspace_id",
"user_id", "user_id",
"activate", "activate",
"memory_storage_type",
"user_rag_memory_id"
}) })
__optional_keys__ = frozenset({ __optional_keys__ = frozenset({
"error", "error",
@@ -62,6 +64,9 @@ class WorkflowState(dict):
# node activate status # node activate status
activate: Annotated[dict[str, bool], merge_activate_state] activate: Annotated[dict[str, bool], merge_activate_state]
memory_storage_type: str
user_rag_memory_id: str
class WorkflowStateManager: class WorkflowStateManager:
def create_initial_state( def create_initial_state(
@@ -85,7 +90,9 @@ class WorkflowStateManager:
looping=0, looping=0,
activate={ activate={
start_node_id: True start_node_id: True
} },
memory_storage_type=execution_context.memory_storage_type,
user_rag_memory_id=execution_context.user_rag_memory_id
) )
@staticmethod @staticmethod

View File

@@ -3,7 +3,7 @@
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/9 15:11 # @Time : 2026/2/9 15:11
import re import re
from queue import Queue from collections import deque
from typing import AsyncGenerator from typing import AsyncGenerator
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
@@ -256,7 +256,7 @@ class StreamOutputCoordinator:
def __init__(self): def __init__(self):
self.end_outputs: dict[str, StreamOutputConfig] = {} self.end_outputs: dict[str, StreamOutputConfig] = {}
self.activate_end: str | None = None self.activate_end: str | None = None
self.output_queue: Queue = Queue() self.output_queue: deque[str] = deque()
self.processed_outputs = [] self.processed_outputs = []
def initialize_end_outputs( def initialize_end_outputs(
@@ -266,7 +266,7 @@ class StreamOutputCoordinator:
self.end_outputs = end_node_map self.end_outputs = end_node_map
self.processed_outputs = [] self.processed_outputs = []
self.activate_end = None self.activate_end = None
self.output_queue = Queue() self.output_queue = deque()
@property @property
def current_activate_end_info(self): def current_activate_end_info(self):
@@ -296,13 +296,13 @@ class StreamOutputCoordinator:
scope (str): The node ID or scope that has completed execution. scope (str): The node ID or scope that has completed execution.
status (str | None): Optional status of the node (used for branch/control nodes). status (str | None): Optional status of the node (used for branch/control nodes).
""" """
for node in self.end_outputs.keys(): for node in self.end_outputs:
self.end_outputs[node].update_activate(scope, status) self.end_outputs[node].update_activate(scope, status)
if self.end_outputs[node].activate and node not in self.processed_outputs: if self.end_outputs[node].activate and node not in self.processed_outputs:
self.output_queue.put(node) self.output_queue.append(node)
self.processed_outputs.append(node) self.processed_outputs.append(node)
if self.activate_end is None and not self.output_queue.empty(): if self.activate_end is None and self.output_queue:
self.activate_end = self.output_queue.get_nowait() self.activate_end = self.output_queue.popleft()
async def emit_activate_chunk( async def emit_activate_chunk(
self, self,
@@ -414,8 +414,8 @@ class StreamOutputCoordinator:
async for msg_event in self.emit_activate_chunk(variable_pool, force=True): async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
yield msg_event yield msg_event
if not self.output_queue.empty(): if self.output_queue:
self.activate_end = self.output_queue.get_nowait() self.activate_end = self.output_queue.popleft()
# Move to next active End node if current one is done # Move to next active End node if current one is done
if not self.activate_end and self.end_outputs: if not self.activate_end and self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0] self.activate_end = list(self.end_outputs.keys())[0]

View File

@@ -13,7 +13,7 @@ from pydantic import BaseModel
from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable.variable_objects import T, create_variable_instance from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -373,6 +373,16 @@ class VariablePool:
def copy(self, pool: 'VariablePool'): def copy(self, pool: 'VariablePool'):
self.variables = deepcopy(pool.variables) self.variables = deepcopy(pool.variables)
def is_file_variable(self, selector):
variable_struct = self.get_instance(selector, default=None, strict=False)
if variable_struct is None:
return False
if isinstance(variable_struct, FileVariable):
return True
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
return True
return False
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""导出为字典 """导出为字典

View File

@@ -3,6 +3,7 @@
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/9 13:51 # @Time : 2026/2/9 13:51
import datetime import datetime
import time
import logging import logging
from typing import Any from typing import Any
@@ -82,13 +83,15 @@ class WorkflowExecutor:
CompiledStateGraph: The compiled and ready-to-run state graph. CompiledStateGraph: The compiled and ready-to-run state graph.
""" """
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}") logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
start_time = time.time()
builder = GraphBuilder( builder = GraphBuilder(
self.workflow_config, self.workflow_config,
stream=stream, stream=stream,
) )
self.graph = builder.build()
self.start_node_id = builder.start_node_id self.start_node_id = builder.start_node_id
self.variable_pool = builder.variable_pool self.variable_pool = builder.variable_pool
self.graph = builder.build()
self.stream_coordinator.initialize_end_outputs(builder.end_node_map) self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
self.event_handler = EventStreamHandler( self.event_handler = EventStreamHandler(
@@ -96,7 +99,8 @@ class WorkflowExecutor:
variable_pool=self.variable_pool, variable_pool=self.variable_pool,
execution_id=self.execution_context.execution_id execution_id=self.execution_context.execution_id
) )
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}") logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, "
f"cost: {time.time() - start_time:.4f}s")
return self.graph return self.graph
@@ -134,94 +138,12 @@ class WorkflowExecutor:
return event.get("data") return event.get("data")
return self.result_builder.build_final_output( return self.result_builder.build_final_output(
{"error": "Workflow execution did not end as expected"}, {"error": "Workflow execution did not end as expected"},
self.execution_context,
self.variable_pool, self.variable_pool,
(datetime.datetime.now() - start).total_seconds(), (datetime.datetime.now() - start).total_seconds(),
"", "",
success=False success=False
) )
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
#
# start_time = datetime.datetime.now()
#
# # Execute the workflow
# try:
# # Build the workflow graph
# graph = self.build_graph()
#
# # Initialize the variable pool with input data
# await self.variable_initializer.initialize(
# variable_pool=self.variable_pool,
# input_data=input_data,
# execution_context=self.execution_context
# )
# initial_state = self.state_manager.create_initial_state(
# workflow_config=self.workflow_config,
# input_data=input_data,
# execution_context=self.execution_context,
# start_node_id=self.start_node_id
# )
#
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
#
# # Aggregate output from all End nodes
# full_content = ''
# for end_id in self.stream_coordinator.end_outputs.keys():
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
#
# # Append messages for user and assistant
# if input_data.get("files"):
# result["messages"].extend(
# [
# {
# "role": "user",
# "content": input_data.get("message", '')
# },
# {
# "role": "user",
# "content": input_data.get("files")
# },
# {
# "role": "assistant",
# "content": full_content
# }
# ]
# )
# else:
# result["messages"].extend(
# [
# {
# "role": "user",
# "content": input_data.get("message", '')
# },
# {
# "role": "assistant",
# "content": full_content
# }
# ]
# )
# # Calculate elapsed time
# end_time = datetime.datetime.now()
# elapsed_time = (end_time - start_time).total_seconds()
#
# logger.info(
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
#
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
#
# except Exception as e:
# end_time = datetime.datetime.now()
# elapsed_time = (end_time - start_time).total_seconds()
#
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
# exc_info=True)
# return {
# "status": "failed",
# "error": str(e),
# "output": None,
# "node_outputs": {},
# "elapsed_time": elapsed_time,
# "token_usage": None
# }
async def execute_stream( async def execute_stream(
self, self,
@@ -255,7 +177,7 @@ class WorkflowExecutor:
"data": { "data": {
"execution_id": self.execution_context.execution_id, "execution_id": self.execution_context.execution_id,
"workspace_id": self.execution_context.workspace_id, "workspace_id": self.execution_context.workspace_id,
"conversation_id": input_data.get("conversation_id"), "conversation_id": self.execution_context.conversation_id,
"timestamp": int(start_time.timestamp() * 1000) "timestamp": int(start_time.timestamp() * 1000)
} }
} }
@@ -376,6 +298,7 @@ class WorkflowExecutor:
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output( "data": self.result_builder.build_final_output(
result, result,
self.execution_context,
self.variable_pool, self.variable_pool,
elapsed_time, elapsed_time,
full_content, full_content,
@@ -396,6 +319,7 @@ class WorkflowExecutor:
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output( "data": self.result_builder.build_final_output(
result, result,
self.execution_context,
self.variable_pool, self.variable_pool,
elapsed_time, elapsed_time,
full_content, full_content,
@@ -409,7 +333,9 @@ async def execute_workflow(
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str user_id: str,
memory_storage_type: str,
user_rag_memory_id: str
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Execute a workflow (convenience function, non-streaming). Execute a workflow (convenience function, non-streaming).
@@ -420,6 +346,8 @@ async def execute_workflow(
execution_id (str): Execution ID. execution_id (str): Execution ID.
workspace_id (str): Workspace ID. workspace_id (str): Workspace ID.
user_id (str): User ID. user_id (str): User ID.
user_rag_memory_id: rag knowledge db id
memory_storage_type: neo4j / rag
Returns: Returns:
dict: Workflow execution result. dict: Workflow execution result.
@@ -427,7 +355,10 @@ async def execute_workflow(
execution_context = ExecutionContext.create( execution_context = ExecutionContext.create(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id user_id=user_id,
conversation_id=input_data.get("conversation_id"),
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(
workflow_config=workflow_config, workflow_config=workflow_config,
@@ -441,7 +372,9 @@ async def execute_workflow_stream(
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str user_id: str,
memory_storage_type: str,
user_rag_memory_id: str
): ):
""" """
Execute a workflow in streaming mode (convenience function). Execute a workflow in streaming mode (convenience function).
@@ -452,6 +385,8 @@ async def execute_workflow_stream(
execution_id (str): Execution ID. execution_id (str): Execution ID.
workspace_id (str): Workspace ID. workspace_id (str): Workspace ID.
user_id (str): User ID. user_id (str): User ID.
user_rag_memory_id: rag knowledge db id
memory_storage_type: neo4j / rag
Yields: Yields:
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end. dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
@@ -459,7 +394,10 @@ async def execute_workflow_stream(
execution_context = ExecutionContext.create( execution_context = ExecutionContext.create(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id user_id=user_id,
memory_storage_type=memory_storage_type,
conversation_id=input_data.get("conversation_id"),
user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(
workflow_config=workflow_config, workflow_config=workflow_config,

View File

@@ -65,8 +65,6 @@ class AgentNode(BaseNode):
if not release: if not release:
raise ValueError(f"Agent 不存在: {agent_id}") raise ValueError(f"Agent 不存在: {agent_id}")
return release, message return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode): class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None self.typed_config: AssignerNodeConfig | None = None

View File

@@ -28,7 +28,7 @@ class BaseNode(ABC):
All node types should inherit from this class and implement the `execute` method. All node types should inherit from this class and implement the `execute` method.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""Initialize the node. """Initialize the node.
Args: Args:
@@ -41,6 +41,7 @@ class BaseNode(ABC):
self.node_type = node_config["type"] self.node_type = node_config["type"]
self.cycle = node_config.get("cycle") self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id) self.node_name = node_config.get("name", self.node_id)
self.down_stream_nodes = down_stream_nodes
# 使用 or 运算符处理 None 值 # 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {} self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {} self.error_handling = node_config.get("error_handling") or {}
@@ -93,18 +94,16 @@ class BaseNode(ABC):
dict: A dict with a single key 'activate', mapping node IDs to dict: A dict with a single key 'activate', mapping node IDs to
their activation status (True/False). their activation status (True/False).
""" """
edges = self.workflow_config.get("edges") activate_flag = self.check_activate(state)
under_stream_nodes = [
edge.get("target") if self.node_type not in BRANCH_NODES:
for edge in edges activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES else:
] activate = {}
return {
"activate": { activate[self.node_id] = activate_flag
node_id: self.check_activate(state)
for node_id in under_stream_nodes return {"activate": activate}
} | {self.node_id: self.check_activate(state)}
}
@abstractmethod @abstractmethod
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -315,8 +314,8 @@ class BaseNode(ABC):
elapsed_time = (time.time() - start_time) * 1000 elapsed_time = (time.time() - start_time) * 1000
logger.info(f"Node {self.node_id} streaming execution finished, " logger.debug(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output) # Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result) extracted_output = self._extract_output(final_result)
@@ -428,8 +427,8 @@ class BaseNode(ABC):
when an error edge exists. If no error edge exists, this method when an error edge exists. If no error edge exists, this method
raises an exception to stop the workflow. raises an exception to stop the workflow.
""" """
# Check if the node has an error edge defined # # Check if the node has an error edge defined
error_edge = self._find_error_edge() # error_edge = self._find_error_edge()
# Extract input data (for logging or audit purposes) # Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool) input_data = self._extract_input(state, variable_pool)
@@ -447,27 +446,26 @@ class BaseNode(ABC):
"error": error_message "error": error_message
} }
if error_edge: # if error_edge:
# If an error edge exists, log a warning and continue to error node # # If an error edge exists, log a warning and continue to error node
logger.warning( # logger.warning(
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" # f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
) # )
return { # return {
"node_outputs": { # "node_outputs": {
self.node_id: node_output # self.node_id: node_output
}, # },
"error": error_message, # "error": error_message,
"error_node": self.node_id # "error_node": self.node_id
} # }
else: # else:
# If no error edge, send the error via stream writer and stop the workflow writer = get_stream_writer()
writer = get_stream_writer() writer({
writer({ "type": "node_error",
"type": "node_error", **node_output
**node_output })
}) logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") raise Exception(f"Node {self.node_id} execution failed: {error_message}")
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit). """Extracts the input data for this node (used for logging or audit).
@@ -623,7 +621,6 @@ class BaseNode(ABC):
async def process_message( async def process_message(
api_config: ModelInfo, api_config: ModelInfo,
content: str | dict | FileObject, content: str | dict | FileObject,
end_user_id: str,
enable_file=False enable_file=False
) -> list | str | None: ) -> list | str | None:
provider = api_config.provider provider = api_config.provider
@@ -642,10 +639,10 @@ class BaseNode(ABC):
return content return content
elif isinstance(content, FileObject): elif isinstance(content, FileObject):
if content.content_cache.get(provider): if content.content_cache.get(f"{provider}_{api_config.is_omni}"):
return content.content_cache[provider] return content.content_cache[f"{provider}_{api_config.is_omni}"]
with get_db_read() as db: with get_db_read() as db:
multimodel_service = MultimodalService(db, api_config=api_config) multimodal_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput( file_obj = FileInput(
type=content.type, type=content.type,
url=content.url, url=content.url,
@@ -654,16 +651,15 @@ class BaseNode(ABC):
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None, upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
) )
file_obj.set_content(content.get_content()) file_obj.set_content(content.get_content())
message = await multimodel_service.process_files( message = await multimodal_service.process_files(
end_user_id,
[file_obj], [file_obj],
) )
content.set_content(file_obj.get_content()) content.set_content(file_obj.get_content())
if message: if message:
content.content_cache[provider] = message content.content_cache[f"{provider}_{api_config.is_omni}"] = message
return message return message
return None return None
raise TypeError(f'Unexpect input value type - {type(content)}') raise TypeError(f'Unexpected input value type - {type(content)}')
@staticmethod @staticmethod
def process_model_output(content) -> str: def process_model_output(content) -> str:

View File

@@ -51,8 +51,8 @@ console.log(result)
class CodeNode(BaseNode): class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: CodeNodeConfig | None = None self.typed_config: CodeNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -30,17 +30,13 @@ class CycleGraphNode(BaseNode):
It acts as a container and execution controller for a subgraph. It acts as a container and execution controller for a subgraph.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle
self.start_node_id = None # ID of the start node within the cycle self.start_node_id = None # ID of the start node within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None self.graph: StateGraph | CompiledStateGraph | None = None
self.child_variable_pool: VariablePool | None = None self.child_variable_pool: VariablePool | None = None
self.build_graph()
self.iteration_flag = True
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
outputs = {"__child_state": VariableType.ARRAY_OBJECT} outputs = {"__child_state": VariableType.ARRAY_OBJECT}
@@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode):
else: else:
remain_edges.append(edge) remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges # # Update workflow_config by removing cycle nodes and internal edges
self.workflow_config["nodes"] = [ # self.workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != self.node_id # node for node in nodes if node.get("cycle") != self.node_id
] # ]
self.workflow_config["edges"] = remain_edges # self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges return cycle_nodes, cycle_edges
@@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode):
3. Compile the graph for runtime execution 3. Compile the graph for runtime execution
""" """
from app.core.workflow.engine.graph_builder import GraphBuilder from app.core.workflow.engine.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.child_variable_pool = VariablePool() self.child_variable_pool = VariablePool()
builder = GraphBuilder( builder = GraphBuilder(
{ {
"nodes": self.cycle_nodes, "nodes": self.cycle_nodes,
"edges": self.cycle_edges, "edges": self.cycle_edges,
}, },
subgraph=True, variable_pool=self.child_variable_pool,
variable_pool=self.child_variable_pool cycle=self.node_id
) )
self.start_node_id = builder.start_node_id
self.graph = builder.build() self.graph = builder.build()
self.start_node_id = builder.start_node_id
self.child_variable_pool = builder.variable_pool self.child_variable_pool = builder.variable_pool
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
Raises: Raises:
RuntimeError: If the node type is unsupported. RuntimeError: If the node type is unsupported.
""" """
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
return await LoopRuntime( return await LoopRuntime(
start_id=self.start_node_id, start_id=self.start_node_id,
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
raise RuntimeError("Unknown cycle node type") raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
yield { yield {
"__final__": True, "__final__": True,

View File

@@ -0,0 +1,4 @@
from .config import DocExtractorNodeConfig
from .node import DocExtractorNode
__all__ = ["DocExtractorNode", "DocExtractorNodeConfig"]

View File

@@ -0,0 +1,18 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
class DocExtractorNodeConfig(BaseNodeConfig):
file_selector: str = Field(
...,
description="File variable selector, e.g. {{ sys.files }} or {{ node_id.file }}"
)
class Config:
json_schema_extra = {
"examples": [
{
"file_selector": "{{ sys.files }}"
}
]
}

View File

@@ -0,0 +1,103 @@
import logging
from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.schemas.app_schema import FileInput, FileType, TransferMethod
logger = logging.getLogger(__name__)
def _file_object_to_file_input(f: FileObject) -> FileInput:
"""Convert workflow FileObject to multimodal FileInput."""
return FileInput(
type=FileType.DOCUMENT,
transfer_method=TransferMethod(f.transfer_method),
url=f.url or None,
upload_file_id=f.file_id or None,
file_type=f.origin_file_type or "",
)
def _normalise_files(val: Any) -> list[FileObject]:
if isinstance(val, FileObject):
return [val]
if isinstance(val, dict) and val.get("is_file"):
return [FileObject(**val)]
if isinstance(val, list):
result: list[FileObject] = []
for item in val:
if isinstance(item, FileObject):
result.append(item)
elif isinstance(item, dict) and item.get("is_file"):
result.append(FileObject(**item))
else:
logger.warning("Ignoring non-file entry in file list for document extractor: %r", item)
return result
return []
class DocExtractorNode(BaseNode):
"""Document Extractor Node.
Reads one or more file variables and extracts their text content
by delegating to MultimodalService._extract_document_text.
Outputs:
text (string) full concatenated text of all input files
chunks (array[string]) per-file extracted text
"""
def _output_types(self) -> dict[str, VariableType]:
return {
"text": VariableType.STRING,
"chunks": VariableType.ARRAY_STRING,
}
def _extract_output(self, business_result: Any) -> Any:
return business_result
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {"file_selector": self.config.get("file_selector")}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
config = DocExtractorNodeConfig(**self.config)
raw_val = self.get_variable(config.file_selector, variable_pool, strict=False)
if raw_val is None:
logger.warning(f"Node {self.node_id}: file variable '{config.file_selector}' is empty")
return {"text": "", "chunks": []}
files = _normalise_files(raw_val)
if not files:
return {"text": "", "chunks": []}
chunks: list[str] = []
with get_db_read() as db:
from app.services.multimodal_service import MultimodalService
svc = MultimodalService(db)
for f in files:
try:
file_input = _file_object_to_file_input(f)
# Ensure URL is populated for local files
if not file_input.url:
file_input.url = await svc.get_file_url(file_input)
# Reuse cached bytes if already fetched
if f.get_content():
file_input.set_content(f.get_content())
text = await svc._extract_document_text(file_input)
chunks.append(text)
except Exception as e:
logger.error(
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
exc_info=True,
)
chunks.append("")
full_text = "\n\n".join(c for c in chunks if c)
logger.info(f"Node {self.node_id}: extracted {len(files)} file(s), total chars={len(full_text)}")
return {"text": full_text, "chunks": chunks}

View File

@@ -1,9 +1,7 @@
"""End 节点配置""" """End 节点配置"""
from pydantic import Field from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.variable.base_variable import VariableType
class EndNodeConfig(BaseNodeConfig): class EndNodeConfig(BaseNodeConfig):

View File

@@ -36,8 +36,6 @@ class EndNode(BaseNode):
Returns: Returns:
最终输出字符串 最终输出字符串
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
@@ -46,11 +44,4 @@ class EndNode(BaseNode):
output = self._render_template(output_template, variable_pool, strict=False) output = self._render_template(output_template, variable_pool, strict=False)
else: else:
output = "" output = ""
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output return output

View File

@@ -23,12 +23,13 @@ class NodeType(StrEnum):
BREAK = "break" BREAK = "break"
MEMORY_READ = "memory-read" MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write" MEMORY_WRITE = "memory-write"
DOCUMENT_EXTRACTOR = "document-extractor"
UNKNOWN = "unknown" UNKNOWN = "unknown"
NOTES = "notes" NOTES = "notes"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):

View File

@@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel):
) )
class HttpErrorDefaultTamplete(BaseModel): class HttpErrorDefaultTemplate(BaseModel):
body: str = Field( body: str = Field(
default="", default="",
description="Default body returned on HTTP error", description="Default body returned on HTTP error",
@@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel):
description="Error handling strategy: 'none', 'default', or 'branch'", description="Error handling strategy: 'none', 'default', or 'branch'",
) )
default: HttpErrorDefaultTamplete | None = Field( default: HttpErrorDefaultTemplate | None = Field(
default=None, default=None,
description="Default response template for error handling", description="Default response template for error handling",
) )

View File

@@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.utils.file_processer import mime_to_file_type from app.core.workflow.utils.file_processor import mime_to_file_type
from app.core.workflow.variable.base_variable import VariableType, FileObject from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.schemas import FileType, TransferMethod from app.schemas import FileType, TransferMethod
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
or a branch identifier string when error branching is enabled. or a branch identifier string when error branching is enabled.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: HttpRequestNodeConfig | None = None self.typed_config: HttpRequestNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode): class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: IfElseNodeConfig | None = None self.typed_config: IfElseNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode): class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: JinjaRenderNodeConfig | None = None self.typed_config: JinjaRenderNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode): class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None self.vector_service: ElasticSearchVector | None = None

View File

@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage - ai/assistant: AI 消息AIMessage
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: LLMNodeConfig | None = None self.typed_config: LLMNodeConfig | None = None
self.messages = [] self.messages = []
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages messages_config = self.typed_config.messages
if messages_config: if messages_config:
# 使用 LangChain 消息格式 # 使用 LangChain 消息格式
messages = [] messages = []
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
content_template = msg_config.content content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool) content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool) content = self._render_template(content_template, variable_pool)
user_id = self.get_variable("sys.user_id", variable_pool)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
if role == "system": if role == "system":
messages.append({ messages.append({
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
"content": await self.process_message( "content": await self.process_message(
model_info, model_info,
content, content,
user_id,
self.typed_config.vision, self.typed_config.vision,
) )
}) })
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision) "content": await self.process_message(model_info, content, self.typed_config.vision)
}) })
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision) "content": await self.process_message(model_info, content, self.typed_config.vision)
}) })
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision) "content": await self.process_message(model_info, content, self.typed_config.vision)
}) })
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value: for file in files.value:
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision) content = await self.process_message(model_info, file.value, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list): if isinstance(message["content"], list):
file_content = [] file_content = []
for file in message["content"]: for file in message["content"]:
content = await self.process_message(model_info, file, user_id, self.typed_config.vision) content = await self.process_message(model_info, file, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
history_message.append( history_message.append(
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
message["content"] = await self.process_message( message["content"] = await self.process_message(
model_info, model_info,
message["content"], message["content"],
user_id,
self.typed_config.vision self.typed_config.vision
) )
history_message.append(message) history_message.append(message)

View File

@@ -1,3 +1,4 @@
import re
from typing import Any from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
@@ -5,14 +6,16 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.db import get_db_read from app.db import get_db_read
from app.schemas import FileInput
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task from app.tasks import write_message_task
class MemoryReadNode(BaseNode): class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryReadNodeConfig | None = None self.typed_config: MemoryReadNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
@@ -36,19 +39,32 @@ class MemoryReadNode(BaseNode):
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,
history=[], history=[],
db=db, db=db,
storage_type="neo4j", storage_type=state["memory_storage_type"],
user_rag_memory_id="" user_rag_memory_id=state["user_rag_memory_id"]
) )
class MemoryWriteNode(BaseNode): class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryWriteNodeConfig | None = None self.typed_config: MemoryWriteNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING}
@staticmethod
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
variable_pattern = re.compile(variable_pattern_string)
variables = variable_pattern.findall(content)
file_variables = []
for variable in variables:
if variable_pool.is_file_variable(variable):
file_variables.append(variable)
for var in file_variables:
content = content.replace(var, "")
return file_variables, content
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = MemoryWriteNodeConfig(**self.config) self.typed_config = MemoryWriteNodeConfig(**self.config)
end_user_id = self.get_variable("sys.user_id", variable_pool) end_user_id = self.get_variable("sys.user_id", variable_pool)
@@ -63,17 +79,42 @@ class MemoryWriteNode(BaseNode):
}) })
for message in self.typed_config.messages: for message in self.typed_config.messages:
file_variables, content = self._extract_multimodal_memory_variables(
message.content,
variable_pool
)
file_info = []
for var in file_variables:
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
if isinstance(instence, FileVariable):
file_info.append(FileInput(
type=instence.value.type,
transfer_method=instence.value.transfer_method,
upload_file_id=instence.value.file_id,
url=instence.value.url,
file_type=instence.value.origin_file_type
).model_dump())
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
for file_instence in instence.value:
file_info.append(FileInput(
type=file_instence.value.type,
transfer_method=file_instence.value.transfer_method,
upload_file_id=file_instence.value.file_id,
url=file_instence.value.url,
file_type=file_instence.value.origin_file_type
).model_dump())
messages.append({ messages.append({
"role": message.role, "role": message.role,
"content": self._render_template(message.content, variable_pool) "content": self._render_template(content, variable_pool),
"files": file_info
}) })
write_message_task.delay( write_message_task.delay(
end_user_id, end_user_id=end_user_id,
messages, message=messages,
str(self.typed_config.config_id), config_id=str(self.typed_config.config_id),
"neo4j", storage_type=state["memory_storage_type"],
"" user_rag_memory_id=state["user_rag_memory_id"]
) )
return "success" return "success"

View File

@@ -26,6 +26,7 @@ from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode from app.core.workflow.nodes.tool import ToolNode
from app.core.workflow.nodes.document_extractor import DocExtractorNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -49,7 +50,8 @@ WorkflowNode = Union[
ToolNode, ToolNode,
MemoryReadNode, MemoryReadNode,
MemoryWriteNode, MemoryWriteNode,
CodeNode CodeNode,
DocExtractorNode
] ]
@@ -81,6 +83,7 @@ class NodeFactory:
NodeType.MEMORY_READ: MemoryReadNode, NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode, NodeType.MEMORY_WRITE: MemoryWriteNode,
NodeType.CODE: CodeNode, NodeType.CODE: CodeNode,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
} }
@classmethod @classmethod
@@ -104,13 +107,15 @@ class NodeFactory:
def create_node( def create_node(
cls, cls,
node_config: dict[str, Any], node_config: dict[str, Any],
workflow_config: dict[str, Any] workflow_config: dict[str, Any],
down_stream_nodes: list[str]
) -> WorkflowNode | None: ) -> WorkflowNode | None:
"""创建节点实例 """创建节点实例
Args: Args:
node_config: 节点配置 node_config: 节点配置
workflow_config: 工作流配置 workflow_config: 工作流配置
down_stream_nodes: 下游节点
Returns: Returns:
节点实例或 None对于不支持的节点类型 节点实例或 None对于不支持的节点类型
@@ -127,7 +132,7 @@ class NodeFactory:
# 创建节点实例 # 创建节点实例
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})") logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config) return node_class(node_config, workflow_config, down_stream_nodes)
@classmethod @classmethod
def get_supported_types(cls) -> list[str]: def get_supported_types(cls) -> list[str]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode): class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ParameterExtractorNodeConfig | None = None self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {} self.response_metadata = {}

View File

@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode): class QuestionClassifierNode(BaseNode):
"""问题分类器节点""" """问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: QuestionClassifierNodeConfig | None = None self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {} self.category_to_case_map = {}
self.response_metadata = {} self.response_metadata = {}

View File

@@ -27,14 +27,8 @@ class StartNode(BaseNode):
注意:变量的验证和默认值处理由 Executor 在初始化时完成。 注意:变量的验证和默认值处理由 Executor 在初始化时完成。
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""初始化 Start 节点 super().__init__(node_config, workflow_config, down_stream_nodes)
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置 # 解析并验证配置
self.typed_config: StartNodeConfig | None = None self.typed_config: StartNodeConfig | None = None
@@ -62,7 +56,6 @@ class StartNode(BaseNode):
包含系统参数、会话变量和自定义变量的字典 包含系统参数、会话变量和自定义变量的字典
""" """
self.typed_config = StartNodeConfig(**self.config) self.typed_config = StartNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 处理自定义变量(传入 pool 避免重复创建) # 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(variable_pool) custom_vars = self._process_custom_variables(variable_pool)
@@ -77,9 +70,9 @@ class StartNode(BaseNode):
**custom_vars # 自定义变量作为节点输出的一部分 **custom_vars # 自定义变量作为节点输出的一部分
} }
logger.info( logger.debug(
f"节点 {self.node_id} (Start) 执行完成," f"Node {self.node_id} (Start) execution completed, "
f"输出了 {len(custom_vars)} 个自定义变量" f"outputting {len(custom_vars)} custom variables"
) )
return result return result

View File

@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
class ToolNode(BaseNode): class ToolNode(BaseNode):
"""工具节点""" """工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ToolNodeConfig | None = None self.typed_config: ToolNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode): class VariableAggregatorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: VariableAggregatorNodeConfig | None = None self.typed_config: VariableAggregatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -153,7 +153,8 @@ class TemplateRenderer:
# 全局渲染器实例(严格模式) # 全局渲染器实例(严格模式)
_default_renderer = TemplateRenderer(strict=True) _strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False)
def render_template( def render_template(
@@ -184,7 +185,7 @@ def render_template(
... ) ... )
'请分析: 这是一段文本' '请分析: 这是一段文本'
""" """
renderer = TemplateRenderer(strict=strict) renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars) return renderer.render(template, conv_vars, node_outputs, system_vars)
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
Returns: Returns:
错误列表 错误列表
""" """
return _default_renderer.validate(template) return _strict_renderer.validate(template)

View File

@@ -6,6 +6,7 @@
import copy import copy
import logging import logging
from collections import defaultdict, deque
from typing import Any, Union, TYPE_CHECKING from typing import Any, Union, TYPE_CHECKING
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
@@ -119,7 +120,6 @@ class WorkflowValidator:
errors = [] errors = []
graphs = cls.get_subgraph(workflow_config) graphs = cls.get_subgraph(workflow_config)
logger.info(graphs)
for index, graph in enumerate(graphs): for index, graph in enumerate(graphs):
nodes = graph.get("nodes", []) nodes = graph.get("nodes", [])
edges = graph.get("edges", []) edges = graph.get("edges", [])
@@ -183,7 +183,7 @@ class WorkflowValidator:
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle: if has_cycle:
errors.append( errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
) )
# 8. 验证变量名 # 8. 验证变量名
@@ -204,18 +204,18 @@ class WorkflowValidator:
Returns: Returns:
可达节点 ID 集合 可达节点 ID 集合
""" """
adj = defaultdict(list)
for edge in edges:
adj[edge["source"]].append(edge["target"])
reachable = {start_id} reachable = {start_id}
queue = [start_id] queue = deque([start_id])
while queue: while queue:
current = queue.pop(0) current = queue.popleft()
for edge in edges: for target in adj[current]:
if edge.get("source") == current: if target not in reachable:
target = edge.get("target") reachable.add(target)
if target and target not in reachable: queue.append(target)
reachable.add(target)
queue.append(target)
return reachable return reachable
@staticmethod @staticmethod
@@ -229,10 +229,6 @@ class WorkflowValidator:
Returns: Returns:
(has_cycle, cycle_path): 是否有循环和循环路径 (has_cycle, cycle_path): 是否有循环和循环路径
""" """
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {} graph: dict[str, list[str]] = {}
for edge in edges: for edge in edges:
source = edge.get("source") source = edge.get("source")
@@ -243,10 +239,6 @@ class WorkflowValidator:
if edge_type == "error": if edge_type == "error":
continue continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target: if source and target:
if source not in graph: if source not in graph:
graph[source] = [] graph[source] = []

View File

@@ -54,7 +54,7 @@ class DictVariable(BaseVariable):
def valid_value(self, value) -> dict: def valid_value(self, value) -> dict:
if not isinstance(value, dict): if not isinstance(value, dict):
raise TypeError(f"Value must be a dict - {type(value)}:{value}") raise TypeError(f"Value must be a dict - {type(value)}:{value}")
return value return value
def to_literal(self) -> str: def to_literal(self) -> str:

View File

@@ -30,6 +30,9 @@ class MemoryConfig(Base):
llm_id = Column(String, nullable=True, comment="LLM模型配置ID") llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
video_id = Column(String, nullable=True, comment="视频模型配置ID")
# 记忆萃取引擎配置 # 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")

View File

@@ -2,10 +2,11 @@ import datetime
import uuid import uuid
from enum import StrEnum from enum import StrEnum
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text
from sqlalchemy.dialects.postgresql import UUID, JSON from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.sql import func from sqlalchemy.sql import func
from app.db import Base from app.db import Base
@@ -26,9 +27,9 @@ class ModelType(StrEnum):
RERANK = "rerank" RERANK = "rerank"
# TTS = "tts" # TTS = "tts"
# SPEECH2TEXT = "speech2text" # SPEECH2TEXT = "speech2text"
# IMAGE = "image" IMAGE = "image"
# AUDIO = "audio" # AUDIO = "audio"
# VISION = "vision" VIDEO = "video"
class ModelProvider(StrEnum): class ModelProvider(StrEnum):
@@ -45,6 +46,7 @@ class ModelProvider(StrEnum):
XINFERENCE = "xinference" XINFERENCE = "xinference"
GPUSTACK = "gpustack" GPUSTACK = "gpustack"
BEDROCK = "bedrock" BEDROCK = "bedrock"
VOLCANO = "volcano"
COMPOSITE = "composite" COMPOSITE = "composite"

View File

@@ -24,6 +24,21 @@ class Tenants(Base):
default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言 default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言
supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表 supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表
# 租户联系信息
contact_name = Column(String(100), nullable=True) # 联系人姓名
contact_email = Column(String(255), nullable=True) # 联系人邮箱
contact_phone = Column(String(50), nullable=True) # 联系人电话
# 租户套餐信息
plan = Column(String(50), nullable=True) # 套餐类型
plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间
api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制
status = Column(String(50), nullable=True, default='active') # 租户状态
# 租户功能开关字段
feature_billing = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用收费管理菜单")
feature_user_management = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用用户管理菜单")
# Relationship to users - one tenant has many users # Relationship to users - one tenant has many users
users = relationship("User", back_populates="tenant") users = relationship("User", back_populates="tenant")

View File

@@ -9,7 +9,7 @@ class User(Base):
__tablename__ = "users" __tablename__ = "users"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
username = Column(String, unique=True, index=True, nullable=False) username = Column(String, index=True, nullable=False) # 社区版:用户名不唯一,仅邮箱唯一
email = Column(String, unique=True, index=True, nullable=False) email = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False) hashed_password = Column(String, nullable=False)
is_active = Column(Boolean, default=True, nullable=False) is_active = Column(Boolean, default=True, nullable=False)

View File

@@ -2,7 +2,7 @@ from datetime import datetime, timedelta
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import func from sqlalchemy import func
from uuid import UUID from uuid import UUID
from typing import Dict from typing import Dict, Optional, Any
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.models.user_model import User from app.models.user_model import User
@@ -191,3 +191,62 @@ class HomePageRepository:
user_count_dict = {workspace_id: count for workspace_id, count in user_counts} user_count_dict = {workspace_id: count for workspace_id, count in user_counts}
return workspaces, app_count_dict, user_count_dict return workspaces, app_count_dict, user_count_dict
@staticmethod
def get_version_introduction(db: Session, version: str) -> Optional[Dict[str, Any]]:
"""
从数据库获取版本说明(优先读取已发布的版本)
使用反射方式读取表结构,不依赖 premium 模型类
Args:
db: 数据库会话
version: 版本号,如 "v0.2.7"
Returns:
版本说明字典,格式与 version_info.json 一致
如果数据库中没有该版本,返回 None
"""
try:
from sqlalchemy import Table, MetaData
metadata = MetaData()
version_notes = Table('version_notes', metadata, autoload_with=db.engine)
version_note_items = Table('version_note_items', metadata, autoload_with=db.engine)
note = db.query(version_notes).filter(
version_notes.c.version == version,
version_notes.c.is_published == True
).first()
if not note:
return None
items = db.query(version_note_items).filter(
version_note_items.c.note_id == note.id
).order_by(version_note_items.c.sort_order).all()
core_upgrades = []
for item in items:
title = item.title
content = item.content
if content:
core_upgrades.append(f"{title}<br>{content}")
else:
core_upgrades.append(title)
return {
"introduction": {
"codeName": "",
"releaseDate": note.release_date.isoformat() if note.release_date else "",
"upgradePosition": "",
"coreUpgrades": core_upgrades
},
"introduction_en": {
"codeName": "",
"releaseDate": note.release_date.isoformat() if note.release_date else "",
"upgradePosition": "",
"coreUpgrades": core_upgrades
}
}
except Exception:
return None

View File

@@ -9,21 +9,22 @@ Classes:
""" """
import uuid import uuid
from uuid import UUID
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from uuid import UUID
from sqlalchemy import desc, select
from sqlalchemy.orm import Session
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_config_logger, get_db_logger from app.core.logging_config import get_config_logger, get_db_logger
from app.models.memory_config_model import MemoryConfig from app.models.memory_config_model import MemoryConfig
from app.models.workspace_model import Workspace
from app.schemas.memory_storage_schema import ( from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate, ConfigParamsCreate,
ConfigUpdate, ConfigUpdate,
ConfigUpdateExtracted, ConfigUpdateExtracted,
ConfigUpdateForget, ConfigUpdateForget,
) )
from sqlalchemy import desc, select
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
# 获取数据库专用日志器 # 获取数据库专用日志器
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
return memory_config_obj return memory_config_obj
@staticmethod @staticmethod
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig: def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数) """构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args: Args:
@@ -309,57 +310,21 @@ class MemoryConfigRepository:
Returns: Returns:
Optional[MemoryConfig]: 更新后的配置对象不存在则返回None Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
""" """
db_logger.debug(f"更新萃取配置: config_id={update.config_id}") db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
try: try:
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first() stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id)
db_config = db.execute(stmt).scalar_one_or_none()
if not db_config: if not db_config:
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None return None
# 更新字段映射 update_data = update.model_dump(exclude_unset=True)
field_mapping = { update_data.pop("config_id", None)
# 模型选择
"llm_id": "llm_id",
"embedding_id": "embedding_id",
"rerank_id": "rerank_id",
# 记忆萃取引擎
"enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
"enable_llm_disambiguation": "enable_llm_disambiguation",
"deep_retrieval": "deep_retrieval",
"t_type_strict": "t_type_strict",
"t_name_strict": "t_name_strict",
"t_overall": "t_overall",
"state": "state",
"chunker_strategy": "chunker_strategy",
# 句子提取
"statement_granularity": "statement_granularity",
"include_dialogue_context": "include_dialogue_context",
"max_context": "max_context",
# 剪枝配置
"pruning_enabled": "pruning_enabled",
"pruning_scene": "pruning_scene",
"pruning_threshold": "pruning_threshold",
# 自我反思配置
"enable_self_reflexion": "enable_self_reflexion",
"iteration_period": "iteration_period",
"reflexion_range": "reflexion_range",
"baseline": "baseline",
}
has_update = False for field, value in update_data.items():
for api_field, db_field in field_mapping.items(): setattr(db_config, field, value)
value = getattr(update, api_field, None)
if value is not None:
setattr(db_config, db_field, value)
has_update = True
if not has_update:
raise ValueError("No fields to update")
db.commit() db.commit()
db.refresh(db_config) db.refresh(db_config)
@@ -443,6 +408,9 @@ class MemoryConfigRepository:
"llm_id": db_config.llm_id, "llm_id": db_config.llm_id,
"embedding_id": db_config.embedding_id, "embedding_id": db_config.embedding_id,
"rerank_id": db_config.rerank_id, "rerank_id": db_config.rerank_id,
"vision_id": db_config.vision_id,
"audio_id": db_config.audio_id,
"video_id": db_config.video_id,
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise, "enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
"enable_llm_disambiguation": db_config.enable_llm_disambiguation, "enable_llm_disambiguation": db_config.enable_llm_disambiguation,
"deep_retrieval": db_config.deep_retrieval, "deep_retrieval": db_config.deep_retrieval,
@@ -527,7 +495,10 @@ class MemoryConfigRepository:
raise raise
@staticmethod @staticmethod
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]: def get_config_with_workspace(
db: Session,
config_id: uuid.UUID | int | str
) -> Optional[tuple[MemoryConfig, Workspace]]:
"""Get memory config and its associated workspace information """Get memory config and its associated workspace information
Args: Args:
@@ -542,8 +513,6 @@ class MemoryConfigRepository:
""" """
import time import time
from app.models.workspace_model import Workspace
start_time = time.time() start_time = time.time()
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
@@ -630,7 +599,7 @@ class MemoryConfigRepository:
db_logger.debug( db_logger.debug(
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
return (config, workspace) return config, workspace
except ValueError: except ValueError:
# Re-raise known business exceptions # Re-raise known business exceptions
@@ -775,9 +744,9 @@ class MemoryConfigRepository:
@staticmethod @staticmethod
def get_with_fallback( def get_with_fallback(
db: Session, db: Session,
config_id: Optional[uuid.UUID], config_id: Optional[uuid.UUID],
workspace_id: uuid.UUID workspace_id: uuid.UUID
) -> Optional[MemoryConfig]: ) -> Optional[MemoryConfig]:
"""获取记忆配置,支持回退到工作空间默认配置 """获取记忆配置,支持回退到工作空间默认配置
@@ -807,4 +776,3 @@ class MemoryConfigRepository:
) )
return MemoryConfigRepository.get_workspace_default(db, workspace_id) return MemoryConfigRepository.get_workspace_default(db, workspace_id)

View File

@@ -1,14 +1,15 @@
from sqlalchemy.orm import Session, joinedload, selectinload
from sqlalchemy import and_, or_, func, desc, select
from typing import List, Optional, Dict, Any, Tuple
import uuid import uuid
from typing import List, Optional, Dict, Any, Tuple
from sqlalchemy import and_, or_, func, desc
from sqlalchemy.orm import Session, joinedload
from app.core.logging_config import get_db_logger
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
from app.schemas.model_schema import ( from app.schemas.model_schema import (
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery, ModelConfigQueryNew ModelConfigQuery, ModelConfigQueryNew
) )
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器 # 获取数据库专用日志器
db_logger = get_db_logger() db_logger = get_db_logger()
@@ -137,6 +138,9 @@ class ModelConfigRepository:
type_values.append(ModelType.LLM) type_values.append(ModelType.LLM)
filters.append(ModelConfig.type.in_(type_values)) filters.append(ModelConfig.type.in_(type_values))
if query.capability:
filters.append(ModelConfig.capability.contains(query.capability))
if query.is_active is not None: if query.is_active is not None:
filters.append(ModelConfig.is_active == query.is_active) filters.append(ModelConfig.is_active == query.is_active)
@@ -435,7 +439,6 @@ class ModelConfigRepository:
ModelConfig.is_public ModelConfig.is_public
), ),
ModelConfig.provider == provider, ModelConfig.provider == provider,
ModelConfig.is_active,
~ModelConfig.is_composite ~ModelConfig.is_composite
) )
).all() ).all()

View File

@@ -1,17 +1,22 @@
import logging
from typing import List, Optional from typing import List, Optional
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
MEMORY_SUMMARY_NODE_SAVE
# 使用新的仓储层 # 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database.""" """Delete all nodes in the database."""
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n") result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
print(f"All end_user_id: {end_user_id} node and edge deleted successfully") logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add dialogue nodes to Neo4j database. """Add dialogue nodes to Neo4j database.
@@ -23,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
List of created node UUIDs or None if failed List of created node UUIDs or None if failed
""" """
if not dialogues: if not dialogues:
print("No dialogues to save") logger.info("No dialogues to save")
return [] return []
try: try:
@@ -48,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
) )
created_uuids = [record["uuid"] for record in result] created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
return created_uuids return created_uuids
except Exception as e: except Exception as e:
print(f"Error creating dialogue nodes: {e}") logger.error(f"Error creating dialogue nodes: {e}")
return None return None
@@ -67,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
List of created node UUIDs or None if failed List of created node UUIDs or None if failed
""" """
if not statements: if not statements:
print("No statements to save") logger.info("No statements to save")
return [] return []
try: try:
@@ -120,13 +125,14 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
) )
created_uuids = [record["uuid"] for record in result] created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} statement nodes") logger.info(f"Successfully created {len(created_uuids)} statement nodes")
return created_uuids return created_uuids
except Exception as e: except Exception as e:
print(f"Error creating statement nodes: {e}") logger.error(f"Error creating statement nodes: {e}")
return None return None
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]: async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add chunk nodes to Neo4j in batch. """Add chunk nodes to Neo4j in batch.
@@ -138,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
List of created chunk UUIDs or None if failed List of created chunk UUIDs or None if failed
""" """
if not chunks: if not chunks:
print("No chunk nodes to add") logger.info("No chunk nodes to add")
return [] return []
try: try:
@@ -171,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
) )
created_uuids = [record["uuid"] for record in result] created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} chunk nodes") logger.info(f"Successfully created {len(created_uuids)} chunk nodes")
return created_uuids return created_uuids
except Exception as e: except Exception as e:
print(f"Error creating chunk nodes: {e}") logger.error(f"Error creating chunk nodes: {e}")
return None return None
async def add_memory_summary_nodes(
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]: summaries: List[MemorySummaryNode],
connector: Neo4jConnector
) -> Optional[List[str]]:
"""Add memory summary nodes to Neo4j in batch. """Add memory summary nodes to Neo4j in batch.
Args: Args:
@@ -191,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
List of created summary node ids or None if failed List of created summary node ids or None if failed
""" """
if not summaries: if not summaries:
print("No memory summary nodes to add") logger.info("No memory summary nodes to add")
return [] return []
try: try:
@@ -217,10 +225,8 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
summaries=flattened summaries=flattened
) )
created_ids = [record.get("uuid") for record in result] created_ids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
return created_ids return created_ids
except Exception as e: except Exception as e:
print(f"Failed to save MemorySummary nodes to Neo4j: {e}") logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
return None return None

View File

@@ -300,7 +300,7 @@ class CommunityRepository:
) )
return bool(result) return bool(result)
except Exception as e: except Exception as e:
logger.error(f"update_community_metadata failed: {e}") logger.error(f"update_community_metadata failed: {e}", exc_info=True)
return False return False
async def batch_update_community_metadata( async def batch_update_community_metadata(

View File

@@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id,
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
# Entity Merge Query # Entity Merge Query
MERGE_ENTITIES = """ MERGE_ENTITIES = """
MATCH (canonical:ExtractedEntity {id: $canonical_id}) MATCH (canonical:ExtractedEntity {id: $canonical_id})
@@ -829,9 +828,8 @@ neo4j_query_all = """
other as entity2 other as entity2
""" """
'''针对当前节点下扩长的句子,实体和总结''' '''针对当前节点下扩长的句子,实体和总结'''
Memory_Timeline_ExtractedEntity=""" Memory_Timeline_ExtractedEntity = """
MATCH (n)-[r1]-(e)-[r2]-(ms) MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) = $id WHERE elementId(n) = $id
AND (ms:ExtractedEntity OR ms:MemorySummary) AND (ms:ExtractedEntity OR ms:MemorySummary)
@@ -869,7 +867,7 @@ RETURN
""" """
Memory_Timeline_MemorySummary=""" Memory_Timeline_MemorySummary = """
MATCH (n)-[r1]-(e)-[r2]-(ms) MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) =$id WHERE elementId(n) =$id
AND (ms:MemorySummary OR ms:ExtractedEntity) AND (ms:MemorySummary OR ms:ExtractedEntity)
@@ -904,7 +902,7 @@ RETURN
} }
) AS statement; ) AS statement;
""" """
Memory_Timeline_Statement=""" Memory_Timeline_Statement = """
MATCH (n) MATCH (n)
WHERE elementId(n) = $id WHERE elementId(n) = $id
@@ -947,7 +945,7 @@ RETURN
""" """
'''针对当前节点,主要获取更加完整的句子节点''' '''针对当前节点,主要获取更加完整的句子节点'''
Memory_Space_Emotion_Statement=""" Memory_Space_Emotion_Statement = """
MATCH (n) MATCH (n)
WHERE elementId(n) = $id WHERE elementId(n) = $id
RETURN RETURN
@@ -957,7 +955,7 @@ RETURN
n.statement AS statement; n.statement AS statement;
""" """
Memory_Space_Emotion_MemorySummary=""" Memory_Space_Emotion_MemorySummary = """
MATCH (n)-[]-(e) MATCH (n)-[]-(e)
WHERE elementId(n) = $id WHERE elementId(n) = $id
AND EXISTS { AND EXISTS {
@@ -970,7 +968,7 @@ RETURN DISTINCT
e.emotion_type AS emotion_type, e.emotion_type AS emotion_type,
e.statement AS statement; e.statement AS statement;
""" """
Memory_Space_Emotion_ExtractedEntity=""" Memory_Space_Emotion_ExtractedEntity = """
MATCH (n)-[]-(e) MATCH (n)-[]-(e)
WHERE elementId(n) = $id WHERE elementId(n) = $id
AND EXISTS { AND EXISTS {
@@ -985,18 +983,18 @@ RETURN DISTINCT
'''获取实体''' '''获取实体'''
Memory_Space_User=""" Memory_Space_User = """
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
WHERE n.end_user_id = $end_user_id AND m.name="用户" WHERE n.end_user_id = $end_user_id AND m.name="用户"
return DISTINCT elementId(m) as id return DISTINCT elementId(m) as id
""" """
Memory_Space_Entity=""" Memory_Space_Entity = """
MATCH (n)-[]-(m) MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person" WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN RETURN
DISTINCT m.name as name,m.end_user_id as end_user_id DISTINCT m.name as name,m.end_user_id as end_user_id
""" """
Memory_Space_Associative=""" Memory_Space_Associative = """
MATCH (u)-[]-(x)-[]-(h) MATCH (u)-[]-(x)-[]-(h)
WHERE elementId(u) = $user_id WHERE elementId(u) = $user_id
AND elementId(h) = $id AND elementId(h) = $id
@@ -1005,61 +1003,69 @@ RETURN DISTINCT
""" """
Graph_Node_query = """ Graph_Node_query = """
MATCH (n:MemorySummary) MATCH (n:MemorySummary)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
0 AS priority 0 AS priority
LIMIT $limit LIMIT $limit
UNION ALL UNION ALL
MATCH (n:Dialogue) MATCH (n:Dialogue)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
1 AS priority 1 AS priority
LIMIT 1 LIMIT 1
UNION ALL UNION ALL
MATCH (n:Statement) MATCH (n:Statement)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
1 AS priority 1 AS priority
LIMIT $limit LIMIT $limit
UNION ALL UNION ALL
MATCH (n:ExtractedEntity) MATCH (n:ExtractedEntity)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
2 AS priority 2 AS priority
LIMIT $limit LIMIT $limit
UNION ALL UNION ALL
MATCH (n:Chunk) MATCH (n:Chunk)
WHERE n.end_user_id = $end_user_id WHERE n.end_user_id = $end_user_id
RETURN RETURN
elementId(n) AS id, elementId(n) AS id,
labels(n) AS labels, labels(n) AS labels,
properties(n) AS properties, properties(n) AS properties,
3 AS priority 3 AS priority
LIMIT $limit LIMIT $limit
""" UNION ALL
MATCH (n:Perceptual)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
4 AS priority
"""
# ============================================================ # ============================================================
# Community 节点 & BELONGS_TO_COMMUNITY 边 # Community 节点 & BELONGS_TO_COMMUNITY 边
@@ -1069,6 +1075,7 @@ Graph_Node_query = """
COMMUNITY_NODE_UPSERT = """ COMMUNITY_NODE_UPSERT = """
MERGE (c:Community {community_id: $community_id}) MERGE (c:Community {community_id: $community_id})
ON CREATE SET c.id = $community_id
SET c.end_user_id = $end_user_id, SET c.end_user_id = $end_user_id,
c.member_count = $member_count, c.member_count = $member_count,
c.updated_at = datetime() c.updated_at = datetime()
@@ -1175,7 +1182,8 @@ RETURN c.community_id AS community_id, cnt AS member_count
UPDATE_COMMUNITY_METADATA = """ UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name, SET c.id = coalesce(c.id, $community_id),
c.name = $name,
c.summary = $summary, c.summary = $summary,
c.core_entities = $core_entities, c.core_entities = $core_entities,
c.summary_embedding = $summary_embedding, c.summary_embedding = $summary_embedding,
@@ -1186,7 +1194,8 @@ RETURN c.community_id AS community_id
BATCH_UPDATE_COMMUNITY_METADATA = """ BATCH_UPDATE_COMMUNITY_METADATA = """
UNWIND $communities AS row UNWIND $communities AS row
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id}) MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
SET c.name = row.name, SET c.id = coalesce(c.id, row.community_id),
c.name = row.name,
c.summary = row.summary, c.summary = row.summary,
c.core_entities = row.core_entities, c.core_entities = row.core_entities,
c.summary_embedding = row.summary_embedding, c.summary_embedding = row.summary_embedding,
@@ -1270,6 +1279,40 @@ RETURN
startNode(r) = e AS r_from_e startNode(r) = e AS r_from_e
""" """
CHECK_COMMUNITY_IS_COMPLETE = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL
) AS is_complete
"""
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL AND
c.summary_embedding IS NOT NULL
) AS is_complete
"""
GET_INCOMPLETE_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
OR c.name = '' OR c.summary = ''
RETURN c.community_id AS community_id
"""
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.name = ''
OR c.summary IS NULL OR c.summary = ''
OR c.core_entities IS NULL
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
RETURN c.community_id AS community_id
"""
# Community keyword search: matches name or summary via fulltext index # Community keyword search: matches name or summary via fulltext index
SEARCH_COMMUNITIES_BY_KEYWORD = """ SEARCH_COMMUNITIES_BY_KEYWORD = """
@@ -1327,37 +1370,35 @@ ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit LIMIT $limit
""" """
CHECK_COMMUNITY_IS_COMPLETE = """ # 感知记忆节点保存
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) PERCEPTUAL_NODE_SAVE = """
RETURN ( UNWIND $perceptuals AS p
c.name IS NOT NULL AND c.name <> '' AND MERGE (n:Perceptual {id: p.id})
c.summary IS NOT NULL AND c.summary <> '' AND SET n += {
c.core_entities IS NOT NULL id: p.id,
) AS is_complete end_user_id: p.end_user_id,
perceptual_type: p.perceptual_type,
file_path: p.file_path,
file_name: p.file_name,
file_ext: p.file_ext,
summary: p.summary,
keywords: p.keywords,
topic: p.topic,
domain: p.domain,
created_at: p.created_at,
file_type: p.file_type,
summary_embedding: p.summary_embedding
}
RETURN n.id AS uuid
""" """
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ # 感知记忆与对话的关联边
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) PERCEPTUAL_CHUNK_EDGE_SAVE = """
RETURN ( UNWIND $edges AS edge
c.name IS NOT NULL AND c.name <> '' AND MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
c.summary IS NOT NULL AND c.summary <> '' AND MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id})
c.core_entities IS NOT NULL AND MERGE (c)-[r:HAS_PERCEPTUAL]->(p)
c.summary_embedding IS NOT NULL ON CREATE SET r.end_user_id = edge.end_user_id,
) AS is_complete r.created_at = edge.created_at
""" RETURN elementId(r) AS uuid
GET_INCOMPLETE_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
OR c.name = '' OR c.summary = ''
RETURN c.community_id AS community_id
"""
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.name = ''
OR c.summary IS NULL OR c.summary = ''
OR c.core_entities IS NULL
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
RETURN c.community_id AS community_id
""" """

View File

@@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import (
StatementNode, StatementNode,
ExtractedEntityNode, ExtractedEntityNode,
EntityEntityEdge, EntityEntityEdge,
PerceptualNode,
PerceptualEdge,
) )
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def save_entities_and_relationships( async def save_entities_and_relationships(
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
connector: Neo4jConnector connector: Neo4jConnector
): ):
"""Save entities and their relationships using graph models""" """Save entities and their relationships using graph models"""
all_entities = [entity.model_dump() for entity in entity_nodes] all_entities = [entity.model_dump() for entity in entity_nodes]
@@ -73,8 +78,8 @@ async def save_entities_and_relationships(
async def save_chunk_nodes( async def save_chunk_nodes(
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
connector: Neo4jConnector connector: Neo4jConnector
): ):
"""Save chunk nodes using graph models""" """Save chunk nodes using graph models"""
if not chunk_nodes: if not chunk_nodes:
@@ -89,8 +94,8 @@ async def save_chunk_nodes(
async def save_statement_chunk_edges( async def save_statement_chunk_edges(
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
connector: Neo4jConnector connector: Neo4jConnector
): ):
"""Save statement-chunk edges using graph models""" """Save statement-chunk edges using graph models"""
if not statement_chunk_edges: if not statement_chunk_edges:
@@ -118,8 +123,8 @@ async def save_statement_chunk_edges(
async def save_statement_entity_edges( async def save_statement_entity_edges(
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector connector: Neo4jConnector
): ):
"""Save statement-entity edges using graph models""" """Save statement-entity edges using graph models"""
if not statement_entity_edges: if not statement_entity_edges:
@@ -154,24 +159,28 @@ async def save_dialog_and_statements_to_neo4j(
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
perceptual_nodes: List[PerceptualNode],
entity_edges: List[EntityEntityEdge], entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
perceptual_edges: List[PerceptualEdge],
connector: Neo4jConnector, connector: Neo4jConnector,
) -> bool: ) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过 只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
schedule_clustering_after_write() 显式触发。 _trigger_clustering_sync() 显式触发。
Args: Args:
dialogue_nodes: List of DialogueNode objects to save dialogue_nodes: List of DialogueNode objects to save
chunk_nodes: List of ChunkNode objects to save chunk_nodes: List of ChunkNode objects to save
statement_nodes: List of StatementNode objects to save statement_nodes: List of StatementNode objects to save
entity_nodes: List of ExtractedEntityNode objects to save entity_nodes: List of ExtractedEntityNode objects to save
perceptual_nodes: List of PerceptualNode objects to save
entity_edges: List of EntityEntityEdge objects to save entity_edges: List of EntityEntityEdge objects to save
statement_chunk_edges: List of StatementChunkEdge objects to save statement_chunk_edges: List of StatementChunkEdge objects to save
statement_entity_edges: List of StatementEntityEdge objects to save statement_entity_edges: List of StatementEntityEdge objects to save
perceptual_edges: List of PerceptualEdge objects to save
connector: Neo4j connector instance connector: Neo4j connector instance
Returns: Returns:
@@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j(
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
dialogue_uuids = [record["uuid"] async for record in result] dialogue_uuids = [record["uuid"] async for record in result]
results['dialogues'] = dialogue_uuids results['dialogues'] = dialogue_uuids
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
# 2. Save all chunk nodes in batch # 2. Save all chunk nodes in batch
if chunk_nodes: if chunk_nodes:
@@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j(
results['chunks'] = chunk_uuids results['chunks'] = chunk_uuids
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
if perceptual_nodes:
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
perceptual_data = [node.model_dump() for node in perceptual_nodes]
result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data)
perceptual_uuids = [record["uuid"] async for record in result]
results["perceptuals"] = perceptual_uuids
logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")
# 3. Save all statement nodes in batch # 3. Save all statement nodes in batch
if statement_nodes: if statement_nodes:
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
@@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j(
results['statement_entity_edges'] = se_uuids results['statement_entity_edges'] = se_uuids
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
if perceptual_edges:
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
perceptual_edge_data = []
for edge in perceptual_edges:
print(edge.source, edge.target)
perceptual_edge_data.append({
"perceptual_id": edge.source,
"chunk_id": edge.target,
"end_user_id": edge.end_user_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
})
result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data)
perceptual_edges_uuids = [record["uuid"] async for record in result]
results['perceptual_chunk_edges'] = perceptual_edges_uuids
logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j")
return results return results
try: try:
@@ -303,16 +336,13 @@ async def save_dialog_and_statements_to_neo4j(
return False return False
def schedule_clustering_after_write( async def _trigger_clustering_sync(
entity_nodes: List, entity_nodes: List,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
写入 Neo4j 成功后,调度后台聚类任务 同步等待聚类完成,避免与其他 LLM 任务并发冲突
可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。
使用 asyncio.create_task 异步触发,不阻塞写入响应。
""" """
if not entity_nodes: if not entity_nodes:
return return
@@ -324,15 +354,16 @@ def schedule_clustering_after_write(
end_user_id = entity_nodes[0].end_user_id end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes] new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id)
async def _trigger_clustering( async def _trigger_clustering(
new_entity_ids: List[str], new_entity_ids: List[str],
end_user_id: str, end_user_id: str,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
聚类触发函数,自动判断全量初始化还是增量更新。 聚类触发函数,自动判断全量初始化还是增量更新。

View File

@@ -196,6 +196,13 @@ class CitationConfig(BaseModel):
enabled: bool = Field(default=False) enabled: bool = Field(default=False)
class Citation(BaseModel):
document_id: str
file_name: str
knowledge_id: str
score: float
class WebSearchConfig(BaseModel): class WebSearchConfig(BaseModel):
"""联网搜索配置""" """联网搜索配置"""
enabled: bool = Field(default=False) enabled: bool = Field(default=False)

View File

@@ -387,6 +387,12 @@ class MemoryConfig:
rerank_model_id: Optional[UUID] = None rerank_model_id: Optional[UUID] = None
rerank_model_name: Optional[str] = None rerank_model_name: Optional[str] = None
video_model_id: Optional[UUID] = None
video_model_name: Optional[str] = None
vision_model_id: Optional[UUID] = None
vision_model_name: Optional[str] = None
audio_model_id: Optional[UUID] = None
audio_model_name: Optional[str] = None
llm_params: Dict[str, Any] = field(default_factory=dict) llm_params: Dict[str, Any] = field(default_factory=dict)
embedding_params: Dict[str, Any] = field(default_factory=dict) embedding_params: Dict[str, Any] = field(default_factory=dict)

View File

@@ -8,9 +8,6 @@ import uuid
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
# ============================================================================ # ============================================================================
# 从 json_schema.py 迁移的 Schema # 从 json_schema.py 迁移的 Schema
# ============================================================================ # ============================================================================
@@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel):
class ConflictResultSchema(BaseModel): class ConflictResultSchema(BaseModel):
"""Schema for the conflict result data in the reflexion_data.json file.""" """Schema for the conflict result data in the reflexion_data.json file."""
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") data: List[BaseDataSchema] = Field(...,
description="The conflict memory data. Only contains conflicting records when conflict is True.")
conflict: bool = Field(..., description="Whether the memory is in conflict.") conflict: bool = Field(..., description="Whether the memory is in conflict.")
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") quality_assessment: Optional[QualityAssessmentSchema] = Field(None,
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
memory_verify: Optional[MemoryVerifySchema] = Field(None,
description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
@model_validator(mode="before") @model_validator(mode="before")
def _normalize_data(cls, v): def _normalize_data(cls, v):
@@ -105,12 +105,15 @@ class ChangeRecordSchema(BaseModel):
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
) )
class ResolvedSchema(BaseModel): class ResolvedSchema(BaseModel):
"""Schema for the resolved memory data in the reflexion_data""" """Schema for the resolved memory data in the reflexion_data"""
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None,
change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
change: Optional[List[ChangeRecordSchema]] = Field(None,
description="List of detailed change records with IDs and field information.")
class SingleReflexionResultSchema(BaseModel): class SingleReflexionResultSchema(BaseModel):
@@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel):
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
type: str = Field("reflexion_result", description="The type identifier.") type: str = Field("reflexion_result", description="The type identifier.")
class ReflexionResultSchema(BaseModel): class ReflexionResultSchema(BaseModel):
"""Schema for the complete reflexion result data - a list of individual conflict resolutions.""" """Schema for the complete reflexion result data - a list of individual conflict resolutions."""
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") results: List[SingleReflexionResultSchema] = Field(...,
description="List of individual conflict resolution results, grouped by conflict type.")
@model_validator(mode="before") @model_validator(mode="before")
def _normalize_resolved(cls, v): def _normalize_resolved(cls, v):
@@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel):
# Composite key identifying a config row # Composite key identifying a config row
class ConfigKey(BaseModel): # 配置参数键模型 class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
user_id: str = Field("user_id", description="用户标识(字符串)") user_id: str | None = Field(default=None, description="用户标识(字符串)")
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)")
# Allowed chunking strategies (extendable later) # Allowed chunking strategies (extendable later)
@@ -241,10 +246,12 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致") reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致")
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致")
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)") # config_name: str = Field("配置名称", description="配置名称(字符串)")
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串")
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
@@ -255,8 +262,11 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id:Union[uuid.UUID, int, str] = None config_id: Union[uuid.UUID, int, str] = None
llm_id: Optional[str] = Field(None, description="LLM模型配置ID") llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
audio_id: Optional[str] = Field(None, description="语音模型ID")
vision_id: Optional[str] = Field(None, description="视觉模型ID")
video_id: Optional[str] = Field(None, description="视频模型ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
enable_llm_dedup_blockwise: Optional[bool] = None enable_llm_dedup_blockwise: Optional[bool] = None
@@ -322,14 +332,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
# 遗忘引擎配置参数更新模型 # 遗忘引擎配置参数更新模型
config_id:Union[uuid.UUID, int, str] = None config_id: Union[uuid.UUID, int, str] = None
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5") lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5")
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5")
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0")
class ConfigPilotRun(BaseModel): # 试运行触发请求模型 class ConfigPilotRun(BaseModel): # 试运行触发请求模型
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID唯一支持UUID、整数或字符串") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID唯一支持UUID、整数或字符串")
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行") custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -364,11 +374,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None)
def fail( def fail(
msg: str, msg: str,
error_code: str = "ERROR", error_code: str = "ERROR",
data: Optional[Any] = None, data: Optional[Any] = None,
time: Optional[int] = None, time: Optional[int] = None,
query_preview: Optional[str] = None, query_preview: Optional[str] = None,
) -> ApiResponse: ) -> ApiResponse:
payload = data payload = data
if query_preview is not None: if query_preview is not None:
@@ -387,6 +397,7 @@ def fail(
time=time or _now_ms(), time=time or _now_ms(),
) )
class GenerateCacheRequest(BaseModel): class GenerateCacheRequest(BaseModel):
"""缓存生成请求模型""" """缓存生成请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -432,7 +443,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
"""遗忘引擎配置更新请求模型""" """遗忘引擎配置更新请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识UUID或int)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
@@ -472,7 +483,8 @@ class ForgettingStatsResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标") activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
node_distribution: Dict[str, int] = Field(..., description="节点类型分布") node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据每天取最后一次执行") recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
description="最近7个日期的遗忘趋势数据每天取最后一次执行")
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点") pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点")
timestamp: int = Field(..., description="统计时间(时间戳)") timestamp: int = Field(..., description="统计时间(时间戳)")

View File

@@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase):
updated_at: datetime.datetime updated_at: datetime.datetime
api_keys: List["ModelApiKey"] = [] api_keys: List["ModelApiKey"] = []
@staticmethod
def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str:
if not key or len(key) <= prefix + suffix:
return "*" * len(key)
return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:]
@field_validator("api_keys", mode="after") @field_validator("api_keys", mode="after")
@classmethod @classmethod
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]: def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
@@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase):
def _serialize_created_at(self, dt: datetime.datetime | None): def _serialize_created_at(self, dt: datetime.datetime | None):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@field_serializer("api_keys", when_used="json")
def _serialize_api_keys(self, api_keys: List["ModelApiKey"]):
result = []
for api_key in api_keys:
data = api_key.model_dump()
data["api_key"] = self.mask_api_key(api_key.api_key)
result.append(data)
return result
@field_serializer("updated_at", when_used="json") @field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime): def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@@ -166,8 +181,8 @@ class ModelApiKey(ModelApiKeyBase):
self.model_config_ids = [ self.model_config_ids = [
mc.id for mc in self.model_configs mc.id for mc in self.model_configs
if hasattr(mc, 'id') if hasattr(mc, 'id')
and not getattr(mc, 'is_composite', False) and not getattr(mc, 'is_composite', False)
and getattr(mc, 'name', None) == self.model_name and getattr(mc, 'name', None) == self.model_name
] ]
# 情况2字典列表 # 情况2字典列表
elif isinstance(self.model_configs, list): elif isinstance(self.model_configs, list):
@@ -193,7 +208,6 @@ class ModelApiKey(ModelApiKeyBase):
validate_assignment=True # 确保赋值触发校验 validate_assignment=True # 确保赋值触发校验
) )
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel):
"""模型配置查询Schema""" """模型配置查询Schema"""
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)") type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)") provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)")
is_active: Optional[bool] = Field(None, description="激活状态筛选") is_active: Optional[bool] = Field(None, description="激活状态筛选")
is_public: Optional[bool] = Field(None, description="公开状态筛选") is_public: Optional[bool] = Field(None, description="公开状态筛选")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255) search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
@@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel):
is_composite: Optional[bool] = Field(None, description="组合模型筛选") is_composite: Optional[bool] = Field(None, description="组合模型筛选")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255) search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
class ModelMarketplace(BaseModel): class ModelMarketplace(BaseModel):
"""模型广场响应Schema""" """模型广场响应Schema"""
llm_models: List[ModelConfig] = [] llm_models: List[ModelConfig] = []
@@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel):
is_deprecated: Optional[bool] = Field(None, description="是否弃用") is_deprecated: Optional[bool] = Field(None, description="是否弃用")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255) search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
"""模型信息Schema""" """模型信息Schema"""
model_name: str = Field(..., description="模型名称") model_name: str = Field(..., description="模型名称")
@@ -336,4 +353,3 @@ class ModelInfo(BaseModel):
is_omni: bool = Field(default=False, description="是否为omni模型") is_omni: bool = Field(default=False, description="是否为omni模型")
model_type: ModelType = Field(..., description="模型类型") model_type: ModelType = Field(..., description="模型类型")
capability: List[str] = Field(default_factory=list, description="模型能力列表") capability: List[str] = Field(default_factory=list, description="模型能力列表")

View File

@@ -82,6 +82,12 @@ class AppChatService:
) )
system_prompt = system_prompt_rendered.get_text_content() or system_prompt system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# opening_statement首轮对话注入开场白
is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
system_prompt = self.agent_service._inject_opening_statement(
features_config, system_prompt, is_new_conversation
)
# 准备工具列表 # 准备工具列表
tools = [] tools = []
@@ -93,7 +99,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
tools.extend(kb_tools)
memory_flag = False memory_flag = False
if memory: if memory:
memory_tools, memory_flag = self.agent_service.load_memory_config( memory_tools, memory_flag = self.agent_service.load_memory_config(
@@ -129,45 +136,18 @@ class AppChatService:
) )
# 加载历史消息 # 加载历史消息
messages = self.conversation_service.get_messages( history = await self.conversation_service.get_conversation_history(
conversation_id=conversation_id, conversation_id=conversation_id,
limit=10 max_history=10,
current_provider=api_key_obj.provider,
current_is_omni=api_key_obj.is_omni
) )
history = []
for msg in messages:
content = [{"type": "text", "text": msg.content}]
# 处理 meta_data 中的 files
if msg.meta_data and msg.meta_data.get("files"):
files = msg.meta_data.get("files", [])
# 使用 MultimodalService 处理文件
multimodal_service = MultimodalService(self.db, api_config=model_info)
# 将 files 转换为 FileInput 格式
file_inputs = []
for file in files:
from app.schemas.app_schema import FileInput, TransferMethod
file_input = FileInput(
type=file.get("type"),
transfer_method=TransferMethod.REMOTE_URL,
url=file.get("url")
)
file_inputs.append(file_input)
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
content.extend(history_processed_files)
history.append({
"role": msg.role,
"content": content
})
# 处理多模态文件 # 处理多模态文件
processed_files = None processed_files = None
if files: if files:
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件") logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent支持多模态 # 调用 Agent支持多模态
@@ -206,7 +186,8 @@ class AppChatService:
# 构建用户消息内容(含多模态文件) # 构建用户消息内容(含多模态文件)
human_meta = { human_meta = {
"files": [] "files": [],
"history_files": {}
} }
assistant_meta = { assistant_meta = {
"model": api_key_obj.model_name, "model": api_key_obj.model_name,
@@ -221,6 +202,13 @@ class AppChatService:
"url": f.url "url": f.url
}) })
if processed_files:
human_meta["history_files"] = {
"content": processed_files,
"provider": api_key_obj.provider,
"is_omni": api_key_obj.is_omni
}
# 保存消息 # 保存消息
if audio_url: if audio_url:
assistant_meta["audio_url"] = audio_url assistant_meta["audio_url"] = audio_url
@@ -249,8 +237,9 @@ class AppChatService:
}), }),
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"suggested_questions": suggested_questions, "suggested_questions": suggested_questions,
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), "citations": self.agent_service._filter_citations(features_config, citations_collector),
"audio_url": audio_url, "audio_url": audio_url,
"audio_status": "pending"
} }
async def agnet_chat_stream( async def agnet_chat_stream(
@@ -301,6 +290,12 @@ class AppChatService:
) )
system_prompt = system_prompt_rendered.get_text_content() or system_prompt system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# opening_statement首轮对话注入开场白
is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
system_prompt = self.agent_service._inject_opening_statement(
features_config, system_prompt, is_new_conversation
)
# 准备工具列表 # 准备工具列表
tools = [] tools = []
@@ -313,7 +308,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
tools.extend(kb_tools)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
if memory: if memory:
@@ -350,45 +346,18 @@ class AppChatService:
) )
# 加载历史消息 # 加载历史消息
messages = self.conversation_service.get_messages( history = await self.conversation_service.get_conversation_history(
conversation_id=conversation_id, conversation_id=conversation_id,
limit=10 max_history=10,
current_provider=api_key_obj.provider,
current_is_omni=api_key_obj.is_omni
) )
history = []
for msg in messages:
content = [{"type": "text", "text": msg.content}]
# 处理 meta_data 中的 files
if msg.meta_data and msg.meta_data.get("files"):
history_files = msg.meta_data.get("files", [])
# 使用 MultimodalService 处理文件
multimodal_service = MultimodalService(self.db, api_config=model_info)
# 将 files 转换为 FileInput 格式
file_inputs = []
for file in history_files:
from app.schemas.app_schema import FileInput, TransferMethod
file_input = FileInput(
type=file.get("type"),
transfer_method=TransferMethod.REMOTE_URL,
url=file.get("url")
)
file_inputs.append(file_input)
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
content.extend(history_processed_files)
history.append({
"role": msg.role,
"content": content
})
# 处理多模态文件 # 处理多模态文件
processed_files = None processed_files = None
if files: if files:
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件") logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent支持多模态同时并行启动 TTS # 流式调用 Agent支持多模态同时并行启动 TTS
@@ -433,7 +402,7 @@ class AppChatService:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件(包含 suggested_questions、tts、citations # 发送结束事件(包含 suggested_questions、tts、audio_status、citations
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
sq_config = features_config.get("suggested_questions_after_answer", {}) sq_config = features_config.get("suggested_questions_after_answer", {})
if isinstance(sq_config, dict) and sq_config.get("enabled"): if isinstance(sq_config, dict) and sq_config.get("enabled"):
@@ -443,11 +412,23 @@ class AppChatService:
"api_base": api_key_obj.api_base}, {} "api_base": api_key_obj.api_base}, {}
) )
end_data["audio_url"] = stream_audio_url end_data["audio_url"] = stream_audio_url
end_data["citations"] = self.agent_service._filter_citations(features_config, []) # 检查TTS是否已完成非阻塞不取消任务
audio_status = "pending"
if tts_task is not None and tts_task.done():
# 任务已完成,检查是否有异常
try:
tts_task.result()
audio_status = "completed"
except Exception as e:
logger.warning(f"TTS任务异常: {e}")
audio_status = "failed"
end_data["audio_status"] = audio_status if stream_audio_url else None
end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector)
# 保存消息 # 保存消息
human_meta = { human_meta = {
"files":[] "files":[],
"history_files": {}
} }
assistant_meta = { assistant_meta = {
"model": api_key_obj.model_name, "model": api_key_obj.model_name,
@@ -457,11 +438,16 @@ class AppChatService:
if files: if files:
for f in files: for f in files:
# url = await MultimodalService(self.db).get_file_url(f)
human_meta["files"].append({ human_meta["files"].append({
"type": f.type, "type": f.type,
"url": f.url "url": f.url
}) })
if processed_files:
human_meta["history_files"] = {
"content": processed_files,
"provider": api_key_obj.provider,
"is_omni": api_key_obj.is_omni
}
if stream_audio_url: if stream_audio_url:
assistant_meta["audio_url"] = stream_audio_url assistant_meta["audio_url"] = stream_audio_url

View File

@@ -1638,7 +1638,7 @@ class AppService:
# ==================== 记忆配置提取方法 ==================== # ==================== 记忆配置提取方法 ====================
def _extract_memory_config_id( def _get_memory_config_id_from_release(
self, self,
app_type: str, app_type: str,
config: Dict[str, Any] config: Dict[str, Any]
@@ -1863,7 +1863,7 @@ class AppService:
self.db.flush() # 先 flush确保 release 已插入数据库 self.db.flush() # 先 flush确保 release 已插入数据库
# 提取记忆配置ID并更新终端用户 # 提取记忆配置ID并更新终端用户
memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config) memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(app.type, config)
# 如果检测到旧格式 int 数据,回退到工作空间默认配置 # 如果检测到旧格式 int 数据,回退到工作空间默认配置
if is_legacy_int and not memory_config_id: if is_legacy_int and not memory_config_id:
@@ -2001,7 +2001,7 @@ class AppService:
raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}") raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}")
# 提取记忆配置ID并更新终端用户 # 提取记忆配置ID并更新终端用户
memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config) memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(release.type, release.config)
# 如果检测到旧格式 int 数据,回退到工作空间默认配置 # 如果检测到旧格式 int 数据,回退到工作空间默认配置
if is_legacy_int and not memory_config_id: if is_legacy_int and not memory_config_id:

View File

@@ -274,7 +274,8 @@ class ConversationService:
self, self,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
max_history: Optional[int] = None, max_history: Optional[int] = None,
api_config: Optional[ModelInfo] = None current_provider: Optional[str] = None,
current_is_omni: Optional[bool] = None
) -> List[dict]: ) -> List[dict]:
""" """
Retrieve historical conversation messages formatted as dictionaries. Retrieve historical conversation messages formatted as dictionaries.
@@ -282,7 +283,8 @@ class ConversationService:
Args: Args:
conversation_id (uuid.UUID): Conversation UUID. conversation_id (uuid.UUID): Conversation UUID.
max_history (Optional[int]): Maximum number of messages to retrieve. max_history (Optional[int]): Maximum number of messages to retrieve.
api_config (Optional[ModelInfo]): Model API configuration for multimodal processing. current_provider (Optional[str]): Current provider for file handling.
current_is_omni (Optional[bool]): Current omni flag for file handling.
Returns: Returns:
List[dict]: List of message dictionaries with keys 'role' and 'content'. List[dict]: List of message dictionaries with keys 'role' and 'content'.
@@ -292,38 +294,30 @@ class ConversationService:
limit=max_history limit=max_history
) )
# 转换为字典格式
history = [] history = []
for msg in messages: for msg in messages:
content = [{"type": "text", "text": msg.content}] msg_dict = {
# 处理 meta_data 中的 files
if msg.meta_data and msg.meta_data.get("files"):
files = msg.meta_data.get("files", [])
if api_config:
# 使用 MultimodalService 处理文件
from app.services.multimodal_service import MultimodalService
multimodal_service = MultimodalService(self.db, api_config=api_config)
# 将 files 转换为 FileInput 格式
file_inputs = []
for file in files:
from app.schemas.app_schema import FileInput, TransferMethod
file_input = FileInput(
type=file.get("type"),
transfer_method=TransferMethod.REMOTE_URL,
url=file.get("url")
)
file_inputs.append(file_input)
processed_files = await multimodal_service.history_process_files(files=file_inputs)
content.extend(processed_files)
history.append({
"role": msg.role, "role": msg.role,
"content": content "content": [{"type": "text", "text": msg.content}]
}) }
# 处理用户消息中的多模态文件
if msg.role == "user" and msg.meta_data:
history_files = msg.meta_data.get("history_files", {})
if history_files and current_provider and current_is_omni is not None:
# 检查是否需要重新处理文件
stored_provider = history_files.get("provider")
stored_is_omni = history_files.get("is_omni")
# 如果provider或is_omni不匹配需要重新处理
if stored_provider != current_provider or stored_is_omni != current_is_omni:
continue
# provider和is_omni匹配直接使用存储的内容
msg_dict["content"].extend(history_files.get("content"))
history.append(msg_dict)
return history return history
@@ -539,6 +533,7 @@ class ConversationService:
provider = api_config.provider provider = api_config.provider
api_key = api_config.api_key api_key = api_config.api_key
api_base = api_config.api_base api_base = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type model_type = config.type
llm = RedBearLLM( llm = RedBearLLM(
@@ -546,7 +541,8 @@ class ConversationService:
model_name=model_name, model_name=model_name,
provider=provider, provider=provider,
api_key=api_key, api_key=api_key,
base_url=api_base base_url=api_base,
is_omni=is_omni
), ),
type=ModelType(model_type) type=ModelType(model_type)
) )
@@ -554,15 +550,8 @@ class ConversationService:
conversation_messages = await self.get_conversation_history( conversation_messages = await self.get_conversation_history(
conversation_id=conversation_id, conversation_id=conversation_id,
max_history=20, max_history=20,
api_config=ModelInfo( current_provider=provider,
model_name=model_name, current_is_omni=is_omni
provider=provider,
api_key=api_key,
api_base=api_base,
capability=api_config.capability,
is_omni=api_config.is_omni,
model_type=model_type
)
) )
if len(conversation_messages) == 0: if len(conversation_messages) == 0:
return ConversationOut( return ConversationOut(

View File

@@ -26,7 +26,7 @@ from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context from app.db import get_db_context
from app.models import AgentConfig, ModelConfig, ModelType from app.models import AgentConfig, ModelConfig, ModelType
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput from app.schemas.app_schema import FileInput, Citation
from app.schemas.model_schema import ModelInfo from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service from app.services import task_service
@@ -190,13 +190,19 @@ def create_web_search_tool(web_search_config: Dict[str, Any]):
return web_search_tool return web_search_tool
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None):
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
Args: Args:
kb_config: 知识库配置 kb_config: 知识库配置
kb_ids: 知识库ID列表 kb_ids: 知识库ID列表
user_id: 用户ID user_id: 用户ID
citations_collector: 用于收集引用信息的列表由外部传入tool 执行时填充)
列表元素类型为 Citation包含字段
- document_id: 文档唯一标识
- file_name: 文件名
- knowledge_id: 知识库 ID
- score: 检索相关性得分
Returns: Returns:
检索到的相关知识内容 检索到的相关知识内容
@@ -229,6 +235,21 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
} }
) )
# 收集引用信息
if citations_collector is not None:
seen_doc_ids = {c.get("document_id") for c in citations_collector}
for chunk in retrieve_chunks_result:
meta = chunk.metadata or {}
doc_id = meta.get("document_id") or meta.get("doc_id")
if doc_id and doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
citations_collector.append(Citation(
document_id=doc_id,
file_name=meta.get("file_name", ""),
knowledge_id=str(meta.get("knowledge_id", "")),
score=meta.get("score", 0)
))
return f"检索到以下相关信息:\n\n{context}" return f"检索到以下相关信息:\n\n{context}"
else: else:
logger.warning("知识库检索未找到结果") logger.warning("知识库检索未找到结果")
@@ -320,26 +341,26 @@ class AgentRunService:
self, self,
knowledge_retrieval_config: dict | None, knowledge_retrieval_config: dict | None,
user_id user_id
) -> list: ) -> tuple[list, list]:
"""返回 (tools, citations_collector)"""
if not knowledge_retrieval_config: if not knowledge_retrieval_config:
return [] return [], []
citations_collector = []
tools = [] tools = []
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids: if kb_ids:
# 创建知识库检索工具 kb_tool = create_knowledge_retrieval_tool(
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) knowledge_retrieval_config, kb_ids, user_id,
citations_collector=citations_collector
)
tools.append(kb_tool) tools.append(kb_tool)
logger.debug( logger.debug(
"已添加知识库检索工具", "已添加知识库检索工具",
extra={ extra={"kb_ids": kb_ids, "tool_count": len(tools)}
"kb_ids": kb_ids,
"tool_count": len(tools)
}
) )
return tools return tools, citations_collector
def load_memory_config( def load_memory_config(
self, self,
@@ -441,12 +462,12 @@ class AgentRunService:
@staticmethod @staticmethod
def _filter_citations( def _filter_citations(
features_config: Dict[str, Any], features_config: Dict[str, Any],
citations: List[Any] citations: List[Citation]
) -> List[Any]: ) -> List[Any]:
"""根据 citation 开关决定是否返回引用来源""" """根据 citation 开关决定是否返回引用来源"""
citation_cfg = features_config.get("citation", {}) citation_cfg = features_config.get("citation", {})
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"): if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
return citations return [cit.model_dump() for cit in citations]
return [] return []
async def run( async def run(
@@ -549,7 +570,8 @@ class AgentRunService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
tools.extend(kb_tools)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
if memory: if memory:
@@ -592,8 +614,9 @@ class AgentRunService:
# 6. 加载历史消息 # 6. 加载历史消息
history = await self._load_conversation_history( history = await self._load_conversation_history(
conversation_id=conversation_id, conversation_id=conversation_id,
api_config=model_info, max_history=10,
max_history=10 current_provider=api_key_config.get("provider"),
current_is_omni=api_key_config.get("is_omni", False)
) )
# 6. 处理多模态文件 # 6. 处理多模态文件
@@ -602,7 +625,7 @@ class AgentRunService:
# 获取 provider 信息 # 获取 provider 信息
provider = api_key_config.get("provider", "openai") provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}") logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索 # 7. 知识库检索
@@ -661,7 +684,10 @@ class AgentRunService:
}) })
}, },
files=files, files=files,
audio_url=audio_url processed_files=processed_files,
audio_url=audio_url,
provider=api_key_config.get("provider"),
is_omni=api_key_config.get("is_omni", False)
) )
response = { response = {
@@ -676,8 +702,9 @@ class AgentRunService:
"suggested_questions": await self._generate_suggested_questions( "suggested_questions": await self._generate_suggested_questions(
features_config, result["content"], api_key_config, effective_params features_config, result["content"], api_key_config, effective_params
) if not sub_agent else [], ) if not sub_agent else [],
"citations": self._filter_citations(features_config, result.get("citations", [])), "citations": self._filter_citations(features_config, citations_collector),
"audio_url": audio_url, "audio_url": audio_url,
"audio_status": "pending"
} }
logger.info( logger.info(
@@ -785,7 +812,8 @@ class AgentRunService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
tools.extend(kb_tools)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
@@ -830,8 +858,9 @@ class AgentRunService:
# 6. 加载历史消息 # 6. 加载历史消息
history = await self._load_conversation_history( history = await self._load_conversation_history(
conversation_id=conversation_id, conversation_id=conversation_id,
api_config=model_info, max_history=memory_config.get("max_history", 10),
max_history=memory_config.get("max_history", 10) current_provider=api_key_config.get("provider"),
current_is_omni=api_key_config.get("is_omni", False)
) )
# 6. 处理多模态文件 # 6. 处理多模态文件
@@ -840,7 +869,7 @@ class AgentRunService:
# 获取 provider 信息 # 获取 provider 信息
provider = api_key_config.get("provider", "openai") provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}") logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索 # 7. 知识库检索
@@ -909,10 +938,13 @@ class AgentRunService:
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
}, },
files=files, files=files,
audio_url=stream_audio_url processed_files=processed_files,
audio_url=stream_audio_url,
provider=api_key_config.get("provider"),
is_omni=api_key_config.get("is_omni", False)
) )
# 12. 发送结束事件(包含 suggested_questions 和 tts # 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status
end_data: Dict[str, Any] = { end_data: Dict[str, Any] = {
"conversation_id": conversation_id, "conversation_id": conversation_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
@@ -923,7 +955,18 @@ class AgentRunService:
features_config, full_content, api_key_config, effective_params features_config, full_content, api_key_config, effective_params
) )
end_data["audio_url"] = stream_audio_url end_data["audio_url"] = stream_audio_url
end_data["citations"] = self._filter_citations(features_config, []) # 检查TTS是否已完成非阻塞不取消任务
audio_status = "pending"
if tts_task is not None and tts_task.done():
# 任务已完成,检查是否有异常
try:
tts_task.result()
audio_status = "completed"
except Exception as e:
logger.warning(f"TTS任务异常: {e}")
audio_status = "failed"
end_data["audio_status"] = audio_status if stream_audio_url else None
end_data["citations"] = self._filter_citations(features_config, citations_collector)
yield self._format_sse_event("end", end_data) yield self._format_sse_event("end", end_data)
logger.info( logger.info(
@@ -1119,14 +1162,17 @@ class AgentRunService:
async def _load_conversation_history( async def _load_conversation_history(
self, self,
conversation_id: str, conversation_id: str,
api_config: ModelInfo | None = None, max_history: int = 10,
max_history: int = 10 current_provider: Optional[str] = None,
current_is_omni: Optional[bool] = None
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
"""加载会话历史消息 """加载会话历史消息,并根据当前模型配置处理多模态文件
Args: Args:
conversation_id: 会话ID conversation_id: 会话ID
max_history: 最大历史消息数量 max_history: 最大历史消息数量
current_provider: 当前模型的provider
current_is_omni: 当前模型的is_omni
Returns: Returns:
List[Dict]: 历史消息列表 List[Dict]: 历史消息列表
@@ -1138,7 +1184,8 @@ class AgentRunService:
history = await conversation_service.get_conversation_history( history = await conversation_service.get_conversation_history(
conversation_id=uuid.UUID(conversation_id), conversation_id=uuid.UUID(conversation_id),
max_history=max_history, max_history=max_history,
api_config=api_config current_provider=current_provider,
current_is_omni=current_is_omni
) )
logger.debug( logger.debug(
@@ -1166,7 +1213,10 @@ class AgentRunService:
app_id: Optional[uuid.UUID] = None, app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
files: Optional[List[FileInput]] = None, files: Optional[List[FileInput]] = None,
audio_url: Optional[str] = None processed_files: Optional[List[Dict[str, Any]]] = None,
audio_url: Optional[str] = None,
provider: Optional[str] = None,
is_omni: Optional[bool] = None
) -> None: ) -> None:
"""保存会话消息(会话已通过 _ensure_conversation 确保存在) """保存会话消息(会话已通过 _ensure_conversation 确保存在)
@@ -1177,6 +1227,11 @@ class AgentRunService:
app_id: 应用ID未使用保留用于兼容性 app_id: 应用ID未使用保留用于兼容性
user_id: 用户ID未使用保留用于兼容性 user_id: 用户ID未使用保留用于兼容性
meta_data: token消耗 meta_data: token消耗
files: 原始文件输入
processed_files: 处理后的文件
audio_url: 音频URL
provider: 模型供应商
is_omni: 是否为全模态模型
""" """
try: try:
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
@@ -1186,15 +1241,24 @@ class AgentRunService:
# 保存消息(会话已经存在) # 保存消息(会话已经存在)
human_meta = { human_meta = {
"files": [] "files": [],
"history_files": {}
} }
if files: if files:
for f in files: for f in files:
# url = await MultimodalService(self.db).get_file_url(f)
human_meta["files"].append({ human_meta["files"].append({
"type": f.type, "type": f.type,
"url": f.url "url": f.url
}) })
# 保存 history_files包含 provider 和 is_omni 信息
if processed_files:
human_meta["history_files"] = {
"content": processed_files,
"provider": provider,
"is_omni": is_omni
}
# 保存用户消息 # 保存用户消息
conversation_service.add_message( conversation_service.add_message(
conversation_id=conv_uuid, conversation_id=conv_uuid,
@@ -1420,8 +1484,9 @@ class AgentRunService:
workspace_id: Optional[uuid.UUID] = None, workspace_id: Optional[uuid.UUID] = None,
) -> tuple[Optional[str], Optional[asyncio.Task]]: ) -> tuple[Optional[str], Optional[asyncio.Task]]:
"""文本流式输入并行合成音频。 """文本流式输入并行合成音频。
返回 (audio_url, task)audio_url 立即可用task 完成后文件内容就绪。 返回 (audio_url, task)audio_url 立即可用pending状态task 完成后文件内容就绪。
调用方向 text_queue put 文本 chunk结束时 put None。 调用方向 text_queue put 文本 chunk结束时 put None。
前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。
""" """
tts_config = features_config.get("text_to_speech", {}) tts_config = features_config.get("text_to_speech", {})
if not isinstance(tts_config, dict) or not tts_config.get("enabled"): if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
@@ -1808,6 +1873,7 @@ class AgentRunService:
), ),
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]), "cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
"audio_url": result.get("audio_url"), "audio_url": result.get("audio_url"),
"audio_status": result.get("audio_status"),
"citations": result.get("citations", []), "citations": result.get("citations", []),
"suggested_questions": result.get("suggested_questions", []), "suggested_questions": result.get("suggested_questions", []),
"error": None "error": None
@@ -1885,6 +1951,7 @@ class AgentRunService:
"results": [{ "results": [{
**r, **r,
"audio_url": r.get("audio_url"), "audio_url": r.get("audio_url"),
"audio_status": r.get("audio_status"),
"citations": r.get("citations", []), "citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []), "suggested_questions": r.get("suggested_questions", []),
} for r in results], } for r in results],
@@ -2016,6 +2083,7 @@ class AgentRunService:
full_content = "" full_content = ""
returned_conversation_id = model_conversation_id returned_conversation_id = model_conversation_id
audio_url = None audio_url = None
audio_status = None
citations = [] citations = []
suggested_questions = [] suggested_questions = []
@@ -2074,6 +2142,7 @@ class AgentRunService:
# 从 end 事件中提取 features 输出字段 # 从 end 事件中提取 features 输出字段
if event_type == "end" and event_data: if event_type == "end" and event_data:
audio_url = event_data.get("audio_url") audio_url = event_data.get("audio_url")
audio_status = event_data.get("audio_status")
citations = event_data.get("citations", []) citations = event_data.get("citations", [])
suggested_questions = event_data.get("suggested_questions", []) suggested_questions = event_data.get("suggested_questions", [])
@@ -2103,6 +2172,7 @@ class AgentRunService:
"message": full_content, "message": full_content,
"elapsed_time": elapsed, "elapsed_time": elapsed,
"audio_url": audio_url, "audio_url": audio_url,
"audio_status": audio_status,
"citations": citations, "citations": citations,
"suggested_questions": suggested_questions, "suggested_questions": suggested_questions,
"error": None "error": None
@@ -2117,6 +2187,7 @@ class AgentRunService:
"elapsed_time": elapsed, "elapsed_time": elapsed,
"message_length": len(full_content), "message_length": len(full_content),
"audio_url": audio_url, "audio_url": audio_url,
"audio_status": audio_status,
"citations": citations, "citations": citations,
"suggested_questions": suggested_questions, "suggested_questions": suggested_questions,
"timestamp": time.time() "timestamp": time.time()
@@ -2253,6 +2324,7 @@ class AgentRunService:
"message": r.get("message"), "message": r.get("message"),
"elapsed_time": r.get("elapsed_time", 0), "elapsed_time": r.get("elapsed_time", 0),
"audio_url": r.get("audio_url"), "audio_url": r.get("audio_url"),
"audio_status": r.get("audio_status"),
"citations": r.get("citations", []), "citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []), "suggested_questions": r.get("suggested_questions", []),
"error": r.get("error") "error": r.get("error")

View File

@@ -325,27 +325,30 @@ class FileStorageService:
) )
raise raise
async def get_file_url(self, file_key: str, expires: int = 3600) -> str: async def get_file_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get an access URL for a file. Get an access URL for a file.
Args: Args:
file_key: The file key. file_key: The file key.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
URL for accessing the file. URL for accessing the file.
""" """
logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s") logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s")
try: try:
url = await self.storage.get_url(file_key, expires) url = await self.storage.get_url(file_key, expires, file_name=file_name)
logger.debug(f"File URL generated: file_key={file_key}") logger.debug(f"File URL generated: file_key={file_key}")
return url return url
except Exception as e: except Exception as e:
logger.error( logger.error(f"Error getting file URL: file_key={file_key}, error={str(e)}")
f"Error getting file URL: file_key={file_key}, error={str(e)}"
)
raise raise

View File

@@ -0,0 +1,162 @@
"""
图片和视频生成服务
提供统一的生成接口,支持多种 Provider
"""
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session
import uuid
from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.models_model import ModelType
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
from app.services.model_service import ModelApiKeyService
class GenerationService:
"""生成服务"""
def __init__(self, db: Session):
self.db = db
async def generate_image(
self,
model_config_id: str,
prompt: str,
size: Optional[str] = "2k",
**kwargs
) -> Dict[str, Any]:
"""
生成图片
Args:
model_config_id: 模型配置ID
prompt: 提示词
size: 图片尺寸
**kwargs: 其他参数
Returns:
生成结果
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
if model_config.type != ModelType.IMAGE:
raise BusinessException(
f"模型类型错误,期望 {ModelType.IMAGE},实际 {model_config.type}",
code=BizCode.INVALID_PARAMETER
)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 生成图片
generator = RedBearImageGenerator(config)
result = await generator.agenerate(prompt, size, **kwargs)
return result
async def generate_video(
self,
model_config_id: str,
prompt: str,
duration: Optional[int] = None,
**kwargs
) -> Dict[str, Any]:
"""
生成视频
Args:
model_config_id: 模型配置ID
prompt: 提示词
duration: 视频时长(秒)
**kwargs: 其他参数
Returns:
生成结果包含任务ID
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
if model_config.type != ModelType.VIDEO:
raise BusinessException(
f"模型类型错误,期望 {ModelType.VIDEO},实际 {model_config.type}",
code=BizCode.INVALID_PARAMETER
)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 生成视频
generator = RedBearVideoGenerator(config)
result = await generator.agenerate(prompt, duration, **kwargs)
return result
async def get_video_task_status(
self,
model_config_id: str,
task_id: str
) -> Dict[str, Any]:
"""
查询视频生成任务状态
Args:
model_config_id: 模型配置ID
task_id: 任务ID
Returns:
任务状态信息
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 查询任务状态
generator = RedBearVideoGenerator(config)
result = await generator.aget_task_status(task_id)
return result

View File

@@ -94,29 +94,38 @@ class HomePageService:
@staticmethod @staticmethod
def load_version_introduction(version: str) -> Dict[str, Any]: def load_version_introduction(version: str) -> Dict[str, Any]:
""" """
从 JSON 文件加载对应版本的介绍 加载对应版本的介绍优先从数据库读取fallback 到 JSON 文件)
:param version: 系统版本号(如 "0.2.0" :param version: 系统版本号(如 "0.2.0"
:return: 对应版本的详细介绍 :return: 对应版本的详细介绍
""" """
# 2. 定义 JSON 文件路径(简化路径处理,保留绝对路径调试特性)
json_abs_path = Path(__file__).parent.parent / "version_info.json"
json_abs_path = json_abs_path.resolve()
# 3. 初始化返回结果(深拷贝默认模板,避免修改原常量)
from copy import deepcopy from copy import deepcopy
from app.db import SessionLocal
from app.repositories.home_page_repository import HomePageRepository
result = deepcopy(HomePageService.DEFAULT_RETURN_DATA) result = deepcopy(HomePageService.DEFAULT_RETURN_DATA)
try: try:
# 4. 简化文件存在性判断(合并逻辑,减少分支) db = SessionLocal()
try:
db_result = HomePageRepository.get_version_introduction(db, version)
if db_result:
return db_result
finally:
db.close()
except Exception as e:
pass
json_abs_path = Path(__file__).parent.parent / "version_info.json"
json_abs_path = json_abs_path.resolve()
try:
if not json_abs_path.exists(): if not json_abs_path.exists():
result["message"] = f"版本介绍文件不存在:{json_abs_path}" result["message"] = f"版本介绍文件不存在:{json_abs_path}"
return result return result
# 5. 读取并解析 JSON 文件(简化文件操作流程)
with open(json_abs_path, "r", encoding="utf-8") as f: with open(json_abs_path, "r", encoding="utf-8") as f:
changelogs = json.load(f) changelogs = json.load(f)
# 6. 简化版本匹配逻辑,直接返回结果或更新提示信息
if version in changelogs: if version in changelogs:
return changelogs[version] return changelogs[version]
result["message"] = f"暂未查询到 {version} 版本的详细介绍" result["message"] = f"暂未查询到 {version} 版本的详细介绍"

Some files were not shown because too many files have changed in this diff Show More