Merge branch 'develop' of https://github.com/SuanmoSuanyangTechnology/MemoryBear into feature/app-message-log
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -11,21 +12,24 @@ from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _mask_url(url: str) -> str:
|
||||
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
||||
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
|
||||
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
|
||||
# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议),
|
||||
# 未配置则回退到 Redis 方案
|
||||
# backend: 结果存储(使用 Redis)
|
||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||
|
||||
# Build canonical broker/backend URLs and force them into os.environ so that
|
||||
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
|
||||
# cannot be overridden by stray env vars.
|
||||
# See: https://github.com/celery/celery/issues/4284
|
||||
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
||||
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||
@@ -45,8 +49,8 @@ celery_app = Celery(
|
||||
logger.info(
|
||||
"Celery app initialized",
|
||||
extra={
|
||||
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
"broker": _mask_url(_broker_url),
|
||||
"backend": _mask_url(_backend_url),
|
||||
},
|
||||
)
|
||||
# Default queue for unrouted tasks
|
||||
@@ -77,6 +81,7 @@ celery_app.conf.update(
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
@@ -57,6 +57,7 @@ def list_apps(
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
@@ -65,10 +66,25 @@ def list_apps(
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||
- 当提供 api_key 参数时,查找该 API Key 关联的应用
|
||||
"""
|
||||
from sqlalchemy import select as sa_select
|
||||
from app.models.api_key_model import ApiKey
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
# 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程
|
||||
if api_key:
|
||||
matched_id = db.execute(
|
||||
sa_select(ApiKey.resource_id).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.api_key == api_key,
|
||||
ApiKey.resource_id.isnot(None),
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
ids = str(matched_id) if matched_id else ""
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
if ids is not None:
|
||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||
|
||||
@@ -14,6 +14,9 @@ Routes:
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
import httpx
|
||||
import mimetypes
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
@@ -290,6 +293,101 @@ async def upload_file_with_share_token(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/info-by-url", response_model=ApiResponse)
|
||||
async def get_file_info_by_url(
|
||||
url: str,
|
||||
):
|
||||
"""
|
||||
Get file information by network URL (no authentication required).
|
||||
|
||||
Fetches file metadata from a remote URL via HTTP HEAD request.
|
||||
Falls back to GET request if HEAD is not supported.
|
||||
Returns file type, name, and size.
|
||||
|
||||
Args:
|
||||
url: The network URL of the file.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file information.
|
||||
"""
|
||||
api_logger.info(f"File info by URL request: url={url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# Try HEAD request first
|
||||
response = await client.head(url, follow_redirects=True)
|
||||
|
||||
# If HEAD fails, try GET request (some servers don't support HEAD)
|
||||
if response.status_code != 200:
|
||||
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access file: HTTP {response.status_code}"
|
||||
)
|
||||
|
||||
# Get file size from Content-Length header or actual content
|
||||
file_size = response.headers.get("Content-Length")
|
||||
if file_size:
|
||||
file_size = int(file_size)
|
||||
elif hasattr(response, 'content'):
|
||||
file_size = len(response.content)
|
||||
else:
|
||||
file_size = None
|
||||
|
||||
# Get content type from Content-Type header
|
||||
content_type = response.headers.get("Content-Type", "application/octet-stream")
|
||||
# Remove charset and other parameters from content type
|
||||
content_type = content_type.split(';')[0].strip()
|
||||
|
||||
# Extract filename from Content-Disposition or URL
|
||||
file_name = None
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
if content_disposition and "filename=" in content_disposition:
|
||||
parts = content_disposition.split("filename=")
|
||||
if len(parts) > 1:
|
||||
file_name = parts[1].strip('"').strip("'")
|
||||
|
||||
if not file_name:
|
||||
parsed_url = urlparse(url)
|
||||
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
|
||||
|
||||
# Extract file extension from filename
|
||||
_, file_ext = os.path.splitext(file_name)
|
||||
|
||||
# If no extension found, infer from content type
|
||||
if not file_ext:
|
||||
ext = mimetypes.guess_extension(content_type)
|
||||
if ext:
|
||||
file_ext = ext
|
||||
file_name = f"{file_name}{file_ext}"
|
||||
|
||||
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
|
||||
|
||||
return success(
|
||||
data={
|
||||
"url": url,
|
||||
"file_name": file_name,
|
||||
"file_ext": file_ext.lower() if file_ext else "",
|
||||
"file_size": file_size,
|
||||
"content_type": content_type,
|
||||
},
|
||||
msg="File information retrieved successfully"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file information: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
request: Request,
|
||||
@@ -476,8 +574,12 @@ async def get_file_url(
|
||||
# For local storage, generate signed URL with expiration
|
||||
url = generate_signed_url(str(file_id), expires)
|
||||
else:
|
||||
# For remote storage (OSS/S3), get presigned URL
|
||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||
# For remote storage (OSS/S3), get presigned URL with forced download
|
||||
url = await storage_service.get_file_url(
|
||||
file_key,
|
||||
expires=expires,
|
||||
file_name=file_metadata.file_name,
|
||||
)
|
||||
url = _match_scheme(request, url)
|
||||
|
||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||
@@ -688,7 +790,7 @@ async def permanent_download_file(
|
||||
# For remote storage, redirect to presigned URL with long expiration
|
||||
try:
|
||||
# Use a very long expiration (7 days max for most cloud providers)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
@@ -697,3 +799,44 @@ async def permanent_download_file(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/status", response_model=ApiResponse)
|
||||
async def get_file_status(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get file upload/processing status (no authentication required).
|
||||
|
||||
This endpoint is used to check if a file (e.g., TTS audio) is ready.
|
||||
Returns status: pending, completed, or failed.
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file status and metadata.
|
||||
"""
|
||||
api_logger.info(f"File status request: file_id={file_id}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"file_id": str(file_id),
|
||||
"status": file_metadata.status,
|
||||
"file_name": file_metadata.file_name,
|
||||
"file_size": file_metadata.file_size,
|
||||
"content_type": file_metadata.content_type,
|
||||
},
|
||||
msg="File status retrieved successfully"
|
||||
)
|
||||
|
||||
@@ -91,9 +91,11 @@ async def get_mcp_servers(
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(token)
|
||||
headers=api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
headers=headers,
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
@@ -173,6 +175,7 @@ async def get_operational_mcp_servers(
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||
@@ -260,7 +263,9 @@ async def create_mcp_market_config(
|
||||
api.login(create_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(create_data.token)
|
||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {create_data.token}'
|
||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
@@ -290,9 +295,11 @@ async def create_mcp_market_config(
|
||||
'search': ""
|
||||
}
|
||||
cookies = api.get_cookies(token)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
headers=headers,
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
@@ -393,7 +400,9 @@ async def update_mcp_market_config(
|
||||
api.login(update_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(update_data.token)
|
||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {update_data.token}'
|
||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
|
||||
@@ -118,142 +118,142 @@ async def download_log(
|
||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
api_logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(
|
||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id,
|
||||
language
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(
|
||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
# @router.post("/writer_service", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Write service endpoint - processes write operations synchronously
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Response with write operation status
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
#
|
||||
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
# if storage_type == 'rag':
|
||||
# if workspace_id:
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge:
|
||||
# user_rag_memory_id = str(knowledge.id)
|
||||
# else:
|
||||
# api_logger.warning(
|
||||
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
# else:
|
||||
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
#
|
||||
# api_logger.info(
|
||||
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
# try:
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
# result = await memory_agent_service.write_memory(
|
||||
# user_input.end_user_id,
|
||||
# messages_list,
|
||||
# config_id,
|
||||
# db,
|
||||
# storage_type,
|
||||
# user_rag_memory_id,
|
||||
# language
|
||||
# )
|
||||
#
|
||||
# return success(data=result, msg="写入成功")
|
||||
# except BaseException as e:
|
||||
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
# if hasattr(e, 'exceptions'):
|
||||
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
# detailed_error = "; ".join(error_messages)
|
||||
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
#
|
||||
#
|
||||
# @router.post("/writer_service_async", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server_async(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Async write service endpoint - enqueues write processing to Celery
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Task ID for tracking async operation
|
||||
# Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(
|
||||
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
# if workspace_id:
|
||||
#
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
# try:
|
||||
# # 获取标准化的消息列表
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
#
|
||||
# task = celery_app.send_task(
|
||||
# "app.core.memory.agent.write_message",
|
||||
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
# )
|
||||
# api_logger.info(f"Write task queued: {task.id}")
|
||||
#
|
||||
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
# except Exception as e:
|
||||
# api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/read_service", response_model=ApiResponse)
|
||||
|
||||
@@ -663,9 +663,12 @@ async def dashboard_data(
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
rag_data["total_app"] = len(apps_orm)
|
||||
# 包含自有app + 被分享给本工作空间的app
|
||||
from app.services import app_service as _app_svc
|
||||
_, total_app = _app_svc.AppService(db).list_apps(
|
||||
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||
)
|
||||
rag_data["total_app"] = total_app
|
||||
|
||||
# total_knowledge: 使用 total_kb(总知识库数)
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
@@ -687,7 +690,7 @@ async def dashboard_data(
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -54,8 +54,8 @@ router = APIRouter(
|
||||
|
||||
@router.get("/info", response_model=ApiResponse)
|
||||
async def get_storage_info(
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Example wrapper endpoint - retrieves storage information
|
||||
@@ -75,24 +75,19 @@ async def get_storage_info(
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
|
||||
try:
|
||||
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
|
||||
@@ -107,9 +102,11 @@ def create_config(
|
||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {err_str}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||
@@ -119,9 +116,11 @@ def create_config(
|
||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||
@@ -129,10 +128,10 @@ def create_config(
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: UUID|int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
config_id: UUID | int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除记忆配置(带终端用户保护)
|
||||
|
||||
@@ -145,24 +144,24 @@ def delete_config(
|
||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
|
||||
f"config_id={config_id}, force={force}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 使用带保护的删除服务
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
result = config_service.delete_config(config_id=config_id, force=force)
|
||||
|
||||
|
||||
if result["status"] == "error":
|
||||
api_logger.warning(
|
||||
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
|
||||
@@ -172,7 +171,7 @@ def delete_config(
|
||||
msg=result["message"],
|
||||
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
|
||||
)
|
||||
|
||||
|
||||
if result["status"] == "warning":
|
||||
api_logger.warning(
|
||||
f"记忆配置正在使用,无法删除: config_id={config_id}, "
|
||||
@@ -186,7 +185,7 @@ def delete_config(
|
||||
"force_required": result["force_required"]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"记忆配置删除成功: config_id={config_id}, "
|
||||
f"affected_users={result['affected_users']}"
|
||||
@@ -195,7 +194,7 @@ def delete_config(
|
||||
msg=result["message"],
|
||||
data={"affected_users": result["affected_users"]}
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||
@@ -203,9 +202,9 @@ def delete_config(
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
@@ -213,12 +212,13 @@ def update_config(
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
|
||||
"config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -231,9 +231,9 @@ def update_config(
|
||||
|
||||
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
||||
def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
@@ -241,7 +241,7 @@ def update_config_extracted(
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -256,11 +256,11 @@ def update_config_extracted(
|
||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
@@ -268,7 +268,7 @@ def read_config_extracted(
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -278,18 +278,19 @@ def read_config_extracted(
|
||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -303,14 +304,14 @@ def read_all_config(
|
||||
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}, "
|
||||
@@ -333,9 +334,9 @@ async def pilot_run(
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await kb_type_distribution(end_user_id)
|
||||
@@ -344,12 +345,12 @@ async def get_kb_type_distribution(
|
||||
api_logger.error(f"KB type distribution failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
@router.get("/search/dialogue", response_model=ApiResponse)
|
||||
async def search_dialogues_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_dialogue(end_user_id)
|
||||
@@ -361,9 +362,9 @@ async def search_dialogues_num(
|
||||
|
||||
@router.get("/search/chunk", response_model=ApiResponse)
|
||||
async def search_chunks_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_chunk(end_user_id)
|
||||
@@ -375,9 +376,9 @@ async def search_chunks_num(
|
||||
|
||||
@router.get("/search/statement", response_model=ApiResponse)
|
||||
async def search_statements_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_statement(end_user_id)
|
||||
@@ -389,9 +390,9 @@ async def search_statements_num(
|
||||
|
||||
@router.get("/search/entity", response_model=ApiResponse)
|
||||
async def search_entities_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_entity(end_user_id)
|
||||
@@ -403,9 +404,9 @@ async def search_entities_num(
|
||||
|
||||
@router.get("/search", response_model=ApiResponse)
|
||||
async def search_all_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_all(end_user_id)
|
||||
@@ -417,9 +418,9 @@ async def search_all_num(
|
||||
|
||||
@router.get("/search/detials", response_model=ApiResponse)
|
||||
async def search_entities_detials(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_detials(end_user_id)
|
||||
@@ -431,9 +432,9 @@ async def search_entities_detials(
|
||||
|
||||
@router.get("/search/edges", response_model=ApiResponse)
|
||||
async def search_entity_edges(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_edges(end_user_id)
|
||||
@@ -443,14 +444,12 @@ async def search_entity_edges(
|
||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_api(
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取热门记忆标签(带Redis缓存)
|
||||
|
||||
@@ -461,18 +460,18 @@ async def get_hot_memory_tags_api(
|
||||
- 缓存未命中:~600-800ms(取决于LLM速度)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 构建缓存键
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
|
||||
|
||||
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
|
||||
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
import json
|
||||
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
api_logger.info(f"Cache hit for key: {cache_key}")
|
||||
@@ -481,11 +480,11 @@ async def get_hot_memory_tags_api(
|
||||
return success(data=data, msg="查询成功(缓存)")
|
||||
except json.JSONDecodeError:
|
||||
api_logger.warning(f"Failed to parse cached data, will refresh")
|
||||
|
||||
|
||||
# 缓存未命中,执行查询
|
||||
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
|
||||
|
||||
# 写入缓存(过期时间:5分钟)
|
||||
# 注意:result是列表,需要转换为JSON字符串
|
||||
try:
|
||||
@@ -495,9 +494,9 @@ async def get_hot_memory_tags_api(
|
||||
except Exception as cache_error:
|
||||
# 缓存写入失败不影响主流程
|
||||
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
|
||||
|
||||
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
@@ -505,8 +504,8 @@ async def get_hot_memory_tags_api(
|
||||
|
||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||
async def clear_hot_memory_tags_cache(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
清除热门标签缓存
|
||||
|
||||
@@ -516,12 +515,12 @@ async def clear_hot_memory_tags_cache(
|
||||
- 数据更新后立即生效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
from app.aioRedis import aio_redis_delete
|
||||
|
||||
|
||||
# 清除所有limit的缓存(常见的limit值)
|
||||
cleared_count = 0
|
||||
for limit in [5, 10, 15, 20, 30, 50]:
|
||||
@@ -530,12 +529,12 @@ async def clear_hot_memory_tags_cache(
|
||||
if result:
|
||||
cleared_count += 1
|
||||
api_logger.info(f"Cleared cache for key: {cache_key}")
|
||||
|
||||
|
||||
return success(
|
||||
data={"cleared_count": cleared_count},
|
||||
data={"cleared_count": cleared_count},
|
||||
msg=f"成功清除 {cleared_count} 个缓存"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Clear cache failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
|
||||
@@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache(
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||
@@ -553,4 +552,3 @@ async def get_recent_activity_stats_api(
|
||||
except Exception as e:
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ def get_model_strategies():
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
@@ -74,10 +75,21 @@ def get_model_list(
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
capability_list = []
|
||||
if capability is not None:
|
||||
flat_capability = []
|
||||
for item in capability:
|
||||
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
||||
flat_capability.extend(split_items)
|
||||
|
||||
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
||||
capability_list = unique_flat_capability
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
capability=capability_list,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
|
||||
@@ -669,6 +669,7 @@ async def config_query(
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables"),
|
||||
"memory": release.config.get("memory", {}).get("enabled"),
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -19,7 +19,7 @@ from app.services.user_memory_service import (
|
||||
analytics_graph_data,
|
||||
analytics_community_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
@@ -45,9 +45,9 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
@@ -73,10 +73,10 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
@@ -90,7 +90,7 @@ async def get_user_summary_api(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
@@ -102,7 +102,7 @@ async def get_user_summary_api(
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -117,10 +117,10 @@ async def get_user_summary_api(
|
||||
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
@@ -134,7 +134,7 @@ async def generate_cache_api(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -155,10 +155,12 @@ async def generate_cache_api(
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
@@ -209,9 +211,9 @@ async def generate_cache_api(
|
||||
|
||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||
async def get_node_statistics_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -220,7 +222,8 @@ async def get_node_statistics_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
api_logger.info(
|
||||
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
try:
|
||||
# 调用新的记忆类型统计函数
|
||||
@@ -228,21 +231,23 @@ async def get_node_statistics_api(
|
||||
|
||||
# 计算总数用于日志
|
||||
total_count = sum(item["count"] for item in result)
|
||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
api_logger.info(
|
||||
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||
async def get_graph_data_api(
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -298,9 +303,9 @@ async def get_graph_data_api(
|
||||
|
||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||
async def get_community_graph_data_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -334,9 +339,9 @@ async def get_community_graph_data_api(
|
||||
|
||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||
async def get_end_user_profile(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
@@ -385,9 +390,9 @@ async def get_end_user_profile(
|
||||
|
||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||
async def update_end_user_profile(
|
||||
profile_update: EndUserProfileUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
profile_update: EndUserProfileUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
@@ -417,7 +422,7 @@ async def update_end_user_profile(
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
|
||||
# 根据错误类型映射到合适的业务错误码
|
||||
if error_msg == "终端用户不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||
@@ -427,15 +432,18 @@ async def update_end_user_profile(
|
||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
async def memory_space_timeline_of_shared_memories(
|
||||
id: str, label: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
@@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
|
||||
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||
|
||||
|
||||
@@ -598,8 +598,10 @@ class LangChainAgent:
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
total_tokens = response_meta.get("token_usage", {}).get(
|
||||
"total_tokens",
|
||||
0
|
||||
) if response_meta else 0
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
|
||||
@@ -231,8 +231,8 @@ class Settings:
|
||||
# Celery configuration (internal)
|
||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||
# 详见 docs/celery-env-bug-report.md
|
||||
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
|
||||
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
|
||||
# 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离
|
||||
# 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost
|
||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||
|
||||
|
||||
@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
|
||||
# Fallback to console only if file write fails
|
||||
print(f"Warning: Could not write to timing log: {e}")
|
||||
|
||||
# Always print to console (backward compatible behavior)
|
||||
print(f"✓ {step_name}: {duration:.2f}s")
|
||||
# Always log at INFO level (avoids Celery treating stdout as WARNING)
|
||||
_timing_logger = logging.getLogger(__name__)
|
||||
_timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
|
||||
|
||||
|
||||
def get_agent_logger(name: str = "agent_service",
|
||||
|
||||
@@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
formatted_messages = redis_messages
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
|
||||
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
ref_id: str = "",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
files = msg.get("file_content", [])
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
|
||||
@@ -6,6 +6,7 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -13,28 +14,28 @@ from dotenv import load_dotenv
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
language: str = "zh",
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
@@ -43,9 +44,11 @@ async def write(
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
ref_id: Reference ID, defaults to ""
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
if not ref_id:
|
||||
ref_id = uuid.uuid4().hex
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
@@ -99,14 +102,14 @@ async def write(
|
||||
if memory_config.scene_id:
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||
|
||||
|
||||
with get_db_context() as db:
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
if ontology_types:
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
@@ -135,9 +138,11 @@ async def write(
|
||||
all_chunk_nodes,
|
||||
all_statement_nodes,
|
||||
all_entity_nodes,
|
||||
all_perceptual_nodes,
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_perceptual_edges,
|
||||
all_dedup_details,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
@@ -162,18 +167,21 @@ async def write(
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
perceptual_nodes=all_perceptual_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
perceptual_edges=all_perceptual_edges,
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||
schedule_clustering_after_write(
|
||||
# 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突)
|
||||
await _trigger_clustering_sync(
|
||||
all_entity_nodes,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
embedding_model_id=str(
|
||||
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
)
|
||||
break
|
||||
else:
|
||||
@@ -208,9 +216,8 @@ async def write(
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||
)
|
||||
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
@@ -251,4 +258,4 @@ async def write(
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Any, List
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -246,6 +246,7 @@ class ChunkerClient:
|
||||
"total_sub_chunks": len(sub_chunks),
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
else:
|
||||
@@ -258,6 +259,7 @@ class ChunkerClient:
|
||||
"message_role": msg.role,
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
OpenAI Embedder 客户端实现
|
||||
|
||||
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
||||
自动支持火山引擎的多模态 Embedding。
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
- 批量文本嵌入
|
||||
- 自动重试机制
|
||||
- 错误处理
|
||||
- 火山引擎多模态 Embedding(自动识别)
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
"""
|
||||
super().__init__(model_config)
|
||||
|
||||
# 初始化 RedBearEmbeddings 模型
|
||||
# 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
self.model = RedBearEmbeddings(
|
||||
RedBearModelConfig(
|
||||
model_name=self.model_name,
|
||||
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
timeout=self.timeout,
|
||||
)
|
||||
)
|
||||
self.is_multimodal = self.model.is_multimodal_supported()
|
||||
|
||||
logger.info("OpenAI Embedder 客户端初始化完成")
|
||||
logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
|
||||
|
||||
async def response(
|
||||
self,
|
||||
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
return []
|
||||
|
||||
# 生成嵌入向量
|
||||
embeddings = await self.model.aembed_documents(texts)
|
||||
if self.is_multimodal:
|
||||
# 火山引擎多模态 Embedding
|
||||
embeddings = await self.model.aembed_multimodal(
|
||||
[{"type": "text", "text": text} for text in texts]
|
||||
)
|
||||
else:
|
||||
# 普通 Embedding
|
||||
embeddings = await self.model.aembed_documents(texts)
|
||||
|
||||
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
@@ -44,21 +44,21 @@ def parse_historical_datetime(v):
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
|
||||
# 处理 Neo4j DateTime 对象
|
||||
if hasattr(v, 'to_native'):
|
||||
return v.to_native()
|
||||
|
||||
|
||||
# 处理 Python datetime 对象
|
||||
if isinstance(v, datetime):
|
||||
return v
|
||||
|
||||
|
||||
if isinstance(v, str):
|
||||
# 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM]
|
||||
# 支持1-4位年份
|
||||
pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?'
|
||||
match = re.match(pattern, v)
|
||||
|
||||
|
||||
if match:
|
||||
try:
|
||||
year = int(match.group(1))
|
||||
@@ -68,31 +68,31 @@ def parse_historical_datetime(v):
|
||||
minute = int(match.group(5)) if match.group(5) else 0
|
||||
second = int(match.group(6)) if match.group(6) else 0
|
||||
microsecond = 0
|
||||
|
||||
|
||||
# 处理微秒
|
||||
if match.group(7):
|
||||
# 补齐或截断到6位
|
||||
us_str = match.group(7).ljust(6, '0')[:6]
|
||||
microsecond = int(us_str)
|
||||
|
||||
|
||||
# 处理时区
|
||||
tzinfo = None
|
||||
if 'Z' in v or match.group(8):
|
||||
tzinfo = timezone.utc
|
||||
|
||||
|
||||
# 创建 datetime 对象
|
||||
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo)
|
||||
|
||||
|
||||
except (ValueError, OverflowError):
|
||||
# 日期值无效(如月份13、日期32等)
|
||||
return None
|
||||
|
||||
|
||||
# 如果不匹配模式,尝试使用 fromisoformat(用于标准格式)
|
||||
try:
|
||||
return datetime.fromisoformat(v.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
return v
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class Edge(BaseModel):
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
||||
|
||||
|
||||
class ChunkEdge(Edge):
|
||||
@@ -167,7 +167,7 @@ class EntityEntityEdge(Edge):
|
||||
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
class PerceptualEdge(Edge):
|
||||
"""Edge connecting perceptual nodes to their source chunks
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
"""Base class for all graph nodes in the knowledge graph.
|
||||
|
||||
@@ -206,7 +212,8 @@ class DialogueNode(Node):
|
||||
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
||||
content: str = Field(..., description="Dialogue content")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this dialogue (integer or string)")
|
||||
|
||||
|
||||
class StatementNode(Node):
|
||||
@@ -241,17 +248,17 @@ class StatementNode(Node):
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk")
|
||||
stmt_type: str = Field(..., description="Type of the statement")
|
||||
statement: str = Field(..., description="The statement text content")
|
||||
|
||||
|
||||
# Speaker identification
|
||||
speaker: Optional[str] = Field(
|
||||
None,
|
||||
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
|
||||
)
|
||||
|
||||
|
||||
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
||||
emotion_intensity: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Emotion intensity: 0.0-1.0 (displayed on node)"
|
||||
)
|
||||
@@ -264,25 +271,26 @@ class StatementNode(Node):
|
||||
description="Emotion subject: self/other/object"
|
||||
)
|
||||
emotion_type: Optional[str] = Field(
|
||||
None,
|
||||
None,
|
||||
description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
|
||||
)
|
||||
emotion_keywords: Optional[List[str]] = Field(
|
||||
default_factory=list,
|
||||
description="Emotion keywords list, max 3 items"
|
||||
)
|
||||
|
||||
|
||||
# Temporal fields
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
|
||||
# Embedding and other fields
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -309,13 +317,13 @@ class StatementNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
"""使用通用的历史日期解析函数"""
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
@field_validator('emotion_type', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_type(cls, v):
|
||||
@@ -326,7 +334,7 @@ class StatementNode(Node):
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
|
||||
return v
|
||||
|
||||
|
||||
@field_validator('emotion_subject', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_subject(cls, v):
|
||||
@@ -337,7 +345,7 @@ class StatementNode(Node):
|
||||
if v not in valid_subjects:
|
||||
raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
|
||||
return v
|
||||
|
||||
|
||||
@field_validator('emotion_keywords', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_keywords(cls, v):
|
||||
@@ -405,19 +413,20 @@ class ExtractedEntityNode(Node):
|
||||
entity_type: str = Field(..., description="Type of the entity")
|
||||
description: str = Field(..., description="Entity description")
|
||||
example: str = Field(
|
||||
default="",
|
||||
default="",
|
||||
description="A concise example (around 20 characters) to help understand the entity"
|
||||
)
|
||||
aliases: List[str] = Field(
|
||||
default_factory=list,
|
||||
default_factory=list,
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -444,16 +453,16 @@ class ExtractedEntityNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
|
||||
# Explicit Memory Classification
|
||||
is_explicit_memory: bool = Field(
|
||||
default=False,
|
||||
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
||||
)
|
||||
|
||||
|
||||
@field_validator('aliases', mode='before')
|
||||
@classmethod
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
"""Validate and clean aliases field using utility function.
|
||||
|
||||
This validator ensures that the aliases field is always a valid list of strings.
|
||||
@@ -507,8 +516,9 @@ class MemorySummaryNode(Node):
|
||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
# ACT-R Forgetting Engine Properties
|
||||
original_statement_id: Optional[str] = Field(
|
||||
None,
|
||||
@@ -522,7 +532,7 @@ class MemorySummaryNode(Node):
|
||||
None,
|
||||
description="Timestamp when the nodes were merged"
|
||||
)
|
||||
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -549,3 +559,18 @@ class MemorySummaryNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||
)
|
||||
|
||||
|
||||
class PerceptualNode(Node):
|
||||
"""Node representing a multimodal message in the knowledge graph.
|
||||
"""
|
||||
perceptual_type: int
|
||||
file_path: str
|
||||
file_name: str
|
||||
file_ext: str
|
||||
summary: str
|
||||
keywords: list[str]
|
||||
topic: str
|
||||
domain: str
|
||||
file_type: str
|
||||
summary_embedding: list[float] | None
|
||||
|
||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
||||
"""
|
||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||
msg: str = Field(..., description="The text content of the message.")
|
||||
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||
|
||||
|
||||
class TemporalValidityRange(BaseModel):
|
||||
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
|
||||
content: str = Field(..., description="The content of the chunk as a string.")
|
||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
||||
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -71,13 +71,11 @@ class LabelPropagationEngine:
|
||||
connector: Neo4jConnector,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
):
|
||||
self.connector = connector
|
||||
self.repo = CommunityRepository(connector)
|
||||
self.llm_model_id = llm_model_id
|
||||
self.embedding_model_id = embedding_model_id
|
||||
self.embedding_model_id = embedding_model_id
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
@@ -239,6 +237,7 @@ class LabelPropagationEngine:
|
||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||
await self._generate_community_metadata([new_cid], end_user_id)
|
||||
return
|
||||
|
||||
# 统计邻居社区分布
|
||||
@@ -273,7 +272,8 @@ class LabelPropagationEngine:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
await self._generate_community_metadata([target_cid], end_user_id)
|
||||
# 新实体加入后成员变化,强制重新生成元数据
|
||||
await self._generate_community_metadata([target_cid], end_user_id, force=True)
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
@@ -453,7 +453,7 @@ class LabelPropagationEngine:
|
||||
return lines
|
||||
|
||||
async def _generate_community_metadata(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
self, community_ids: List[str], end_user_id: str, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
为一个或多个社区生成并写入元数据。
|
||||
@@ -462,69 +462,82 @@ class LabelPropagationEngine:
|
||||
1. 逐个社区调 LLM 生成 name / summary(串行)
|
||||
2. 收集所有 summary,一次性批量 embed
|
||||
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
||||
"""
|
||||
if not community_ids:
|
||||
return
|
||||
|
||||
Args:
|
||||
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
||||
"""
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# --- 阶段1:并发调 LLM 生成每个社区的 name / summary ---
|
||||
async def _build_one(cid: str):
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
async def _build_one(cid: str) -> Optional[Dict]:
|
||||
try:
|
||||
if not force:
|
||||
check_embedding = bool(self.embedding_model_id)
|
||||
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
|
||||
return None
|
||||
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
|
||||
return None
|
||||
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
all_names = [m["name"] for m in members if m.get("name")]
|
||||
|
||||
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
||||
summary = f"包含实体:{', '.join(all_names)}"
|
||||
|
||||
if self.llm_model_id:
|
||||
try:
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
|
||||
# 方案四:注入社区内实体间关系三元组
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
name, summary = "", ""
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[_build_one(cid) for cid in community_ids],
|
||||
return_exceptions=True,
|
||||
@@ -537,15 +550,20 @@ class LabelPropagationEngine:
|
||||
metadata_list.append(res)
|
||||
|
||||
if not metadata_list:
|
||||
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
||||
return
|
||||
|
||||
# --- 阶段2:批量生成 summary_embedding ---
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
if self.embedding_model_id:
|
||||
try:
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||
|
||||
# --- 阶段3:写入(单个 or 批量)---
|
||||
if len(metadata_list) == 1:
|
||||
@@ -558,17 +576,13 @@ class LabelPropagationEngine:
|
||||
core_entities=m["core_entities"],
|
||||
summary_embedding=m["summary_embedding"],
|
||||
)
|
||||
if result:
|
||||
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
|
||||
if not result:
|
||||
logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
|
||||
else:
|
||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||
if ok:
|
||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 批量写入社区元数据失败")
|
||||
if not ok:
|
||||
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
||||
|
||||
@staticmethod
|
||||
def _new_community_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
return str(uuid.uuid4())
|
||||
@@ -9,6 +9,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
@@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
|
||||
ScenePatterns
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||
@@ -706,7 +709,7 @@ class SemanticPruner:
|
||||
# 阈值保护:最高0.9
|
||||
proportion = float(self.config.pruning_threshold)
|
||||
if proportion > 0.9:
|
||||
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||
logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||
proportion = 0.9
|
||||
if proportion < 0.0:
|
||||
proportion = 0.0
|
||||
@@ -905,7 +908,7 @@ class SemanticPruner:
|
||||
|
||||
# Safety: avoid empty dataset
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
|
||||
return result
|
||||
@@ -915,8 +918,7 @@ class SemanticPruner:
|
||||
try:
|
||||
self.run_logs.append(msg)
|
||||
except Exception:
|
||||
# 任何异常都不影响打印
|
||||
pass
|
||||
print(msg)
|
||||
logger.debug(msg)
|
||||
|
||||
|
||||
|
||||
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client = None,
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client=None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
dict, # 新增:返回去重详情
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
执行两层实体去重与融合:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,8 +5,11 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
|
||||
return await self.embedder_client.response(texts)
|
||||
|
||||
# 分批并行处理
|
||||
print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
||||
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
|
||||
# 并行发送所有批次
|
||||
batch_results = await asyncio.gather(*[
|
||||
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
|
||||
for batch_result in batch_results:
|
||||
embeddings.extend(batch_result)
|
||||
|
||||
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
async def generate_statement_embeddings(
|
||||
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的陈述句嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成陈述句嵌入向量 ===")
|
||||
logger.debug("=== 生成陈述句嵌入向量 ===")
|
||||
|
||||
# 收集所有陈述句
|
||||
all_statements = []
|
||||
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
|
||||
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
||||
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
||||
|
||||
print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
return stmt_embedding_maps
|
||||
|
||||
async def generate_chunk_embeddings(
|
||||
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的分块嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成分块嵌入向量 ===")
|
||||
logger.debug("=== 生成分块嵌入向量 ===")
|
||||
|
||||
# 收集所有分块
|
||||
all_chunks = []
|
||||
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
|
||||
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
||||
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
||||
|
||||
print(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
return chunk_embedding_maps
|
||||
|
||||
async def generate_dialog_embeddings(
|
||||
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
||||
"""
|
||||
print("\n=== 生成所有嵌入向量 ===")
|
||||
logger.debug("=== 生成所有嵌入向量 ===")
|
||||
|
||||
# 并发生成陈述句和分块嵌入向量
|
||||
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
||||
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
|
||||
# 对话嵌入向量(当前跳过)
|
||||
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
||||
|
||||
print(
|
||||
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
|
||||
)
|
||||
logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量")
|
||||
|
||||
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
||||
|
||||
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
更新后的三元组映射列表(实体包含嵌入向量)
|
||||
"""
|
||||
print("\n=== 生成实体嵌入向量 ===")
|
||||
logger.debug("=== 生成实体嵌入向量 ===")
|
||||
|
||||
entity_texts: List[str] = []
|
||||
entity_refs: List[Any] = []
|
||||
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
|
||||
entity_refs.append(ent)
|
||||
|
||||
if not entity_texts:
|
||||
print("没有找到需要生成嵌入向量的实体")
|
||||
logger.debug("没有找到需要生成嵌入向量的实体")
|
||||
return triplet_maps
|
||||
|
||||
# 批量生成嵌入向量
|
||||
@@ -227,13 +228,13 @@ class EmbeddingGenerator:
|
||||
|
||||
# 打印前几个嵌入向量的维度
|
||||
for i in range(min(5, len(embeddings))):
|
||||
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
|
||||
# 将嵌入向量赋值给实体
|
||||
for ent, emb in zip(entity_refs, embeddings):
|
||||
setattr(ent, "name_embedding", emb)
|
||||
|
||||
print(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
return triplet_maps
|
||||
|
||||
|
||||
@@ -296,7 +297,7 @@ async def embedding_generation_all(
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
||||
"""
|
||||
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
|
||||
generator = EmbeddingGenerator(embedding_id)
|
||||
|
||||
|
||||
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
summary_text = structured.summary.strip()
|
||||
|
||||
# Generate title and type for the summary
|
||||
title = None
|
||||
episodic_type = None
|
||||
|
||||
@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
|
||||
from .llm import RedBearLLM
|
||||
from .embedding import RedBearEmbeddings
|
||||
from .rerank import RedBearRerank
|
||||
from .generation import RedBearImageGenerator, RedBearVideoGenerator
|
||||
|
||||
__all__ = [
|
||||
"RedBearModelConfig",
|
||||
@@ -9,5 +10,7 @@ __all__ = [
|
||||
"RedBearEmbeddings",
|
||||
"RedBearRerank",
|
||||
"RedBearModelFactory",
|
||||
"get_provider_llm_class"
|
||||
"get_provider_llm_class",
|
||||
"RedBearImageGenerator",
|
||||
"RedBearVideoGenerator"
|
||||
]
|
||||
@@ -67,7 +67,7 @@ class RedBearModelFactory:
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
# 这样可以分别控制连接超时和读取超时
|
||||
import httpx
|
||||
@@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
return ChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
|
||||
if type == ModelType.LLM:
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
return ChatOpenAI
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
|
||||
@@ -1,23 +1,190 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Callable
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory
|
||||
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
|
||||
class RedBearEmbeddings(Embeddings):
|
||||
"""Embedding → 完全符合 LangChain Embeddings"""
|
||||
"""统一的 Embedding 类,自动支持多模态(根据 provider 判断)"""
|
||||
|
||||
def __init__(self, config: RedBearModelConfig):
|
||||
self._model = self._create_model(config)
|
||||
self._config = config
|
||||
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
|
||||
|
||||
if self._is_volcano:
|
||||
# 火山引擎使用 Ark SDK
|
||||
self._client = self._create_volcano_client(config)
|
||||
self._model = None
|
||||
else:
|
||||
# 其他 provider 使用 LangChain
|
||||
self._model = self._create_model(config)
|
||||
self._client = None
|
||||
|
||||
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
||||
"""根据配置创建模型"""
|
||||
"""根据配置创建 LangChain 模型"""
|
||||
embedding_class = get_provider_embedding_class(config.provider)
|
||||
model_params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**model_params)
|
||||
|
||||
def _create_volcano_client(self, config: RedBearModelConfig):
|
||||
"""创建火山引擎客户端"""
|
||||
from volcenginesdkarkruntime import Ark
|
||||
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
# ==================== LangChain 标准接口 ====================
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self._model.embed_documents(texts)
|
||||
"""批量文本向量化(LangChain 标准接口)"""
|
||||
if self._is_volcano:
|
||||
# 火山引擎多模态 Embedding
|
||||
contents = [{"type": "text", "text": text} for text in texts]
|
||||
response = self._client.multimodal_embeddings.create(
|
||||
model=self._config.model_name,
|
||||
input=contents,
|
||||
encoding_format="float"
|
||||
)
|
||||
return [response.data.embedding]
|
||||
else:
|
||||
# 其他 provider
|
||||
return self._model.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._model.embed_query(text)
|
||||
"""单个文本向量化(LangChain 标准接口)"""
|
||||
if self._is_volcano:
|
||||
# 火山引擎多模态 Embedding
|
||||
result = self.embed_documents([text])
|
||||
return result[0] if result else []
|
||||
else:
|
||||
# 其他 provider
|
||||
return self._model.embed_query(text)
|
||||
|
||||
# ==================== 多模态扩展方法 ====================
|
||||
|
||||
def embed_multimodal(
|
||||
self,
|
||||
contents: List[Dict[str, Any]],
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
多模态向量化(仅火山引擎支持)
|
||||
|
||||
Args:
|
||||
contents: 内容列表,格式:
|
||||
- 文本: {"type": "text", "text": "..."}
|
||||
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
|
||||
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
向量列表
|
||||
"""
|
||||
if not self._is_volcano:
|
||||
raise NotImplementedError(
|
||||
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||
)
|
||||
|
||||
response = self._client.multimodal_embeddings.create(
|
||||
model=self._config.model_name,
|
||||
input=contents,
|
||||
**kwargs
|
||||
)
|
||||
return [response.data.embedding]
|
||||
|
||||
async def aembed_multimodal(
|
||||
self,
|
||||
contents: List[Dict[str, Any]],
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""异步多模态向量化"""
|
||||
# 火山引擎 SDK 暂不支持异步,使用同步方法
|
||||
return self.embed_multimodal(contents, **kwargs)
|
||||
|
||||
def embed_text(self, text: str, **kwargs) -> List[float]:
|
||||
"""文本向量化(便捷方法)"""
|
||||
if self._is_volcano:
|
||||
result = self.embed_multimodal(
|
||||
[{"type": "text", "text": text}],
|
||||
**kwargs
|
||||
)
|
||||
return result[0] if result else []
|
||||
else:
|
||||
return self.embed_query(text)
|
||||
|
||||
def embed_image(self, image_url: str, **kwargs) -> List[float]:
|
||||
"""图片向量化(仅火山引擎支持)"""
|
||||
if not self._is_volcano:
|
||||
raise NotImplementedError(
|
||||
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||
)
|
||||
|
||||
result = self.embed_multimodal(
|
||||
[{"type": "image_url", "image_url": {"url": image_url}}],
|
||||
**kwargs
|
||||
)
|
||||
return result[0] if result else []
|
||||
|
||||
def embed_video(self, video_url: str, **kwargs) -> List[float]:
|
||||
"""视频向量化(仅火山引擎支持)"""
|
||||
if not self._is_volcano:
|
||||
raise NotImplementedError(
|
||||
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||
)
|
||||
|
||||
result = self.embed_multimodal(
|
||||
[{"type": "video_url", "video_url": {"url": video_url}}],
|
||||
**kwargs
|
||||
)
|
||||
return result[0] if result else []
|
||||
|
||||
def embed_batch(
|
||||
self,
|
||||
items: List[Union[str, Dict[str, Any]]],
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
批量向量化(支持混合类型)
|
||||
|
||||
Args:
|
||||
items: 可以是字符串列表或内容字典列表
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
向量列表
|
||||
"""
|
||||
# 如果全是字符串,使用标准方法
|
||||
if all(isinstance(item, str) for item in items):
|
||||
return self.embed_documents(items)
|
||||
|
||||
# 如果包含字典,需要多模态支持
|
||||
if not self._is_volcano:
|
||||
raise NotImplementedError(
|
||||
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||
)
|
||||
|
||||
# 标准化输入格式
|
||||
contents = []
|
||||
for item in items:
|
||||
if isinstance(item, str):
|
||||
contents.append({"type": "text", "text": item})
|
||||
elif isinstance(item, dict):
|
||||
contents.append(item)
|
||||
else:
|
||||
raise ValueError(f"不支持的输入类型: {type(item)}")
|
||||
|
||||
return self.embed_multimodal(contents, **kwargs)
|
||||
|
||||
# ==================== 工具方法 ====================
|
||||
|
||||
def is_multimodal_supported(self) -> bool:
|
||||
"""检查是否支持多模态"""
|
||||
return self._is_volcano
|
||||
|
||||
def get_provider(self) -> str:
|
||||
"""获取 provider"""
|
||||
return self._config.provider
|
||||
|
||||
|
||||
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
|
||||
RedBearMultimodalEmbeddings = RedBearEmbeddings
|
||||
|
||||
344
api/app/core/models/generation.py
Normal file
344
api/app/core/models/generation.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
图片和视频生成模型封装
|
||||
|
||||
支持的 Provider:
|
||||
- Volcano (火山引擎): 使用 volcenginesdkarkruntime
|
||||
- OpenAI: 使用 openai SDK
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from volcenginesdkarkruntime import Ark
|
||||
from volcenginesdkarkruntime.types.images.images import (
|
||||
SequentialImageGenerationOptions,
|
||||
ContentGenerationTool,
|
||||
OptimizePromptOptions
|
||||
)
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
|
||||
class RedBearImageGenerator:
|
||||
"""图片生成模型封装"""
|
||||
|
||||
def __init__(self, config: RedBearModelConfig):
|
||||
self._config = config
|
||||
self._client = self._create_client(config)
|
||||
|
||||
def _create_client(self, config: RedBearModelConfig):
|
||||
"""根据 provider 创建客户端"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||
# elif provider == ModelProvider.OPENAI:
|
||||
# from openai import OpenAI
|
||||
# return OpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
else:
|
||||
raise BusinessException(
|
||||
f"不支持的图片生成提供商: {provider}",
|
||||
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
image: Optional[Any] = None,
|
||||
size: Optional[str] = "2K",
|
||||
output_format: str = "png",
|
||||
response_format: str = "url",
|
||||
watermark: bool = False,
|
||||
sequential_image_generation: Optional[str] = None,
|
||||
sequential_image_generation_options: Optional[Dict] = None,
|
||||
tools: Optional[list] = None,
|
||||
optimize_prompt_options: Optional[Dict] = None,
|
||||
stream: bool = False,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成图片
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
image: 参考图片URL或URL列表(图文生图/多图融合)
|
||||
size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080" 等(至少3686400像素)
|
||||
output_format: 输出格式,如 "png", "jpg"
|
||||
response_format: 返回格式,"url" 或 "b64_json"
|
||||
watermark: 是否添加水印
|
||||
sequential_image_generation: 组图生成模式,"auto" 或 "disabled"
|
||||
sequential_image_generation_options: 组图生成选项,如 {"max_images": 4}
|
||||
tools: 工具列表,如 [{"type": "web_search"}] 用于联网搜索生图
|
||||
optimize_prompt_options: 提示词优化选项,如 {"mode": "fast"}
|
||||
stream: 是否使用流式生成
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成结果
|
||||
"""
|
||||
provider = self._config.provider.lower()
|
||||
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
params = {
|
||||
"model": self._config.model_name,
|
||||
"prompt": prompt,
|
||||
"size": size,
|
||||
"output_format": output_format,
|
||||
"response_format": response_format,
|
||||
"watermark": watermark,
|
||||
}
|
||||
|
||||
if image is not None:
|
||||
params["image"] = image
|
||||
|
||||
if sequential_image_generation:
|
||||
params["sequential_image_generation"] = sequential_image_generation
|
||||
if sequential_image_generation_options:
|
||||
params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
|
||||
**sequential_image_generation_options
|
||||
)
|
||||
|
||||
if tools:
|
||||
params["tools"] = [ContentGenerationTool(**tool) if isinstance(tool, dict) else tool for tool in tools]
|
||||
|
||||
if optimize_prompt_options:
|
||||
params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)
|
||||
|
||||
if stream:
|
||||
params["stream"] = True
|
||||
|
||||
params.update(kwargs)
|
||||
response = self._client.images.generate(**params)
|
||||
|
||||
# elif provider == ModelProvider.OPENAI:
|
||||
# response = self._client.images.generate(
|
||||
# model=self._config.model_name,
|
||||
# prompt=prompt,
|
||||
# size=size,
|
||||
# n=n,
|
||||
# **kwargs
|
||||
# )
|
||||
else:
|
||||
raise BusinessException(
|
||||
f"不支持的提供商: {provider}",
|
||||
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||
)
|
||||
|
||||
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
prompt: str,
|
||||
image: Optional[Any] = None,
|
||||
size: Optional[str] = "2K",
|
||||
output_format: str = "png",
|
||||
response_format: str = "url",
|
||||
watermark: bool = False,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""异步生成图片"""
|
||||
return self.generate(prompt, image, size, output_format, response_format, watermark, **kwargs)
|
||||
|
||||
|
||||
class RedBearVideoGenerator:
|
||||
"""视频生成模型封装"""
|
||||
|
||||
def __init__(self, config: RedBearModelConfig):
|
||||
self._config = config
|
||||
self._client = self._create_client(config)
|
||||
|
||||
def _create_client(self, config: RedBearModelConfig):
|
||||
"""根据 provider 创建客户端"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||
else:
|
||||
raise BusinessException(
|
||||
f"不支持的视频生成提供商: {provider}",
|
||||
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
image_url: Optional[str] = None,
|
||||
first_frame_url: Optional[str] = None,
|
||||
last_frame_url: Optional[str] = None,
|
||||
reference_images: Optional[list] = None,
|
||||
draft_task_id: Optional[str] = None,
|
||||
duration: Optional[int] = None,
|
||||
frames: Optional[int] = None,
|
||||
ratio: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
generate_audio: bool = False,
|
||||
watermark: bool = False,
|
||||
camera_fixed: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
return_last_frame: bool = False,
|
||||
service_tier: str = "default",
|
||||
execution_expires_after: Optional[int] = None,
|
||||
draft: bool = False,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成视频
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
image_url: 首帧图片URL(图生视频-基于首帧)
|
||||
first_frame_url: 首帧图片URL(图生视频-基于首尾帧)
|
||||
last_frame_url: 尾帧图片URL(图生视频-基于首尾帧)
|
||||
reference_images: 参考图片URL列表(图生视频-基于参考图)
|
||||
draft_task_id: Draft任务ID(基于Draft生成正式视频)
|
||||
duration: 视频时长(秒),与frames二选一
|
||||
frames: 视频帧数,与duration二选一
|
||||
ratio: 视频比例,如 "16:9", "9:16", "adaptive"
|
||||
resolution: 视频分辨率,如 "720p", "1080p"
|
||||
generate_audio: 是否生成音频
|
||||
watermark: 是否添加水印
|
||||
camera_fixed: 是否固定镜头
|
||||
seed: 随机种子
|
||||
return_last_frame: 是否返回最后一帧
|
||||
service_tier: 服务层级,"default" 或 "flex"(离线推理)
|
||||
execution_expires_after: 任务过期时间(秒)
|
||||
draft: 是否生成样片
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成结果(包含任务ID,需要轮询获取结果)
|
||||
"""
|
||||
provider = self._config.provider.lower()
|
||||
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
|
||||
if draft_task_id:
|
||||
content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]
|
||||
else:
|
||||
if image_url:
|
||||
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
|
||||
if first_frame_url:
|
||||
content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
|
||||
if last_frame_url:
|
||||
content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})
|
||||
|
||||
if reference_images:
|
||||
for ref_url in reference_images:
|
||||
content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"})
|
||||
|
||||
params = {"model": self._config.model_name, "content": content, "watermark": watermark}
|
||||
|
||||
if duration:
|
||||
params["duration"] = duration
|
||||
if frames:
|
||||
params["frames"] = frames
|
||||
if ratio:
|
||||
params["ratio"] = ratio
|
||||
if resolution:
|
||||
params["resolution"] = resolution
|
||||
if generate_audio:
|
||||
params["generate_audio"] = generate_audio
|
||||
if camera_fixed:
|
||||
params["camera_fixed"] = camera_fixed
|
||||
if seed is not None:
|
||||
params["seed"] = seed
|
||||
if return_last_frame:
|
||||
params["return_last_frame"] = return_last_frame
|
||||
if service_tier != "default":
|
||||
params["service_tier"] = service_tier
|
||||
if execution_expires_after:
|
||||
params["execution_expires_after"] = execution_expires_after
|
||||
if draft:
|
||||
params["draft"] = draft
|
||||
|
||||
params.update(kwargs)
|
||||
response = self._client.content_generation.tasks.create(**params)
|
||||
else:
|
||||
raise BusinessException(
|
||||
f"不支持的提供商: {provider}",
|
||||
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||
)
|
||||
|
||||
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
prompt: str,
|
||||
image_url: Optional[str] = None,
|
||||
duration: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""异步生成视频"""
|
||||
return self.generate(prompt, image_url=image_url, duration=duration, **kwargs)
|
||||
|
||||
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
查询视频生成任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
任务状态信息
|
||||
"""
|
||||
provider = self._config.provider.lower()
|
||||
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
response = self._client.content_generation.tasks.get(task_id=task_id)
|
||||
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||
else:
|
||||
raise BusinessException(
|
||||
f"不支持的提供商: {provider}",
|
||||
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||
)
|
||||
|
||||
async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""异步查询任务状态"""
|
||||
return self.get_task_status(task_id)
|
||||
|
||||
def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
查询视频生成任务列表
|
||||
|
||||
Args:
|
||||
page_size: 每页数量
|
||||
status: 任务状态筛选,如 "succeeded", "failed", "pending"
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
任务列表
|
||||
"""
|
||||
provider = self._config.provider.lower()
|
||||
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
params = {"page_size": page_size}
|
||||
if status:
|
||||
params["status"] = status
|
||||
params.update(kwargs)
|
||||
response = self._client.content_generation.tasks.list(**params)
|
||||
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||
else:
|
||||
raise BusinessException(
|
||||
f"不支持的提供商: {provider}",
|
||||
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||
)
|
||||
|
||||
def delete_task(self, task_id: str) -> None:
|
||||
"""
|
||||
删除或取消视频生成任务
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
"""
|
||||
provider = self._config.provider.lower()
|
||||
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
self._client.content_generation.tasks.delete(task_id=task_id)
|
||||
else:
|
||||
raise BusinessException(
|
||||
f"不支持的提供商: {provider}",
|
||||
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||
)
|
||||
334
api/app/core/models/scripts/volcano_models.yaml
Normal file
334
api/app/core/models/scripts/volcano_models.yaml
Normal file
@@ -0,0 +1,334 @@
|
||||
provider: volcano
|
||||
models:
|
||||
# Doubao-Seed 2.0 系列
|
||||
- name: doubao-seed-2-0-pro-260215
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-2-0-lite-260215
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 面向高频企业场景兼顾性能与成本的均衡型模型,综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-2-0-mini-260215
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 面向低时延、高并发与成本敏感场景,提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解,适合成本和速度优先的轻量级任务。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-2-0-code-preview-260215
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills,可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
logo: volcano
|
||||
|
||||
# Doubao-Seed 1.x 系列
|
||||
- name: doubao-seed-1-8-251228
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上,Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面,视觉基础能力显著提升,可低帧率理解超长视频,视频运动理解、复杂空间理解及文档结构化解析能力也有所优化,还原生支持智能上下文管理,用户可配置上下文策略。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-1-6-251015
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: Doubao-Seed-1.6全新多模态深度思考模型,同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-1-6-lite-251015
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 更高性价比,常见任务的最佳选择,支持minimal、low、medium、high 四种reasoning_effort思考深度
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-1-6-flash-250828
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型,TPOT低至10ms; 同时支持文本和视觉理解,文本理解能力超过上一代lite,视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-code-preview-251028
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 面向Agentic编程任务进行了深度优化。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-1-6-vision-250815
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 全新Doubao-Seed-1.6系列视觉深度思考模型,视觉理解能力显著增强,并支持image_process视觉工具
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
logo: volcano
|
||||
|
||||
# Doubao 1.5 系列
|
||||
- name: doubao-1-5-vision-pro-32k-250115
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-1-5-pro-32k-250115
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-1-5-lite-32k-250115
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
# Doubao-Seedance 视频生成系列
|
||||
- name: doubao-seedance-1-5-pro-251215
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedance-1-0-pro-250528
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedance-1-0-pro-fast-251015
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 一款价格触底、效能封顶的全面模型,在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedance-1-0-lite-i2v-250428
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
- 图生视频
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedance-1-0-lite-t2v-250428
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 基于文本提示词生成视频
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
- 文生视频
|
||||
logo: volcano
|
||||
|
||||
# Doubao-Seedream 图像生成系列
|
||||
- name: doubao-seedream-5-0-260128
|
||||
type: image
|
||||
provider: volcano
|
||||
description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 图像生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedream-4-5-251128
|
||||
type: image
|
||||
provider: volcano
|
||||
description: 字节跳动最新推出的图像多模态模型,整合了文生图、图生图、组图输出等能力,融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 图像生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedream-4-0-250828
|
||||
type: image
|
||||
provider: volcano
|
||||
description: 基于领先架构的SOTA级多模态图像创作模型,其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一,原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 图像生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedream-3-0-t2i-250415
|
||||
type: image
|
||||
provider: volcano
|
||||
description: 一款支持原生高分辨率的中英双语图像生成基础模型,综合能力媲美GPT-4o,处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 图像生成
|
||||
- 文生图
|
||||
logo: volcano
|
||||
|
||||
# Doubao 翻译系列
|
||||
- name: doubao-seed-translation-250915
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 通用多语言翻译模型,支持30余种语言互译,支持 4K 上下文窗口,输出长度支持最大 3K tokens
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 翻译模型
|
||||
logo: volcano
|
||||
|
||||
# Doubao Embedding 系列
|
||||
- name: doubao-embedding-vision-251215
|
||||
type: embedding
|
||||
provider: volcano
|
||||
description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 向量模型
|
||||
- 多模态模型
|
||||
logo: volcano
|
||||
@@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel):
|
||||
class ElasticSearchVector(BaseVector):
|
||||
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
||||
super().__init__(index_name.lower())
|
||||
# self.embeddings = XinferenceEmbeddings(
|
||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port
|
||||
# model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
|
||||
# )
|
||||
# Remove debug printing to avoid leaking sensitive information
|
||||
# print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
|
||||
|
||||
# 初始化 Embedding 模型(自动支持火山引擎多模态)
|
||||
self.embeddings = RedBearEmbeddings(RedBearModelConfig(
|
||||
model_name=embedding_config.model_name,
|
||||
provider=embedding_config.provider,
|
||||
api_key=embedding_config.api_key,
|
||||
base_url=embedding_config.api_base
|
||||
))
|
||||
# self.reranker = XinferenceRerank(
|
||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
|
||||
# model_uid="bge-reranker-large"
|
||||
# )
|
||||
# Remove debug printing to avoid leaking sensitive information
|
||||
# print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
|
||||
self.is_multimodal_embedding = self.embeddings.is_multimodal_supported()
|
||||
|
||||
self.reranker = RedBearRerank(RedBearModelConfig(
|
||||
model_name=reranker_config.model_name,
|
||||
provider=reranker_config.provider,
|
||||
@@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector):
|
||||
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
||||
# 实现 Elasticsearch 保存向量
|
||||
texts = [chunk.page_content for chunk in chunks]
|
||||
embeddings = self.embeddings.embed_documents(list(texts))
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
embeddings = self.embeddings.embed_batch(texts)
|
||||
else:
|
||||
embeddings = self.embeddings.embed_documents(list(texts))
|
||||
self.create(chunks, embeddings, **kwargs)
|
||||
|
||||
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
||||
@@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector):
|
||||
updated count.
|
||||
"""
|
||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
chunk.vector = self.embeddings.embed_text(chunk.page_content)
|
||||
else:
|
||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||
|
||||
body = {
|
||||
"script": {
|
||||
@@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
|
||||
"""Search the nearest neighbors to a vector."""
|
||||
query_vector = self.embeddings.embed_query(query)
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
query_vector = self.embeddings.embed_text(query)
|
||||
else:
|
||||
query_vector = self.embeddings.embed_query(query)
|
||||
top_k = kwargs.get("top_k", 1024)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.3)
|
||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||
|
||||
@@ -109,17 +109,13 @@ class StorageBackend(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
||||
"""
|
||||
Get an access URL for the file.
|
||||
|
||||
Args:
|
||||
file_key: Unique identifier for the file in the storage system.
|
||||
expires: URL validity period in seconds (default: 1 hour).
|
||||
|
||||
Returns:
|
||||
URL for accessing the file.
|
||||
"""
|
||||
async def get_url(
|
||||
self,
|
||||
file_key: str,
|
||||
expires: int = 3600,
|
||||
file_name: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get an access URL for the file."""
|
||||
pass
|
||||
|
||||
async def get_permanent_url(self, file_key: str) -> Optional[str]:
|
||||
|
||||
@@ -210,7 +210,12 @@ class LocalStorage(StorageBackend):
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
||||
async def get_url(
|
||||
self,
|
||||
file_key: str,
|
||||
expires: int = 3600,
|
||||
file_name: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Get an access URL for the file.
|
||||
|
||||
@@ -220,6 +225,7 @@ class LocalStorage(StorageBackend):
|
||||
Args:
|
||||
file_key: Unique identifier for the file in the storage system.
|
||||
expires: URL validity period in seconds (not used for local storage).
|
||||
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||
|
||||
Returns:
|
||||
A relative URL path for accessing the file.
|
||||
|
||||
@@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
import oss2
|
||||
@@ -242,24 +243,33 @@ class OSSStorage(StorageBackend):
|
||||
logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
|
||||
return False
|
||||
|
||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
||||
async def get_url(
|
||||
self,
|
||||
file_key: str,
|
||||
expires: int = 3600,
|
||||
file_name: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get a presigned URL for accessing the file.
|
||||
|
||||
Args:
|
||||
file_key: Unique identifier for the file in the storage system.
|
||||
expires: URL validity period in seconds (default: 1 hour).
|
||||
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||
|
||||
Returns:
|
||||
A presigned URL for accessing the file.
|
||||
"""
|
||||
try:
|
||||
url = self.bucket.sign_url("GET", file_key, expires)
|
||||
params = {}
|
||||
if file_name:
|
||||
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
|
||||
params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
|
||||
url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None)
|
||||
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||
# Return a basic URL format as fallback
|
||||
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
|
||||
|
||||
async def get_permanent_url(self, file_key: str) -> str:
|
||||
|
||||
@@ -6,6 +6,7 @@ using the boto3 SDK.
|
||||
"""
|
||||
|
||||
import io
|
||||
import urllib.parse
|
||||
import logging
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
@@ -352,31 +353,37 @@ class S3Storage(StorageBackend):
|
||||
logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
|
||||
return False
|
||||
|
||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
||||
async def get_url(
|
||||
self,
|
||||
file_key: str,
|
||||
expires: int = 3600,
|
||||
file_name: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get a presigned URL for accessing the file.
|
||||
|
||||
Args:
|
||||
file_key: Unique identifier for the file in the storage system.
|
||||
expires: URL validity period in seconds (default: 1 hour).
|
||||
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||
|
||||
Returns:
|
||||
A presigned URL for accessing the file.
|
||||
"""
|
||||
try:
|
||||
params = {"Bucket": self.bucket_name, "Key": file_key}
|
||||
if file_name:
|
||||
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
|
||||
params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
|
||||
url = self.client.generate_presigned_url(
|
||||
"get_object",
|
||||
Params={
|
||||
"Bucket": self.bucket_name,
|
||||
"Key": file_key,
|
||||
},
|
||||
Params=params,
|
||||
ExpiresIn=expires,
|
||||
)
|
||||
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||
# Return a basic URL format as fallback
|
||||
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
||||
|
||||
async def get_permanent_url(self, file_key: str) -> str:
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition
|
||||
from app.core.workflow.adapters.errors import ExceptionDefinition
|
||||
from app.schemas.workflow_schema import (
|
||||
EdgeDefinition,
|
||||
NodeDefinition,
|
||||
@@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel):
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowImportResult(BaseModel):
|
||||
@@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel):
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BasePlatformAdapter(ABC):
|
||||
|
||||
@@ -9,9 +9,9 @@ from urllib.parse import quote
|
||||
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import (
|
||||
UnsupportVariableType,
|
||||
UnknowModelWarning,
|
||||
ExceptionDefineition,
|
||||
UnsupportedVariableType,
|
||||
UnknownModelWarning,
|
||||
ExceptionDefinition,
|
||||
ExceptionType
|
||||
)
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
@@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import (
|
||||
HttpFormData,
|
||||
HttpTimeOutConfig,
|
||||
HttpRetryConfig,
|
||||
HttpErrorDefaultTamplete,
|
||||
HttpErrorDefaultTemplate,
|
||||
HttpErrorHandleConfig
|
||||
)
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||
@@ -108,7 +108,7 @@ class DifyConverter(BaseConverter):
|
||||
try:
|
||||
return config.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
@@ -138,7 +138,7 @@ class DifyConverter(BaseConverter):
|
||||
var_selector = mapping.get(var_selector, var_selector)
|
||||
return var_selector
|
||||
|
||||
def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
|
||||
def _process_list_variable_literal(self, variable_selector: list) -> str | None:
|
||||
if not self.process_var_selector(".".join(variable_selector)):
|
||||
return None
|
||||
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
|
||||
@@ -269,7 +269,7 @@ class DifyConverter(BaseConverter):
|
||||
var_type = self.variable_type_map(var["type"])
|
||||
if not var_type:
|
||||
self.errors.append(
|
||||
UnsupportVariableType(
|
||||
UnsupportedVariableType(
|
||||
scope=node["id"],
|
||||
name=var["variable"],
|
||||
var_type=var["type"],
|
||||
@@ -281,7 +281,7 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
if var_type in ["file", "array[file]"]:
|
||||
self.errors.append(
|
||||
ExceptionDefineition(
|
||||
ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
@@ -311,7 +311,7 @@ class DifyConverter(BaseConverter):
|
||||
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
UnknownModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
@@ -327,7 +327,7 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
|
||||
result = QuestionClassifierNodeConfig.model_construct(
|
||||
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
|
||||
input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")),
|
||||
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
|
||||
categories=categories,
|
||||
).model_dump()
|
||||
@@ -337,13 +337,13 @@ class DifyConverter(BaseConverter):
|
||||
def convert_llm_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
UnknownModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
)
|
||||
)
|
||||
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
|
||||
context = self._process_list_variable_literal(node_data["context"]["variable_selector"])
|
||||
memory = MemoryWindowSetting(
|
||||
enable=bool(node_data.get("memory")),
|
||||
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
|
||||
@@ -367,7 +367,7 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
)
|
||||
vision = node_data["vision"]["enabled"]
|
||||
vision_input = self._process_list_variable_litearl(
|
||||
vision_input = self._process_list_variable_literal(
|
||||
node_data["vision"]["configs"]["variable_selector"]
|
||||
) if vision else None
|
||||
result = LLMNodeConfig.model_construct(
|
||||
@@ -433,7 +433,7 @@ class DifyConverter(BaseConverter):
|
||||
conditions.append(
|
||||
LoopConditionDetail.model_construct(
|
||||
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||
left=self._process_list_variable_litearl(condition["variable_selector"]),
|
||||
left=self._process_list_variable_literal(condition["variable_selector"]),
|
||||
right=self.trans_variable_format(
|
||||
right_value
|
||||
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||
@@ -453,7 +453,7 @@ class DifyConverter(BaseConverter):
|
||||
right_input_type = variable["value_type"]
|
||||
right_value_type = self.variable_type_map(variable["var_type"])
|
||||
if right_input_type == ValueInputType.VARIABLE:
|
||||
right_value = self._process_list_variable_litearl(variable.get("value", ""))
|
||||
right_value = self._process_list_variable_literal(variable.get("value", ""))
|
||||
else:
|
||||
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
|
||||
loop_variables.append(
|
||||
@@ -475,10 +475,10 @@ class DifyConverter(BaseConverter):
|
||||
def convert_iteration_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
result = IterationNodeConfig.model_construct(
|
||||
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
|
||||
input=self._process_list_variable_literal(node_data["iterator_selector"]),
|
||||
parallel=node_data["is_parallel"],
|
||||
parallel_count=node_data["parallel_nums"],
|
||||
output=self._process_list_variable_litearl(node_data["output_selector"]),
|
||||
output=self._process_list_variable_literal(node_data["output_selector"]),
|
||||
output_type=self.variable_type_map(node_data.get("output_type")),
|
||||
flatten=node_data["flatten_output"],
|
||||
).model_dump()
|
||||
@@ -494,8 +494,8 @@ class DifyConverter(BaseConverter):
|
||||
continue
|
||||
assignments.append(
|
||||
AssignmentItem(
|
||||
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
|
||||
value=self._process_list_variable_litearl(
|
||||
variable_selector=self._process_list_variable_literal(assignment["variable_selector"]),
|
||||
value=self._process_list_variable_literal(
|
||||
assignment["value"]
|
||||
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
|
||||
operation=self.convert_assignment_operator(assignment["operation"])
|
||||
@@ -514,7 +514,7 @@ class DifyConverter(BaseConverter):
|
||||
input_variables.append(
|
||||
InputVariable.model_construct(
|
||||
name=input_variable["variable"],
|
||||
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
|
||||
variable=self._process_list_variable_literal(input_variable["value_selector"]),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -570,7 +570,7 @@ class DifyConverter(BaseConverter):
|
||||
else:
|
||||
if node_data["body"]["data"]:
|
||||
body_content = (node_data["body"]["data"][0].get("value") or
|
||||
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
|
||||
self._process_list_variable_literal(node_data["body"]["data"][0].get("file")))
|
||||
else:
|
||||
body_content = ""
|
||||
|
||||
@@ -585,7 +585,7 @@ class DifyConverter(BaseConverter):
|
||||
self.trans_variable_format(key_value[0])
|
||||
] = self.trans_variable_format(key_value[1])
|
||||
else:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
@@ -603,7 +603,7 @@ class DifyConverter(BaseConverter):
|
||||
self.trans_variable_format(key_value[0])
|
||||
] = self.trans_variable_format(key_value[1])
|
||||
else:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
@@ -625,7 +625,7 @@ class DifyConverter(BaseConverter):
|
||||
default_header = var["value"]
|
||||
elif var["key"] == "status_code":
|
||||
default_status_code = var["value"]
|
||||
default_value = HttpErrorDefaultTamplete(
|
||||
default_value = HttpErrorDefaultTemplate(
|
||||
body=default_body,
|
||||
headers=default_header,
|
||||
status_code=default_status_code,
|
||||
@@ -668,7 +668,7 @@ class DifyConverter(BaseConverter):
|
||||
for variable in node_data["variables"]:
|
||||
mapping.append(VariablesMappingConfig.model_construct(
|
||||
name=variable["variable"],
|
||||
value=self._process_list_variable_litearl(variable["value_selector"])
|
||||
value=self._process_list_variable_literal(variable["value_selector"])
|
||||
))
|
||||
result = JinjaRenderNodeConfig.model_construct(
|
||||
template=node_data["template"],
|
||||
@@ -679,14 +679,14 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
def convert_knowledge_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
type=ExceptionType.CONFIG,
|
||||
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
||||
))
|
||||
result = KnowledgeRetrievalNodeConfig.model_construct(
|
||||
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
||||
query=self._process_list_variable_literal(node_data["query_variable_selector"]),
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
|
||||
@@ -695,7 +695,7 @@ class DifyConverter(BaseConverter):
|
||||
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
UnknownModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
@@ -712,7 +712,7 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
)
|
||||
result = ParameterExtractorNodeConfig.model_construct(
|
||||
text=self._process_list_variable_litearl(node_data["query"]),
|
||||
text=self._process_list_variable_literal(node_data["query"]),
|
||||
params=params,
|
||||
prompt=node_data.get("instruction")
|
||||
).model_dump()
|
||||
@@ -727,14 +727,14 @@ class DifyConverter(BaseConverter):
|
||||
group_type = {}
|
||||
if not advanced_settings or not advanced_settings["group_enabled"]:
|
||||
group_variables = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
self._process_list_variable_literal(variable)
|
||||
for variable in node_data["variables"]
|
||||
]
|
||||
group_type["output"] = node_data["output_type"]
|
||||
else:
|
||||
for group in advanced_settings["groups"]:
|
||||
group_variables[group["group_name"]] = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
self._process_list_variable_literal(variable)
|
||||
for variable in group["variables"]
|
||||
]
|
||||
group_type[group["group_name"]] = group["output_type"]
|
||||
@@ -751,7 +751,7 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
def convert_tool_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
type=ExceptionType.CONFIG,
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import (
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.core.workflow.adapters.dify.converter import DifyConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import (
|
||||
NodeDefinition,
|
||||
@@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
if self.config.get("app", {}).get("mode") == "workflow":
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.PLATFORM,
|
||||
detail="workflow mode is not supported"
|
||||
))
|
||||
@@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
edge = self._convert_edge(edge)
|
||||
if edge:
|
||||
self.edges.append(edge)
|
||||
#
|
||||
|
||||
for variable in self.config.get("workflow").get("conversation_variables"):
|
||||
con_var = self._convert_variable(variable)
|
||||
if variable:
|
||||
self.conv_variables.append(con_var)
|
||||
#
|
||||
|
||||
# for variables in config.get("workflow").get("environment_variables"):
|
||||
# variable = self._convert_variable(variables)
|
||||
# conv_variables.append(variable)
|
||||
@@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
"y": node["position"]["y"] + position["y"]
|
||||
}
|
||||
self.errors.append(
|
||||
ExceptionDefineition(
|
||||
ExceptionDefinition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
detail="parent cycle node not found"
|
||||
@@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
node_data = node["data"]
|
||||
converter = self.get_node_convert(node_type)
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
@@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
))
|
||||
return converter(node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
@@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
|
||||
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
|
||||
try:
|
||||
|
||||
source = edge["source"]
|
||||
target = edge["target"]
|
||||
label = None
|
||||
@@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
label=label,
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}",
|
||||
))
|
||||
@@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
description=variable.get("description")
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}",
|
||||
|
||||
@@ -18,7 +18,7 @@ class ExceptionType(StrEnum):
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class ExceptionDefineition(BaseModel):
|
||||
class ExceptionDefinition(BaseModel):
|
||||
type: ExceptionType
|
||||
detail: str
|
||||
|
||||
@@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel):
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class UnknowModelWarning(ExceptionDefineition):
|
||||
class UnknownModelWarning(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.NODE
|
||||
|
||||
def __init__(self, node_id, node_name, model_name):
|
||||
@@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition):
|
||||
)
|
||||
|
||||
|
||||
class UnknowError(ExceptionDefineition):
|
||||
class UnknownError(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.UNKNOWN
|
||||
|
||||
def __init__(self, detail: str, **kwargs):
|
||||
super().__init__(detail=detail, **kwargs)
|
||||
|
||||
|
||||
class UnsupportPlatform(ExceptionDefineition):
|
||||
class UnsupportedPlatform(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.PLATFORM
|
||||
|
||||
def __init__(self, platform: str):
|
||||
super().__init__(detail=f"Unsupport platform {platform}")
|
||||
super().__init__(detail=f"Unsupported platform {platform}")
|
||||
|
||||
|
||||
class UnsupportVariableType(ExceptionDefineition):
|
||||
class UnsupportedVariableType(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.VARIABLE
|
||||
|
||||
def __init__(self, scope, name, var_type: str, **kwargs):
|
||||
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs)
|
||||
super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs)
|
||||
|
||||
|
||||
class InvalidConfiguration(ExceptionDefineition):
|
||||
class InvalidConfiguration(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.CONFIG
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(detail="Invalid workflow configuration format")
|
||||
|
||||
|
||||
class UnsupportNodeType(ExceptionDefineition):
|
||||
class UnsupportedNodeType(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.NODE
|
||||
|
||||
def __init__(self, node_id: str, node_type: str):
|
||||
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")
|
||||
super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}")
|
||||
|
||||
@@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import (
|
||||
BasePlatformAdapter,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
|
||||
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
||||
@@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
try:
|
||||
node_type = self.map_node_type(node["type"])
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(UnsupportNodeType(
|
||||
self.errors.append(UnsupportedNodeType(
|
||||
node_id=node_id,
|
||||
node_type=node["type"]
|
||||
))
|
||||
@@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
|
||||
return NodeDefinition(**node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
@@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
||||
try:
|
||||
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
||||
))
|
||||
return None
|
||||
return EdgeDefinition(**edge)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}"
|
||||
))
|
||||
@@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
try:
|
||||
return VariableDefinition(**variable)
|
||||
except Exception as e:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
@@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter):
|
||||
try:
|
||||
return config_cls.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
|
||||
@@ -7,7 +7,7 @@ import re
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import Any, Iterable
|
||||
from typing import Any, Iterable, Callable
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import START, END
|
||||
@@ -41,48 +41,31 @@ class GraphBuilder:
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
stream: bool = False,
|
||||
subgraph: bool = False,
|
||||
cycle: str = '',
|
||||
variable_pool: VariablePool | None = None
|
||||
):
|
||||
self.workflow_config = workflow_config
|
||||
|
||||
self.stream = stream
|
||||
self.subgraph = subgraph
|
||||
self.cycle = cycle
|
||||
|
||||
self.start_node_id: str | None = None
|
||||
|
||||
self.node_map = {node["id"]: node for node in self.nodes}
|
||||
self.node_map: dict[str, dict] = {}
|
||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||
self._find_upstream_activation_dep = lru_cache(
|
||||
maxsize=len(self.nodes) * 2
|
||||
)(self._find_upstream_activation_dep)
|
||||
self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
|
||||
if variable_pool:
|
||||
self.variable_pool = variable_pool
|
||||
else:
|
||||
self.variable_pool = VariablePool()
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||
self.end_nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||
]
|
||||
self.add_edges()
|
||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||
|
||||
self.graph: StateGraph | None = None
|
||||
self.nodes: list = []
|
||||
self.edges: list = []
|
||||
self.reachable_nodes: set[str] | None = None
|
||||
self.end_nodes: list[dict] = []
|
||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||
self._build_reverse_adj()
|
||||
self._analyze_end_node_output()
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[dict[str, Any]]:
|
||||
return self.workflow_config.get("nodes", [])
|
||||
|
||||
@property
|
||||
def edges(self) -> list[dict[str, Any]]:
|
||||
return self.workflow_config.get("edges", [])
|
||||
self._adj: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
def get_node_type(self, node_id: str) -> str:
|
||||
"""Retrieve the type of node given its ID.
|
||||
@@ -108,13 +91,14 @@ class GraphBuilder:
|
||||
result[node[0]].append(node[1])
|
||||
return result
|
||||
|
||||
def _build_reverse_adj(self):
|
||||
def _build_adj(self):
|
||||
for edge in self.edges:
|
||||
if edge["source"] not in self.reachable_nodes:
|
||||
continue
|
||||
self._reverse_adj[edge.get("target")].append({
|
||||
"id": edge["source"], "branch": edge.get("label")
|
||||
})
|
||||
self._adj[edge.get("source")].append(edge["target"])
|
||||
|
||||
def _find_upstream_activation_dep(
|
||||
self,
|
||||
@@ -302,22 +286,13 @@ class GraphBuilder:
|
||||
"""
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
if node_type == NodeType.NOTES:
|
||||
continue
|
||||
node_id = node.get("id")
|
||||
cycle_node = node.get("cycle")
|
||||
if cycle_node:
|
||||
# Nodes within a loop subgraph are constructed by CycleGraphNode
|
||||
if not self.subgraph:
|
||||
continue
|
||||
|
||||
# Record start and end node IDs
|
||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||
self.start_node_id = node_id
|
||||
if node_id not in self.reachable_nodes:
|
||||
continue
|
||||
|
||||
# Create node instance (start and end nodes are also created)
|
||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
|
||||
|
||||
if node_type in BRANCH_NODES:
|
||||
|
||||
@@ -390,6 +365,8 @@ class GraphBuilder:
|
||||
for edge in self.edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
if source not in self.reachable_nodes or target not in self.reachable_nodes:
|
||||
continue
|
||||
condition = edge.get("condition")
|
||||
edge_type = edge.get("type")
|
||||
|
||||
@@ -411,11 +388,12 @@ class GraphBuilder:
|
||||
# Add conditional edges
|
||||
for source_node, branches in conditional_edges.items():
|
||||
def make_router(src, branch_list):
|
||||
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
||||
"""Create a router function for each source node that routes to a NOP node for later merging."""
|
||||
|
||||
def make_branch_node(node_name, targets):
|
||||
def node(s):
|
||||
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
||||
# NOTE: NOP NODE USED FOR ROUTING ONLY.
|
||||
# MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS.
|
||||
return {
|
||||
"activate": {
|
||||
node_id: s["activate"][node_name]
|
||||
@@ -502,14 +480,52 @@ class GraphBuilder:
|
||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||
|
||||
# Connect End nodes to the global END node
|
||||
for end_node in self.end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
if end_node_id:
|
||||
self.graph.add_edge(end_node_id, END)
|
||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||
for node in self.reachable_nodes:
|
||||
if not self._adj[node]:
|
||||
self.graph.add_edge(node, END)
|
||||
return
|
||||
|
||||
def build(self) -> CompiledStateGraph:
|
||||
nodes = self.workflow_config.get("nodes", [])
|
||||
edges = self.workflow_config.get("edges", [])
|
||||
|
||||
for node in nodes:
|
||||
if (node.get("cycle") or '') == self.cycle:
|
||||
node_type = node.get("type")
|
||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||
self.start_node_id = node.get("id")
|
||||
elif node_type == NodeType.NOTES:
|
||||
continue
|
||||
self.nodes.append(node)
|
||||
self.node_map[node.get("id")] = node
|
||||
|
||||
for edge in edges:
|
||||
source_in = edge.get("source") in self.node_map
|
||||
target_in = edge.get("target") in self.node_map
|
||||
if source_in ^ target_in:
|
||||
raise ValueError(
|
||||
f"Cycle node is connected to external node, "
|
||||
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||
)
|
||||
|
||||
if source_in and target_in:
|
||||
self.edges.append(edge)
|
||||
|
||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||
self.end_nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||
]
|
||||
self._build_adj()
|
||||
self._find_upstream_activation_dep: Callable = lru_cache(
|
||||
maxsize=len(self.nodes)*2
|
||||
)(self._find_upstream_activation_dep)
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.add_edges()
|
||||
|
||||
self._analyze_end_node_output()
|
||||
checkpointer = InMemorySaver()
|
||||
self.graph = self.graph.compile(checkpointer=checkpointer)
|
||||
return self.graph
|
||||
return self.graph.compile(checkpointer=checkpointer)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
|
||||
@@ -9,6 +10,7 @@ class WorkflowResultBuilder:
|
||||
def build_final_output(
|
||||
self,
|
||||
result: dict,
|
||||
execution_context: ExecutionContext,
|
||||
variable_pool: VariablePool,
|
||||
elapsed_time: float,
|
||||
final_output: str,
|
||||
@@ -26,6 +28,8 @@ class WorkflowResultBuilder:
|
||||
- "node_outputs" (dict): Outputs of executed nodes.
|
||||
- "messages" (list): Conversation messages exchanged during execution.
|
||||
- "error" (str, optional): Error message if any node failed.
|
||||
execution_context (ExecutionContext): The execution context containing metadata like
|
||||
execution ID, workspace ID, and user ID.)
|
||||
variable_pool (VariablePool): Variable Pool
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
@@ -48,18 +52,23 @@ class WorkflowResultBuilder:
|
||||
"""
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
token_usage = self.aggregate_token_usage(node_outputs)
|
||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||
conversation_vars = {}
|
||||
sys_vars = {}
|
||||
|
||||
if variable_pool:
|
||||
conversation_vars = variable_pool.get_all_conversation_vars()
|
||||
sys_vars = variable_pool.get_all_system_vars()
|
||||
|
||||
return {
|
||||
"status": "completed" if success else "failed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": variable_pool.get_all_conversation_vars(),
|
||||
"sys": variable_pool.get_all_system_vars()
|
||||
"conv": conversation_vars,
|
||||
"sys": sys_vars
|
||||
},
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"conversation_id": execution_context.conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
|
||||
@@ -12,14 +12,29 @@ class ExecutionContext(BaseModel):
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
conversation_id: str
|
||||
memory_storage_type: str
|
||||
user_rag_memory_id: str
|
||||
checkpoint_config: RunnableConfig
|
||||
|
||||
@classmethod
|
||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
||||
def create(
|
||||
cls,
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
):
|
||||
return cls(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
|
||||
checkpoint_config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
|
||||
@@ -33,6 +33,8 @@ class WorkflowState(dict):
|
||||
"workspace_id",
|
||||
"user_id",
|
||||
"activate",
|
||||
"memory_storage_type",
|
||||
"user_rag_memory_id"
|
||||
})
|
||||
__optional_keys__ = frozenset({
|
||||
"error",
|
||||
@@ -62,6 +64,9 @@ class WorkflowState(dict):
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
memory_storage_type: str
|
||||
user_rag_memory_id: str
|
||||
|
||||
|
||||
class WorkflowStateManager:
|
||||
def create_initial_state(
|
||||
@@ -85,7 +90,9 @@ class WorkflowStateManager:
|
||||
looping=0,
|
||||
activate={
|
||||
start_node_id: True
|
||||
}
|
||||
},
|
||||
memory_storage_type=execution_context.memory_storage_type,
|
||||
user_rag_memory_id=execution_context.user_rag_memory_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 15:11
|
||||
import re
|
||||
from queue import Queue
|
||||
from collections import deque
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
@@ -256,7 +256,7 @@ class StreamOutputCoordinator:
|
||||
def __init__(self):
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
self.output_queue: Queue = Queue()
|
||||
self.output_queue: deque[str] = deque()
|
||||
self.processed_outputs = []
|
||||
|
||||
def initialize_end_outputs(
|
||||
@@ -266,7 +266,7 @@ class StreamOutputCoordinator:
|
||||
self.end_outputs = end_node_map
|
||||
self.processed_outputs = []
|
||||
self.activate_end = None
|
||||
self.output_queue = Queue()
|
||||
self.output_queue = deque()
|
||||
|
||||
@property
|
||||
def current_activate_end_info(self):
|
||||
@@ -296,13 +296,13 @@ class StreamOutputCoordinator:
|
||||
scope (str): The node ID or scope that has completed execution.
|
||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
for node in self.end_outputs:
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and node not in self.processed_outputs:
|
||||
self.output_queue.put(node)
|
||||
self.output_queue.append(node)
|
||||
self.processed_outputs.append(node)
|
||||
if self.activate_end is None and not self.output_queue.empty():
|
||||
self.activate_end = self.output_queue.get_nowait()
|
||||
if self.activate_end is None and self.output_queue:
|
||||
self.activate_end = self.output_queue.popleft()
|
||||
|
||||
async def emit_activate_chunk(
|
||||
self,
|
||||
@@ -414,8 +414,8 @@ class StreamOutputCoordinator:
|
||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||
yield msg_event
|
||||
|
||||
if not self.output_queue.empty():
|
||||
self.activate_end = self.output_queue.get_nowait()
|
||||
if self.output_queue:
|
||||
self.activate_end = self.output_queue.popleft()
|
||||
# Move to next active End node if current one is done
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
|
||||
@@ -13,7 +13,7 @@ from pydantic import BaseModel
|
||||
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -373,6 +373,16 @@ class VariablePool:
|
||||
def copy(self, pool: 'VariablePool'):
|
||||
self.variables = deepcopy(pool.variables)
|
||||
|
||||
def is_file_variable(self, selector):
|
||||
variable_struct = self.get_instance(selector, default=None, strict=False)
|
||||
if variable_struct is None:
|
||||
return False
|
||||
if isinstance(variable_struct, FileVariable):
|
||||
return True
|
||||
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||
return True
|
||||
return False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""导出为字典
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 13:51
|
||||
import datetime
|
||||
import time
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -82,13 +83,15 @@ class WorkflowExecutor:
|
||||
CompiledStateGraph: The compiled and ready-to-run state graph.
|
||||
"""
|
||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
|
||||
start_time = time.time()
|
||||
builder = GraphBuilder(
|
||||
self.workflow_config,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
self.graph = builder.build()
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.variable_pool = builder.variable_pool
|
||||
self.graph = builder.build()
|
||||
|
||||
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
||||
self.event_handler = EventStreamHandler(
|
||||
@@ -96,7 +99,8 @@ class WorkflowExecutor:
|
||||
variable_pool=self.variable_pool,
|
||||
execution_id=self.execution_context.execution_id
|
||||
)
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, "
|
||||
f"cost: {time.time() - start_time:.4f}s")
|
||||
|
||||
return self.graph
|
||||
|
||||
@@ -134,94 +138,12 @@ class WorkflowExecutor:
|
||||
return event.get("data")
|
||||
return self.result_builder.build_final_output(
|
||||
{"error": "Workflow execution did not end as expected"},
|
||||
self.execution_context,
|
||||
self.variable_pool,
|
||||
(datetime.datetime.now() - start).total_seconds(),
|
||||
"",
|
||||
success=False
|
||||
)
|
||||
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||
#
|
||||
# start_time = datetime.datetime.now()
|
||||
#
|
||||
# # Execute the workflow
|
||||
# try:
|
||||
# # Build the workflow graph
|
||||
# graph = self.build_graph()
|
||||
#
|
||||
# # Initialize the variable pool with input data
|
||||
# await self.variable_initializer.initialize(
|
||||
# variable_pool=self.variable_pool,
|
||||
# input_data=input_data,
|
||||
# execution_context=self.execution_context
|
||||
# )
|
||||
# initial_state = self.state_manager.create_initial_state(
|
||||
# workflow_config=self.workflow_config,
|
||||
# input_data=input_data,
|
||||
# execution_context=self.execution_context,
|
||||
# start_node_id=self.start_node_id
|
||||
# )
|
||||
#
|
||||
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
#
|
||||
# # Aggregate output from all End nodes
|
||||
# full_content = ''
|
||||
# for end_id in self.stream_coordinator.end_outputs.keys():
|
||||
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
#
|
||||
# # Append messages for user and assistant
|
||||
# if input_data.get("files"):
|
||||
# result["messages"].extend(
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("message", '')
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("files")
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": full_content
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# else:
|
||||
# result["messages"].extend(
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("message", '')
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": full_content
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# # Calculate elapsed time
|
||||
# end_time = datetime.datetime.now()
|
||||
# elapsed_time = (end_time - start_time).total_seconds()
|
||||
#
|
||||
# logger.info(
|
||||
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||
#
|
||||
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
#
|
||||
# except Exception as e:
|
||||
# end_time = datetime.datetime.now()
|
||||
# elapsed_time = (end_time - start_time).total_seconds()
|
||||
#
|
||||
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
# exc_info=True)
|
||||
# return {
|
||||
# "status": "failed",
|
||||
# "error": str(e),
|
||||
# "output": None,
|
||||
# "node_outputs": {},
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "token_usage": None
|
||||
# }
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
@@ -255,7 +177,7 @@ class WorkflowExecutor:
|
||||
"data": {
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"workspace_id": self.execution_context.workspace_id,
|
||||
"conversation_id": input_data.get("conversation_id"),
|
||||
"conversation_id": self.execution_context.conversation_id,
|
||||
"timestamp": int(start_time.timestamp() * 1000)
|
||||
}
|
||||
}
|
||||
@@ -376,6 +298,7 @@ class WorkflowExecutor:
|
||||
"event": "workflow_end",
|
||||
"data": self.result_builder.build_final_output(
|
||||
result,
|
||||
self.execution_context,
|
||||
self.variable_pool,
|
||||
elapsed_time,
|
||||
full_content,
|
||||
@@ -396,6 +319,7 @@ class WorkflowExecutor:
|
||||
"event": "workflow_end",
|
||||
"data": self.result_builder.build_final_output(
|
||||
result,
|
||||
self.execution_context,
|
||||
self.variable_pool,
|
||||
elapsed_time,
|
||||
full_content,
|
||||
@@ -409,7 +333,9 @@ async def execute_workflow(
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a workflow (convenience function, non-streaming).
|
||||
@@ -420,6 +346,8 @@ async def execute_workflow(
|
||||
execution_id (str): Execution ID.
|
||||
workspace_id (str): Workspace ID.
|
||||
user_id (str): User ID.
|
||||
user_rag_memory_id: rag knowledge db id
|
||||
memory_storage_type: neo4j / rag
|
||||
|
||||
Returns:
|
||||
dict: Workflow execution result.
|
||||
@@ -427,7 +355,10 @@ async def execute_workflow(
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
conversation_id=input_data.get("conversation_id"),
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
@@ -441,7 +372,9 @@ async def execute_workflow_stream(
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
):
|
||||
"""
|
||||
Execute a workflow in streaming mode (convenience function).
|
||||
@@ -452,6 +385,8 @@ async def execute_workflow_stream(
|
||||
execution_id (str): Execution ID.
|
||||
workspace_id (str): Workspace ID.
|
||||
user_id (str): User ID.
|
||||
user_rag_memory_id: rag knowledge db id
|
||||
memory_storage_type: neo4j / rag
|
||||
|
||||
Yields:
|
||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||
@@ -459,7 +394,10 @@ async def execute_workflow_stream(
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
conversation_id=input_data.get("conversation_id"),
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
|
||||
@@ -64,9 +64,7 @@ class AgentNode(BaseNode):
|
||||
|
||||
if not release:
|
||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||
|
||||
|
||||
|
||||
return release, message
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
|
||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssignerNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.variable_updater = True
|
||||
self.typed_config: AssignerNodeConfig | None = None
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class BaseNode(ABC):
|
||||
All node types should inherit from this class and implement the `execute` method.
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
"""Initialize the node.
|
||||
|
||||
Args:
|
||||
@@ -41,6 +41,7 @@ class BaseNode(ABC):
|
||||
self.node_type = node_config["type"]
|
||||
self.cycle = node_config.get("cycle")
|
||||
self.node_name = node_config.get("name", self.node_id)
|
||||
self.down_stream_nodes = down_stream_nodes
|
||||
# 使用 or 运算符处理 None 值
|
||||
self.config = node_config.get("config") or {}
|
||||
self.error_handling = node_config.get("error_handling") or {}
|
||||
@@ -93,18 +94,16 @@ class BaseNode(ABC):
|
||||
dict: A dict with a single key 'activate', mapping node IDs to
|
||||
their activation status (True/False).
|
||||
"""
|
||||
edges = self.workflow_config.get("edges")
|
||||
under_stream_nodes = [
|
||||
edge.get("target")
|
||||
for edge in edges
|
||||
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
|
||||
]
|
||||
return {
|
||||
"activate": {
|
||||
node_id: self.check_activate(state)
|
||||
for node_id in under_stream_nodes
|
||||
} | {self.node_id: self.check_activate(state)}
|
||||
}
|
||||
activate_flag = self.check_activate(state)
|
||||
|
||||
if self.node_type not in BRANCH_NODES:
|
||||
activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
|
||||
else:
|
||||
activate = {}
|
||||
|
||||
activate[self.node_id] = activate_flag
|
||||
|
||||
return {"activate": activate}
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
@@ -315,8 +314,8 @@ class BaseNode(ABC):
|
||||
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(f"Node {self.node_id} streaming execution finished, "
|
||||
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
||||
logger.debug(f"Node {self.node_id} streaming execution finished, "
|
||||
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
||||
|
||||
# Extract processed output (call subclass's _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
@@ -428,8 +427,8 @@ class BaseNode(ABC):
|
||||
when an error edge exists. If no error edge exists, this method
|
||||
raises an exception to stop the workflow.
|
||||
"""
|
||||
# Check if the node has an error edge defined
|
||||
error_edge = self._find_error_edge()
|
||||
# # Check if the node has an error edge defined
|
||||
# error_edge = self._find_error_edge()
|
||||
|
||||
# Extract input data (for logging or audit purposes)
|
||||
input_data = self._extract_input(state, variable_pool)
|
||||
@@ -447,27 +446,26 @@ class BaseNode(ABC):
|
||||
"error": error_message
|
||||
}
|
||||
|
||||
if error_edge:
|
||||
# If an error edge exists, log a warning and continue to error node
|
||||
logger.warning(
|
||||
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
||||
)
|
||||
return {
|
||||
"node_outputs": {
|
||||
self.node_id: node_output
|
||||
},
|
||||
"error": error_message,
|
||||
"error_node": self.node_id
|
||||
}
|
||||
else:
|
||||
# If no error edge, send the error via stream writer and stop the workflow
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "node_error",
|
||||
**node_output
|
||||
})
|
||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||
# if error_edge:
|
||||
# # If an error edge exists, log a warning and continue to error node
|
||||
# logger.warning(
|
||||
# f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
||||
# )
|
||||
# return {
|
||||
# "node_outputs": {
|
||||
# self.node_id: node_output
|
||||
# },
|
||||
# "error": error_message,
|
||||
# "error_node": self.node_id
|
||||
# }
|
||||
# else:
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "node_error",
|
||||
**node_output
|
||||
})
|
||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""Extracts the input data for this node (used for logging or audit).
|
||||
@@ -623,7 +621,6 @@ class BaseNode(ABC):
|
||||
async def process_message(
|
||||
api_config: ModelInfo,
|
||||
content: str | dict | FileObject,
|
||||
end_user_id: str,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
provider = api_config.provider
|
||||
@@ -642,10 +639,10 @@ class BaseNode(ABC):
|
||||
return content
|
||||
|
||||
elif isinstance(content, FileObject):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
if content.content_cache.get(f"{provider}_{api_config.is_omni}"):
|
||||
return content.content_cache[f"{provider}_{api_config.is_omni}"]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||
multimodal_service = MultimodalService(db, api_config=api_config)
|
||||
file_obj = FileInput(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
@@ -654,16 +651,15 @@ class BaseNode(ABC):
|
||||
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
end_user_id,
|
||||
message = await multimodal_service.process_files(
|
||||
[file_obj],
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
content.content_cache[f"{provider}_{api_config.is_omni}"] = message
|
||||
return message
|
||||
return None
|
||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||
raise TypeError(f'Unexpected input value type - {type(content)}')
|
||||
|
||||
@staticmethod
|
||||
def process_model_output(content) -> str:
|
||||
|
||||
@@ -51,8 +51,8 @@ console.log(result)
|
||||
|
||||
|
||||
class CodeNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: CodeNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -30,17 +30,13 @@ class CycleGraphNode(BaseNode):
|
||||
It acts as a container and execution controller for a subgraph.
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
|
||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.start_node_id = None # ID of the start node within the cycle
|
||||
|
||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||
self.child_variable_pool: VariablePool | None = None
|
||||
self.build_graph()
|
||||
self.iteration_flag = True
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
||||
@@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode):
|
||||
else:
|
||||
remain_edges.append(edge)
|
||||
|
||||
# Update workflow_config by removing cycle nodes and internal edges
|
||||
self.workflow_config["nodes"] = [
|
||||
node for node in nodes if node.get("cycle") != self.node_id
|
||||
]
|
||||
self.workflow_config["edges"] = remain_edges
|
||||
# # Update workflow_config by removing cycle nodes and internal edges
|
||||
# self.workflow_config["nodes"] = [
|
||||
# node for node in nodes if node.get("cycle") != self.node_id
|
||||
# ]
|
||||
# self.workflow_config["edges"] = remain_edges
|
||||
|
||||
return cycle_nodes, cycle_edges
|
||||
|
||||
@@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode):
|
||||
3. Compile the graph for runtime execution
|
||||
"""
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
|
||||
self.child_variable_pool = VariablePool()
|
||||
builder = GraphBuilder(
|
||||
{
|
||||
"nodes": self.cycle_nodes,
|
||||
"edges": self.cycle_edges,
|
||||
},
|
||||
subgraph=True,
|
||||
variable_pool=self.child_variable_pool
|
||||
variable_pool=self.child_variable_pool,
|
||||
cycle=self.node_id
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.graph = builder.build()
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.child_variable_pool = builder.variable_pool
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
|
||||
Raises:
|
||||
RuntimeError: If the node type is unsupported.
|
||||
"""
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
return await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
yield {
|
||||
"__final__": True,
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .config import DocExtractorNodeConfig
|
||||
from .node import DocExtractorNode
|
||||
|
||||
__all__ = ["DocExtractorNode", "DocExtractorNodeConfig"]
|
||||
18
api/app/core/workflow/nodes/document_extractor/config.py
Normal file
18
api/app/core/workflow/nodes/document_extractor/config.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic import Field
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class DocExtractorNodeConfig(BaseNodeConfig):
|
||||
file_selector: str = Field(
|
||||
...,
|
||||
description="File variable selector, e.g. {{ sys.files }} or {{ node_id.file }}"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"file_selector": "{{ sys.files }}"
|
||||
}
|
||||
]
|
||||
}
|
||||
103
api/app/core/workflow/nodes/document_extractor/node.py
Normal file
103
api/app/core/workflow/nodes/document_extractor/node.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.db import get_db_read
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _file_object_to_file_input(f: FileObject) -> FileInput:
|
||||
"""Convert workflow FileObject to multimodal FileInput."""
|
||||
return FileInput(
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=TransferMethod(f.transfer_method),
|
||||
url=f.url or None,
|
||||
upload_file_id=f.file_id or None,
|
||||
file_type=f.origin_file_type or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalise_files(val: Any) -> list[FileObject]:
|
||||
if isinstance(val, FileObject):
|
||||
return [val]
|
||||
if isinstance(val, dict) and val.get("is_file"):
|
||||
return [FileObject(**val)]
|
||||
if isinstance(val, list):
|
||||
result: list[FileObject] = []
|
||||
for item in val:
|
||||
if isinstance(item, FileObject):
|
||||
result.append(item)
|
||||
elif isinstance(item, dict) and item.get("is_file"):
|
||||
result.append(FileObject(**item))
|
||||
else:
|
||||
logger.warning("Ignoring non-file entry in file list for document extractor: %r", item)
|
||||
return result
|
||||
return []
|
||||
|
||||
|
||||
class DocExtractorNode(BaseNode):
|
||||
"""Document Extractor Node.
|
||||
|
||||
Reads one or more file variables and extracts their text content
|
||||
by delegating to MultimodalService._extract_document_text.
|
||||
|
||||
Outputs:
|
||||
text (string) – full concatenated text of all input files
|
||||
chunks (array[string]) – per-file extracted text
|
||||
"""
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"text": VariableType.STRING,
|
||||
"chunks": VariableType.ARRAY_STRING,
|
||||
}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
return business_result
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {"file_selector": self.config.get("file_selector")}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
config = DocExtractorNodeConfig(**self.config)
|
||||
|
||||
raw_val = self.get_variable(config.file_selector, variable_pool, strict=False)
|
||||
if raw_val is None:
|
||||
logger.warning(f"Node {self.node_id}: file variable '{config.file_selector}' is empty")
|
||||
return {"text": "", "chunks": []}
|
||||
|
||||
files = _normalise_files(raw_val)
|
||||
if not files:
|
||||
return {"text": "", "chunks": []}
|
||||
|
||||
chunks: list[str] = []
|
||||
with get_db_read() as db:
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
svc = MultimodalService(db)
|
||||
for f in files:
|
||||
try:
|
||||
file_input = _file_object_to_file_input(f)
|
||||
# Ensure URL is populated for local files
|
||||
if not file_input.url:
|
||||
file_input.url = await svc.get_file_url(file_input)
|
||||
# Reuse cached bytes if already fetched
|
||||
if f.get_content():
|
||||
file_input.set_content(f.get_content())
|
||||
text = await svc._extract_document_text(file_input)
|
||||
chunks.append(text)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
chunks.append("")
|
||||
|
||||
full_text = "\n\n".join(c for c in chunks if c)
|
||||
logger.info(f"Node {self.node_id}: extracted {len(files)} file(s), total chars={len(full_text)}")
|
||||
return {"text": full_text, "chunks": chunks}
|
||||
@@ -1,9 +1,7 @@
|
||||
"""End 节点配置"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class EndNodeConfig(BaseNodeConfig):
|
||||
|
||||
@@ -36,8 +36,6 @@ class EndNode(BaseNode):
|
||||
Returns:
|
||||
最终输出字符串
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行")
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
@@ -46,11 +44,4 @@ class EndNode(BaseNode):
|
||||
output = self._render_template(output_template, variable_pool, strict=False)
|
||||
else:
|
||||
output = ""
|
||||
|
||||
# 统计信息(用于日志)
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
|
||||
return output
|
||||
|
||||
@@ -23,12 +23,13 @@ class NodeType(StrEnum):
|
||||
BREAK = "break"
|
||||
MEMORY_READ = "memory-read"
|
||||
MEMORY_WRITE = "memory-write"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
NOTES = "notes"
|
||||
|
||||
|
||||
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
||||
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
|
||||
@@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class HttpErrorDefaultTamplete(BaseModel):
|
||||
class HttpErrorDefaultTemplate(BaseModel):
|
||||
body: str = Field(
|
||||
default="",
|
||||
description="Default body returned on HTTP error",
|
||||
@@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel):
|
||||
description="Error handling strategy: 'none', 'default', or 'branch'",
|
||||
)
|
||||
|
||||
default: HttpErrorDefaultTamplete | None = Field(
|
||||
default: HttpErrorDefaultTemplate | None = Field(
|
||||
default=None,
|
||||
description="Default response template for error handling",
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.utils.file_processer import mime_to_file_type
|
||||
from app.core.workflow.utils.file_processor import mime_to_file_type
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.schemas import FileType, TransferMethod
|
||||
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
|
||||
or a branch identifier string when error branching is enabled.
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: HttpRequestNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IfElseNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: IfElseNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JinjaRenderNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
self.vector_service: ElasticSearchVector | None = None
|
||||
|
||||
|
||||
@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: LLMNodeConfig | None = None
|
||||
self.messages = []
|
||||
|
||||
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
|
||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
|
||||
"content": await self.process_message(
|
||||
model_info,
|
||||
content,
|
||||
user_id,
|
||||
self.typed_config.vision,
|
||||
)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file.value, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
|
||||
message["content"] = await self.process_message(
|
||||
model_info,
|
||||
message["content"],
|
||||
user_id,
|
||||
self.typed_config.vision
|
||||
)
|
||||
history_message.append(message)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
@@ -5,14 +6,16 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.db import get_db_read
|
||||
from app.schemas import FileInput
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
|
||||
class MemoryReadNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: MemoryReadNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
@@ -36,19 +39,32 @@ class MemoryReadNode(BaseNode):
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
)
|
||||
|
||||
|
||||
class MemoryWriteNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: MemoryWriteNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
@staticmethod
|
||||
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
|
||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
variable_pattern = re.compile(variable_pattern_string)
|
||||
variables = variable_pattern.findall(content)
|
||||
file_variables = []
|
||||
for variable in variables:
|
||||
if variable_pool.is_file_variable(variable):
|
||||
file_variables.append(variable)
|
||||
for var in file_variables:
|
||||
content = content.replace(var, "")
|
||||
return file_variables, content
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
@@ -63,17 +79,42 @@ class MemoryWriteNode(BaseNode):
|
||||
})
|
||||
|
||||
for message in self.typed_config.messages:
|
||||
file_variables, content = self._extract_multimodal_memory_variables(
|
||||
message.content,
|
||||
variable_pool
|
||||
)
|
||||
file_info = []
|
||||
for var in file_variables:
|
||||
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
|
||||
if isinstance(instence, FileVariable):
|
||||
file_info.append(FileInput(
|
||||
type=instence.value.type,
|
||||
transfer_method=instence.value.transfer_method,
|
||||
upload_file_id=instence.value.file_id,
|
||||
url=instence.value.url,
|
||||
file_type=instence.value.origin_file_type
|
||||
).model_dump())
|
||||
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
|
||||
for file_instence in instence.value:
|
||||
file_info.append(FileInput(
|
||||
type=file_instence.value.type,
|
||||
transfer_method=file_instence.value.transfer_method,
|
||||
upload_file_id=file_instence.value.file_id,
|
||||
url=file_instence.value.url,
|
||||
file_type=file_instence.value.origin_file_type
|
||||
).model_dump())
|
||||
messages.append({
|
||||
"role": message.role,
|
||||
"content": self._render_template(message.content, variable_pool)
|
||||
"content": self._render_template(content, variable_pool),
|
||||
"files": file_info
|
||||
})
|
||||
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
messages,
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
end_user_id=end_user_id,
|
||||
message=messages,
|
||||
config_id=str(self.typed_config.config_id),
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
)
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -26,6 +26,7 @@ from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.document_extractor import DocExtractorNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,7 +50,8 @@ WorkflowNode = Union[
|
||||
ToolNode,
|
||||
MemoryReadNode,
|
||||
MemoryWriteNode,
|
||||
CodeNode
|
||||
CodeNode,
|
||||
DocExtractorNode
|
||||
]
|
||||
|
||||
|
||||
@@ -81,6 +83,7 @@ class NodeFactory:
|
||||
NodeType.MEMORY_READ: MemoryReadNode,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -104,13 +107,15 @@ class NodeFactory:
|
||||
def create_node(
|
||||
cls,
|
||||
node_config: dict[str, Any],
|
||||
workflow_config: dict[str, Any]
|
||||
workflow_config: dict[str, Any],
|
||||
down_stream_nodes: list[str]
|
||||
) -> WorkflowNode | None:
|
||||
"""创建节点实例
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
down_stream_nodes: 下游节点
|
||||
|
||||
Returns:
|
||||
节点实例或 None(对于不支持的节点类型)
|
||||
@@ -127,7 +132,7 @@ class NodeFactory:
|
||||
|
||||
# 创建节点实例
|
||||
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
||||
return node_class(node_config, workflow_config)
|
||||
return node_class(node_config, workflow_config, down_stream_nodes)
|
||||
|
||||
@classmethod
|
||||
def get_supported_types(cls) -> list[str]:
|
||||
|
||||
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ParameterExtractorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||
self.response_metadata = {}
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
|
||||
class QuestionClassifierNode(BaseNode):
|
||||
"""问题分类器节点"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||
self.category_to_case_map = {}
|
||||
self.response_metadata = {}
|
||||
|
||||
@@ -27,14 +27,8 @@ class StartNode(BaseNode):
|
||||
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
"""初始化 Start 节点
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
"""
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
|
||||
# 解析并验证配置
|
||||
self.typed_config: StartNodeConfig | None = None
|
||||
@@ -62,7 +56,6 @@ class StartNode(BaseNode):
|
||||
包含系统参数、会话变量和自定义变量的字典
|
||||
"""
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||
|
||||
# 处理自定义变量(传入 pool 避免重复创建)
|
||||
custom_vars = self._process_custom_variables(variable_pool)
|
||||
@@ -77,9 +70,9 @@ class StartNode(BaseNode):
|
||||
**custom_vars # 自定义变量作为节点输出的一部分
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"节点 {self.node_id} (Start) 执行完成,"
|
||||
f"输出了 {len(custom_vars)} 个自定义变量"
|
||||
logger.debug(
|
||||
f"Node {self.node_id} (Start) execution completed, "
|
||||
f"outputting {len(custom_vars)} custom variables"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
||||
class ToolNode(BaseNode):
|
||||
"""工具节点"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: ToolNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VariableAggregatorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -153,7 +153,8 @@ class TemplateRenderer:
|
||||
|
||||
|
||||
# 全局渲染器实例(严格模式)
|
||||
_default_renderer = TemplateRenderer(strict=True)
|
||||
_strict_renderer = TemplateRenderer(strict=True)
|
||||
_lenient_renderer = TemplateRenderer(strict=False)
|
||||
|
||||
|
||||
def render_template(
|
||||
@@ -184,7 +185,7 @@ def render_template(
|
||||
... )
|
||||
'请分析: 这是一段文本'
|
||||
"""
|
||||
renderer = TemplateRenderer(strict=strict)
|
||||
renderer = _strict_renderer if strict else _lenient_renderer
|
||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||
|
||||
|
||||
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
|
||||
Returns:
|
||||
错误列表
|
||||
"""
|
||||
return _default_renderer.validate(template)
|
||||
return _strict_renderer.validate(template)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Union, TYPE_CHECKING
|
||||
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
@@ -119,7 +120,6 @@ class WorkflowValidator:
|
||||
errors = []
|
||||
|
||||
graphs = cls.get_subgraph(workflow_config)
|
||||
logger.info(graphs)
|
||||
for index, graph in enumerate(graphs):
|
||||
nodes = graph.get("nodes", [])
|
||||
edges = graph.get("edges", [])
|
||||
@@ -183,7 +183,7 @@ class WorkflowValidator:
|
||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||
if has_cycle:
|
||||
errors.append(
|
||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
@@ -204,18 +204,18 @@ class WorkflowValidator:
|
||||
Returns:
|
||||
可达节点 ID 集合
|
||||
"""
|
||||
adj = defaultdict(list)
|
||||
for edge in edges:
|
||||
adj[edge["source"]].append(edge["target"])
|
||||
|
||||
reachable = {start_id}
|
||||
queue = [start_id]
|
||||
|
||||
queue = deque([start_id])
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
for edge in edges:
|
||||
if edge.get("source") == current:
|
||||
target = edge.get("target")
|
||||
if target and target not in reachable:
|
||||
reachable.add(target)
|
||||
queue.append(target)
|
||||
|
||||
current = queue.popleft()
|
||||
for target in adj[current]:
|
||||
if target not in reachable:
|
||||
reachable.add(target)
|
||||
queue.append(target)
|
||||
return reachable
|
||||
|
||||
@staticmethod
|
||||
@@ -229,10 +229,6 @@ class WorkflowValidator:
|
||||
Returns:
|
||||
(has_cycle, cycle_path): 是否有循环和循环路径
|
||||
"""
|
||||
# 排除 loop 类型的节点
|
||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
||||
|
||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
||||
graph: dict[str, list[str]] = {}
|
||||
for edge in edges:
|
||||
source = edge.get("source")
|
||||
@@ -243,10 +239,6 @@ class WorkflowValidator:
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
# 如果涉及 loop 节点,跳过
|
||||
if source in loop_nodes or target in loop_nodes:
|
||||
continue
|
||||
|
||||
if source and target:
|
||||
if source not in graph:
|
||||
graph[source] = []
|
||||
|
||||
@@ -54,7 +54,7 @@ class DictVariable(BaseVariable):
|
||||
|
||||
def valid_value(self, value) -> dict:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
||||
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
|
||||
@@ -30,6 +30,9 @@ class MemoryConfig(Base):
|
||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
|
||||
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
|
||||
video_id = Column(String, nullable=True, comment="视频模型配置ID")
|
||||
|
||||
# 记忆萃取引擎配置
|
||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||
|
||||
@@ -2,10 +2,11 @@ import datetime
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
@@ -26,9 +27,9 @@ class ModelType(StrEnum):
|
||||
RERANK = "rerank"
|
||||
# TTS = "tts"
|
||||
# SPEECH2TEXT = "speech2text"
|
||||
# IMAGE = "image"
|
||||
IMAGE = "image"
|
||||
# AUDIO = "audio"
|
||||
# VISION = "vision"
|
||||
VIDEO = "video"
|
||||
|
||||
|
||||
class ModelProvider(StrEnum):
|
||||
@@ -45,6 +46,7 @@ class ModelProvider(StrEnum):
|
||||
XINFERENCE = "xinference"
|
||||
GPUSTACK = "gpustack"
|
||||
BEDROCK = "bedrock"
|
||||
VOLCANO = "volcano"
|
||||
COMPOSITE = "composite"
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,21 @@ class Tenants(Base):
|
||||
# 国际化语言配置字段
|
||||
default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言
|
||||
supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表
|
||||
|
||||
# 租户联系信息
|
||||
contact_name = Column(String(100), nullable=True) # 联系人姓名
|
||||
contact_email = Column(String(255), nullable=True) # 联系人邮箱
|
||||
contact_phone = Column(String(50), nullable=True) # 联系人电话
|
||||
|
||||
# 租户套餐信息
|
||||
plan = Column(String(50), nullable=True) # 套餐类型
|
||||
plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间
|
||||
api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制
|
||||
status = Column(String(50), nullable=True, default='active') # 租户状态
|
||||
|
||||
# 租户功能开关字段
|
||||
feature_billing = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用收费管理菜单")
|
||||
feature_user_management = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用用户管理菜单")
|
||||
|
||||
# Relationship to users - one tenant has many users
|
||||
users = relationship("User", back_populates="tenant")
|
||||
|
||||
@@ -9,7 +9,7 @@ class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
username = Column(String, unique=True, index=True, nullable=False)
|
||||
username = Column(String, index=True, nullable=False) # 社区版:用户名不唯一,仅邮箱唯一
|
||||
email = Column(String, unique=True, index=True, nullable=False)
|
||||
hashed_password = Column(String, nullable=False)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from uuid import UUID
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.user_model import User
|
||||
@@ -190,4 +190,63 @@ class HomePageRepository:
|
||||
|
||||
user_count_dict = {workspace_id: count for workspace_id, count in user_counts}
|
||||
|
||||
return workspaces, app_count_dict, user_count_dict
|
||||
return workspaces, app_count_dict, user_count_dict
|
||||
|
||||
@staticmethod
|
||||
def get_version_introduction(db: Session, version: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
从数据库获取版本说明(优先读取已发布的版本)
|
||||
使用反射方式读取表结构,不依赖 premium 模型类
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
version: 版本号,如 "v0.2.7"
|
||||
|
||||
Returns:
|
||||
版本说明字典,格式与 version_info.json 一致
|
||||
如果数据库中没有该版本,返回 None
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import Table, MetaData
|
||||
|
||||
metadata = MetaData()
|
||||
version_notes = Table('version_notes', metadata, autoload_with=db.engine)
|
||||
version_note_items = Table('version_note_items', metadata, autoload_with=db.engine)
|
||||
|
||||
note = db.query(version_notes).filter(
|
||||
version_notes.c.version == version,
|
||||
version_notes.c.is_published == True
|
||||
).first()
|
||||
|
||||
if not note:
|
||||
return None
|
||||
|
||||
items = db.query(version_note_items).filter(
|
||||
version_note_items.c.note_id == note.id
|
||||
).order_by(version_note_items.c.sort_order).all()
|
||||
|
||||
core_upgrades = []
|
||||
for item in items:
|
||||
title = item.title
|
||||
content = item.content
|
||||
if content:
|
||||
core_upgrades.append(f"{title}<br>{content}")
|
||||
else:
|
||||
core_upgrades.append(title)
|
||||
|
||||
return {
|
||||
"introduction": {
|
||||
"codeName": "",
|
||||
"releaseDate": note.release_date.isoformat() if note.release_date else "",
|
||||
"upgradePosition": "",
|
||||
"coreUpgrades": core_upgrades
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "",
|
||||
"releaseDate": note.release_date.isoformat() if note.release_date else "",
|
||||
"upgradePosition": "",
|
||||
"coreUpgrades": core_upgrades
|
||||
}
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
@@ -9,21 +9,22 @@ Classes:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_config_logger, get_db_logger
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
)
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取数据库专用日志器
|
||||
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
|
||||
return memory_config_obj
|
||||
|
||||
@staticmethod
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
|
||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
@@ -309,57 +310,21 @@ class MemoryConfigRepository:
|
||||
|
||||
Returns:
|
||||
Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
||||
|
||||
try:
|
||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
|
||||
stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id)
|
||||
db_config = db.execute(stmt).scalar_one_or_none()
|
||||
if not db_config:
|
||||
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
# 更新字段映射
|
||||
field_mapping = {
|
||||
# 模型选择
|
||||
"llm_id": "llm_id",
|
||||
"embedding_id": "embedding_id",
|
||||
"rerank_id": "rerank_id",
|
||||
# 记忆萃取引擎
|
||||
"enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
|
||||
"enable_llm_disambiguation": "enable_llm_disambiguation",
|
||||
"deep_retrieval": "deep_retrieval",
|
||||
"t_type_strict": "t_type_strict",
|
||||
"t_name_strict": "t_name_strict",
|
||||
"t_overall": "t_overall",
|
||||
"state": "state",
|
||||
"chunker_strategy": "chunker_strategy",
|
||||
# 句子提取
|
||||
"statement_granularity": "statement_granularity",
|
||||
"include_dialogue_context": "include_dialogue_context",
|
||||
"max_context": "max_context",
|
||||
# 剪枝配置
|
||||
"pruning_enabled": "pruning_enabled",
|
||||
"pruning_scene": "pruning_scene",
|
||||
"pruning_threshold": "pruning_threshold",
|
||||
# 自我反思配置
|
||||
"enable_self_reflexion": "enable_self_reflexion",
|
||||
"iteration_period": "iteration_period",
|
||||
"reflexion_range": "reflexion_range",
|
||||
"baseline": "baseline",
|
||||
}
|
||||
update_data = update.model_dump(exclude_unset=True)
|
||||
update_data.pop("config_id", None)
|
||||
|
||||
has_update = False
|
||||
for api_field, db_field in field_mapping.items():
|
||||
value = getattr(update, api_field, None)
|
||||
if value is not None:
|
||||
setattr(db_config, db_field, value)
|
||||
has_update = True
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
for field, value in update_data.items():
|
||||
setattr(db_config, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
@@ -443,6 +408,9 @@ class MemoryConfigRepository:
|
||||
"llm_id": db_config.llm_id,
|
||||
"embedding_id": db_config.embedding_id,
|
||||
"rerank_id": db_config.rerank_id,
|
||||
"vision_id": db_config.vision_id,
|
||||
"audio_id": db_config.audio_id,
|
||||
"video_id": db_config.video_id,
|
||||
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
|
||||
"enable_llm_disambiguation": db_config.enable_llm_disambiguation,
|
||||
"deep_retrieval": db_config.deep_retrieval,
|
||||
@@ -527,7 +495,10 @@ class MemoryConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
|
||||
def get_config_with_workspace(
|
||||
db: Session,
|
||||
config_id: uuid.UUID | int | str
|
||||
) -> Optional[tuple[MemoryConfig, Workspace]]:
|
||||
"""Get memory config and its associated workspace information
|
||||
|
||||
Args:
|
||||
@@ -542,8 +513,6 @@ class MemoryConfigRepository:
|
||||
"""
|
||||
import time
|
||||
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
start_time = time.time()
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
@@ -630,7 +599,7 @@ class MemoryConfigRepository:
|
||||
|
||||
db_logger.debug(
|
||||
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||
return (config, workspace)
|
||||
return config, workspace
|
||||
|
||||
except ValueError:
|
||||
# Re-raise known business exceptions
|
||||
@@ -666,7 +635,7 @@ class MemoryConfigRepository:
|
||||
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
|
||||
"""
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
|
||||
|
||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
@@ -730,7 +699,7 @@ class MemoryConfigRepository:
|
||||
Optional[MemoryConfig]: 默认配置对象,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 优先查找显式标记为默认的配置
|
||||
stmt = (
|
||||
@@ -742,13 +711,13 @@ class MemoryConfigRepository:
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
|
||||
config = db.scalars(stmt).first()
|
||||
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"找到默认配置: config_id={config.config_id}")
|
||||
return config
|
||||
|
||||
|
||||
# 回退:获取最早创建的活跃配置
|
||||
stmt = (
|
||||
select(MemoryConfig)
|
||||
@@ -759,25 +728,25 @@ class MemoryConfigRepository:
|
||||
.order_by(MemoryConfig.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
|
||||
config = db.scalars(stmt).first()
|
||||
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}")
|
||||
else:
|
||||
db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
return config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_with_fallback(
|
||||
db: Session,
|
||||
config_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID
|
||||
db: Session,
|
||||
config_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID
|
||||
) -> Optional[MemoryConfig]:
|
||||
"""获取记忆配置,支持回退到工作空间默认配置
|
||||
|
||||
@@ -792,19 +761,18 @@ class MemoryConfigRepository:
|
||||
Optional[MemoryConfig]: 配置对象,如果都不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}")
|
||||
|
||||
|
||||
if not config_id:
|
||||
db_logger.debug("config_id 为空,使用工作空间默认配置")
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
|
||||
config = db.get(MemoryConfig, config_id)
|
||||
|
||||
|
||||
if config:
|
||||
return config
|
||||
|
||||
|
||||
db_logger.warning(
|
||||
f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}"
|
||||
)
|
||||
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
from sqlalchemy import and_, or_, func, desc, select
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import uuid
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
from sqlalchemy import and_, or_, func, desc
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery, ModelConfigQueryNew
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
@@ -137,6 +138,9 @@ class ModelConfigRepository:
|
||||
type_values.append(ModelType.LLM)
|
||||
filters.append(ModelConfig.type.in_(type_values))
|
||||
|
||||
if query.capability:
|
||||
filters.append(ModelConfig.capability.contains(query.capability))
|
||||
|
||||
if query.is_active is not None:
|
||||
filters.append(ModelConfig.is_active == query.is_active)
|
||||
|
||||
@@ -435,7 +439,6 @@ class ModelConfigRepository:
|
||||
ModelConfig.is_public
|
||||
),
|
||||
ModelConfig.provider == provider,
|
||||
ModelConfig.is_active,
|
||||
~ModelConfig.is_composite
|
||||
)
|
||||
).all()
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||
MEMORY_SUMMARY_NODE_SAVE
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||
"""Delete all nodes in the database."""
|
||||
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
|
||||
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||
logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||
return result
|
||||
|
||||
|
||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
"""Add dialogue nodes to Neo4j database.
|
||||
|
||||
@@ -23,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
||||
List of created node UUIDs or None if failed
|
||||
"""
|
||||
if not dialogues:
|
||||
print("No dialogues to save")
|
||||
logger.info("No dialogues to save")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -48,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
||||
)
|
||||
|
||||
created_uuids = [record["uuid"] for record in result]
|
||||
print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
||||
logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating dialogue nodes: {e}")
|
||||
logger.error(f"Error creating dialogue nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -67,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
List of created node UUIDs or None if failed
|
||||
"""
|
||||
if not statements:
|
||||
print("No statements to save")
|
||||
logger.info("No statements to save")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -120,13 +125,14 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
)
|
||||
|
||||
created_uuids = [record["uuid"] for record in result]
|
||||
print(f"Successfully created {len(created_uuids)} statement nodes")
|
||||
logger.info(f"Successfully created {len(created_uuids)} statement nodes")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating statement nodes: {e}")
|
||||
logger.error(f"Error creating statement nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
"""Add chunk nodes to Neo4j in batch.
|
||||
|
||||
@@ -138,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
List of created chunk UUIDs or None if failed
|
||||
"""
|
||||
if not chunks:
|
||||
print("No chunk nodes to add")
|
||||
logger.info("No chunk nodes to add")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -171,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
)
|
||||
|
||||
created_uuids = [record["uuid"] for record in result]
|
||||
print(f"Successfully created {len(created_uuids)} chunk nodes")
|
||||
logger.info(f"Successfully created {len(created_uuids)} chunk nodes")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating chunk nodes: {e}")
|
||||
logger.error(f"Error creating chunk nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
async def add_memory_summary_nodes(
|
||||
summaries: List[MemorySummaryNode],
|
||||
connector: Neo4jConnector
|
||||
) -> Optional[List[str]]:
|
||||
"""Add memory summary nodes to Neo4j in batch.
|
||||
|
||||
Args:
|
||||
@@ -191,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
List of created summary node ids or None if failed
|
||||
"""
|
||||
if not summaries:
|
||||
print("No memory summary nodes to add")
|
||||
logger.info("No memory summary nodes to add")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -211,16 +219,14 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
|
||||
"config_id": s.config_id, # 添加 config_id
|
||||
})
|
||||
|
||||
|
||||
result = await connector.execute_query(
|
||||
MEMORY_SUMMARY_NODE_SAVE,
|
||||
summaries=flattened
|
||||
)
|
||||
created_ids = [record.get("uuid") for record in result]
|
||||
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||
logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||
return created_ids
|
||||
except Exception as e:
|
||||
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -300,7 +300,7 @@ class CommunityRepository:
|
||||
)
|
||||
return bool(result)
|
||||
except Exception as e:
|
||||
logger.error(f"update_community_metadata failed: {e}")
|
||||
logger.error(f"update_community_metadata failed: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def batch_update_community_metadata(
|
||||
|
||||
@@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id,
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
|
||||
# Entity Merge Query
|
||||
MERGE_ENTITIES = """
|
||||
MATCH (canonical:ExtractedEntity {id: $canonical_id})
|
||||
@@ -829,9 +828,8 @@ neo4j_query_all = """
|
||||
other as entity2
|
||||
"""
|
||||
|
||||
|
||||
'''针对当前节点下扩长的句子,实体和总结'''
|
||||
Memory_Timeline_ExtractedEntity="""
|
||||
Memory_Timeline_ExtractedEntity = """
|
||||
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
||||
WHERE elementId(n) = $id
|
||||
AND (ms:ExtractedEntity OR ms:MemorySummary)
|
||||
@@ -869,7 +867,7 @@ RETURN
|
||||
|
||||
|
||||
"""
|
||||
Memory_Timeline_MemorySummary="""
|
||||
Memory_Timeline_MemorySummary = """
|
||||
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
||||
WHERE elementId(n) =$id
|
||||
AND (ms:MemorySummary OR ms:ExtractedEntity)
|
||||
@@ -904,7 +902,7 @@ RETURN
|
||||
}
|
||||
) AS statement;
|
||||
"""
|
||||
Memory_Timeline_Statement="""
|
||||
Memory_Timeline_Statement = """
|
||||
MATCH (n)
|
||||
WHERE elementId(n) = $id
|
||||
|
||||
@@ -947,7 +945,7 @@ RETURN
|
||||
"""
|
||||
|
||||
'''针对当前节点,主要获取更加完整的句子节点'''
|
||||
Memory_Space_Emotion_Statement="""
|
||||
Memory_Space_Emotion_Statement = """
|
||||
MATCH (n)
|
||||
WHERE elementId(n) = $id
|
||||
RETURN
|
||||
@@ -957,7 +955,7 @@ RETURN
|
||||
n.statement AS statement;
|
||||
|
||||
"""
|
||||
Memory_Space_Emotion_MemorySummary="""
|
||||
Memory_Space_Emotion_MemorySummary = """
|
||||
MATCH (n)-[]-(e)
|
||||
WHERE elementId(n) = $id
|
||||
AND EXISTS {
|
||||
@@ -970,7 +968,7 @@ RETURN DISTINCT
|
||||
e.emotion_type AS emotion_type,
|
||||
e.statement AS statement;
|
||||
"""
|
||||
Memory_Space_Emotion_ExtractedEntity="""
|
||||
Memory_Space_Emotion_ExtractedEntity = """
|
||||
MATCH (n)-[]-(e)
|
||||
WHERE elementId(n) = $id
|
||||
AND EXISTS {
|
||||
@@ -985,18 +983,18 @@ RETURN DISTINCT
|
||||
|
||||
'''获取实体'''
|
||||
|
||||
Memory_Space_User="""
|
||||
Memory_Space_User = """
|
||||
MATCH (n)-[r]->(m)
|
||||
WHERE n.end_user_id = $end_user_id AND m.name="用户"
|
||||
return DISTINCT elementId(m) as id
|
||||
"""
|
||||
Memory_Space_Entity="""
|
||||
Memory_Space_Entity = """
|
||||
MATCH (n)-[]-(m)
|
||||
WHERE elementId(m) = $id AND m.entity_type = "Person"
|
||||
RETURN
|
||||
DISTINCT m.name as name,m.end_user_id as end_user_id
|
||||
"""
|
||||
Memory_Space_Associative="""
|
||||
Memory_Space_Associative = """
|
||||
MATCH (u)-[]-(x)-[]-(h)
|
||||
WHERE elementId(u) = $user_id
|
||||
AND elementId(h) = $id
|
||||
@@ -1005,61 +1003,69 @@ RETURN DISTINCT
|
||||
"""
|
||||
|
||||
Graph_Node_query = """
|
||||
MATCH (n:MemorySummary)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
0 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:MemorySummary)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
0 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Dialogue)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT 1
|
||||
MATCH (n:Dialogue)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT 1
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Statement)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:Statement)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:ExtractedEntity)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
2 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:ExtractedEntity)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
2 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Chunk)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
3 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:Chunk)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
3 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
"""
|
||||
UNION ALL
|
||||
MATCH (n:Perceptual)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
4 AS priority
|
||||
|
||||
"""
|
||||
|
||||
# ============================================================
|
||||
# Community 节点 & BELONGS_TO_COMMUNITY 边
|
||||
@@ -1069,6 +1075,7 @@ Graph_Node_query = """
|
||||
|
||||
COMMUNITY_NODE_UPSERT = """
|
||||
MERGE (c:Community {community_id: $community_id})
|
||||
ON CREATE SET c.id = $community_id
|
||||
SET c.end_user_id = $end_user_id,
|
||||
c.member_count = $member_count,
|
||||
c.updated_at = datetime()
|
||||
@@ -1175,7 +1182,8 @@ RETURN c.community_id AS community_id, cnt AS member_count
|
||||
|
||||
UPDATE_COMMUNITY_METADATA = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
SET c.name = $name,
|
||||
SET c.id = coalesce(c.id, $community_id),
|
||||
c.name = $name,
|
||||
c.summary = $summary,
|
||||
c.core_entities = $core_entities,
|
||||
c.summary_embedding = $summary_embedding,
|
||||
@@ -1186,7 +1194,8 @@ RETURN c.community_id AS community_id
|
||||
BATCH_UPDATE_COMMUNITY_METADATA = """
|
||||
UNWIND $communities AS row
|
||||
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
|
||||
SET c.name = row.name,
|
||||
SET c.id = coalesce(c.id, row.community_id),
|
||||
c.name = row.name,
|
||||
c.summary = row.summary,
|
||||
c.core_entities = row.core_entities,
|
||||
c.summary_embedding = row.summary_embedding,
|
||||
@@ -1270,6 +1279,40 @@ RETURN
|
||||
startNode(r) = e AS r_from_e
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL
|
||||
) AS is_complete
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL AND
|
||||
c.summary_embedding IS NOT NULL
|
||||
) AS is_complete
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
|
||||
OR c.name = '' OR c.summary = ''
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.name = ''
|
||||
OR c.summary IS NULL OR c.summary = ''
|
||||
OR c.core_entities IS NULL
|
||||
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
# Community keyword search: matches name or summary via fulltext index
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
||||
@@ -1327,37 +1370,35 @@ ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL
|
||||
) AS is_complete
|
||||
# 感知记忆节点保存
|
||||
PERCEPTUAL_NODE_SAVE = """
|
||||
UNWIND $perceptuals AS p
|
||||
MERGE (n:Perceptual {id: p.id})
|
||||
SET n += {
|
||||
id: p.id,
|
||||
end_user_id: p.end_user_id,
|
||||
perceptual_type: p.perceptual_type,
|
||||
file_path: p.file_path,
|
||||
file_name: p.file_name,
|
||||
file_ext: p.file_ext,
|
||||
summary: p.summary,
|
||||
keywords: p.keywords,
|
||||
topic: p.topic,
|
||||
domain: p.domain,
|
||||
created_at: p.created_at,
|
||||
file_type: p.file_type,
|
||||
summary_embedding: p.summary_embedding
|
||||
}
|
||||
RETURN n.id AS uuid
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL AND
|
||||
c.summary_embedding IS NOT NULL
|
||||
) AS is_complete
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
|
||||
OR c.name = '' OR c.summary = ''
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.name = ''
|
||||
OR c.summary IS NULL OR c.summary = ''
|
||||
OR c.core_entities IS NULL
|
||||
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
|
||||
RETURN c.community_id AS community_id
|
||||
# 感知记忆与对话的关联边
|
||||
PERCEPTUAL_CHUNK_EDGE_SAVE = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
|
||||
MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id})
|
||||
MERGE (c)-[r:HAS_PERCEPTUAL]->(p)
|
||||
ON CREATE SET r.end_user_id = edge.end_user_id,
|
||||
r.created_at = edge.created_at
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
@@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import (
|
||||
StatementNode,
|
||||
ExtractedEntityNode,
|
||||
EntityEntityEdge,
|
||||
PerceptualNode,
|
||||
PerceptualEdge,
|
||||
)
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def save_entities_and_relationships(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save entities and their relationships using graph models"""
|
||||
all_entities = [entity.model_dump() for entity in entity_nodes]
|
||||
@@ -73,8 +78,8 @@ async def save_entities_and_relationships(
|
||||
|
||||
|
||||
async def save_chunk_nodes(
|
||||
chunk_nodes: List[ChunkNode],
|
||||
connector: Neo4jConnector
|
||||
chunk_nodes: List[ChunkNode],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save chunk nodes using graph models"""
|
||||
if not chunk_nodes:
|
||||
@@ -89,8 +94,8 @@ async def save_chunk_nodes(
|
||||
|
||||
|
||||
async def save_statement_chunk_edges(
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
connector: Neo4jConnector
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save statement-chunk edges using graph models"""
|
||||
if not statement_chunk_edges:
|
||||
@@ -118,8 +123,8 @@ async def save_statement_chunk_edges(
|
||||
|
||||
|
||||
async def save_statement_entity_edges(
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save statement-entity edges using graph models"""
|
||||
if not statement_entity_edges:
|
||||
@@ -142,7 +147,7 @@ async def save_statement_entity_edges(
|
||||
if all_se_edges:
|
||||
try:
|
||||
await connector.execute_query(
|
||||
STATEMENT_ENTITY_EDGE_SAVE,
|
||||
STATEMENT_ENTITY_EDGE_SAVE,
|
||||
relationships=all_se_edges
|
||||
)
|
||||
except Exception:
|
||||
@@ -154,24 +159,28 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
perceptual_nodes: List[PerceptualNode],
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
perceptual_edges: List[PerceptualEdge],
|
||||
connector: Neo4jConnector,
|
||||
) -> bool:
|
||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||
|
||||
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
|
||||
schedule_clustering_after_write() 显式触发。
|
||||
_trigger_clustering_sync() 显式触发。
|
||||
|
||||
Args:
|
||||
dialogue_nodes: List of DialogueNode objects to save
|
||||
chunk_nodes: List of ChunkNode objects to save
|
||||
statement_nodes: List of StatementNode objects to save
|
||||
entity_nodes: List of ExtractedEntityNode objects to save
|
||||
perceptual_nodes: List of PerceptualNode objects to save
|
||||
entity_edges: List of EntityEntityEdge objects to save
|
||||
statement_chunk_edges: List of StatementChunkEdge objects to save
|
||||
statement_entity_edges: List of StatementEntityEdge objects to save
|
||||
perceptual_edges: List of PerceptualEdge objects to save
|
||||
connector: Neo4j connector instance
|
||||
|
||||
Returns:
|
||||
@@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
||||
dialogue_uuids = [record["uuid"] async for record in result]
|
||||
results['dialogues'] = dialogue_uuids
|
||||
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||
logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||
|
||||
# 2. Save all chunk nodes in batch
|
||||
if chunk_nodes:
|
||||
@@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
results['chunks'] = chunk_uuids
|
||||
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
||||
|
||||
if perceptual_nodes:
|
||||
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
|
||||
perceptual_data = [node.model_dump() for node in perceptual_nodes]
|
||||
result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data)
|
||||
perceptual_uuids = [record["uuid"] async for record in result]
|
||||
results["perceptuals"] = perceptual_uuids
|
||||
logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")
|
||||
|
||||
# 3. Save all statement nodes in batch
|
||||
if statement_nodes:
|
||||
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
||||
@@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
results['statement_entity_edges'] = se_uuids
|
||||
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
||||
|
||||
if perceptual_edges:
|
||||
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
|
||||
perceptual_edge_data = []
|
||||
for edge in perceptual_edges:
|
||||
print(edge.source, edge.target)
|
||||
perceptual_edge_data.append({
|
||||
"perceptual_id": edge.source,
|
||||
"chunk_id": edge.target,
|
||||
"end_user_id": edge.end_user_id,
|
||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||
})
|
||||
result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data)
|
||||
perceptual_edges_uuids = [record["uuid"] async for record in result]
|
||||
results['perceptual_chunk_edges'] = perceptual_edges_uuids
|
||||
logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j")
|
||||
|
||||
return results
|
||||
|
||||
try:
|
||||
@@ -303,16 +336,13 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
return False
|
||||
|
||||
|
||||
def schedule_clustering_after_write(
|
||||
entity_nodes: List,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
async def _trigger_clustering_sync(
|
||||
entity_nodes: List,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
写入 Neo4j 成功后,调度后台聚类任务。
|
||||
|
||||
可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。
|
||||
使用 asyncio.create_task 异步触发,不阻塞写入响应。
|
||||
同步等待聚类完成,避免与其他 LLM 任务并发冲突。
|
||||
"""
|
||||
if not entity_nodes:
|
||||
return
|
||||
@@ -324,15 +354,16 @@ def schedule_clustering_after_write(
|
||||
|
||||
end_user_id = entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in entity_nodes]
|
||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
|
||||
logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
|
||||
embedding_model_id=embedding_model_id)
|
||||
|
||||
|
||||
async def _trigger_clustering(
|
||||
new_entity_ids: List[str],
|
||||
end_user_id: str,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
new_entity_ids: List[str],
|
||||
end_user_id: str,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||
|
||||
@@ -196,6 +196,13 @@ class CitationConfig(BaseModel):
|
||||
enabled: bool = Field(default=False)
|
||||
|
||||
|
||||
class Citation(BaseModel):
|
||||
document_id: str
|
||||
file_name: str
|
||||
knowledge_id: str
|
||||
score: float
|
||||
|
||||
|
||||
class WebSearchConfig(BaseModel):
|
||||
"""联网搜索配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
|
||||
@@ -387,6 +387,12 @@ class MemoryConfig:
|
||||
|
||||
rerank_model_id: Optional[UUID] = None
|
||||
rerank_model_name: Optional[str] = None
|
||||
video_model_id: Optional[UUID] = None
|
||||
video_model_name: Optional[str] = None
|
||||
vision_model_id: Optional[UUID] = None
|
||||
vision_model_name: Optional[str] = None
|
||||
audio_model_id: Optional[UUID] = None
|
||||
audio_model_name: Optional[str] = None
|
||||
|
||||
llm_params: Dict[str, Any] = field(default_factory=dict)
|
||||
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -8,9 +8,6 @@ import uuid
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 从 json_schema.py 迁移的 Schema
|
||||
# ============================================================================
|
||||
@@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel):
|
||||
|
||||
class ConflictResultSchema(BaseModel):
|
||||
"""Schema for the conflict result data in the reflexion_data.json file."""
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.")
|
||||
data: List[BaseDataSchema] = Field(...,
|
||||
description="The conflict memory data. Only contains conflicting records when conflict is True.")
|
||||
conflict: bool = Field(..., description="Whether the memory is in conflict.")
|
||||
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
|
||||
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
|
||||
quality_assessment: Optional[QualityAssessmentSchema] = Field(None,
|
||||
description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
|
||||
memory_verify: Optional[MemoryVerifySchema] = Field(None,
|
||||
description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_data(cls, v):
|
||||
@@ -101,16 +101,19 @@ class ChangeRecordSchema(BaseModel):
|
||||
- entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式
|
||||
"""
|
||||
field: List[Dict[str, Any]] = Field(
|
||||
...,
|
||||
...,
|
||||
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
|
||||
)
|
||||
|
||||
|
||||
class ResolvedSchema(BaseModel):
|
||||
"""Schema for the resolved memory data in the reflexion_data"""
|
||||
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
|
||||
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
|
||||
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
|
||||
change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.")
|
||||
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None,
|
||||
description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
|
||||
change: Optional[List[ChangeRecordSchema]] = Field(None,
|
||||
description="List of detailed change records with IDs and field information.")
|
||||
|
||||
|
||||
class SingleReflexionResultSchema(BaseModel):
|
||||
@@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel):
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
|
||||
type: str = Field("reflexion_result", description="The type identifier.")
|
||||
|
||||
|
||||
class ReflexionResultSchema(BaseModel):
|
||||
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
results: List[SingleReflexionResultSchema] = Field(...,
|
||||
description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_resolved(cls, v):
|
||||
@@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel):
|
||||
# Composite key identifying a config row
|
||||
class ConfigKey(BaseModel): # 配置参数键模型
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
user_id: str = Field("user_id", description="用户标识(字符串)")
|
||||
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
user_id: str | None = Field(default=None, description="用户标识(字符串)")
|
||||
apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)")
|
||||
|
||||
|
||||
# Allowed chunking strategies (extendable later)
|
||||
@@ -228,23 +233,25 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
||||
config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
||||
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)")
|
||||
|
||||
|
||||
# 本体场景关联(可选)
|
||||
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表")
|
||||
|
||||
|
||||
# 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入)
|
||||
pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充")
|
||||
|
||||
|
||||
# 模型配置字段(可选,用于手动指定或自动填充)
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||
|
||||
|
||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
@@ -255,8 +262,11 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用
|
||||
|
||||
|
||||
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
config_id:Union[uuid.UUID, int, str] = None
|
||||
config_id: Union[uuid.UUID, int, str] = None
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
audio_id: Optional[str] = Field(None, description="语音模型ID")
|
||||
vision_id: Optional[str] = Field(None, description="视觉模型ID")
|
||||
video_id: Optional[str] = Field(None, description="视频模型ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
enable_llm_dedup_blockwise: Optional[bool] = None
|
||||
@@ -322,14 +332,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
||||
|
||||
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
||||
# 遗忘引擎配置参数更新模型
|
||||
config_id:Union[uuid.UUID, int, str] = None
|
||||
config_id: Union[uuid.UUID, int, str] = None
|
||||
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
||||
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
||||
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
||||
|
||||
|
||||
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
|
||||
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
||||
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
@@ -364,11 +374,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None)
|
||||
|
||||
|
||||
def fail(
|
||||
msg: str,
|
||||
error_code: str = "ERROR",
|
||||
data: Optional[Any] = None,
|
||||
time: Optional[int] = None,
|
||||
query_preview: Optional[str] = None,
|
||||
msg: str,
|
||||
error_code: str = "ERROR",
|
||||
data: Optional[Any] = None,
|
||||
time: Optional[int] = None,
|
||||
query_preview: Optional[str] = None,
|
||||
) -> ApiResponse:
|
||||
payload = data
|
||||
if query_preview is not None:
|
||||
@@ -387,12 +397,13 @@ def fail(
|
||||
time=time or _now_ms(),
|
||||
)
|
||||
|
||||
|
||||
class GenerateCacheRequest(BaseModel):
|
||||
"""缓存生成请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
end_user_id: Optional[str] = Field(
|
||||
None,
|
||||
None,
|
||||
description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成"
|
||||
)
|
||||
|
||||
@@ -404,7 +415,7 @@ class GenerateCacheRequest(BaseModel):
|
||||
class ForgettingTriggerRequest(BaseModel):
|
||||
"""手动触发遗忘周期请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
end_user_id: str = Field(..., description="组ID(即终端用户ID,必填)")
|
||||
max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)")
|
||||
min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)")
|
||||
@@ -413,7 +424,7 @@ class ForgettingTriggerRequest(BaseModel):
|
||||
class ForgettingConfigResponse(BaseModel):
|
||||
"""遗忘引擎配置响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
||||
decay_constant: float = Field(..., description="衰减常数 d")
|
||||
lambda_time: float = Field(..., description="时间衰减参数")
|
||||
@@ -432,7 +443,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
|
||||
"""遗忘引擎配置更新请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
||||
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
||||
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
||||
@@ -448,7 +459,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
|
||||
class ForgettingCycleHistoryPoint(BaseModel):
|
||||
"""遗忘周期历史数据点模型(用于趋势图)"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
date: str = Field(..., description="日期(格式: '1/1', '1/2')")
|
||||
merged_count: int = Field(..., description="每日融合节点数")
|
||||
average_activation: Optional[float] = Field(None, description="平均激活值")
|
||||
@@ -459,7 +470,7 @@ class ForgettingCycleHistoryPoint(BaseModel):
|
||||
class PendingForgettingNode(BaseModel):
|
||||
"""待遗忘节点模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
node_id: str = Field(..., description="节点ID")
|
||||
node_type: str = Field(..., description="节点类型:statement/entity/summary")
|
||||
content_summary: str = Field(..., description="内容摘要")
|
||||
@@ -472,7 +483,8 @@ class ForgettingStatsResponse(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
|
||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
||||
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
||||
timestamp: int = Field(..., description="统计时间(时间戳)")
|
||||
|
||||
@@ -480,7 +492,7 @@ class ForgettingStatsResponse(BaseModel):
|
||||
class ForgettingReportResponse(BaseModel):
|
||||
"""遗忘周期报告响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
merged_count: int = Field(..., description="融合的节点对数量")
|
||||
nodes_before: int = Field(..., description="遗忘前的节点总数")
|
||||
nodes_after: int = Field(..., description="遗忘后的节点总数")
|
||||
@@ -495,7 +507,7 @@ class ForgettingReportResponse(BaseModel):
|
||||
class ForgettingCurvePoint(BaseModel):
|
||||
"""遗忘曲线数据点模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
day: int = Field(..., description="天数")
|
||||
activation: float = Field(..., description="激活值")
|
||||
retention_rate: float = Field(..., description="保持率(与激活值相同)")
|
||||
@@ -504,7 +516,7 @@ class ForgettingCurvePoint(BaseModel):
|
||||
class ForgettingCurveRequest(BaseModel):
|
||||
"""遗忘曲线请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
|
||||
days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
|
||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||
@@ -513,6 +525,6 @@ class ForgettingCurveRequest(BaseModel):
|
||||
class ForgettingCurveResponse(BaseModel):
|
||||
"""遗忘曲线响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表")
|
||||
config: Dict[str, Any] = Field(..., description="使用的配置参数")
|
||||
|
||||
@@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase):
|
||||
updated_at: datetime.datetime
|
||||
api_keys: List["ModelApiKey"] = []
|
||||
|
||||
@staticmethod
|
||||
def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str:
|
||||
if not key or len(key) <= prefix + suffix:
|
||||
return "*" * len(key)
|
||||
return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:]
|
||||
|
||||
@field_validator("api_keys", mode="after")
|
||||
@classmethod
|
||||
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
|
||||
@@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase):
|
||||
def _serialize_created_at(self, dt: datetime.datetime | None):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("api_keys", when_used="json")
|
||||
def _serialize_api_keys(self, api_keys: List["ModelApiKey"]):
|
||||
result = []
|
||||
for api_key in api_keys:
|
||||
data = api_key.model_dump()
|
||||
data["api_key"] = self.mask_api_key(api_key.api_key)
|
||||
result.append(data)
|
||||
return result
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -165,20 +180,20 @@ class ModelApiKey(ModelApiKeyBase):
|
||||
if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict):
|
||||
self.model_config_ids = [
|
||||
mc.id for mc in self.model_configs
|
||||
if hasattr(mc, 'id')
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name
|
||||
if hasattr(mc, 'id')
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name
|
||||
]
|
||||
# 情况2:字典列表
|
||||
elif isinstance(self.model_configs, list):
|
||||
self.model_config_ids = [
|
||||
mc['id'] if isinstance(mc, dict) else mc.id
|
||||
for mc in self.model_configs
|
||||
if ((isinstance(mc, dict)
|
||||
and 'id' in mc
|
||||
if ((isinstance(mc, dict)
|
||||
and 'id' in mc
|
||||
and not mc.get('is_composite', False)
|
||||
and mc.get('name') == self.model_name) or
|
||||
(hasattr(mc, 'id')
|
||||
and mc.get('name') == self.model_name) or
|
||||
(hasattr(mc, 'id')
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name))
|
||||
]
|
||||
@@ -193,11 +208,10 @@ class ModelApiKey(ModelApiKeyBase):
|
||||
validate_assignment=True # 确保赋值触发校验
|
||||
)
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel):
|
||||
"""模型配置查询Schema"""
|
||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
||||
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
||||
capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)")
|
||||
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
||||
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
@@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel):
|
||||
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
|
||||
class ModelMarketplace(BaseModel):
|
||||
"""模型广场响应Schema"""
|
||||
llm_models: List[ModelConfig] = []
|
||||
@@ -304,7 +320,7 @@ class ModelBaseUpdate(BaseModel):
|
||||
class ModelBase(BaseModel):
|
||||
"""基础模型Schema"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
type: str
|
||||
@@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel):
|
||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""模型信息Schema"""
|
||||
model_name: str = Field(..., description="模型名称")
|
||||
@@ -336,4 +353,3 @@ class ModelInfo(BaseModel):
|
||||
is_omni: bool = Field(default=False, description="是否为omni模型")
|
||||
model_type: ModelType = Field(..., description="模型类型")
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||
|
||||
|
||||
@@ -82,6 +82,12 @@ class AppChatService:
|
||||
)
|
||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||
|
||||
# opening_statement:首轮对话注入开场白
|
||||
is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
|
||||
system_prompt = self.agent_service._inject_opening_statement(
|
||||
features_config, system_prompt, is_new_conversation
|
||||
)
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
@@ -93,7 +99,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
tools.extend(kb_tools)
|
||||
memory_flag = False
|
||||
if memory:
|
||||
memory_tools, memory_flag = self.agent_service.load_memory_config(
|
||||
@@ -129,45 +136,18 @@ class AppChatService:
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
messages = self.conversation_service.get_messages(
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
limit=10
|
||||
max_history=10,
|
||||
current_provider=api_key_obj.provider,
|
||||
current_is_omni=api_key_obj.is_omni
|
||||
)
|
||||
history = []
|
||||
for msg in messages:
|
||||
content = [{"type": "text", "text": msg.content}]
|
||||
|
||||
# 处理 meta_data 中的 files
|
||||
if msg.meta_data and msg.meta_data.get("files"):
|
||||
files = msg.meta_data.get("files", [])
|
||||
# 使用 MultimodalService 处理文件
|
||||
multimodal_service = MultimodalService(self.db, api_config=model_info)
|
||||
|
||||
# 将 files 转换为 FileInput 格式
|
||||
file_inputs = []
|
||||
for file in files:
|
||||
from app.schemas.app_schema import FileInput, TransferMethod
|
||||
file_input = FileInput(
|
||||
type=file.get("type"),
|
||||
transfer_method=TransferMethod.REMOTE_URL,
|
||||
url=file.get("url")
|
||||
)
|
||||
file_inputs.append(file_input)
|
||||
|
||||
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
|
||||
|
||||
content.extend(history_processed_files)
|
||||
|
||||
history.append({
|
||||
"role": msg.role,
|
||||
"content": content
|
||||
})
|
||||
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
@@ -206,7 +186,8 @@ class AppChatService:
|
||||
|
||||
# 构建用户消息内容(含多模态文件)
|
||||
human_meta = {
|
||||
"files": []
|
||||
"files": [],
|
||||
"history_files": {}
|
||||
}
|
||||
assistant_meta = {
|
||||
"model": api_key_obj.model_name,
|
||||
@@ -221,6 +202,13 @@ class AppChatService:
|
||||
"url": f.url
|
||||
})
|
||||
|
||||
if processed_files:
|
||||
human_meta["history_files"] = {
|
||||
"content": processed_files,
|
||||
"provider": api_key_obj.provider,
|
||||
"is_omni": api_key_obj.is_omni
|
||||
}
|
||||
|
||||
# 保存消息
|
||||
if audio_url:
|
||||
assistant_meta["audio_url"] = audio_url
|
||||
@@ -249,8 +237,9 @@ class AppChatService:
|
||||
}),
|
||||
"elapsed_time": elapsed_time,
|
||||
"suggested_questions": suggested_questions,
|
||||
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
|
||||
"citations": self.agent_service._filter_citations(features_config, citations_collector),
|
||||
"audio_url": audio_url,
|
||||
"audio_status": "pending"
|
||||
}
|
||||
|
||||
async def agnet_chat_stream(
|
||||
@@ -301,6 +290,12 @@ class AppChatService:
|
||||
)
|
||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||
|
||||
# opening_statement:首轮对话注入开场白
|
||||
is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
|
||||
system_prompt = self.agent_service._inject_opening_statement(
|
||||
features_config, system_prompt, is_new_conversation
|
||||
)
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
@@ -313,7 +308,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
tools.extend(kb_tools)
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
if memory:
|
||||
@@ -350,45 +346,18 @@ class AppChatService:
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
messages = self.conversation_service.get_messages(
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
limit=10
|
||||
max_history=10,
|
||||
current_provider=api_key_obj.provider,
|
||||
current_is_omni=api_key_obj.is_omni
|
||||
)
|
||||
history = []
|
||||
for msg in messages:
|
||||
content = [{"type": "text", "text": msg.content}]
|
||||
|
||||
# 处理 meta_data 中的 files
|
||||
if msg.meta_data and msg.meta_data.get("files"):
|
||||
history_files = msg.meta_data.get("files", [])
|
||||
# 使用 MultimodalService 处理文件
|
||||
multimodal_service = MultimodalService(self.db, api_config=model_info)
|
||||
|
||||
# 将 files 转换为 FileInput 格式
|
||||
file_inputs = []
|
||||
for file in history_files:
|
||||
from app.schemas.app_schema import FileInput, TransferMethod
|
||||
file_input = FileInput(
|
||||
type=file.get("type"),
|
||||
transfer_method=TransferMethod.REMOTE_URL,
|
||||
url=file.get("url")
|
||||
)
|
||||
file_inputs.append(file_input)
|
||||
|
||||
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
|
||||
|
||||
content.extend(history_processed_files)
|
||||
|
||||
history.append({
|
||||
"role": msg.role,
|
||||
"content": content
|
||||
})
|
||||
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||
@@ -433,7 +402,7 @@ class AppChatService:
|
||||
elapsed_time = time.time() - start_time
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
# 发送结束事件(包含 suggested_questions、tts、citations)
|
||||
# 发送结束事件(包含 suggested_questions、tts、audio_status、citations)
|
||||
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
||||
sq_config = features_config.get("suggested_questions_after_answer", {})
|
||||
if isinstance(sq_config, dict) and sq_config.get("enabled"):
|
||||
@@ -443,11 +412,23 @@ class AppChatService:
|
||||
"api_base": api_key_obj.api_base}, {}
|
||||
)
|
||||
end_data["audio_url"] = stream_audio_url
|
||||
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
|
||||
# 检查TTS是否已完成(非阻塞,不取消任务)
|
||||
audio_status = "pending"
|
||||
if tts_task is not None and tts_task.done():
|
||||
# 任务已完成,检查是否有异常
|
||||
try:
|
||||
tts_task.result()
|
||||
audio_status = "completed"
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS任务异常: {e}")
|
||||
audio_status = "failed"
|
||||
end_data["audio_status"] = audio_status if stream_audio_url else None
|
||||
end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector)
|
||||
|
||||
# 保存消息
|
||||
human_meta = {
|
||||
"files":[]
|
||||
"files":[],
|
||||
"history_files": {}
|
||||
}
|
||||
assistant_meta = {
|
||||
"model": api_key_obj.model_name,
|
||||
@@ -457,11 +438,16 @@ class AppChatService:
|
||||
|
||||
if files:
|
||||
for f in files:
|
||||
# url = await MultimodalService(self.db).get_file_url(f)
|
||||
human_meta["files"].append({
|
||||
"type": f.type,
|
||||
"url": f.url
|
||||
})
|
||||
if processed_files:
|
||||
human_meta["history_files"] = {
|
||||
"content": processed_files,
|
||||
"provider": api_key_obj.provider,
|
||||
"is_omni": api_key_obj.is_omni
|
||||
}
|
||||
|
||||
if stream_audio_url:
|
||||
assistant_meta["audio_url"] = stream_audio_url
|
||||
|
||||
@@ -1638,7 +1638,7 @@ class AppService:
|
||||
|
||||
# ==================== 记忆配置提取方法 ====================
|
||||
|
||||
def _extract_memory_config_id(
|
||||
def _get_memory_config_id_from_release(
|
||||
self,
|
||||
app_type: str,
|
||||
config: Dict[str, Any]
|
||||
@@ -1863,7 +1863,7 @@ class AppService:
|
||||
self.db.flush() # 先 flush,确保 release 已插入数据库
|
||||
|
||||
# 提取记忆配置ID并更新终端用户
|
||||
memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config)
|
||||
memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(app.type, config)
|
||||
|
||||
# 如果检测到旧格式 int 数据,回退到工作空间默认配置
|
||||
if is_legacy_int and not memory_config_id:
|
||||
@@ -2001,7 +2001,7 @@ class AppService:
|
||||
raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}")
|
||||
|
||||
# 提取记忆配置ID并更新终端用户
|
||||
memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config)
|
||||
memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(release.type, release.config)
|
||||
|
||||
# 如果检测到旧格式 int 数据,回退到工作空间默认配置
|
||||
if is_legacy_int and not memory_config_id:
|
||||
|
||||
@@ -274,7 +274,8 @@ class ConversationService:
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
max_history: Optional[int] = None,
|
||||
api_config: Optional[ModelInfo] = None
|
||||
current_provider: Optional[str] = None,
|
||||
current_is_omni: Optional[bool] = None
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve historical conversation messages formatted as dictionaries.
|
||||
@@ -282,7 +283,8 @@ class ConversationService:
|
||||
Args:
|
||||
conversation_id (uuid.UUID): Conversation UUID.
|
||||
max_history (Optional[int]): Maximum number of messages to retrieve.
|
||||
api_config (Optional[ModelInfo]): Model API configuration for multimodal processing.
|
||||
current_provider (Optional[str]): Current provider for file handling.
|
||||
current_is_omni (Optional[bool]): Current omni flag for file handling.
|
||||
|
||||
Returns:
|
||||
List[dict]: List of message dictionaries with keys 'role' and 'content'.
|
||||
@@ -292,38 +294,30 @@ class ConversationService:
|
||||
limit=max_history
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
history = []
|
||||
for msg in messages:
|
||||
content = [{"type": "text", "text": msg.content}]
|
||||
|
||||
# 处理 meta_data 中的 files
|
||||
if msg.meta_data and msg.meta_data.get("files"):
|
||||
files = msg.meta_data.get("files", [])
|
||||
if api_config:
|
||||
# 使用 MultimodalService 处理文件
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
multimodal_service = MultimodalService(self.db, api_config=api_config)
|
||||
|
||||
# 将 files 转换为 FileInput 格式
|
||||
file_inputs = []
|
||||
for file in files:
|
||||
from app.schemas.app_schema import FileInput, TransferMethod
|
||||
file_input = FileInput(
|
||||
type=file.get("type"),
|
||||
transfer_method=TransferMethod.REMOTE_URL,
|
||||
url=file.get("url")
|
||||
)
|
||||
file_inputs.append(file_input)
|
||||
|
||||
processed_files = await multimodal_service.history_process_files(files=file_inputs)
|
||||
|
||||
content.extend(processed_files)
|
||||
|
||||
history.append({
|
||||
msg_dict = {
|
||||
"role": msg.role,
|
||||
"content": content
|
||||
})
|
||||
"content": [{"type": "text", "text": msg.content}]
|
||||
}
|
||||
|
||||
# 处理用户消息中的多模态文件
|
||||
if msg.role == "user" and msg.meta_data:
|
||||
history_files = msg.meta_data.get("history_files", {})
|
||||
|
||||
if history_files and current_provider and current_is_omni is not None:
|
||||
# 检查是否需要重新处理文件
|
||||
stored_provider = history_files.get("provider")
|
||||
stored_is_omni = history_files.get("is_omni")
|
||||
|
||||
# 如果provider或is_omni不匹配,需要重新处理
|
||||
if stored_provider != current_provider or stored_is_omni != current_is_omni:
|
||||
continue
|
||||
|
||||
# provider和is_omni匹配,直接使用存储的内容
|
||||
msg_dict["content"].extend(history_files.get("content"))
|
||||
|
||||
history.append(msg_dict)
|
||||
|
||||
return history
|
||||
|
||||
@@ -539,6 +533,7 @@ class ConversationService:
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
llm = RedBearLLM(
|
||||
@@ -546,7 +541,8 @@ class ConversationService:
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
base_url=api_base,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
@@ -554,15 +550,8 @@ class ConversationService:
|
||||
conversation_messages = await self.get_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=20,
|
||||
api_config=ModelInfo(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
capability=api_config.capability,
|
||||
is_omni=api_config.is_omni,
|
||||
model_type=model_type
|
||||
)
|
||||
current_provider=provider,
|
||||
current_is_omni=is_omni
|
||||
)
|
||||
if len(conversation_messages) == 0:
|
||||
return ConversationOut(
|
||||
|
||||
@@ -26,7 +26,7 @@ from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.app_schema import FileInput, Citation
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
@@ -190,13 +190,19 @@ def create_web_search_tool(web_search_config: Dict[str, Any]):
|
||||
return web_search_tool
|
||||
|
||||
|
||||
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
||||
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None):
|
||||
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
||||
|
||||
Args:
|
||||
kb_config: 知识库配置
|
||||
kb_ids: 知识库ID列表
|
||||
user_id: 用户ID
|
||||
citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充)
|
||||
列表元素类型为 Citation,包含字段:
|
||||
- document_id: 文档唯一标识
|
||||
- file_name: 文件名
|
||||
- knowledge_id: 知识库 ID
|
||||
- score: 检索相关性得分
|
||||
|
||||
Returns:
|
||||
检索到的相关知识内容
|
||||
@@ -229,6 +235,21 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
||||
}
|
||||
)
|
||||
|
||||
# 收集引用信息
|
||||
if citations_collector is not None:
|
||||
seen_doc_ids = {c.get("document_id") for c in citations_collector}
|
||||
for chunk in retrieve_chunks_result:
|
||||
meta = chunk.metadata or {}
|
||||
doc_id = meta.get("document_id") or meta.get("doc_id")
|
||||
if doc_id and doc_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
citations_collector.append(Citation(
|
||||
document_id=doc_id,
|
||||
file_name=meta.get("file_name", ""),
|
||||
knowledge_id=str(meta.get("knowledge_id", "")),
|
||||
score=meta.get("score", 0)
|
||||
))
|
||||
|
||||
return f"检索到以下相关信息:\n\n{context}"
|
||||
else:
|
||||
logger.warning("知识库检索未找到结果")
|
||||
@@ -320,26 +341,26 @@ class AgentRunService:
|
||||
self,
|
||||
knowledge_retrieval_config: dict | None,
|
||||
user_id
|
||||
) -> list:
|
||||
) -> tuple[list, list]:
|
||||
"""返回 (tools, citations_collector)"""
|
||||
if not knowledge_retrieval_config:
|
||||
return []
|
||||
return [], []
|
||||
|
||||
citations_collector = []
|
||||
tools = []
|
||||
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
|
||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||
kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
# 创建知识库检索工具
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
|
||||
kb_tool = create_knowledge_retrieval_tool(
|
||||
knowledge_retrieval_config, kb_ids, user_id,
|
||||
citations_collector=citations_collector
|
||||
)
|
||||
tools.append(kb_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加知识库检索工具",
|
||||
extra={
|
||||
"kb_ids": kb_ids,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
extra={"kb_ids": kb_ids, "tool_count": len(tools)}
|
||||
)
|
||||
return tools
|
||||
return tools, citations_collector
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
@@ -441,12 +462,12 @@ class AgentRunService:
|
||||
@staticmethod
|
||||
def _filter_citations(
|
||||
features_config: Dict[str, Any],
|
||||
citations: List[Any]
|
||||
citations: List[Citation]
|
||||
) -> List[Any]:
|
||||
"""根据 citation 开关决定是否返回引用来源"""
|
||||
citation_cfg = features_config.get("citation", {})
|
||||
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
|
||||
return citations
|
||||
return [cit.model_dump() for cit in citations]
|
||||
return []
|
||||
|
||||
async def run(
|
||||
@@ -549,7 +570,8 @@ class AgentRunService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
||||
kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
|
||||
tools.extend(kb_tools)
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
if memory:
|
||||
@@ -592,8 +614,9 @@ class AgentRunService:
|
||||
# 6. 加载历史消息
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
api_config=model_info,
|
||||
max_history=10
|
||||
max_history=10,
|
||||
current_provider=api_key_config.get("provider"),
|
||||
current_is_omni=api_key_config.get("is_omni", False)
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
@@ -602,7 +625,7 @@ class AgentRunService:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -661,7 +684,10 @@ class AgentRunService:
|
||||
})
|
||||
},
|
||||
files=files,
|
||||
audio_url=audio_url
|
||||
processed_files=processed_files,
|
||||
audio_url=audio_url,
|
||||
provider=api_key_config.get("provider"),
|
||||
is_omni=api_key_config.get("is_omni", False)
|
||||
)
|
||||
|
||||
response = {
|
||||
@@ -676,8 +702,9 @@ class AgentRunService:
|
||||
"suggested_questions": await self._generate_suggested_questions(
|
||||
features_config, result["content"], api_key_config, effective_params
|
||||
) if not sub_agent else [],
|
||||
"citations": self._filter_citations(features_config, result.get("citations", [])),
|
||||
"citations": self._filter_citations(features_config, citations_collector),
|
||||
"audio_url": audio_url,
|
||||
"audio_status": "pending"
|
||||
}
|
||||
|
||||
logger.info(
|
||||
@@ -785,7 +812,8 @@ class AgentRunService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
||||
kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
|
||||
tools.extend(kb_tools)
|
||||
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
@@ -830,8 +858,9 @@ class AgentRunService:
|
||||
# 6. 加载历史消息
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
api_config=model_info,
|
||||
max_history=memory_config.get("max_history", 10)
|
||||
max_history=memory_config.get("max_history", 10),
|
||||
current_provider=api_key_config.get("provider"),
|
||||
current_is_omni=api_key_config.get("is_omni", False)
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
@@ -840,7 +869,7 @@ class AgentRunService:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -909,10 +938,13 @@ class AgentRunService:
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||
},
|
||||
files=files,
|
||||
audio_url=stream_audio_url
|
||||
processed_files=processed_files,
|
||||
audio_url=stream_audio_url,
|
||||
provider=api_key_config.get("provider"),
|
||||
is_omni=api_key_config.get("is_omni", False)
|
||||
)
|
||||
|
||||
# 12. 发送结束事件(包含 suggested_questions 和 tts)
|
||||
# 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status)
|
||||
end_data: Dict[str, Any] = {
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -923,7 +955,18 @@ class AgentRunService:
|
||||
features_config, full_content, api_key_config, effective_params
|
||||
)
|
||||
end_data["audio_url"] = stream_audio_url
|
||||
end_data["citations"] = self._filter_citations(features_config, [])
|
||||
# 检查TTS是否已完成(非阻塞,不取消任务)
|
||||
audio_status = "pending"
|
||||
if tts_task is not None and tts_task.done():
|
||||
# 任务已完成,检查是否有异常
|
||||
try:
|
||||
tts_task.result()
|
||||
audio_status = "completed"
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS任务异常: {e}")
|
||||
audio_status = "failed"
|
||||
end_data["audio_status"] = audio_status if stream_audio_url else None
|
||||
end_data["citations"] = self._filter_citations(features_config, citations_collector)
|
||||
yield self._format_sse_event("end", end_data)
|
||||
|
||||
logger.info(
|
||||
@@ -1119,14 +1162,17 @@ class AgentRunService:
|
||||
async def _load_conversation_history(
|
||||
self,
|
||||
conversation_id: str,
|
||||
api_config: ModelInfo | None = None,
|
||||
max_history: int = 10
|
||||
max_history: int = 10,
|
||||
current_provider: Optional[str] = None,
|
||||
current_is_omni: Optional[bool] = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""加载会话历史消息
|
||||
"""加载会话历史消息,并根据当前模型配置处理多模态文件
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
max_history: 最大历史消息数量
|
||||
current_provider: 当前模型的provider
|
||||
current_is_omni: 当前模型的is_omni
|
||||
|
||||
Returns:
|
||||
List[Dict]: 历史消息列表
|
||||
@@ -1138,7 +1184,8 @@ class AgentRunService:
|
||||
history = await conversation_service.get_conversation_history(
|
||||
conversation_id=uuid.UUID(conversation_id),
|
||||
max_history=max_history,
|
||||
api_config=api_config
|
||||
current_provider=current_provider,
|
||||
current_is_omni=current_is_omni
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -1166,7 +1213,10 @@ class AgentRunService:
|
||||
app_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None,
|
||||
audio_url: Optional[str] = None
|
||||
processed_files: Optional[List[Dict[str, Any]]] = None,
|
||||
audio_url: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
is_omni: Optional[bool] = None
|
||||
) -> None:
|
||||
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
|
||||
|
||||
@@ -1177,6 +1227,11 @@ class AgentRunService:
|
||||
app_id: 应用ID(未使用,保留用于兼容性)
|
||||
user_id: 用户ID(未使用,保留用于兼容性)
|
||||
meta_data: token消耗
|
||||
files: 原始文件输入
|
||||
processed_files: 处理后的文件
|
||||
audio_url: 音频URL
|
||||
provider: 模型供应商
|
||||
is_omni: 是否为全模态模型
|
||||
"""
|
||||
try:
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -1186,15 +1241,24 @@ class AgentRunService:
|
||||
|
||||
# 保存消息(会话已经存在)
|
||||
human_meta = {
|
||||
"files": []
|
||||
"files": [],
|
||||
"history_files": {}
|
||||
}
|
||||
if files:
|
||||
for f in files:
|
||||
# url = await MultimodalService(self.db).get_file_url(f)
|
||||
human_meta["files"].append({
|
||||
"type": f.type,
|
||||
"url": f.url
|
||||
})
|
||||
|
||||
# 保存 history_files,包含 provider 和 is_omni 信息
|
||||
if processed_files:
|
||||
human_meta["history_files"] = {
|
||||
"content": processed_files,
|
||||
"provider": provider,
|
||||
"is_omni": is_omni
|
||||
}
|
||||
|
||||
# 保存用户消息
|
||||
conversation_service.add_message(
|
||||
conversation_id=conv_uuid,
|
||||
@@ -1420,8 +1484,9 @@ class AgentRunService:
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
) -> tuple[Optional[str], Optional[asyncio.Task]]:
|
||||
"""文本流式输入并行合成音频。
|
||||
返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。
|
||||
返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。
|
||||
调用方向 text_queue put 文本 chunk,结束时 put None。
|
||||
前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。
|
||||
"""
|
||||
tts_config = features_config.get("text_to_speech", {})
|
||||
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
|
||||
@@ -1808,6 +1873,7 @@ class AgentRunService:
|
||||
),
|
||||
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
|
||||
"audio_url": result.get("audio_url"),
|
||||
"audio_status": result.get("audio_status"),
|
||||
"citations": result.get("citations", []),
|
||||
"suggested_questions": result.get("suggested_questions", []),
|
||||
"error": None
|
||||
@@ -1885,6 +1951,7 @@ class AgentRunService:
|
||||
"results": [{
|
||||
**r,
|
||||
"audio_url": r.get("audio_url"),
|
||||
"audio_status": r.get("audio_status"),
|
||||
"citations": r.get("citations", []),
|
||||
"suggested_questions": r.get("suggested_questions", []),
|
||||
} for r in results],
|
||||
@@ -2016,6 +2083,7 @@ class AgentRunService:
|
||||
full_content = ""
|
||||
returned_conversation_id = model_conversation_id
|
||||
audio_url = None
|
||||
audio_status = None
|
||||
citations = []
|
||||
suggested_questions = []
|
||||
|
||||
@@ -2074,6 +2142,7 @@ class AgentRunService:
|
||||
# 从 end 事件中提取 features 输出字段
|
||||
if event_type == "end" and event_data:
|
||||
audio_url = event_data.get("audio_url")
|
||||
audio_status = event_data.get("audio_status")
|
||||
citations = event_data.get("citations", [])
|
||||
suggested_questions = event_data.get("suggested_questions", [])
|
||||
|
||||
@@ -2103,6 +2172,7 @@ class AgentRunService:
|
||||
"message": full_content,
|
||||
"elapsed_time": elapsed,
|
||||
"audio_url": audio_url,
|
||||
"audio_status": audio_status,
|
||||
"citations": citations,
|
||||
"suggested_questions": suggested_questions,
|
||||
"error": None
|
||||
@@ -2117,6 +2187,7 @@ class AgentRunService:
|
||||
"elapsed_time": elapsed,
|
||||
"message_length": len(full_content),
|
||||
"audio_url": audio_url,
|
||||
"audio_status": audio_status,
|
||||
"citations": citations,
|
||||
"suggested_questions": suggested_questions,
|
||||
"timestamp": time.time()
|
||||
@@ -2253,6 +2324,7 @@ class AgentRunService:
|
||||
"message": r.get("message"),
|
||||
"elapsed_time": r.get("elapsed_time", 0),
|
||||
"audio_url": r.get("audio_url"),
|
||||
"audio_status": r.get("audio_status"),
|
||||
"citations": r.get("citations", []),
|
||||
"suggested_questions": r.get("suggested_questions", []),
|
||||
"error": r.get("error")
|
||||
|
||||
@@ -325,27 +325,30 @@ class FileStorageService:
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_file_url(self, file_key: str, expires: int = 3600) -> str:
|
||||
async def get_file_url(
|
||||
self,
|
||||
file_key: str,
|
||||
expires: int = 3600,
|
||||
file_name: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get an access URL for a file.
|
||||
|
||||
Args:
|
||||
file_key: The file key.
|
||||
expires: URL validity period in seconds (default: 1 hour).
|
||||
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||
|
||||
Returns:
|
||||
URL for accessing the file.
|
||||
"""
|
||||
logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s")
|
||||
|
||||
try:
|
||||
url = await self.storage.get_url(file_key, expires)
|
||||
url = await self.storage.get_url(file_key, expires, file_name=file_name)
|
||||
logger.debug(f"File URL generated: file_key={file_key}")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting file URL: file_key={file_key}, error={str(e)}"
|
||||
)
|
||||
logger.error(f"Error getting file URL: file_key={file_key}, error={str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
162
api/app/services/generation_service.py
Normal file
162
api/app/services/generation_service.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
图片和视频生成服务
|
||||
|
||||
提供统一的生成接口,支持多种 Provider
|
||||
"""
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.models.models_model import ModelType
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
|
||||
class GenerationService:
|
||||
"""生成服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
async def generate_image(
|
||||
self,
|
||||
model_config_id: str,
|
||||
prompt: str,
|
||||
size: Optional[str] = "2k",
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成图片
|
||||
|
||||
Args:
|
||||
model_config_id: 模型配置ID
|
||||
prompt: 提示词
|
||||
size: 图片尺寸
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成结果
|
||||
"""
|
||||
# 获取模型配置
|
||||
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
|
||||
|
||||
if model_config.type != ModelType.IMAGE:
|
||||
raise BusinessException(
|
||||
f"模型类型错误,期望 {ModelType.IMAGE},实际 {model_config.type}",
|
||||
code=BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 获取 API Key
|
||||
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
|
||||
if not api_key_info:
|
||||
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
|
||||
|
||||
# 创建配置
|
||||
config = RedBearModelConfig(
|
||||
model_name=api_key_info.model_name,
|
||||
provider=api_key_info.provider,
|
||||
api_key=api_key_info.api_key,
|
||||
base_url=api_key_info.api_base,
|
||||
extra_params=api_key_info.config or {}
|
||||
)
|
||||
|
||||
# 生成图片
|
||||
generator = RedBearImageGenerator(config)
|
||||
result = await generator.agenerate(prompt, size, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
model_config_id: str,
|
||||
prompt: str,
|
||||
duration: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成视频
|
||||
|
||||
Args:
|
||||
model_config_id: 模型配置ID
|
||||
prompt: 提示词
|
||||
duration: 视频时长(秒)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成结果(包含任务ID)
|
||||
"""
|
||||
# 获取模型配置
|
||||
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
|
||||
|
||||
if model_config.type != ModelType.VIDEO:
|
||||
raise BusinessException(
|
||||
f"模型类型错误,期望 {ModelType.VIDEO},实际 {model_config.type}",
|
||||
code=BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 获取 API Key
|
||||
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
|
||||
if not api_key_info:
|
||||
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
|
||||
|
||||
# 创建配置
|
||||
config = RedBearModelConfig(
|
||||
model_name=api_key_info.model_name,
|
||||
provider=api_key_info.provider,
|
||||
api_key=api_key_info.api_key,
|
||||
base_url=api_key_info.api_base,
|
||||
extra_params=api_key_info.config or {}
|
||||
)
|
||||
|
||||
# 生成视频
|
||||
generator = RedBearVideoGenerator(config)
|
||||
result = await generator.agenerate(prompt, duration, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
async def get_video_task_status(
|
||||
self,
|
||||
model_config_id: str,
|
||||
task_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
查询视频生成任务状态
|
||||
|
||||
Args:
|
||||
model_config_id: 模型配置ID
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
任务状态信息
|
||||
"""
|
||||
# 获取模型配置
|
||||
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
|
||||
|
||||
# 获取 API Key
|
||||
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
|
||||
if not api_key_info:
|
||||
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
|
||||
|
||||
# 创建配置
|
||||
config = RedBearModelConfig(
|
||||
model_name=api_key_info.model_name,
|
||||
provider=api_key_info.provider,
|
||||
api_key=api_key_info.api_key,
|
||||
base_url=api_key_info.api_base,
|
||||
extra_params=api_key_info.config or {}
|
||||
)
|
||||
|
||||
# 查询任务状态
|
||||
generator = RedBearVideoGenerator(config)
|
||||
result = await generator.aget_task_status(task_id)
|
||||
|
||||
return result
|
||||
@@ -94,29 +94,38 @@ class HomePageService:
|
||||
@staticmethod
|
||||
def load_version_introduction(version: str) -> Dict[str, Any]:
|
||||
"""
|
||||
从 JSON 文件加载对应版本的介绍
|
||||
加载对应版本的介绍(优先从数据库读取,fallback 到 JSON 文件)
|
||||
:param version: 系统版本号(如 "0.2.0")
|
||||
:return: 对应版本的详细介绍
|
||||
"""
|
||||
# 2. 定义 JSON 文件路径(简化路径处理,保留绝对路径调试特性)
|
||||
from copy import deepcopy
|
||||
from app.db import SessionLocal
|
||||
from app.repositories.home_page_repository import HomePageRepository
|
||||
|
||||
result = deepcopy(HomePageService.DEFAULT_RETURN_DATA)
|
||||
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_result = HomePageRepository.get_version_introduction(db, version)
|
||||
if db_result:
|
||||
return db_result
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
json_abs_path = Path(__file__).parent.parent / "version_info.json"
|
||||
json_abs_path = json_abs_path.resolve()
|
||||
|
||||
# 3. 初始化返回结果(深拷贝默认模板,避免修改原常量)
|
||||
from copy import deepcopy
|
||||
result = deepcopy(HomePageService.DEFAULT_RETURN_DATA)
|
||||
|
||||
try:
|
||||
# 4. 简化文件存在性判断(合并逻辑,减少分支)
|
||||
if not json_abs_path.exists():
|
||||
result["message"] = f"版本介绍文件不存在:{json_abs_path}"
|
||||
return result
|
||||
|
||||
# 5. 读取并解析 JSON 文件(简化文件操作流程)
|
||||
with open(json_abs_path, "r", encoding="utf-8") as f:
|
||||
changelogs = json.load(f)
|
||||
|
||||
# 6. 简化版本匹配逻辑,直接返回结果或更新提示信息
|
||||
if version in changelogs:
|
||||
return changelogs[version]
|
||||
result["message"] = f"暂未查询到 {version} 版本的详细介绍"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user