Merge branch 'develop' into feature/multimodel_memory
# Conflicts:
#   api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py
#   api/app/repositories/neo4j/add_nodes.py
#   api/app/repositories/neo4j/cypher_queries.py
#   api/app/repositories/neo4j/graph_saver.py
#   api/app/services/memory_agent_service.py
#   api/app/services/multimodal_service.py
@@ -77,6 +77,7 @@ celery_app.conf.update(
    # Worker 设置 (per-worker settings are in docker-compose command line)
    worker_prefetch_multiplier=1,  # Don't hoard tasks, fairer distribution
    worker_redirect_stdouts_level='INFO',  # stdout/print → INFO instead of WARNING

    # 结果过期时间
    result_expires=3600,  # 结果保存1小时
@@ -8,6 +8,7 @@ from fastapi import APIRouter
from . import (
    api_key_controller,
    app_controller,
    app_log_controller,
    auth_controller,
    chunk_controller,
    document_controller,
@@ -70,6 +71,7 @@ manager_router.include_router(chunk_controller.router)
manager_router.include_router(test_controller.router)
manager_router.include_router(knowledgeshare_controller.router)
manager_router.include_router(app_controller.router)
manager_router.include_router(app_log_controller.router)
manager_router.include_router(upload_controller.router)
manager_router.include_router(memory_agent_controller.router)
manager_router.include_router(memory_dashboard_controller.router)
@@ -57,6 +57,7 @@ def list_apps(
    page: int = 1,
    pagesize: int = 10,
    ids: Optional[str] = None,
    api_key: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
@@ -65,10 +66,25 @@ def list_apps(
    - 默认包含本工作空间的应用和分享给本工作空间的应用
    - 设置 include_shared=false 可以只查看本工作空间的应用
    - 当提供 ids 参数时,按逗号分割获取指定应用,不分页
    - 当提供 api_key 参数时,查找该 API Key 关联的应用
    """
    from sqlalchemy import select as sa_select
    from app.models.api_key_model import ApiKey

    workspace_id = current_user.current_workspace_id
    service = app_service.AppService(db)

    # 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程
    if api_key:
        matched_id = db.execute(
            sa_select(ApiKey.resource_id).where(
                ApiKey.workspace_id == workspace_id,
                ApiKey.api_key == api_key,
                ApiKey.resource_id.isnot(None),
            )
        ).scalar_one_or_none()
        ids = str(matched_id) if matched_id else ""

    # 当 ids 存在且不为 None 时,根据 ids 获取应用
    if ids is not None:
        app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
api/app/controllers/app_log_controller.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""应用日志(消息记录)接口"""
import uuid
from typing import Optional

from fastapi import APIRouter, Depends, Query
from sqlalchemy import select, desc, func
from sqlalchemy.orm import Session

from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.conversation_model import Conversation, Message
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
from app.schemas.response_schema import PageData, PageMeta
from app.services.app_service import AppService

router = APIRouter(prefix="/apps", tags=["App Logs"])
logger = get_business_logger()


@router.get("/{app_id}/logs", summary="应用日志 - 会话列表")
@cur_workspace_access_guard()
def list_app_logs(
    app_id: uuid.UUID,
    page: int = Query(1, ge=1),
    pagesize: int = Query(20, ge=1, le=100),
    user_id: Optional[str] = None,
    is_draft: Optional[bool] = None,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """查看应用下所有会话记录(分页)

    - 支持按 user_id 筛选
    - 支持按 is_draft 筛选(草稿会话 / 发布会话)
    - 按最新更新时间倒序排列
    """
    workspace_id = current_user.current_workspace_id

    # 验证应用访问权限
    service = AppService(db)
    service.get_app(app_id, workspace_id)

    stmt = select(Conversation).where(
        Conversation.app_id == app_id,
        Conversation.workspace_id == workspace_id,
        Conversation.is_active.is_(True),
    )

    if user_id:
        stmt = stmt.where(Conversation.user_id == user_id)

    if is_draft is not None:
        stmt = stmt.where(Conversation.is_draft == is_draft)

    total = int(db.execute(
        select(func.count()).select_from(stmt.subquery())
    ).scalar_one())

    stmt = stmt.order_by(desc(Conversation.updated_at))
    stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)

    conversations = list(db.scalars(stmt).all())

    items = [AppLogConversation.model_validate(c) for c in conversations]
    meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)

    logger.info(
        "查询应用日志会话列表",
        extra={"app_id": str(app_id), "total": total, "page": page}
    )

    return success(data=PageData(page=meta, items=items))


@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情")
@cur_workspace_access_guard()
def get_app_log_detail(
    app_id: uuid.UUID,
    conversation_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """查看某会话的完整消息记录

    - 返回会话基本信息 + 所有消息(按时间正序)
    - 消息 meta_data 包含模型名、token 用量等信息
    """
    workspace_id = current_user.current_workspace_id

    # 验证应用访问权限
    service = AppService(db)
    service.get_app(app_id, workspace_id)

    # 查询会话(确保属于该应用和工作空间)
    conversation = db.scalars(
        select(Conversation).where(
            Conversation.id == conversation_id,
            Conversation.app_id == app_id,
            Conversation.workspace_id == workspace_id,
            Conversation.is_active.is_(True),
        )
    ).first()

    if not conversation:
        from app.core.exceptions import ResourceNotFoundException
        raise ResourceNotFoundException("会话", str(conversation_id))

    # 查询消息(按时间正序)
    messages = list(db.scalars(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at)
    ).all())

    detail = AppLogConversationDetail.model_validate(conversation)
    detail.messages = [AppLogMessage.model_validate(m) for m in messages]

    logger.info(
        "查询应用日志会话详情",
        extra={
            "app_id": str(app_id),
            "conversation_id": str(conversation_id),
            "message_count": len(messages)
        }
    )

    return success(data=detail)
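A quick sanity check of the two new endpoints; this is a minimal sketch assuming a locally running API. The base URL, bearer token, and IDs are placeholders, any extra router prefix in front of /apps must be added, and the "data" envelope key is assumed from the success() helper; the paths, query parameters, and field names come from the controller above.

import httpx

BASE = "http://localhost:8000"                      # placeholder host (add the manager router prefix if any)
HEADERS = {"Authorization": "Bearer <token>"}       # placeholder credentials
APP_ID = "00000000-0000-0000-0000-000000000000"     # placeholder app id

with httpx.Client(base_url=BASE, headers=HEADERS) as client:
    # Page through the conversation list, newest first
    logs = client.get(f"/apps/{APP_ID}/logs", params={"page": 1, "pagesize": 20}).json()
    page_meta = logs["data"]["page"]     # PageMeta: page / pagesize / total / hasnext
    items = logs["data"]["items"]        # AppLogConversation entries

    # Drill into the first conversation; messages come back in chronological order
    if items:
        conv_id = items[0]["id"]
        detail = client.get(f"/apps/{APP_ID}/logs/{conv_id}").json()
        for msg in detail["data"]["messages"]:
            print(msg["role"], msg["content"][:80])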
@@ -14,6 +14,9 @@ Routes:
import os
import uuid
from typing import Any
import httpx
import mimetypes
from urllib.parse import urlparse, unquote

from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, RedirectResponse
@@ -91,7 +94,7 @@ async def upload_file(

    if file_size > settings.MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            status_code=status.HTTP_413_CONTENT_TOO_LARGE,
            detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
        )
@@ -172,7 +175,6 @@ async def upload_file_with_share_token(

    # Get share and release info from share_token
    service = ReleaseShareService(db)
    share_info = service.get_shared_release_info(share_token=share_data.share_token)

    # Get share object to access app_id
    share = service.repo.get_by_share_token(share_data.share_token)
@@ -291,6 +293,101 @@ async def upload_file_with_share_token(
    )

@router.get("/files/info-by-url", response_model=ApiResponse)
|
||||
async def get_file_info_by_url(
|
||||
url: str,
|
||||
):
|
||||
"""
|
||||
Get file information by network URL (no authentication required).
|
||||
|
||||
Fetches file metadata from a remote URL via HTTP HEAD request.
|
||||
Falls back to GET request if HEAD is not supported.
|
||||
Returns file type, name, and size.
|
||||
|
||||
Args:
|
||||
url: The network URL of the file.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file information.
|
||||
"""
|
||||
api_logger.info(f"File info by URL request: url={url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# Try HEAD request first
|
||||
response = await client.head(url, follow_redirects=True)
|
||||
|
||||
# If HEAD fails, try GET request (some servers don't support HEAD)
|
||||
if response.status_code != 200:
|
||||
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access file: HTTP {response.status_code}"
|
||||
)
|
||||
|
||||
# Get file size from Content-Length header or actual content
|
||||
file_size = response.headers.get("Content-Length")
|
||||
if file_size:
|
||||
file_size = int(file_size)
|
||||
elif hasattr(response, 'content'):
|
||||
file_size = len(response.content)
|
||||
else:
|
||||
file_size = None
|
||||
|
||||
# Get content type from Content-Type header
|
||||
content_type = response.headers.get("Content-Type", "application/octet-stream")
|
||||
# Remove charset and other parameters from content type
|
||||
content_type = content_type.split(';')[0].strip()
|
||||
|
||||
# Extract filename from Content-Disposition or URL
|
||||
file_name = None
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
if content_disposition and "filename=" in content_disposition:
|
||||
parts = content_disposition.split("filename=")
|
||||
if len(parts) > 1:
|
||||
file_name = parts[1].strip('"').strip("'")
|
||||
|
||||
if not file_name:
|
||||
parsed_url = urlparse(url)
|
||||
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
|
||||
|
||||
# Extract file extension from filename
|
||||
_, file_ext = os.path.splitext(file_name)
|
||||
|
||||
# If no extension found, infer from content type
|
||||
if not file_ext:
|
||||
ext = mimetypes.guess_extension(content_type)
|
||||
if ext:
|
||||
file_ext = ext
|
||||
file_name = f"{file_name}{file_ext}"
|
||||
|
||||
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
|
||||
|
||||
return success(
|
||||
data={
|
||||
"url": url,
|
||||
"file_name": file_name,
|
||||
"file_ext": file_ext.lower() if file_ext else "",
|
||||
"file_size": file_size,
|
||||
"content_type": content_type,
|
||||
},
|
||||
msg="File information retrieved successfully"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file information: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
request: Request,
|
||||
@@ -499,6 +596,51 @@ async def get_file_url(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
|
||||
async def get_permanent_file_url(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
获取文件的永久公开 URL(无过期时间)。
|
||||
|
||||
- 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置)
|
||||
- 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限)
|
||||
"""
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}")
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
try:
|
||||
if isinstance(storage, LocalStorage):
|
||||
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
|
||||
else:
|
||||
url = await storage.get_permanent_url(file_key)
|
||||
if not url:
|
||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Permanent URL not supported for current storage backend")
|
||||
|
||||
api_logger.info(f"Generated permanent URL: file_id={file_id}")
|
||||
return success(
|
||||
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
|
||||
msg="Permanent file URL generated successfully"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to generate permanent URL: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate permanent URL: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/public/{file_id}", response_model=Any)
|
||||
async def public_download_file(
|
||||
request: Request,
|
||||
@@ -653,3 +795,44 @@ async def permanent_download_file(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/status", response_model=ApiResponse)
|
||||
async def get_file_status(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get file upload/processing status (no authentication required).
|
||||
|
||||
This endpoint is used to check if a file (e.g., TTS audio) is ready.
|
||||
Returns status: pending, completed, or failed.
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file status and metadata.
|
||||
"""
|
||||
api_logger.info(f"File status request: file_id={file_id}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"file_id": str(file_id),
|
||||
"status": file_metadata.status,
|
||||
"file_name": file_metadata.file_name,
|
||||
"file_size": file_metadata.file_size,
|
||||
"content_type": file_metadata.content_type,
|
||||
},
|
||||
msg="File status retrieved successfully"
|
||||
)
|
||||
|
||||
@@ -195,10 +195,9 @@ async def get_workspace_end_users(
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
|
||||
try:
|
||||
from app.tasks import init_community_clustering_for_users
|
||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
|
||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
@@ -33,35 +33,47 @@ def get_memory_count(
|
||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
end_user_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
pagesize: int = 20,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
Retrieve conversations for the current user in a specific group with pagination.
|
||||
|
||||
Args:
|
||||
end_user_id (UUID): The group identifier.
|
||||
page (int): Page number (1-based). Defaults to 1.
|
||||
pagesize (int): Number of items per page. Defaults to 20.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains a list of conversation IDs.
|
||||
|
||||
Notes:
|
||||
- Initializes the ConversationService with the current DB session.
|
||||
- Returns only conversation IDs for lightweight response.
|
||||
- Logs can be added to trace requests in production.
|
||||
ApiResponse: Contains a paginated list of conversations.
|
||||
"""
|
||||
page = max(1, page)
|
||||
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
end_user_id
|
||||
conversations, total = conversation_service.get_user_conversations(
|
||||
end_user_id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
], msg="get conversations success")
|
||||
return success(data={
|
||||
"items": [
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
],
|
||||
"total": total,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": page_size,
|
||||
"total": total,
|
||||
"hasnext": (page * page_size) < total
|
||||
},
|
||||
}, msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||
|
||||
@@ -76,6 +76,8 @@ async def get_tool_methods(
|
||||
if methods is None:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return success(data=methods, msg="获取工具方法成功")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -121,6 +123,8 @@ async def create_tool(
|
||||
raise HTTPException(status_code=400, detail=e.message)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -149,6 +153,8 @@ async def update_tool(
|
||||
return success(msg="工具更新成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -191,6 +197,8 @@ async def set_tool_active(
|
||||
return success(msg=f"工具已{action}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -223,6 +231,8 @@ async def execute_tool(
|
||||
},
|
||||
msg="工具执行完成"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@@ -97,6 +97,7 @@ class Settings:

    # File Upload
    MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
    MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
    FILE_PATH: str = os.getenv("FILE_PATH", "/files")
    FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
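For reference, the matching environment overrides would read as below; the variable names are taken from the os.getenv calls above, and the values shown are simply the shipped defaults, not recommendations.

MAX_FILE_SIZE=52428800     # 50 MB upload cap
MAX_FILE_COUNT=20
FILE_PATH=/files
FILE_URL_EXPIRES=3600      # presigned URL lifetime, in seconds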
@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
        # Fallback to console only if file write fails
        print(f"Warning: Could not write to timing log: {e}")

    # Always print to console (backward compatible behavior)
    print(f"✓ {step_name}: {duration:.2f}s")
    # Always log at INFO level (avoids Celery treating stdout as WARNING)
    _timing_logger = logging.getLogger(__name__)
    _timing_logger.info(f"✓ {step_name}: {duration:.2f}s")


def get_agent_logger(name: str = "agent_service",
@@ -21,7 +21,7 @@ from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
@@ -176,8 +176,8 @@ async def write(
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||
schedule_clustering_after_write(
|
||||
# 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突)
|
||||
await _trigger_clustering_sync(
|
||||
all_entity_nodes,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
embedding_model_id=str(
|
||||
|
||||
@@ -237,6 +237,7 @@ class LabelPropagationEngine:
|
||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||
await self._generate_community_metadata([new_cid], end_user_id)
|
||||
return
|
||||
|
||||
# 统计邻居社区分布
|
||||
@@ -271,7 +272,8 @@ class LabelPropagationEngine:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
await self._generate_community_metadata([target_cid], end_user_id)
|
||||
# 新实体加入后成员变化,强制重新生成元数据
|
||||
await self._generate_community_metadata([target_cid], end_user_id, force=True)
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
@@ -451,7 +453,7 @@ class LabelPropagationEngine:
|
||||
return lines
|
||||
|
||||
async def _generate_community_metadata(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
self, community_ids: List[str], end_user_id: str, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
为一个或多个社区生成并写入元数据。
|
||||
@@ -460,69 +462,82 @@ class LabelPropagationEngine:
|
||||
1. 逐个社区调 LLM 生成 name / summary(串行)
|
||||
2. 收集所有 summary,一次性批量 embed
|
||||
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
||||
"""
|
||||
if not community_ids:
|
||||
return
|
||||
|
||||
Args:
|
||||
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
||||
"""
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# --- 阶段1:并发调 LLM 生成每个社区的 name / summary ---
|
||||
async def _build_one(cid: str):
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
async def _build_one(cid: str) -> Optional[Dict]:
|
||||
try:
|
||||
if not force:
|
||||
check_embedding = bool(self.embedding_model_id)
|
||||
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
|
||||
return None
|
||||
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
|
||||
return None
|
||||
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
all_names = [m["name"] for m in members if m.get("name")]
|
||||
|
||||
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
||||
summary = f"包含实体:{', '.join(all_names)}"
|
||||
|
||||
if self.llm_model_id:
|
||||
try:
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
|
||||
# 方案四:注入社区内实体间关系三元组
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
name, summary = "", ""
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[_build_one(cid) for cid in community_ids],
|
||||
return_exceptions=True,
|
||||
@@ -535,15 +550,20 @@ class LabelPropagationEngine:
|
||||
metadata_list.append(res)
|
||||
|
||||
if not metadata_list:
|
||||
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
||||
return
|
||||
|
||||
# --- 阶段2:批量生成 summary_embedding ---
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
if self.embedding_model_id:
|
||||
try:
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||
|
||||
# --- 阶段3:写入(单个 or 批量)---
|
||||
if len(metadata_list) == 1:
|
||||
@@ -556,17 +576,13 @@ class LabelPropagationEngine:
|
||||
core_entities=m["core_entities"],
|
||||
summary_embedding=m["summary_embedding"],
|
||||
)
|
||||
if result:
|
||||
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
|
||||
if not result:
|
||||
logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
|
||||
else:
|
||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||
if ok:
|
||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 批量写入社区元数据失败")
|
||||
if not ok:
|
||||
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
||||
|
||||
@staticmethod
|
||||
def _new_community_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
return str(uuid.uuid4())
|
||||
@@ -9,6 +9,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
@@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
|
||||
ScenePatterns
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||
@@ -706,7 +709,7 @@ class SemanticPruner:
|
||||
# 阈值保护:最高0.9
|
||||
proportion = float(self.config.pruning_threshold)
|
||||
if proportion > 0.9:
|
||||
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||
logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||
proportion = 0.9
|
||||
if proportion < 0.0:
|
||||
proportion = 0.0
|
||||
@@ -905,7 +908,7 @@ class SemanticPruner:
|
||||
|
||||
# Safety: avoid empty dataset
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
|
||||
return result
|
||||
@@ -915,8 +918,7 @@ class SemanticPruner:
|
||||
try:
|
||||
self.run_logs.append(msg)
|
||||
except Exception:
|
||||
# 任何异常都不影响打印
|
||||
pass
|
||||
print(msg)
|
||||
logger.debug(msg)
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,11 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
|
||||
return await self.embedder_client.response(texts)
|
||||
|
||||
# 分批并行处理
|
||||
print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
||||
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
|
||||
# 并行发送所有批次
|
||||
batch_results = await asyncio.gather(*[
|
||||
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
|
||||
for batch_result in batch_results:
|
||||
embeddings.extend(batch_result)
|
||||
|
||||
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
async def generate_statement_embeddings(
|
||||
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的陈述句嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成陈述句嵌入向量 ===")
|
||||
logger.debug("=== 生成陈述句嵌入向量 ===")
|
||||
|
||||
# 收集所有陈述句
|
||||
all_statements = []
|
||||
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
|
||||
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
||||
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
||||
|
||||
print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
return stmt_embedding_maps
|
||||
|
||||
async def generate_chunk_embeddings(
|
||||
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的分块嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成分块嵌入向量 ===")
|
||||
logger.debug("=== 生成分块嵌入向量 ===")
|
||||
|
||||
# 收集所有分块
|
||||
all_chunks = []
|
||||
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
|
||||
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
||||
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
||||
|
||||
print(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
return chunk_embedding_maps
|
||||
|
||||
async def generate_dialog_embeddings(
|
||||
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
||||
"""
|
||||
print("\n=== 生成所有嵌入向量 ===")
|
||||
logger.debug("=== 生成所有嵌入向量 ===")
|
||||
|
||||
# 并发生成陈述句和分块嵌入向量
|
||||
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
||||
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
|
||||
# 对话嵌入向量(当前跳过)
|
||||
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
||||
|
||||
print(
|
||||
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
|
||||
)
|
||||
logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量")
|
||||
|
||||
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
||||
|
||||
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
更新后的三元组映射列表(实体包含嵌入向量)
|
||||
"""
|
||||
print("\n=== 生成实体嵌入向量 ===")
|
||||
logger.debug("=== 生成实体嵌入向量 ===")
|
||||
|
||||
entity_texts: List[str] = []
|
||||
entity_refs: List[Any] = []
|
||||
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
|
||||
entity_refs.append(ent)
|
||||
|
||||
if not entity_texts:
|
||||
print("没有找到需要生成嵌入向量的实体")
|
||||
logger.debug("没有找到需要生成嵌入向量的实体")
|
||||
return triplet_maps
|
||||
|
||||
# 批量生成嵌入向量
|
||||
@@ -227,14 +228,13 @@ class EmbeddingGenerator:
|
||||
|
||||
# 打印前几个嵌入向量的维度
|
||||
for i in range(min(5, len(embeddings))):
|
||||
print(f"实体 '{entity_texts[i]}' "
|
||||
f"嵌入向量维度: {len(embeddings[i])}")
|
||||
logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
|
||||
# 将嵌入向量赋值给实体
|
||||
for ent, emb in zip(entity_refs, embeddings):
|
||||
setattr(ent, "name_embedding", emb)
|
||||
|
||||
print(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
return triplet_maps
|
||||
|
||||
|
||||
@@ -297,7 +297,7 @@ async def embedding_generation_all(
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
||||
"""
|
||||
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
|
||||
generator = EmbeddingGenerator(embedding_id)
|
||||
|
||||
|
||||
@@ -121,3 +121,18 @@ class StorageBackend(ABC):
            URL for accessing the file.
        """
        pass

    async def get_permanent_url(self, file_key: str) -> Optional[str]:
        """
        Get a permanent public URL for the file (no expiration).

        Returns None by default; remote storage backends should override this
        if the bucket is configured for public read access.

        Args:
            file_key: Unique identifier for the file in the storage system.

        Returns:
            A permanent public URL, or None if not supported.
        """
        return None

@@ -261,3 +261,13 @@ class OSSStorage(StorageBackend):
            logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
            # Return a basic URL format as fallback
            return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"

    async def get_permanent_url(self, file_key: str) -> str:
        """
        Get a permanent public URL for the file (requires bucket public read).

        Returns:
            A permanent URL in the format: https://{bucket}.{endpoint}/{file_key}
        """
        host = self.endpoint.replace("https://", "").replace("http://", "")
        return f"https://{self.bucket_name}.{host}/{file_key}"

@@ -378,3 +378,12 @@ class S3Storage(StorageBackend):
            logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
            # Return a basic URL format as fallback
            return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"

    async def get_permanent_url(self, file_key: str) -> str:
        """
        Get a permanent public URL for the file (requires bucket public read).

        Returns:
            A permanent URL in the format: https://{bucket}.s3.{region}.amazonaws.com/{file_key}
        """
        return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
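Any other backend can fill in the same hook. Below is a sketch for a hypothetical MinIO-style backend; the class, its attributes, and the public-read assumption are illustrative only and not part of this change, and the required upload/download methods of StorageBackend are omitted.

class MinioStorage(StorageBackend):  # hypothetical, for illustration; other abstract methods omitted
    def __init__(self, endpoint: str, bucket_name: str):
        self.endpoint = endpoint          # e.g. "https://minio.internal:9000"
        self.bucket_name = bucket_name

    async def get_permanent_url(self, file_key: str) -> str:
        # Path-style addressing; only meaningful if the bucket allows anonymous read.
        host = self.endpoint.replace("https://", "").replace("http://", "")
        return f"https://{host}/{self.bucket_name}/{file_key}"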
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
@@ -24,6 +24,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
self.vector_service: ElasticSearchVector | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
@@ -163,6 +164,50 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
return reranker
|
||||
|
||||
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||
for child in children:
|
||||
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||
continue
|
||||
kb_config.kb_id = child.id
|
||||
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||
return
|
||||
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
return
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the knowledge retrieval workflow node.
|
||||
@@ -191,56 +236,19 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
query = self._render_template(self.typed_config.query, variable_pool)
|
||||
with get_db_read() as db:
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
||||
|
||||
if not existing_ids:
|
||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
||||
|
||||
rs = []
|
||||
for kb_config in knowledge_bases:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
|
||||
# Deduplicate hy brid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
continue
|
||||
if self.typed_config.reranker_id:
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
if not rs:
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
else:
|
||||
final_rs = sorted(
|
||||
rs,
|
||||
|
||||
@@ -506,10 +506,13 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
404: "errors.common.not_found",
|
||||
405: "errors.common.method_not_allowed",
|
||||
409: "errors.common.conflict",
|
||||
413: "errors.common.payload_too_large",
|
||||
422: "errors.common.validation_failed",
|
||||
429: "errors.common.too_many_requests",
|
||||
500: "errors.common.internal_error",
|
||||
502: "errors.common.bad_gateway",
|
||||
503: "errors.common.service_unavailable",
|
||||
504: "errors.common.gateway_timeout",
|
||||
}
|
||||
|
||||
# 如果有对应的翻译键,使用翻译
|
||||
@@ -534,7 +537,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=fail(code=exc.status_code, msg=translated_message, error=translated_message)
|
||||
content=fail(code=exc.status_code, msg=translated_message, error=exc.detail)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ class User(Base):
    __tablename__ = "users"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
    username = Column(String, unique=True, index=True, nullable=False)
    username = Column(String, index=True, nullable=False)  # 社区版:用户名不唯一,仅邮箱唯一
    email = Column(String, unique=True, index=True, nullable=False)
    hashed_password = Column(String, nullable=False)
    is_active = Column(Boolean, default=True, nullable=False)
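Relaxing unique=True on an existing column also needs a schema migration, which is not shown in this diff. A minimal Alembic sketch under the assumption that SQLAlchemy generated the default ix_users_username index; the revision identifiers and the real index or constraint name must be checked against the actual database.

from alembic import op

# revision identifiers are placeholders
revision = "xxxx_username_not_unique"
down_revision = None


def upgrade():
    # Recreate the username index without the UNIQUE flag (index name assumed).
    op.drop_index("ix_users_username", table_name="users")
    op.create_index("ix_users_username", "users", ["username"], unique=False)


def downgrade():
    op.drop_index("ix_users_username", table_name="users")
    op.create_index("ix_users_username", "users", ["username"], unique=True)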
@@ -90,27 +90,27 @@ class ConversationRepository:
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID = None,
|
||||
limit: int = 10,
|
||||
is_activate: bool = True
|
||||
) -> list[Conversation]:
|
||||
is_activate: bool = True,
|
||||
page: int = 1,
|
||||
page_size: int = 20
|
||||
) -> tuple[list[Conversation], int]:
|
||||
"""
|
||||
Retrieve recent conversations for a specific user.
|
||||
Retrieve recent conversations for a specific user with pagination.
|
||||
|
||||
This method queries conversations associated with the given user ID,
|
||||
optionally scoped to a specific workspace. Results are ordered by the
|
||||
most recently updated conversations and limited to a fixed number.
|
||||
most recently updated conversations.
|
||||
|
||||
Args:
|
||||
user_id (uuid.UUID): Unique identifier of the user.
|
||||
workspace_id (uuid.UUID, optional): Workspace scope for the query.
|
||||
If provided, only conversations under this workspace will be returned.
|
||||
limit (int): Maximum number of conversations to return.
|
||||
Defaults to 10.
|
||||
is_activate (bool): Convsersation State limit
|
||||
is_activate (bool): Conversation State limit.
|
||||
page (int): Page number (1-based). Defaults to 1.
|
||||
page_size (int): Number of items per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
list[Conversation]: A list of conversation entities ordered by
|
||||
last updated time (descending).
|
||||
tuple[list[Conversation], int]: A list of conversation entities and total count.
|
||||
"""
|
||||
logger.info(f"Fetching conversation by user_id: {user_id}")
|
||||
|
||||
@@ -122,18 +122,25 @@ class ConversationRepository:
|
||||
if workspace_id:
|
||||
stmt = stmt.where(Conversation.workspace_id == workspace_id)
|
||||
|
||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||
stmt = stmt.limit(limit)
|
||||
# Calculate total count
|
||||
total = int(self.db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
).scalar_one())
|
||||
|
||||
convsersations = list(self.db.scalars(stmt).all())
|
||||
# Apply ordering and pagination
|
||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
logger.info(
|
||||
"Conversation fetched successfully",
|
||||
extra={
|
||||
"user_id": str(user_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"total": total,
|
||||
}
|
||||
)
|
||||
return convsersations
|
||||
return conversations, total
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||
MEMORY_SUMMARY_NODE_SAVE
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||
@@ -57,7 +57,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f"Error creating dialogue nodes: {e}")
|
||||
logger.error(f"Error creating dialogue nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -129,7 +129,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f"Error creating statement nodes: {e}")
|
||||
logger.error(f"Error creating statement nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f"Error creating chunk nodes: {e}")
|
||||
logger.error(f"Error creating chunk nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -228,5 +228,5 @@ async def add_memory_summary_nodes(
|
||||
logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||
return created_ids
|
||||
except Exception as e:
|
||||
logger.info(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
return None
|
||||
|
||||
@@ -24,6 +24,10 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
CHECK_USER_HAS_COMMUNITIES,
|
||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
GET_INCOMPLETE_COMMUNITIES,
|
||||
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING,
|
||||
CHECK_COMMUNITY_IS_COMPLETE,
|
||||
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING,
|
||||
BATCH_UPDATE_COMMUNITY_METADATA,
|
||||
)
|
||||
|
||||
@@ -249,6 +253,31 @@ class CommunityRepository:
|
||||
logger.error(f"refresh_member_count failed: {e}")
|
||||
return 0
|
||||
|
||||
async def get_incomplete_communities(self, end_user_id: str, check_embedding: bool = False) -> List[str]:
|
||||
"""查询该用户下属性不完整的 Community 节点 ID 列表。
|
||||
|
||||
Args:
|
||||
end_user_id: 用户 ID
|
||||
check_embedding: 为 True 时额外检查 summary_embedding 是否缺失(仅当用户有 embedding 模型配置时传 True)
|
||||
"""
|
||||
try:
|
||||
query = GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING if check_embedding else GET_INCOMPLETE_COMMUNITIES
|
||||
result = await self.connector.execute_query(query, end_user_id=end_user_id)
|
||||
return [row["community_id"] for row in result]
|
||||
except Exception as e:
|
||||
logger.error(f"get_incomplete_communities failed: {e}")
|
||||
return []
|
||||
|
||||
async def is_community_complete(self, community_id: str, end_user_id: str, check_embedding: bool = False) -> bool:
|
||||
"""检查单个社区节点的属性是否完整。"""
|
||||
try:
|
||||
query = CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING if check_embedding else CHECK_COMMUNITY_IS_COMPLETE
|
||||
result = await self.connector.execute_query(query, community_id=community_id, end_user_id=end_user_id)
|
||||
return result[0]["is_complete"] if result else False
|
||||
except Exception as e:
|
||||
logger.error(f"is_community_complete failed: {e}")
|
||||
return False
|
||||
|
||||
async def update_community_metadata(
|
||||
self,
|
||||
community_id: str,
|
||||
@@ -258,7 +287,7 @@ class CommunityRepository:
|
||||
core_entities: List[str],
|
||||
summary_embedding: Optional[List[float]] = None,
|
||||
) -> bool:
|
||||
"""更新社区的名称、摘要、核心实体列表和摘要向量。"""
|
||||
"""更新社区的名称、摘要、核心实体列表及 summary_embedding。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
@@ -271,7 +300,7 @@ class CommunityRepository:
|
||||
)
|
||||
return bool(result)
|
||||
except Exception as e:
|
||||
logger.error(f"update_community_metadata failed: {e}")
|
||||
logger.error(f"update_community_metadata failed: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def batch_update_community_metadata(
|
||||
|
||||
@@ -1075,6 +1075,7 @@ RETURN
|
||||
|
||||
COMMUNITY_NODE_UPSERT = """
|
||||
MERGE (c:Community {community_id: $community_id})
|
||||
ON CREATE SET c.id = $community_id
|
||||
SET c.end_user_id = $end_user_id,
|
||||
c.member_count = $member_count,
|
||||
c.updated_at = datetime()
|
||||
@@ -1181,7 +1182,8 @@ RETURN c.community_id AS community_id, cnt AS member_count
|
||||
|
||||
UPDATE_COMMUNITY_METADATA = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
SET c.name = $name,
|
||||
SET c.id = coalesce(c.id, $community_id),
|
||||
c.name = $name,
|
||||
c.summary = $summary,
|
||||
c.core_entities = $core_entities,
|
||||
c.summary_embedding = $summary_embedding,
|
||||
@@ -1192,7 +1194,8 @@ RETURN c.community_id AS community_id
|
||||
BATCH_UPDATE_COMMUNITY_METADATA = """
|
||||
UNWIND $communities AS row
|
||||
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
|
||||
SET c.name = row.name,
|
||||
SET c.id = coalesce(c.id, row.community_id),
|
||||
c.name = row.name,
|
||||
c.summary = row.summary,
|
||||
c.core_entities = row.core_entities,
|
||||
c.summary_embedding = row.summary_embedding,
|
||||
@@ -1276,6 +1279,40 @@ RETURN
|
||||
startNode(r) = e AS r_from_e
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL
|
||||
) AS is_complete
|
||||
"""
|
||||
|
||||
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
RETURN (
|
||||
c.name IS NOT NULL AND c.name <> '' AND
|
||||
c.summary IS NOT NULL AND c.summary <> '' AND
|
||||
c.core_entities IS NOT NULL AND
|
||||
c.summary_embedding IS NOT NULL
|
||||
) AS is_complete
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
|
||||
OR c.name = '' OR c.summary = ''
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.name = ''
|
||||
OR c.summary IS NULL OR c.summary = ''
|
||||
OR c.core_entities IS NULL
|
||||
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
# Community keyword search: matches name or summary via fulltext index
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
||||
|
||||
@@ -169,7 +169,7 @@ async def save_dialog_and_statements_to_neo4j(
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.

只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
schedule_clustering_after_write() 显式触发。
_trigger_clustering_sync() 显式触发。

Args:
dialogue_nodes: List of DialogueNode objects to save
@@ -336,16 +336,13 @@ async def save_dialog_and_statements_to_neo4j(
return False


def schedule_clustering_after_write(
async def _trigger_clustering_sync(
entity_nodes: List,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None:
"""
写入 Neo4j 成功后,调度后台聚类任务。

可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。
使用 asyncio.create_task 异步触发,不阻塞写入响应。
同步等待聚类完成,避免与其他 LLM 任务并发冲突。
"""
if not entity_nodes:
return
@@ -357,9 +354,9 @@ def schedule_clustering_after_write(

end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id))
logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id)

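A caller-side sketch of the new synchronous flow (the save function's exact parameters are assumed, only the clustering helper's signature comes from this diff):

# Hedged sketch: write first, then await clustering instead of fire-and-forget.
async def save_and_cluster(dialogue_nodes, entity_nodes, llm_model_id=None, embedding_model_id=None):
    ok = await save_dialog_and_statements_to_neo4j(dialogue_nodes=dialogue_nodes, entity_nodes=entity_nodes)
    if ok:
        # Awaited on purpose, per the updated docstring, so clustering does not
        # race with other LLM tasks started by the same request.
        await _trigger_clustering_sync(entity_nodes, llm_model_id=llm_model_id,
                                       embedding_model_id=embedding_model_id)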
async def _trigger_clustering(

@@ -43,6 +43,7 @@ class WorkflowConfigRepository:
edges: list[dict[str, Any]],
variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None,
features: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None
) -> WorkflowConfig:
"""创建或更新工作流配置
@@ -53,6 +54,7 @@ class WorkflowConfigRepository:
edges: 边列表
variables: 变量列表
execution_config: 执行配置
features: 功能特性
triggers: 触发器列表

Returns:
@@ -82,6 +84,7 @@ class WorkflowConfigRepository:
edges=edges,
variables=variables or [],
execution_config=execution_config or {},
features=features or {},
triggers=triggers or []
)
self.db.add(config)

53
api/app/schemas/app_log_schema.py
Normal file
@@ -0,0 +1,53 @@
"""应用日志(消息记录)Schema"""
import uuid
import datetime
from typing import Optional, Dict, Any, List

from pydantic import BaseModel, Field, ConfigDict, field_serializer


class AppLogMessage(BaseModel):
    """单条消息记录"""
    model_config = ConfigDict(from_attributes=True)

    id: uuid.UUID
    conversation_id: uuid.UUID
    role: str = Field(description="角色: user / assistant / system")
    content: str
    meta_data: Optional[Dict[str, Any]] = None
    created_at: datetime.datetime

    @field_serializer("created_at", when_used="json")
    def _serialize_created_at(self, dt: datetime.datetime):
        return int(dt.timestamp() * 1000) if dt else None

    @field_serializer("meta_data", when_used="json")
    def _serialize_meta_data(self, data: Optional[Dict[str, Any]]):
        return data or {}


class AppLogConversation(BaseModel):
    """会话摘要(用于列表)"""
    model_config = ConfigDict(from_attributes=True)

    id: uuid.UUID
    app_id: uuid.UUID
    user_id: Optional[str] = None
    title: Optional[str] = None
    message_count: int = 0
    is_draft: bool
    created_at: datetime.datetime
    updated_at: datetime.datetime

    @field_serializer("created_at", when_used="json")
    def _serialize_created_at(self, dt: datetime.datetime):
        return int(dt.timestamp() * 1000) if dt else None

    @field_serializer("updated_at", when_used="json")
    def _serialize_updated_at(self, dt: datetime.datetime):
        return int(dt.timestamp() * 1000) if dt else None


class AppLogConversationDetail(AppLogConversation):
    """会话详情(包含消息列表)"""
    messages: List[AppLogMessage] = Field(default_factory=list)
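A small usage sketch of the schema above (field values are made up; behavior follows the serializers defined in the new file):

# Hedged sketch: JSON serialization emits epoch-millisecond timestamps and never a null meta_data.
import uuid, datetime

msg = AppLogMessage(
    id=uuid.uuid4(),
    conversation_id=uuid.uuid4(),
    role="user",
    content="你好",
    meta_data=None,
    created_at=datetime.datetime(2025, 1, 1, 12, 0, 0),
)
print(msg.model_dump_json())  # created_at -> int milliseconds, meta_data -> {}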
@@ -149,18 +149,26 @@ class FileUploadConfig(BaseModel):
)
# 通用文件:PDF/DOCX/XLSX/TXT/CSV/JSON,最大 100MB
document_enabled: bool = Field(default=False)
document_max_size_mb: int = Field(default=100)
document_max_size_mb: int = Field(default=50)
document_allowed_extensions: List[str] = Field(
default=["pdf", "docx", "xlsx", "txt", "csv", "json", "md"]
default=["pdf", "docx", "doc", "xlsx", "xls", "txt", "csv", "json", "md"]
)
# 视频文件:MP4/MOV/AVI/WebM,最大 500MB
video_enabled: bool = Field(default=False)
video_max_size_mb: int = Field(default=500)
video_max_size_mb: int = Field(default=50)
video_allowed_extensions: List[str] = Field(
default=["mp4", "mov"]
default=["mp4"]
)
# 最大文件数量
max_file_count: int = Field(default=5, ge=1, le=20)
max_file_count: int = Field(default=5, ge=1)

@field_validator("max_file_count")
@classmethod
def validate_max_file_count(cls, v: int) -> int:
from app.core.config import settings
if v > settings.MAX_FILE_COUNT:
raise ValueError(f"max_file_count 不能超过 {settings.MAX_FILE_COUNT}")
return v

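The static `le=20` bound is replaced by a validator that defers to the global setting; a minimal sketch of what that looks like from the caller side (assuming the remaining FileUploadConfig fields all have defaults):

# Hedged sketch: exceeding settings.MAX_FILE_COUNT now fails at validation time.
from pydantic import ValidationError

try:
    FileUploadConfig(max_file_count=999)
except ValidationError as exc:
    print(exc)  # contains "max_file_count 不能超过 <settings.MAX_FILE_COUNT>"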
class OpeningStatementConfig(BaseModel):

@@ -118,28 +118,27 @@ class AppChatService:

)

# 加载历史消息
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=10
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)

# 加载历史消息
history = await self.conversation_service.get_conversation_history(
conversation_id=conversation_id,
max_history=10,
current_provider=api_key_obj.provider,
current_is_omni=api_key_obj.is_omni
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]

# 处理多模态文件
processed_files = None
if files:
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
@@ -180,7 +179,8 @@ class AppChatService:

# 构建用户消息内容(含多模态文件)
human_meta = {
"files": []
"files": [],
"history_files": {}
}
assistant_meta = {
"model": api_key_obj.model_name,
@@ -195,6 +195,13 @@ class AppChatService:
"url": f.url
})

if processed_files:
human_meta["history_files"] = {
"content": processed_files,
"provider": api_key_obj.provider,
"is_omni": api_key_obj.is_omni
}

# 保存消息
if audio_url:
assistant_meta["audio_url"] = audio_url
@@ -225,6 +232,7 @@ class AppChatService:
"suggested_questions": suggested_questions,
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
"audio_url": audio_url,
"audio_status": "pending"
}

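For reference, a sketch of the meta_data shape this change starts persisting on the user message; "history_files" caches the already-processed multimodal payload together with the provider/is_omni it was produced for (all values below are made up):

# Hedged sketch of the persisted user-message meta_data.
human_meta = {
    "files": [
        {"type": "image", "url": "https://example.com/cat.png"},
    ],
    "history_files": {
        "content": [{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}],
        "provider": "openai",
        "is_omni": False,
    },
}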
async def agnet_chat_stream(
@@ -313,31 +321,27 @@ class AppChatService:
streaming=True
)

model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)

# 加载历史消息
history = []
memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=memory_config.get("max_history", 10)
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
history = await self.conversation_service.get_conversation_history(
conversation_id=conversation_id,
max_history=10,
current_provider=api_key_obj.provider,
current_is_omni=api_key_obj.is_omni
)

# 处理多模态文件
processed_files = None
if files:
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
@@ -347,8 +351,14 @@ class AppChatService:
total_tokens = 0

text_queue: asyncio.Queue = asyncio.Queue()
api_key_config = {
"model_name": api_key_obj.model_name,
"api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base,
"provider": api_key_obj.provider,
}
stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
features_config, api_key_obj,
features_config, api_key_config,
text_queue=text_queue,
tenant_id=tenant_id, workspace_id=workspace_id
)
@@ -378,7 +388,7 @@ class AppChatService:
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)

# 发送结束事件(包含 suggested_questions、tts、citations)
# 发送结束事件(包含 suggested_questions、tts、audio_status、citations)
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
sq_config = features_config.get("suggested_questions_after_answer", {})
if isinstance(sq_config, dict) and sq_config.get("enabled"):
@@ -388,11 +398,23 @@ class AppChatService:
"api_base": api_key_obj.api_base}, {}
)
end_data["audio_url"] = stream_audio_url
# 检查TTS是否已完成(非阻塞,不取消任务)
audio_status = "pending"
if tts_task is not None and tts_task.done():
# 任务已完成,检查是否有异常
try:
tts_task.result()
audio_status = "completed"
except Exception as e:
logger.warning(f"TTS任务异常: {e}")
audio_status = "failed"
end_data["audio_status"] = audio_status if stream_audio_url else None
end_data["citations"] = self.agent_service._filter_citations(features_config, [])

# 保存消息
human_meta = {
"files":[]
"files":[],
"history_files": {}
}
assistant_meta = {
"model": api_key_obj.model_name,
@@ -402,11 +424,16 @@ class AppChatService:

if files:
for f in files:
# url = await MultimodalService(self.db).get_file_url(f)
human_meta["files"].append({
"type": f.type,
"url": f.url
})
if processed_files:
human_meta["history_files"] = {
"content": processed_files,
"provider": api_key_obj.provider,
"is_omni": api_key_obj.is_omni
}

if stream_audio_url:
assistant_meta["audio_url"] = stream_audio_url

@@ -16,6 +16,7 @@ from app.models.app_release_model import AppRelease
from app.models.knowledge_model import Knowledge
from app.models.models_model import ModelConfig
from app.models.tool_model import ToolConfig as ToolConfigModel
from app.models.skill_model import Skill
from app.models.workflow_model import WorkflowConfig
from app.services.workflow_service import WorkflowService
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
@@ -84,7 +85,9 @@ class AppDslService:
if "knowledge_retrieval" in cfg:
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
if "tools" in cfg:
enriched["tools"] = self._enrich_tools(cfg["tools"])
enriched["tools"] = self._enrich_tools(cfg.get("tools"))
if "skills" in cfg:
enriched["skills"] = self._enrich_skills(cfg.get("skills"))
return enriched
if app_type == AppType.MULTI_AGENT:
enriched = {**cfg}
@@ -108,6 +111,7 @@ class AppDslService:
"variables": config.variables if config else [],
"edges": config.edges if config else [],
"nodes": config.nodes if config else [],
"features": config.features if config else {},
"execution_config": config.execution_config if config else {},
"triggers": config.triggers if config else [],
} if config else {}
@@ -123,7 +127,8 @@ class AppDslService:
"memory": config.memory if config else None,
"variables": config.variables if config else [],
"tools": self._enrich_tools(config.tools) if config else [],
"skills": config.skills if config else {},
"skills": self._enrich_skills(config.skills) if config else {},
"features": config.features if config else {}
} if config else {}
dsl = {**meta, "app": app_meta, "agent_config": config_data}

@@ -185,6 +190,22 @@ class AppDslService:
def _enrich_tools(self, tools: list) -> list:
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]

def _skill_ref(self, skill_id) -> Optional[dict]:
if not skill_id:
return None
s = self.db.query(Skill).filter(Skill.id == skill_id).first()
return {"id": str(skill_id), "name": s.name} if s else {"id": str(skill_id)}

def _enrich_skills(self, skills: Optional[dict]) -> Optional[dict]:
if not skills:
return skills
skill_ids = skills.get("skill_ids", [])
enriched_ids = [
{"id": sid, "_ref": self._skill_ref(sid)}
for sid in (skill_ids or [])
]
return {**skills, "skill_ids": enriched_ids}

def _agent_ref(self, agent_id) -> Optional[dict]:
if not agent_id:
return None
@@ -249,7 +270,8 @@ class AppDslService:
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
variables=cfg.get("variables", []),
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
skills=cfg.get("skills", {}),
skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings),
features=cfg.get("features", {}),
is_active=True,
created_at=now,
updated_at=now,
@@ -290,6 +312,7 @@ class AppDslService:
edges=[e.model_dump() for e in result.edges],
variables=[v.model_dump() for v in result.variables],
execution_config=wf.get("execution_config", {}),
features=wf.get("features", {}),
triggers=wf.get("triggers", []),
validate=False,
)
@@ -444,6 +467,46 @@ class AppDslService:
return {**memory, "memory_config_id": None, "enabled": False}
return memory

def _resolve_skills(self, skills: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> dict:
if not skills:
return skills or {}
resolved_ids = []
for entry in (skills.get("skill_ids") or []):
# entry 可能是 {"id": "...", "_ref": {...}} 或直接是字符串
if isinstance(entry, dict):
ref = entry.get("_ref") or ({"name": None, "id": entry.get("id")} if entry.get("id") else None)
skill_id = self._resolve_skill(ref, tenant_id, warnings)
else:
skill_id = self._resolve_skill({"id": str(entry)}, tenant_id, warnings)
if skill_id:
resolved_ids.append(str(skill_id))
return {**{k: v for k, v in skills.items() if k != "skill_ids"}, "skill_ids": resolved_ids}

def _resolve_skill(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
if not ref:
return None
# 先按 id 匹配
if ref.get("id"):
try:
s = self.db.query(Skill).filter(
Skill.id == uuid.UUID(str(ref["id"])),
Skill.tenant_id == tenant_id
).first()
if s:
return str(s.id)
except Exception:
pass
# 再按名称匹配
if ref.get("name"):
s = self.db.query(Skill).filter(
Skill.name == ref["name"],
Skill.tenant_id == tenant_id
).first()
if s:
return str(s.id)
warnings.append(f"未找到技能: {ref}")
return None

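A round-trip sketch of the new skills handling in DSL export/import; the `dsl_service` and `target_tenant_id` names are placeholders, only the enriched shape and the resolver come from this diff:

# Hedged sketch: export enriches each skill id with a "_ref" (id + name),
# import resolves the refs back against the target tenant, by id first, then by name.
exported = {"enabled": True, "skill_ids": [
    {"id": "2b1f0c7e-...", "_ref": {"id": "2b1f0c7e-...", "name": "web_search"}},
]}
warnings: list = []
resolved = dsl_service._resolve_skills(exported, tenant_id=target_tenant_id, warnings=warnings)
# -> {"enabled": True, "skill_ids": ["<Skill.id found in the target tenant>"]}
# Unresolvable entries are dropped and reported through `warnings`.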
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
result = []
for t in (tools or []):

@@ -833,8 +833,6 @@ class AppService:

# 跨工作空间时,获取目标工作空间的 tenant_id 用于判断模型配置是否可用
target_tenant_id = None
available_model_ids: set = set()
available_kb_ids: set = set()
if is_cross_workspace:
target_ws = self.db.get(Workspace, target_workspace_id)
if not target_ws:
@@ -849,28 +847,29 @@ class AppService:

if source_config:
if is_cross_workspace:
# Batch-collect and preload all referenced resources
model_ids, kb_ids = self._collect_resource_ids_from_config(
source_config.default_model_config_id,
source_config.knowledge_retrieval,
source_config.tools
# 跨工作空间:model/tools/skills 属于 tenant 级别直接保留,
# knowledge_bases 属于 workspace 级别需过滤,memory_config 需清空
_, kb_ids = self._collect_resource_ids_from_config(
None, source_config.knowledge_retrieval
)
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
target_tenant_id, target_workspace_id, model_ids, kb_ids
)
new_model_config_id = self._is_model_available(
source_config.default_model_config_id, available_model_ids
_, available_kb_ids = self._preload_cross_workspace_resources(
target_tenant_id, target_workspace_id, set(), kb_ids
)
new_model_config_id = source_config.default_model_config_id
new_knowledge_retrieval = self._clean_knowledge_retrieval(
source_config.knowledge_retrieval, available_kb_ids
)
new_tools = self._clean_tools(
source_config.tools, available_kb_ids
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
new_memory = self._clean_memory_cross_workspace(
source_config.memory, target_workspace_id
)
new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {}
else:
new_model_config_id = source_config.default_model_config_id
new_knowledge_retrieval = copy.deepcopy(source_config.knowledge_retrieval) if source_config.knowledge_retrieval else None
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
new_memory = copy.deepcopy(source_config.memory) if source_config.memory else None
new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {}

new_config = AgentConfig(
id=uuid.uuid4(),
@@ -879,9 +878,11 @@ class AppService:
default_model_config_id=new_model_config_id,
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
knowledge_retrieval=new_knowledge_retrieval,
memory=copy.deepcopy(source_config.memory) if source_config.memory else None,
memory=new_memory,
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
tools=new_tools,
skills=new_skills,
features=copy.deepcopy(source_config.features) if source_config.features else {},
is_active=True,
created_at=now,
updated_at=now,
@@ -894,28 +895,14 @@ class AppService:
).first()

if source_config:
if is_cross_workspace:
model_ids, kb_ids = self._collect_resource_ids_from_workflow_nodes(
source_config.nodes
)
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
target_tenant_id, target_workspace_id, model_ids, kb_ids
)
new_nodes = self._clean_workflow_nodes_for_cross_workspace(
source_config.nodes or [],
available_model_ids,
available_kb_ids
)
else:
new_nodes = copy.deepcopy(source_config.nodes) if source_config.nodes else []

new_config = WorkflowConfig(
id=uuid.uuid4(),
app_id=new_app.id,
nodes=new_nodes,
nodes=copy.deepcopy(source_config.nodes) if source_config.nodes else [],
edges=copy.deepcopy(source_config.edges) if source_config.edges else [],
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
execution_config=copy.deepcopy(source_config.execution_config) if source_config.execution_config else {},
features=copy.deepcopy(source_config.features) if source_config.features else {},
triggers=copy.deepcopy(source_config.triggers) if source_config.triggers else [],
is_active=True,
created_at=now,
@@ -929,24 +916,15 @@ class AppService:
).first()

if source_config:
if is_cross_workspace:
model_ids = {source_config.default_model_config_id} if source_config.default_model_config_id else set()
available_model_ids, _ = self._preload_cross_workspace_resources(
target_tenant_id, target_workspace_id, model_ids, set()
)
new_model_config_id = self._is_model_available(
source_config.default_model_config_id, available_model_ids
)
else:
new_model_config_id = source_config.default_model_config_id

# multi_agent 的 model_config_id/sub_agents/routing_rules 均属于 tenant 级别直接保留
# 跨空间时 master_agent_id(AppRelease)属于源空间,需清空
new_config = MultiAgentConfig(
id=uuid.uuid4(),
app_id=new_app.id,
master_agent_id=source_config.master_agent_id if not is_cross_workspace else None,
master_agent_name=source_config.master_agent_name,
default_model_config_id=new_model_config_id,
model_parameters=source_config.model_parameters,
default_model_config_id=source_config.default_model_config_id,
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
orchestration_mode=source_config.orchestration_mode,
sub_agents=copy.deepcopy(source_config.sub_agents) if source_config.sub_agents else [],
routing_rules=copy.deepcopy(source_config.routing_rules) if source_config.routing_rules else None,
@@ -1037,8 +1015,7 @@ class AppService:
@staticmethod
def _collect_resource_ids_from_config(
model_config_id: Optional[uuid.UUID],
knowledge_retrieval: Optional[dict],
tools: Optional[list]
knowledge_retrieval: Optional[dict]
) -> tuple:
"""Extract all model config IDs and knowledge base IDs from an app config."""
model_ids: set = set()
@@ -1048,62 +1025,12 @@ class AppService:
model_ids.add(model_config_id)

if knowledge_retrieval and isinstance(knowledge_retrieval, dict):
if "kb_ids" in knowledge_retrieval:
for kid in knowledge_retrieval.get("kb_ids", []):
if kid:
kb_ids.add(str(kid))
if knowledge_retrieval.get("knowledge_id"):
kb_ids.add(str(knowledge_retrieval["knowledge_id"]))

if tools:
for tool in tools:
if isinstance(tool, dict):
kid = tool.get("knowledge_id") or tool.get("kb_id")
if kid:
kb_ids.add(str(kid))
if "knowledge_bases" in knowledge_retrieval:
for kid in knowledge_retrieval.get("knowledge_bases", []):
kb_ids.add(str(kid.get("kb_id")))

return model_ids, kb_ids

@staticmethod
def _collect_resource_ids_from_workflow_nodes(nodes: list) -> tuple:
"""Extract all model config IDs and knowledge base IDs from workflow nodes."""
model_ids: set = set()
kb_ids: set = set()

for node in (nodes or []):
if not isinstance(node, dict):
continue
data = node.get("data", {})
if not isinstance(data, dict):
continue
for key in ("model_config_id", "default_model_config_id"):
val = data.get(key)
if val:
try:
model_ids.add(uuid.UUID(str(val)))
except (ValueError, AttributeError):
pass
kr = data.get("knowledge_retrieval")
if isinstance(kr, dict):
for kid in kr.get("kb_ids", []):
if kid:
kb_ids.add(str(kid))
if kr.get("knowledge_id"):
kb_ids.add(str(kr["knowledge_id"]))
if data.get("knowledge_id"):
kb_ids.add(str(data["knowledge_id"]))
for kid in data.get("kb_ids", []):
if kid:
kb_ids.add(str(kid))

return model_ids, kb_ids

@staticmethod
def _is_model_available(model_config_id: Optional[uuid.UUID], available_model_ids: set) -> Optional[uuid.UUID]:
if not model_config_id:
return None
return model_config_id if model_config_id in available_model_ids else None

@staticmethod
def _is_kb_available(kb_id: Optional[str], available_kb_ids: set) -> Optional[str]:
if not kb_id:
@@ -1124,95 +1051,53 @@ class AppService:

cleaned = copy.deepcopy(knowledge_retrieval)

if "kb_ids" in cleaned and isinstance(cleaned["kb_ids"], list):
cleaned["kb_ids"] = [
kid for kid in cleaned["kb_ids"]
if self._is_kb_available(kid, available_kb_ids)
if "knowledge_bases" in cleaned and isinstance(cleaned["knowledge_bases"], list):
cleaned["knowledge_bases"] = [
kb for kb in cleaned["knowledge_bases"]
if self._is_kb_available(kb.get("kb_id"), available_kb_ids)
]

if "knowledge_id" in cleaned:
cleaned["knowledge_id"] = self._is_kb_available(
cleaned.get("knowledge_id"), available_kb_ids
)

return cleaned

def _clean_tools(
def _clean_memory_cross_workspace(
self,
tools: Optional[list],
available_kb_ids: set
) -> list:
"""Clean tools config, keeping built-in tools and tools with available KBs."""
if not tools:
return []
memory: Optional[dict],
target_workspace_id: uuid.UUID
) -> Optional[dict]:
"""Clear memory_config_id/memory_content if it doesn't belong to target workspace."""
if not memory:
return None

cleaned = []
for tool in tools:
if not isinstance(tool, dict):
cleaned.append(tool)
continue
from app.models.memory_config_model import MemoryConfig

tool_type = tool.get("type", "")
if tool_type in ("builtin", "built_in", "system"):
cleaned.append(copy.deepcopy(tool))
continue
cleaned = copy.deepcopy(memory)
# 兼容旧字段 memory_content 和新字段 memory_config_id
mid = cleaned.get("memory_config_id") or cleaned.get("memory_content")
if mid:
try:
mid_uuid = uuid.UUID(str(mid))
except (ValueError, AttributeError):
exists = self.db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == int(mid),
MemoryConfig.workspace_id == target_workspace_id
).first()
if not exists:
cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None)
cleaned["enabled"] = False
return cleaned

kb_id = tool.get("knowledge_id") or tool.get("kb_id")
if kb_id:
if self._is_kb_available(kb_id, available_kb_ids):
cleaned.append(copy.deepcopy(tool))
continue
exists = self.db.query(
self.db.query(MemoryConfig).filter(
MemoryConfig.config_id == mid_uuid,
MemoryConfig.workspace_id == target_workspace_id
).exists()
).scalar()
if not exists:
cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None)
cleaned["enabled"] = False

cleaned.append(copy.deepcopy(tool))

return cleaned

def _clean_workflow_nodes_for_cross_workspace(
self,
nodes: list,
available_model_ids: set,
available_kb_ids: set
) -> list:
"""Clean workflow nodes, using pre-loaded resource sets. Uses deepcopy to avoid mutating source."""
if not nodes:
return []

cleaned = []
for node in nodes:
if not isinstance(node, dict):
cleaned.append(node)
continue

node_copy = copy.deepcopy(node)
data = node_copy.get("data")
if not isinstance(data, dict):
cleaned.append(node_copy)
continue

for key in ("model_config_id", "default_model_config_id"):
if key in data and data[key]:
try:
mid = uuid.UUID(str(data[key]))
except (ValueError, AttributeError):
data[key] = None
continue
data[key] = str(mid) if mid in available_model_ids else None

if "knowledge_retrieval" in data and data["knowledge_retrieval"]:
data["knowledge_retrieval"] = self._clean_knowledge_retrieval(
data["knowledge_retrieval"], available_kb_ids
)
if "knowledge_id" in data:
data["knowledge_id"] = self._is_kb_available(
data.get("knowledge_id"), available_kb_ids
)
if "kb_ids" in data and isinstance(data["kb_ids"], list):
data["kb_ids"] = [
kid for kid in data["kb_ids"]
if self._is_kb_available(kid, available_kb_ids)
]

cleaned.append(node_copy)
return cleaned

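A small sketch of the knowledge-base filtering that remains in the cross-workspace copy path (instance and ID values are illustrative):

# Hedged sketch: only knowledge bases visible in the target workspace survive the copy.
available_kb_ids = {"kb-1"}
source_kr = {"enabled": True, "knowledge_bases": [{"kb_id": "kb-1"}, {"kb_id": "kb-2"}]}
cleaned = app_service._clean_knowledge_retrieval(source_kr, available_kb_ids)
# -> {"enabled": True, "knowledge_bases": [{"kb_id": "kb-1"}]}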
def list_apps(

@@ -21,6 +21,7 @@ from app.models.conversation_model import ConversationDetail
from app.models.prompt_optimizer_model import RoleType
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
from app.schemas.conversation_schema import ConversationOut
from app.schemas.model_schema import ModelInfo
from app.services import workspace_service
from app.services.model_service import ModelConfigService

@@ -119,25 +120,27 @@ class ConversationService:

def get_user_conversations(
self,
user_id: uuid.UUID
) -> list[Conversation]:
user_id: uuid.UUID,
page: int = 1,
page_size: int = 20
) -> tuple[list[Conversation], int]:
"""
Retrieve recent conversations for a specific user

This method delegates persistence logic to the repository layer and
applies service-level defaults (e.g. recent conversation limit).
Retrieve recent conversations for a specific user with pagination.

Args:
user_id (uuid.UUID): Unique identifier of the user.
page (int): Page number (1-based). Defaults to 1.
page_size (int): Number of items per page. Defaults to 20.

Returns:
list[Conversation]: A list of recent conversation entities.
tuple[list[Conversation], int]: A list of recent conversation entities and total count.
"""
conversations = self.conversation_repo.get_conversation_by_user_id(
conversations, total = self.conversation_repo.get_conversation_by_user_id(
user_id,
limit=10
page=page,
page_size=page_size
)
return conversations
return conversations, total

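A minimal usage sketch of the paginated call; the controller-side PageData/PageMeta wiring is assumed from the schema imports and not shown in this diff:

# Hedged sketch: page through a user's conversations and derive paging metadata.
conversations, total = conversation_service.get_user_conversations(
    user_id=user_id, page=2, page_size=20
)
total_pages = (total + 20 - 1) // 20  # ceiling division over the page size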
def list_conversations(
self,
@@ -267,10 +270,12 @@ class ConversationService:

return messages

def get_conversation_history(
async def get_conversation_history(
self,
conversation_id: uuid.UUID,
max_history: Optional[int] = None
max_history: Optional[int] = None,
current_provider: Optional[str] = None,
current_is_omni: Optional[bool] = None
) -> List[dict]:
"""
Retrieve historical conversation messages formatted as dictionaries.
@@ -278,6 +283,8 @@ class ConversationService:
Args:
conversation_id (uuid.UUID): Conversation UUID.
max_history (Optional[int]): Maximum number of messages to retrieve.
current_provider (Optional[str]): Current provider for file handling.
current_is_omni (Optional[bool]): Current omni flag for file handling.

Returns:
List[dict]: List of message dictionaries with keys 'role' and 'content'.
@@ -287,14 +294,30 @@ class ConversationService:
limit=max_history
)

# 转换为字典格式
history = [
{
history = []
for msg in messages:
msg_dict = {
"role": msg.role,
"content": msg.content
"content": [{"type": "text", "text": msg.content}]
}
for msg in messages
]

# 处理用户消息中的多模态文件
if msg.role == "user" and msg.meta_data:
history_files = msg.meta_data.get("history_files", {})

if history_files and current_provider and current_is_omni is not None:
# 检查是否需要重新处理文件
stored_provider = history_files.get("provider")
stored_is_omni = history_files.get("is_omni")

# 如果provider或is_omni不匹配,需要重新处理
if stored_provider != current_provider or stored_is_omni != current_is_omni:
continue

# provider和is_omni匹配,直接使用存储的内容
msg_dict["content"].extend(history_files.get("content"))

history.append(msg_dict)

return history

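A caller-side sketch of the new signature (argument values are illustrative):

# Hedged sketch: callers pass the active provider/is_omni so cached multimodal
# payloads are reused only when they match; mismatching turns are skipped rather
# than re-sent in an incompatible format.
history = await conversation_service.get_conversation_history(
    conversation_id=conversation_id,
    max_history=10,
    current_provider="dashscope",
    current_is_omni=True,
)
# Each item looks like:
# {"role": "user", "content": [{"type": "text", "text": "..."}, *cached_file_parts]}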
@@ -510,6 +533,7 @@ class ConversationService:
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type

llm = RedBearLLM(
@@ -517,14 +541,17 @@ class ConversationService:
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
base_url=api_base,
is_omni=is_omni
),
type=ModelType(model_type)
)

conversation_messages = self.get_conversation_history(
conversation_messages = await self.get_conversation_history(
conversation_id=conversation_id,
max_history=20
max_history=20,
current_provider=provider,
current_is_omni=is_omni
)
if len(conversation_messages) == 0:
return ConversationOut(

@@ -579,25 +579,28 @@ class AgentRunService:
user_id=user_id
)

model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=model_config.type
)

# 6. 加载历史消息
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=10
max_history=10,
current_provider=api_key_config.get("provider"),
current_is_omni=api_key_config.get("is_omni", False)
)

# 6. 处理多模态文件
processed_files = None
if files:
# 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(files)
@@ -659,7 +662,10 @@ class AgentRunService:
})
},
files=files,
audio_url=audio_url
processed_files=processed_files,
audio_url=audio_url,
provider=api_key_config.get("provider"),
is_omni=api_key_config.get("is_omni", False)
)

response = {
@@ -676,6 +682,7 @@ class AgentRunService:
) if not sub_agent else [],
"citations": self._filter_citations(features_config, result.get("citations", [])),
"audio_url": audio_url,
"audio_status": "pending"
}

logger.info(
@@ -815,25 +822,28 @@ class AgentRunService:
sub_agent=sub_agent
)

model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=model_config.type
)

# 6. 加载历史消息
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=memory_config.get("max_history", 10)
max_history=memory_config.get("max_history", 10),
current_provider=api_key_config.get("provider"),
current_is_omni=api_key_config.get("is_omni", False)
)

# 6. 处理多模态文件
processed_files = None
if files:
# 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(files)
@@ -905,10 +915,13 @@ class AgentRunService:
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
},
files=files,
audio_url=stream_audio_url
processed_files=processed_files,
audio_url=stream_audio_url,
provider=api_key_config.get("provider"),
is_omni=api_key_config.get("is_omni", False)
)

# 12. 发送结束事件(包含 suggested_questions 和 tts)
# 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status)
end_data: Dict[str, Any] = {
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
@@ -919,6 +932,17 @@ class AgentRunService:
features_config, full_content, api_key_config, effective_params
)
end_data["audio_url"] = stream_audio_url
# 检查TTS是否已完成(非阻塞,不取消任务)
audio_status = "pending"
if tts_task is not None and tts_task.done():
# 任务已完成,检查是否有异常
try:
tts_task.result()
audio_status = "completed"
except Exception as e:
logger.warning(f"TTS任务异常: {e}")
audio_status = "failed"
end_data["audio_status"] = audio_status if stream_audio_url else None
end_data["citations"] = self._filter_citations(features_config, [])
yield self._format_sse_event("end", end_data)

@@ -1115,13 +1139,17 @@ class AgentRunService:
async def _load_conversation_history(
self,
conversation_id: str,
max_history: int = 10
max_history: int = 10,
current_provider: Optional[str] = None,
current_is_omni: Optional[bool] = None
) -> List[Dict[str, str]]:
"""加载会话历史消息
"""加载会话历史消息,并根据当前模型配置处理多模态文件

Args:
conversation_id: 会话ID
max_history: 最大历史消息数量
current_provider: 当前模型的provider
current_is_omni: 当前模型的is_omni

Returns:
List[Dict]: 历史消息列表
@@ -1129,9 +1157,12 @@ class AgentRunService:
try:

conversation_service = ConversationService(self.db)
history = conversation_service.get_conversation_history(
# 获取 API 配置用于多模态处理
history = await conversation_service.get_conversation_history(
conversation_id=uuid.UUID(conversation_id),
max_history=max_history
max_history=max_history,
current_provider=current_provider,
current_is_omni=current_is_omni
)

logger.debug(
@@ -1159,7 +1190,10 @@ class AgentRunService:
app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
files: Optional[List[FileInput]] = None,
audio_url: Optional[str] = None
processed_files: Optional[List[Dict[str, Any]]] = None,
audio_url: Optional[str] = None,
provider: Optional[str] = None,
is_omni: Optional[bool] = None
) -> None:
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)

@@ -1170,6 +1204,11 @@ class AgentRunService:
app_id: 应用ID(未使用,保留用于兼容性)
user_id: 用户ID(未使用,保留用于兼容性)
meta_data: token消耗
files: 原始文件输入
processed_files: 处理后的文件
audio_url: 音频URL
provider: 模型供应商
is_omni: 是否为全模态模型
"""
try:
from app.services.conversation_service import ConversationService
@@ -1179,15 +1218,24 @@ class AgentRunService:

# 保存消息(会话已经存在)
human_meta = {
"files": []
"files": [],
"history_files": {}
}
if files:
for f in files:
# url = await MultimodalService(self.db).get_file_url(f)
human_meta["files"].append({
"type": f.type,
"url": f.url
})

# 保存 history_files,包含 provider 和 is_omni 信息
if processed_files:
human_meta["history_files"] = {
"content": processed_files,
"provider": provider,
"is_omni": is_omni
}

# 保存用户消息
conversation_service.add_message(
conversation_id=conv_uuid,
@@ -1413,8 +1461,9 @@ class AgentRunService:
workspace_id: Optional[uuid.UUID] = None,
) -> tuple[Optional[str], Optional[asyncio.Task]]:
"""文本流式输入并行合成音频。
返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。
返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。
调用方向 text_queue put 文本 chunk,结束时 put None。
前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。
"""
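A sketch of the polling loop the docstring describes; only the path comes from the docstring, the response payload shape below is an assumption:

# Hedged sketch: poll the file status endpoint until the streamed TTS audio is ready.
import asyncio
import httpx

async def wait_for_audio(base_url: str, file_id: str, timeout: float = 30.0) -> bool:
    async with httpx.AsyncClient(base_url=base_url, timeout=5.0) as client:
        deadline = asyncio.get_running_loop().time() + timeout
        while asyncio.get_running_loop().time() < deadline:
            resp = await client.get(f"/storage/files/{file_id}/status")
            if resp.json().get("data", {}).get("status") == "completed":
                return True
            await asyncio.sleep(1.0)
    return False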
tts_config = features_config.get("text_to_speech", {})
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
@@ -1801,6 +1850,7 @@ class AgentRunService:
),
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
"audio_url": result.get("audio_url"),
"audio_status": result.get("audio_status"),
"citations": result.get("citations", []),
"suggested_questions": result.get("suggested_questions", []),
"error": None
@@ -1878,6 +1928,7 @@ class AgentRunService:
"results": [{
**r,
"audio_url": r.get("audio_url"),
"audio_status": r.get("audio_status"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
} for r in results],
@@ -2009,6 +2060,7 @@ class AgentRunService:
full_content = ""
returned_conversation_id = model_conversation_id
audio_url = None
audio_status = None
citations = []
suggested_questions = []

@@ -2067,6 +2119,7 @@ class AgentRunService:
# 从 end 事件中提取 features 输出字段
if event_type == "end" and event_data:
audio_url = event_data.get("audio_url")
audio_status = event_data.get("audio_status")
citations = event_data.get("citations", [])
suggested_questions = event_data.get("suggested_questions", [])

@@ -2096,6 +2149,7 @@ class AgentRunService:
"message": full_content,
"elapsed_time": elapsed,
"audio_url": audio_url,
"audio_status": audio_status,
"citations": citations,
"suggested_questions": suggested_questions,
"error": None
@@ -2110,6 +2164,7 @@ class AgentRunService:
"elapsed_time": elapsed,
"message_length": len(full_content),
"audio_url": audio_url,
"audio_status": audio_status,
"citations": citations,
"suggested_questions": suggested_questions,
"timestamp": time.time()
@@ -2246,6 +2301,7 @@ class AgentRunService:
"message": r.get("message"),
"elapsed_time": r.get("elapsed_time", 0),
"audio_url": r.get("audio_url"),
"audio_status": r.get("audio_status"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
"error": r.get("error")

@@ -619,7 +619,7 @@ class MemoryForgetService:
recent_trends.append({
'date': date_str,
'merged_count': record.merged_count,
'average_activation': record.average_activation_value,
'average_activation': round(record.average_activation_value, 2) if record.average_activation_value is not None else None,
'total_nodes': record.total_nodes,
'execution_time': int(record.execution_time.timestamp() * 1000)
})

@@ -12,10 +12,12 @@ import base64
import csv
import io
import json
import zipfile
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional

import PyPDF2
import chardet
import httpx
import magic
import openpyxl
@@ -39,12 +41,10 @@ PDF_MIME = ['application/pdf']
DOC_MIME = [
'application/msword',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/zip'
]
XLSX_MIME = [
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'application/vnd.ms-excel',
'application/zip'
]
CSV_MIME = ['text/csv', 'application/csv']
JSON_MIME = ['application/json']
@@ -402,6 +402,71 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
return result

async def history_process_files(
self,
files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式

Args:
files: 文件输入列表

Returns:
List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式)
"""
if not files:
return []

# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
if self.provider == "dashscope" and self.is_omni:
strategy_class = OpenAIFormatStrategy
else:
strategy_class = PROVIDER_STRATEGIES.get(self.provider)
if not strategy_class:
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
strategy_class = DashScopeFormatStrategy

result = []
for idx, file in enumerate(files):
strategy = strategy_class(file)
if not file.url:
file.url = await self.get_file_url(file)
try:
if file.type == FileType.IMAGE and "vision" in self.capability:
is_support, content = await self._process_image(file, strategy)
result.append(content)
elif file.type == FileType.DOCUMENT:
is_support, content = await self._process_document(file, strategy)
result.append(content)
elif file.type == FileType.AUDIO and "audio" in self.capability:
is_support, content = await self._process_audio(file, strategy)
result.append(content)
elif file.type == FileType.VIDEO and "video" in self.capability:
is_support, content = await self._process_video(file, strategy)
result.append(content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e:
logger.error(
f"处理文件失败",
extra={
"file_index": idx,
"file_type": file.type,
"error": str(e)
},
exc_info=True
)
# 继续处理其他文件,不中断整个流程
result.append({
"type": "text",
"text": f"[文件处理失败: {str(e)}]"
})

logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
return result

async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
"""
处理图片文件
@@ -561,12 +626,12 @@ class MultimodalService:
file.set_content(file_content)
file_mime_type = magic.from_buffer(file_content, mime=True)
if file_mime_type in TEXT_MIME:
return file_content.decode("utf-8")
return self._decode_text_safe(file_content)
elif file_mime_type in PDF_MIME:
return await self._extract_pdf_text(file_content)
elif file_mime_type in DOC_MIME and file.file_type.endswith(('docx', 'doc')):
elif self._is_word_file(file_content, file_mime_type):
return await self._extract_word_text(file_content)
elif file_mime_type in XLSX_MIME and file.file_type.endswith(("xlsx", "xls")):
elif self._is_excel_file(file_content, file_mime_type):
return await self._extract_xlsx_text(file_content)
elif file_mime_type in CSV_MIME:
return await self._extract_csv_text(file_content)
@@ -595,52 +660,156 @@ class MultimodalService:

@staticmethod
async def _extract_word_text(file_content: bytes) -> str:
"""提取 Word 文档文本"""
"""提取 Word 文档文本(支持 .docx 和旧版 .doc)"""
# 先尝试 docx(ZIP 格式)
if file_content[:2] == b'PK':
try:
word_file = io.BytesIO(file_content)
doc = Document(word_file)
return '\n'.join(p.text for p in doc.paragraphs)
except Exception as e:
logger.error(f"提取 docx 文本失败: {e}")
return f"[docx 提取失败: {str(e)}]"

# 旧版 .doc(OLE2 格式)
try:
word_file = io.BytesIO(file_content)
doc = Document(word_file)
text_parts = [paragraph.text for paragraph in doc.paragraphs]
return '\n'.join(text_parts)
import olefile
ole = olefile.OleFileIO(io.BytesIO(file_content))
if not ole.exists('WordDocument'):
return "[doc 提取失败: 未找到 WordDocument 流]"
# 读取 WordDocument 流,提取可见 ASCII/Unicode 文本
stream = ole.openstream('WordDocument').read()
# Word Binary Format: 文本在流中以 UTF-16-LE 编码存储
# 简单提取:过滤出可打印字符段
try:
text = stream.decode('utf-16-le', errors='ignore')
except Exception:
text = stream.decode('latin-1', errors='ignore')
# 过滤控制字符,保留可打印内容
import re
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
text = re.sub(r' +', ' ', text).strip()
ole.close()
return text
except Exception as e:
logger.error(f"提取 Word 文本失败: {e}")
return f"[Word 提取失败: {str(e)}]"
logger.error(f"提取 doc 文本失败: {e}")
return f"[doc 提取失败: {str(e)}]"

@staticmethod
async def _extract_xlsx_text(file_content: bytes) -> str:
"""提取 Excel 文本"""
"""提取 Excel 文本(支持 .xlsx 和旧版 .xls)"""
# xlsx(ZIP 格式)
if file_content[:2] == b'PK':
try:
wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True)
parts = []
for sheet in wb.worksheets:
parts.append(f"[Sheet: {sheet.title}]")
for row in sheet.iter_rows(values_only=True):
parts.append('\t'.join('' if v is None else str(v) for v in row))
return '\n'.join(parts)
except Exception as e:
logger.error(f"提取 xlsx 文本失败: {e}")
return f"[xlsx 提取失败: {str(e)}]"

# xls(OLE2/BIFF 格式)
try:
wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True)
import xlrd
wb = xlrd.open_workbook(file_contents=file_content)
parts = []
for sheet in wb.worksheets:
parts.append(f"[Sheet: {sheet.title}]")
for row in sheet.iter_rows(values_only=True):
parts.append('\t'.join('' if v is None else str(v) for v in row))
for sheet in wb.sheets():
parts.append(f"[Sheet: {sheet.name}]")
for row_idx in range(sheet.nrows):
parts.append('\t'.join(str(sheet.cell_value(row_idx, col)) for col in range(sheet.ncols)))
return '\n'.join(parts)
except Exception as e:
logger.error(f"提取 Excel 文本失败: {e}")
return f"[Excel 提取失败: {str(e)}]"
logger.error(f"提取 xls 文本失败: {e}")
return f"[xls 提取失败: {str(e)}]"

@staticmethod
async def _extract_csv_text(file_content: bytes) -> str:
async def _extract_csv_text(self, file_content: bytes) -> str:
"""提取 CSV 文本"""
try:
text = file_content.decode('utf-8-sig')
text = self._decode_text_safe(file_content)
reader = csv.reader(io.StringIO(text))
return '\n'.join('\t'.join(row) for row in reader)
except Exception as e:
logger.error(f"提取 CSV 文本失败: {e}")
return f"[CSV 提取失败: {str(e)}]"

@staticmethod
async def _extract_json_text(file_content: bytes) -> str:
async def _extract_json_text(self, file_content: bytes) -> str:
"""提取 JSON 文本"""
try:
data = json.loads(file_content.decode('utf-8'))
text = self._decode_text_safe(file_content)
data = json.loads(text)
return json.dumps(data, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"提取 JSON 文本失败: {e}")
return f"[JSON 提取失败: {str(e)}]"

def _is_word_file(self, file_content: bytes, mime_type: str) -> bool:
"""判断是不是 Word 文件(doc / docx),不依赖后缀"""
# 旧版 .doc
if mime_type == 'application/msword':
return True

# 新版 .docx(ZIP 内部包含 word/document.xml)
header = file_content[:4]
if header == b'PK\x03\x04':
try:
with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
return "word/document.xml" in zf.namelist()
except:
pass

return False

def _is_excel_file(self, file_content: bytes, mime_type: str) -> bool:
"""判断是不是 Excel 文件(xls / xlsx),不依赖后缀"""
# 旧版 .xls
if mime_type == 'application/vnd.ms-excel':
return True

# 新版 .xlsx(ZIP 内部包含 xl/workbook.xml)
header = file_content[:4]
if header == b'PK\x03\x04':
try:
with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
return "xl/workbook.xml" in zf.namelist()
except:
pass

return False

@staticmethod
def _decode_text_safe(file_content: bytes) -> str:
"""
【万能文本解码】
自动检测编码,支持 utf-8 / gbk / gb2312 / utf-8-sig / ascii 等
永远不报错,永远不乱码
"""
if not file_content:
return ""

# 1. 自动检测文件编码
detect = chardet.detect(file_content)
encoding = detect.get("encoding") or "utf-8"
encoding = encoding.lower()

# 2. 兼容常见中文编码
compatible_encodings = ["utf-8", "gbk", "gb18030", "gb2312", "ascii", "latin-1"]

# 3. 按优先级尝试解码
for enc in [encoding] + compatible_encodings:
if not enc:
continue
try:
return file_content.decode(enc.strip())
except (UnicodeDecodeError, LookupError):
continue

# 终极兜底
return file_content.decode("utf-8", errors="replace")


def get_multimodal_service(db: Session) -> MultimodalService:
"""获取多模态服务实例(依赖注入)"""

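For context, a standalone sketch of the header-based sniffing the helpers above rely on: Office 2007+ files are ZIP containers ("PK\x03\x04") whose member list distinguishes docx from xlsx, while legacy .doc/.xls are OLE2 compound files (the function name below is illustrative, not from the diff):

# Hedged sketch: classify an Office payload by magic bytes instead of file extension.
import io, zipfile

def sniff_office_kind(file_content: bytes) -> str:
    if file_content[:4] == b'PK\x03\x04':
        with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
            names = zf.namelist()
        if "word/document.xml" in names:
            return "docx"
        if "xl/workbook.xml" in names:
            return "xlsx"
        return "zip"
    if file_content[:8] == b'\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1':
        return "ole2 (legacy .doc/.xls)"
    return "unknown"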
@@ -1408,12 +1408,11 @@ async def analytics_memory_types(
if end_user_id:
try:
conversation_repo = ConversationRepository(db)
conversations = conversation_repo.get_conversation_by_user_id(
conversations, total = conversation_repo.get_conversation_by_user_id(
user_id=uuid.UUID(end_user_id),
limit=100, # 获取更多会话以准确统计
is_activate=True
)
work_count = len(conversations)
work_count = total
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")

@@ -78,18 +78,7 @@ def create_user(db: Session, user: UserCreate) -> User:
business_logger.info(f"创建用户: {user.username}, email: {user.email}")

try:
# 检查用户名是否已存在
business_logger.debug(f"检查用户名是否已存在: {user.username}")
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
if db_user_by_username:
business_logger.warning(f"用户名已存在: {user.username}")
raise BusinessException(
"用户名已存在",
code=BizCode.DUPLICATE_NAME,
context={"username": user.username, "email": user.email}
)

# 检查邮箱是否已注册
# 检查邮箱是否已注册(邮箱保持唯一)
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
if db_user_by_email:
@@ -164,22 +153,7 @@ def create_superuser(db: Session, user: UserCreate, current_user: User) -> User:
)

try:
# 检查用户名是否已存在
business_logger.debug(f"检查用户名是否已存在: {user.username}")
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
if db_user_by_username:
business_logger.warning(f"用户名已存在: {user.username}")
raise BusinessException(
"用户名已存在",
code=BizCode.DUPLICATE_NAME,
context={
"username": user.username,
"email": user.email,
"created_by": str(current_user.id)
}
)

# 检查邮箱是否已注册
# 检查邮箱是否已注册(邮箱保持唯一)
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
if db_user_by_email:

@@ -57,6 +57,7 @@ class WorkflowService:
        edges: list[dict[str, Any]],
        variables: list[dict[str, Any]] | None = None,
        execution_config: dict[str, Any] | None = None,
        features: dict[str, Any] | None = None,
        triggers: list[dict[str, Any]] | None = None,
        validate: bool = True
    ) -> WorkflowConfig:

@@ -68,6 +69,7 @@ class WorkflowService:
            edges: 边列表
            variables: 变量列表
            execution_config: 执行配置
            features: 功能特性
            triggers: 触发器列表
            validate: 是否验证配置

@@ -83,6 +85,7 @@ class WorkflowService:
            "edges": edges,
            "variables": variables or [],
            "execution_config": execution_config or {},
            "features": features or {},
            "triggers": triggers or []
        }

@@ -103,6 +106,7 @@ class WorkflowService:
            edges=edges,
            variables=variables,
            execution_config=execution_config,
            features=features,
            triggers=triggers
        )
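The new triggers parameter is threaded through the whole path: the signature, the docstring, the serialized config dict, and the downstream constructor call all gain the field, defaulting to an empty list. A hedged sketch of the resulting config shape (the "nodes" key and the cron trigger are illustrative assumptions, not taken from the diff):

def build_workflow_config_dict(nodes, edges, variables=None, execution_config=None,
                               features=None, triggers=None):
    # Mirrors the dict assembled in WorkflowService above; only keys visible in
    # the diff come from the source, the rest is illustrative.
    return {
        "nodes": nodes,
        "edges": edges,
        "variables": variables or [],
        "execution_config": execution_config or {},
        "features": features or {},
        "triggers": triggers or [],
    }

config = build_workflow_config_dict(
    nodes=[{"id": "start"}],
    edges=[],
    triggers=[{"type": "cron", "expression": "0 9 * * *"}],  # illustrative trigger
)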
@@ -2625,13 +2625,15 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
    time_limit=7200,  # 2小时硬超时
    soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
    """触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。

    由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。
    任务完成且所有用户数据均完整时,写入 Redis 标记,避免下次重复投递。

    Args:
        end_user_ids: 需要检查的用户 ID 列表
        workspace_id: 工作空间 ID,用于完成标记

    Returns:
        包含任务执行结果的字典

@@ -2657,6 +2659,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[s

    # 批量预取所有用户的配置(内置兜底:用户配置不可用时自动回退到工作空间默认配置)
    user_llm_map: Dict[str, Optional[str]] = {}
    user_embedding_map: Dict[str, Optional[str]] = {}
    try:
        with get_db_context() as db:
            from app.services.memory_agent_service import get_end_users_connected_configs_batch

@@ -2668,21 +2671,54 @@ def init_community_clustering_for_users(self, end_user_ids: List[s
                try:
                    cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
                    user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
                    user_embedding_map[uid] = str(cfg.embedding_model_id) if cfg.embedding_model_id else None
                except Exception as e:
                    logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
                    logger.warning(f"[CommunityCluster] 用户 {uid} 加载配置失败,将使用 None: {e}")
                    user_llm_map[uid] = None
                    user_embedding_map[uid] = None
            else:
                user_llm_map[uid] = None
                user_embedding_map[uid] = None
    except Exception as e:
        logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")
        logger.warning(f"[CommunityCluster] 批量获取配置失败,所有用户将使用 None: {e}")

    for end_user_id in end_user_ids:
        try:
            # 已有社区节点则跳过
            # 已有社区节点时,检查是否存在属性不完整的节点
            has_communities = await repo.has_communities(end_user_id)
            if has_communities:
                skipped += 1
                logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
                llm_model_id = user_llm_map.get(end_user_id)
                embedding_model_id = user_embedding_map.get(end_user_id)
                incomplete_ids = await repo.get_incomplete_communities(
                    end_user_id, check_embedding=bool(embedding_model_id)
                )
                if not incomplete_ids:
                    skipped += 1
                    logger.debug(f"[CommunityCluster] 用户 {end_user_id} 社区节点均完整,跳过")
                    continue

                # 对不完整的社区节点逐一补全元数据
                engine = LabelPropagationEngine(
                    connector=connector,
                    llm_model_id=llm_model_id,
                    embedding_model_id=embedding_model_id,
                )
                logger.info(
                    f"[CommunityCluster] 用户 {end_user_id} 发现 {len(incomplete_ids)} 个属性不完整的社区,开始补全"
                )
                patch_ok = 0
                patch_fail = 0
                for cid in incomplete_ids:
                    try:
                        await engine._generate_community_metadata([cid], end_user_id)
                        patch_ok += 1
                    except Exception as patch_err:
                        patch_fail += 1
                        logger.error(f"[CommunityCluster] 社区 {cid} 元数据补全失败: {patch_err}")
                logger.info(
                    f"[CommunityCluster] 用户 {end_user_id} 社区补全完成: 成功={patch_ok}, 失败={patch_fail}"
                )
                initialized += 1
                continue

            # 检查是否有 ExtractedEntity 节点
@@ -2692,11 +2728,13 @@ def init_community_clustering_for_users(self, end_user_ids: List[s
                logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
                continue

            # 每个用户使用自己的 llm_model_id
            # 每个用户使用自己的 llm_model_id / embedding_model_id
            llm_model_id = user_llm_map.get(end_user_id)
            embedding_model_id = user_embedding_map.get(end_user_id)
            engine = LabelPropagationEngine(
                connector=connector,
                llm_model_id=llm_model_id,
                embedding_model_id=embedding_model_id,
            )

            logger.info(
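The behavioural change is that "user already has communities" no longer simply means "skip": the task now distinguishes complete from incomplete Community nodes and regenerates metadata only for the latter. A hedged sketch of that decision flow (method names mirror the diff, signatures are assumed):

async def reconcile_user_communities(repo, engine, end_user_id, embedding_model_id):
    if not await repo.has_communities(end_user_id):
        return "full_clustering_needed"        # legacy users without any Community nodes
    incomplete = await repo.get_incomplete_communities(
        end_user_id, check_embedding=bool(embedding_model_id)
    )
    if not incomplete:
        return "skipped"                        # name / summary / embedding all present
    ok = fail = 0
    for community_id in incomplete:
        try:
            await engine._generate_community_metadata([community_id], end_user_id)
            ok += 1
        except Exception:
            fail += 1
    return f"patched ok={ok} fail={fail}"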
@@ -1,4 +1,38 @@
{
  "v0.2.8": {
    "introduction": {
      "codeName": "景玉",
      "releaseDate": "2026-3-20",
      "upgradePosition": "🐻 MemoryBear v0.2.8 社区版全面升级应用共享、多模态交互与平台基础设施,引入语音交互、感知记忆和云端存储,打造更强大的开放 AI 记忆平台",
      "coreUpgrades": [
        "1. 应用共享与发布<br>* 应用共享(Agent、工作流、Agent 集群):全类型应用共享至其他空间<br>* 分享应用默认开启记忆功能:发布分享后记忆默认开启,关闭时提醒<br>* 工作流记忆分享规则:按记忆配置自动控制分享页记忆开关<br>* 分享会话联网搜索修复:恢复分享应用的联网搜索能力",
        "2. 多模态与交互 💬<br>* 语音输入:模型接口和应用支持语音输入<br>* 语音回复:应用支持语音回复模态<br>* 多模态感知记忆:记忆系统支持视觉、音频、图片和文件的感知记忆<br>* 对话框文件展示:试运行和体验分享中正确展示上传文件",
        "3. 平台与基础设施 ⚙️<br>* i18n 国际化:全面多语言多地区支持<br>* 云端文件存储(OSS + S3):支持阿里云 OSS 和 S3 云端上传<br>* Flower 容器监控:Celery 异步任务监控与管理",
        "4. EndUser 身份迁移 🔐<br>* EndUser 从 app_id 迁移至 workspace_id:身份从应用级迁移至工作空间级",
        "5. 情景记忆 🧠<br>* 情景记忆聚类算法:基于社区图谱的聚类算法,支持老用户图谱生成",
        "6. 稳健性与缺陷修复 🔧<br>* MCP 服务删除后工具 404:修复删除 MCP 服务后接口报错<br>* 应用导出配置不一致:导出已保存配置而非画布状态<br>* 工作流节点 ID 重复:修复复制节点后 ID 冲突<br>* 条件分支连线错误:修复保存刷新后连线错乱<br>* 回复节点内容丢失:修复点击画布后内容消失<br>* 连接桩规则优化:禁止非法连接方向<br>* 知识库状态列宽度:锁定或自适应宽度<br>* 等待中文档预览:支持未完成解析文档预览<br>* 知识库关联修复:统一修复关联问题<br>* 多模态对话连续性:修复多模态内容后无法继续对话<br>* 时区统一:环境变量统一控制存储和任务时区<br>* 遗忘强度精度:修复小数显示过长",
        "<br>",
        "v0.2.8 社区版在应用共享和多模态交互方面实现重大升级,感知记忆扩展了平台的认知维度。后续将深化多智能体协作、情景记忆聚类,并持续优化平台稳定性与开放生态。",
        "MemoryBear —— 让 AI 拥有记忆 🐻✨"
      ]
    },
    "introduction_en": {
      "codeName": "JingYu",
      "releaseDate": "2026-3-20",
      "upgradePosition": "🐻 MemoryBear v0.2.8 Community delivers multimodal interaction, perceptual memory, cloud storage, and workspace-level identity for a more capable open AI memory platform",
      "coreUpgrades": [
        "1. Application Sharing & Publishing<br>* Application Sharing (Agent, Workflow, Agent Cluster): Full sharing across all app types<br>* Memory Enabled by Default: Memory auto-enabled on shared apps with disable reminder<br>* Workflow Memory Sharing Rules: Auto-controlled based on memory configuration<br>* Shared Session Web Search Fix: Restored web search for shared apps",
        "2. Multimodal & Interaction 💬<br>* Voice Input: Model interfaces and apps support voice input<br>* Voice Reply: Apps support voice reply modality<br>* Multimodal Perceptual Memory: Memory system supports visual, audio, image, and file perception<br>* File Display in Chat: Uploaded files display correctly in dry-run and sharing",
        "3. Platform & Infrastructure ⚙️<br>* i18n Internationalization: Full multi-language multi-region support<br>* Cloud File Storage (OSS + S3): Alibaba Cloud OSS and S3 cloud uploads<br>* Flower Container Monitoring: Celery async task monitoring and management",
        "4. EndUser Identity Migration 🔐<br>* EndUser Migration from app_id to workspace_id: Identity migrated to workspace level",
        "5. Episodic Memory 🧠<br>* Episodic Memory Clustering: Community-graph-based clustering with legacy user support",
        "6. Robustness & Bug Fixes 🔧<br>* MCP Service Deletion 404: Fixed tool endpoint error after MCP removal<br>* App Export Config Mismatch: Exports saved config instead of canvas state<br>* Workflow Duplicate Node ID: Fixed ID conflict on node duplication<br>* Conditional Branch Wiring: Fixed wiring reset after save/refresh<br>* Reply Node Content Loss: Fixed content disappearing on canvas click<br>* Port Connection Rules: Prohibited invalid connection directions<br>* Knowledge Base Status Width: Locked or adaptive column width<br>* Pending Document Preview: Preview support for unparsed documents<br>* Knowledge Base Association Fixes: Consolidated association fixes<br>* Multimodal Conversation Continuity: Fixed single-round limit after multimodal input<br>* Timezone Unification: Env-var controlled unified timezone<br>* Forgetting Strength Precision: Fixed excessive decimal display",
        "<br>",
        "v0.2.8 Community delivers major upgrades in application sharing and multimodal interaction, with perceptual memory expanding the platform's cognitive dimensions. Multi-agent collaboration, episodic clustering, and continued platform stability improvements are ahead.",
        "MemoryBear — Give AI Memory 🐻✨"
      ]
    }
  },
  "v0.2.7": {
    "introduction": {
      "codeName": "武陵",
32  api/migrations/versions/05a681a6ca93_202603231611.py  Normal file
@@ -0,0 +1,32 @@
"""202603231611

Revision ID: 05a681a6ca93
Revises: 74b51dfece29
Create Date: 2026-03-23 16:12:44.110292

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '05a681a6ca93'
down_revision: Union[str, None] = '74b51dfece29'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_users_username'), table_name='users')
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_users_username'), table_name='users')
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
    # ### end Alembic commands ###
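This migration relaxes the users.username index from unique to non-unique, which matches the removal of the duplicate-username check in user_service above (email stays unique). A hedged sketch of verifying the index state after running alembic upgrade head (the engine URL is illustrative; the index name comes from the migration):

from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql://user:pass@localhost/memorybear")  # illustrative URL
for ix in inspect(engine).get_indexes("users"):
    if ix["name"] == "ix_users_username":
        # Expected: False after upgrade(), True after downgrade()
        print(ix["unique"])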
@@ -52,6 +52,10 @@ export const getKnowledgeBaseTypeList = async (): Promise<string[]> => {
  // 如果不是数组,返回空数组
  return [];
};
// 获取文件地址
export const getFileUrl = (fileId: string) => {
  return `${apiPrefix}/files/${fileId}`;
};
// 知识库文档解析类型
export const getKnowledgeBaseDocumentParseTypeList = async () => {
  const response = await request.get(`${apiPrefix}/knowledges/parsertype`);
@@ -2,7 +2,7 @@
 * @Author: ZhaoYing
 * @Date: 2026-02-03 14:00:06
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-13 10:48:41
 * @Last Modified time: 2026-03-19 18:35:10
 */
import { request } from '@/utils/request'
import type { AxiosRequestConfig } from 'axios'

@@ -218,8 +218,8 @@ export const getExplicitMemory = (end_user_id: string) => {
export const getExplicitMemoryDetails = (data: { end_user_id: string, memory_id: string; }) => {
  return request.post(`/memory/explicit-memory/details`, data)
}
export const getConversations = (end_user_id: string) => {
  return request.get(`/memory/work/${end_user_id}/conversations`)
export const getConversations = (end_user_id: string, page = 1, pagesize = 20) => {
  return request.get(`/memory/work/${end_user_id}/conversations`, { page, pagesize })
}
export const getConversationMessages = (end_user_id: string, conversation_id: string) => {
  return request.get(`/memory/work/${end_user_id}/messages`, { conversation_id })
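The conversations endpoint is now paginated on both sides: the frontend passes page/pagesize, and the backend wraps results in items plus page metadata (the hasnext flag consumed in the WorkingDetail view further below). A hedged sketch of calling it directly (base URL, token, and response field names are assumptions inferred from the frontend usage):

import requests

BASE = "http://localhost/api"                      # illustrative
headers = {"Authorization": "Bearer <token>"}      # placeholder token

resp = requests.get(
    f"{BASE}/memory/work/<end_user_id>/conversations",
    params={"page": 1, "pagesize": 20},
    headers=headers,
    timeout=10,
)
payload = resp.json()
items = payload.get("items", [])
has_next = payload.get("page", {}).get("hasnext", False)
print(len(items), has_next)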
@@ -2,7 +2,7 @@
 * @Author: ZhaoYing
 * @Date: 2025-12-10 16:46:17
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-19 13:38:20
 * @Last Modified time: 2026-03-19 19:45:40
 */
import { type FC, useRef, useEffect, useState } from 'react'
import clsx from 'clsx'

@@ -143,15 +143,20 @@ const ChatContent: FC<ChatContentProps> = ({
  }
  return (
    <div key={file.url || file.uid} className="rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:p-1! rb:cursor-pointer" onClick={() => handleDownload(file)}>
      {(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
        className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word.svg')]"
      ></div>}
      {(file.type.includes('pdf')) && <div
        className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf.svg')]"
      ></div>}
      {(file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv')) && <div
        className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/excel.svg')]"
      ></div>}
      {(file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv'))
        ? <div
            className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/excel.svg')]"
          ></div>
        :(file.type.includes('pdf'))
          ? <div
              className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf.svg')]"
            ></div>
          : (file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document'))
            ? <div
                className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word.svg')]"
              ></div>
            : null
      }
    </div>
  )
})}
@@ -2,10 +2,11 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2025-12-10 16:46:14
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-19 16:05:56
|
||||
* @Last Modified time: 2026-03-19 18:44:51
|
||||
*/
|
||||
import { type FC, useEffect, useMemo } from 'react'
|
||||
import { Flex, Input, Form } from 'antd'
|
||||
import { Flex, Input, Form, Spin } from 'antd'
|
||||
import clsx from 'clsx'
|
||||
|
||||
import SendIcon from '@/assets/images/conversation/send.svg'
|
||||
import SendDisabledIcon from '@/assets/images/conversation/sendDisabled.svg'
|
||||
@@ -69,6 +70,8 @@ const ChatInput: FC<ChatInputProps> = ({
|
||||
onSend(values.message)
|
||||
}
|
||||
|
||||
console.log('previewFileList', previewFileList)
|
||||
|
||||
return (
|
||||
<div className={`rb:absolute rb:bottom-3 rb:left-0 rb:right-0 rb:w-full ${className}`}>
|
||||
<Flex vertical justify="space-between" className="rb:border rb:border-[#DFE4ED] rb:rounded-xl rb:min-h-30">
|
||||
@@ -76,62 +79,78 @@ const ChatInput: FC<ChatInputProps> = ({
|
||||
{previewFileList.map((file) => {
|
||||
if (file.type.includes('image')) {
|
||||
return (
|
||||
<div key={file.url || file.uid} className="rb:inline-block rb:group rb:relative rb:rounded-lg">
|
||||
<img src={file.url} alt={file.name} className="rb:size-12! rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
onClick={() => handleDelete(file)}
|
||||
></div>
|
||||
</div>
|
||||
<Spin key={`${file.url || file.uid}_${file.status}`} spinning={file.status === 'uploading'}>
|
||||
<div key={file.url || file.uid} className={clsx("rb:inline-block rb:group rb:relative rb:rounded-lg", {
|
||||
'rb:border rb:border-[#FF5D34]': file.status === 'error'
|
||||
})}>
|
||||
<img src={file.url} alt={file.name} className="rb:size-12! rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
onClick={() => handleDelete(file)}
|
||||
></div>
|
||||
</div>
|
||||
</Spin>
|
||||
)
|
||||
}
|
||||
if (file.type.includes('video')) {
|
||||
return (
|
||||
<div key={file.url || file.uid} className="rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg">
|
||||
<video src={file.url} controls className="rb:w-45 rb:h-16 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
onClick={() => handleDelete(file)}
|
||||
></div>
|
||||
</div>
|
||||
<Spin key={`${file.url || file.uid}_${file.status}`} spinning={file.status === 'uploading'}>
|
||||
<div key={file.url || file.uid} className={clsx("rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg", {
|
||||
'rb:border rb:border-[#FF5D34]': file.status === 'error'
|
||||
})}>
|
||||
<video src={file.url} controls className="rb:w-45 rb:h-15.5 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
onClick={() => handleDelete(file)}
|
||||
></div>
|
||||
</div>
|
||||
</Spin>
|
||||
)
|
||||
}
|
||||
if (file.type.includes('audio')) {
|
||||
return (
|
||||
<div key={file.url || file.uid} className="rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2">
|
||||
<audio src={file.url} controls className="rb:w-45 rb:h-16" />
|
||||
<Spin key={`${file.url || file.uid}_${file.status}`} spinning={file.status === 'uploading'}>
|
||||
<div key={file.url || file.uid} className={clsx("rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2", {
|
||||
'rb:border rb:border-[#FF5D34]': file.status === 'error'
|
||||
})}>
|
||||
<audio src={file.url} controls className="rb:w-45 rb:h-15.5" />
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
onClick={() => handleDelete(file)}
|
||||
></div>
|
||||
</div>
|
||||
</Spin>
|
||||
)
|
||||
}
|
||||
return (
|
||||
<Spin key={`${file.url || file.uid}_${file.status}`} spinning={file.status === 'uploading'}>
|
||||
<div key={file.url || file.uid} className={clsx("rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5", {
|
||||
'rb:border rb:border-[#FF5D34]': file.status === 'error'
|
||||
})}>
|
||||
{file.type.includes('pdf')
|
||||
? <div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/pdf.svg')]"
|
||||
></div>
|
||||
: (file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv'))
|
||||
? <div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/excel_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/excel.svg')]"
|
||||
></div>
|
||||
: (file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document'))
|
||||
? <div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/word.svg')]"
|
||||
></div>
|
||||
: null
|
||||
}
|
||||
<div className="rb:flex-1 rb:w-32.5">
|
||||
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
|
||||
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.type} · {file.size}</div>
|
||||
</div>
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
onClick={() => handleDelete(file)}
|
||||
></div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
return (
|
||||
<div key={file.url || file.uid} className="rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5">
|
||||
{file.type.includes('pdf')
|
||||
? <div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/pdf.svg')]"
|
||||
></div>
|
||||
: (file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv'))
|
||||
? <div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/excel_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/excel.svg')]"
|
||||
></div>
|
||||
: (file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document'))
|
||||
? <div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/word.svg')]"
|
||||
></div>
|
||||
: null
|
||||
}
|
||||
<div className="rb:flex-1 rb:w-32.5">
|
||||
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
|
||||
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.type} · {file.size}</div>
|
||||
</div>
|
||||
<div
|
||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||
onClick={() => handleDelete(file)}
|
||||
></div>
|
||||
</div>
|
||||
</div>
|
||||
</Spin>
|
||||
)
|
||||
})}
|
||||
</Flex></div>}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-03-17 14:22:25
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-18 15:55:13
|
||||
* @Last Modified time: 2026-03-19 18:59:37
|
||||
*/
|
||||
// Toolbar component for chat input area, supporting file upload, audio recording, and variable configuration
|
||||
import { useRef, forwardRef, useImperativeHandle, type ReactNode, useEffect } from 'react'
|
||||
@@ -49,6 +49,7 @@ interface FormValues {
|
||||
memory?: boolean;
|
||||
}
|
||||
|
||||
const max_file_count = 1;
|
||||
const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
||||
features,
|
||||
extra,
|
||||
@@ -85,10 +86,18 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
||||
|
||||
// Append newly uploaded file to the file list when upload is complete
|
||||
const fileChange = (file?: any) => {
|
||||
if (file?.status !== 'done') return
|
||||
const files = [...(queryValues?.files || []), file]
|
||||
form.setFieldValue('files', files)
|
||||
onFilesChange?.(files)
|
||||
console.log('file', file)
|
||||
const lastFiles = form.getFieldValue('files') || [];
|
||||
const index = lastFiles.findIndex((item: any) => item.uid === file.uid)
|
||||
if (index > -1) {
|
||||
lastFiles[index] = file
|
||||
} else {
|
||||
lastFiles.push(file)
|
||||
}
|
||||
form.setFieldValue('files', [...lastFiles])
|
||||
onFilesChange?.([...lastFiles])
|
||||
|
||||
console.log('lastFiles', lastFiles)
|
||||
}
|
||||
|
||||
// Append recorded audio file to the file list and notify parent
|
||||
@@ -128,8 +137,8 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
||||
key: 'url',
|
||||
label: t('memoryConversation.addRemoteFile'),
|
||||
onClick: () => {
|
||||
if ((queryValues?.files?.length || 0) >= file_upload.max_file_count) {
|
||||
messageApi.warning(t('common.fileNumTip', { num: file_upload.max_file_count }))
|
||||
if ((queryValues?.files?.length || 0) >= max_file_count) {
|
||||
messageApi.warning(t('common.fileNumTip', { num: max_file_count }))
|
||||
return
|
||||
}
|
||||
uploadFileListModalRef.current?.handleOpen()
|
||||
@@ -145,7 +154,7 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
||||
onChange={fileChange}
|
||||
requestConfig={uploadRequestConfig}
|
||||
featureConfig={file_upload}
|
||||
disabled={(queryValues?.files?.length || 0) >= file_upload.max_file_count}
|
||||
disabled={(queryValues?.files?.length || 0) >= max_file_count}
|
||||
/>
|
||||
)
|
||||
})
|
||||
@@ -177,7 +186,7 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
||||
{file_upload?.audio_enabled && file_upload?.allowed_transfer_methods?.includes('local_file') && (
|
||||
<Flex align="center">
|
||||
<AudioRecorder
|
||||
disabled={(queryValues?.files?.length || 0) >= file_upload.max_file_count}
|
||||
disabled={(queryValues?.files?.length || 0) >= max_file_count}
|
||||
action={uploadAction}
|
||||
requestConfig={uploadRequestConfig}
|
||||
onRecordingComplete={handleRecordingComplete}
|
||||
|
||||
@@ -4,7 +4,7 @@
 * @Author: yujiangping
 * @Date: 2026-03-16 19:01:12
 * @LastEditors: yujiangping
 * @LastEditTime: 2026-03-18 18:35:53
 * @LastEditTime: 2026-03-20 12:12:20
 */
import { useState, useEffect, useRef, useCallback, type FC } from 'react';
import { Spin, Alert, Button, Table, InputNumber, Image } from 'antd';

@@ -309,23 +309,64 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
    }
  };

  const [csvTruncated, setCsvTruncated] = useState(false);

  const isCsvFile = () => getFileExtension() === '.csv';

  // CSV 预览大小限制:1MB
  const CSV_PREVIEW_SIZE = 1 * 1024 * 1024;
  // 最大预览行数
  const MAX_PREVIEW_ROWS = 500;

  const fetchFileBufferWithLimit = async (url: string, maxBytes?: number): Promise<ArrayBuffer> => {
    const requestUrl = getRequestUrl(url);
    const headers: Record<string, string> = {
      'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`,
    };
    if (maxBytes) {
      headers['Range'] = `bytes=0-${maxBytes - 1}`;
    }
    const response = await fetch(requestUrl, {
      credentials: 'include',
      headers,
    });
    if (!response.ok && response.status !== 206) {
      throw new Error(`HTTP ${response.status}: ${response.statusText}`);
    }
    return response.arrayBuffer();
  };

  const loadExcelFile = async () => {
    setLoading(true);
    setError(false);
    setErrorMessage('');
    setCsvTruncated(false);
    try {
      const arrayBuffer = await fetchFileBuffer(fileUrl);

      // CSV 文件需要处理编码问题(可能是 GBK/GB2312)
      // CSV 文件需要处理编码问题(可能是 GBK/GB2312),且大文件只取前 1MB
      if (isCsvFile()) {
        let arrayBuffer: ArrayBuffer;
        let truncated = false;
        try {
          // 先尝试 Range 请求只取前 1MB
          arrayBuffer = await fetchFileBufferWithLimit(fileUrl, CSV_PREVIEW_SIZE);
          // 如果返回的数据刚好等于限制大小,说明可能被截断了
          if (arrayBuffer.byteLength >= CSV_PREVIEW_SIZE) {
            truncated = true;
          }
        } catch {
          // Range 请求不支持时,全量获取后截断
          const fullBuffer = await fetchFileBuffer(fileUrl);
          if (fullBuffer.byteLength > CSV_PREVIEW_SIZE) {
            arrayBuffer = fullBuffer.slice(0, CSV_PREVIEW_SIZE);
            truncated = true;
          } else {
            arrayBuffer = fullBuffer;
          }
        }

        let csvText: string;
        // 先尝试 UTF-8 解码
        const utf8Text = new TextDecoder('utf-8').decode(arrayBuffer);
        // 检测是否有乱码特征(常见的 GBK 被错误解析为 UTF-8 的替换字符)
        if (utf8Text.includes('\uFFFD') || /[\x80-\xff]/.test(utf8Text.slice(0, 200))) {
          // 尝试 GBK 解码
          try {
            csvText = new TextDecoder('gbk').decode(arrayBuffer);
          } catch {
@@ -334,19 +375,35 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
        } else {
          csvText = utf8Text;
        }

        // 如果被截断,去掉最后一行不完整的数据
        if (truncated) {
          const lastNewline = csvText.lastIndexOf('\n');
          if (lastNewline > 0) {
            csvText = csvText.substring(0, lastNewline);
          }
        }

        const workbook = XLSX.read(csvText, { type: 'string' });
        const sheets = workbook.SheetNames.map(sheetName => {
          const worksheet = workbook.Sheets[sheetName];
          const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
          let data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
          // 限制最大行数
          if (data.length > MAX_PREVIEW_ROWS + 1) {
            data = data.slice(0, MAX_PREVIEW_ROWS + 1); // +1 保留表头
            truncated = true;
          }
          return { sheetName, data };
        });
        setCsvTruncated(truncated);
        setExcelData(sheets);
        setLoading(false);
        return;
      }

      const arrayBuffer = await fetchFileBuffer(fileUrl);
      const workbook = XLSX.read(arrayBuffer, { type: 'array' });
      const sheets = workbook.SheetNames.map(sheetName => {
      const sheets = workbook.SheetNames.map((sheetName: string) => {
        const worksheet = workbook.Sheets[sheetName];
        const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
        return { sheetName, data };
@@ -522,9 +579,14 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
      )
    )}

    {/* Excel 预览 */}
    {/* Excel/CSV 预览 */}
    {isExcelFile() && !error && !loading && (
      <div className="rb:w-full rb:flex-1 rb:overflow-auto rb:bg-white rb:p-4 rb:rounded rb:border rb:border-gray-200">
        {csvTruncated && (
          <div className="rb:mb-3 rb:px-3 rb:py-2 rb:bg-yellow-50 rb:border rb:border-yellow-200 rb:rounded rb:text-sm rb:text-yellow-700">
            文件较大,仅预览前 {MAX_PREVIEW_ROWS} 行数据
          </div>
        )}
        {excelData.map((sheet, index) => (
          <div key={index} className="rb:mb-6">
            <h3 className="rb:text-lg rb:font-semibold rb:mb-3">{sheet.sheetName}</h3>
@@ -541,6 +603,7 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
            scroll={{ x: 'max-content' }}
            size="small"
            bordered
            virtual
          />
        )}
      </div>
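The CSV preview strategy above is: fetch only the first CSV_PREVIEW_SIZE bytes via an HTTP Range header, fall back to a full download plus a local slice when the server ignores Range, drop the final possibly-partial line, and cap the parsed rows. A hedged sketch of the same idea in Python (the URL is illustrative and the requests package is assumed):

import requests

CSV_PREVIEW_SIZE = 1 * 1024 * 1024  # 1MB, mirroring the frontend constant

def fetch_csv_head(url: str, max_bytes: int = CSV_PREVIEW_SIZE) -> tuple[str, bool]:
    resp = requests.get(url, headers={"Range": f"bytes=0-{max_bytes - 1}"}, timeout=30)
    data = resp.content
    if len(data) > max_bytes:            # server ignored the Range header: slice locally
        data = data[:max_bytes]
    truncated = len(data) >= max_bytes   # exactly at the limit: probably cut mid-file
    text = data.decode("utf-8", errors="replace")
    if truncated and "\n" in text:       # drop the last, possibly partial line
        text = text.rsplit("\n", 1)[0]
    return text, truncated

head, truncated = fetch_csv_head("https://example.com/big.csv")
print(len(head.splitlines()), truncated)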
@@ -460,6 +460,7 @@ export const en = {
    nameInvalid: 'Name cannot start or end with a space',
    notAllSpaces: 'Cannot be all spaces',
    view: 'View',
    callbackUrlInvalid: 'Please enter a valid URL',
  },
  model: {
    searchPlaceholder: 'search model…',

@@ -1093,6 +1093,7 @@ export const zh = {
    nameInvalid: '不能是空格开头或结尾',
    notAllSpaces: '不能是纯空格',
    view: '查看',
    callbackUrlInvalid: '请输入有效的 URL',
  },
  model: {
    searchPlaceholder: '搜索模型…',
@@ -183,7 +183,7 @@ const TestChat: FC<TestChatProps> = ({

  const handleSend = () => {
    if (loading || !application || !message || !message?.trim()) return
    const files = toolbarRef.current?.getFiles() || []
    const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
    const variables = toolbarRef.current?.getVariables() || []
    const { isCanSend, params } = buildVariableParams(variables)
    if (!isCanSend) return

@@ -235,7 +235,7 @@ const TestChat: FC<TestChatProps> = ({

  const handleWorkflowSend = () => {
    if (loading || !application || !message || !message?.trim()) return
    const files = toolbarRef.current?.getFiles() || []
    const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
    const variables = toolbarRef.current?.getVariables() || []
    const { isCanSend, params } = buildVariableParams(variables)
    if (!isCanSend) return
@@ -191,7 +191,7 @@ const Chat: FC<ChatProps> = ({
      .then(() => {
        const message = msg
        if (!message?.trim()) return
        const files = toolbarRef.current?.getFiles() || []
        const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
        // Validate required variables before sending
        let isCanSend = true
        const params: Record<string, any> = {}

@@ -352,7 +352,7 @@ const Chat: FC<ChatProps> = ({
      .then(() => {
        const message = msg
        if (!message || message.trim() === '') return
        const files = toolbarRef.current?.getFiles() || []
        const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
        addUserMessage(message, files)
        setMessage(undefined)
        toolbarRef.current?.setFiles([])
@@ -24,7 +24,7 @@ interface FeaturesConfigModalProps {
  refresh: (value: FeaturesConfigForm) => void;
  source?: Application['type'];
}

const max_file_count = 1;
/**
 * Modal for copying applications
 */

@@ -133,7 +133,7 @@ const FeaturesConfigModal = forwardRef<FeaturesConfigModalRef, FeaturesConfigMod
          </div>
          <div>
            <div className="rb:text-[12px] rb:text-[#5B6167] rb:py-1">{t('application.maxCount')}</div>
            {fu.max_file_count} {t('application.unix')}
            {max_file_count} {t('application.unix')}
          </div>
        </Flex>
        <Button block onClick={handleOpenSettings}>{t('application.setting')}</Button>
@@ -2,7 +2,7 @@
 * @Author: ZhaoYing
 * @Date: 2026-03-05
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-19 15:18:20
 * @Last Modified time: 2026-03-19 20:19:14
 */
import { forwardRef, useImperativeHandle, useState } from 'react';
import { Form, InputNumber, Flex, Switch, Row, Col, Radio } from 'antd';

@@ -82,28 +82,27 @@ const defaultValues: FileUpload = {
    "mp3",
    "wav",
    "m4a",
    "ogg",
    "flac"
  ],
  document_enabled: false,
  document_max_size_mb: 100,
  document_allowed_extensions: [
    "pdf",
    "docx",
    "doc",
    "xlsx",
    "xls",
    "txt",
    "csv",
    "json"
    "json",
    "md",
  ],
  video_enabled: false,
  video_max_size_mb: 100,
  video_allowed_extensions: [
    "mp4",
    "mov",
    "avi",
    "webm"
  ],
  max_file_count: 5,
  max_file_count: 1,
  allowed_transfer_methods: 'both'
}

@@ -168,8 +167,8 @@ const FileUploadSettingModal = forwardRef<FileUploadSettingModalRef, FileUploadS
        </Radio.Group>
      </Form.Item>

      <div className="rb:text-[12px] rb:text-[#5B6167] rb:mb-1">{t('application.maxCount')}</div>
      <Form.Item label={t('application.maxCount')} name="max_file_count">
      {/* <div className="rb:text-[12px] rb:text-[#5B6167] rb:mb-1">{t('application.maxCount')}</div> */}
      <Form.Item label={t('application.maxCount')} name="max_file_count" hidden>
        <InputNumber min={1} max={20} precision={0} className="rb:w-full!" placeholder={t('common.pleaseEnter')} />
      </Form.Item>
@@ -2,7 +2,7 @@
 * @Author: ZhaoYing
 * @Date: 2026-02-06 21:09:42
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-18 20:32:54
 * @Last Modified time: 2026-03-19 18:38:41
 */
/**
 * File Upload Component

@@ -23,7 +23,7 @@
import { useState, useEffect, forwardRef, useImperativeHandle, useMemo } from 'react';
import { Upload, Progress, App } from 'antd';
import type { UploadProps, UploadFile } from 'antd';
import type { UploadProps as RcUploadProps } from 'antd/es/upload/interface';
import type { UploadProps as RcUploadProps, RcFile, UploadFileStatus } from 'antd/es/upload/interface';
import { useTranslation } from 'react-i18next';

import { request } from '@/utils/request'

@@ -221,17 +221,29 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
   */
  const handleCustomRequest: RcUploadProps['customRequest'] = async (options) => {
    const { file, onSuccess, onError } = options;

    try {
      const formData = new FormData();
      formData.append('file', file);

      const response = await request.uploadFile(action, formData, requestConfig);

      onSuccess?.({data: response});
    } catch (error) {
      onError?.(error as Error);
    if (typeof file === 'string') return;
    const rcFile = file as RcFile;
    const formData = new FormData();
    formData.append('file', rcFile);
    const fileVo: UploadFile = {
      uid: rcFile.uid,
      name: rcFile.name,
      status: 'uploading' as UploadFileStatus,
      percent: 0,
      type: rcFile.type,
      originFileObj: rcFile,
      thumbUrl: URL.createObjectURL(rcFile)
    }
    onChange?.(fileVo)
    request.uploadFile(action, formData, requestConfig)
      .then(res => {
        onSuccess?.({ data: res });
      })
      .catch((error) => {
        onError?.(error as Error);
        fileVo.status = 'error'
        onChange?.(fileVo)
      })
  };

  /**
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-06 21:09:47
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-18 21:10:01
|
||||
* @Last Modified time: 2026-03-19 20:32:32
|
||||
*/
|
||||
/**
|
||||
* Upload File List Modal Component
|
||||
@@ -19,7 +19,10 @@
|
||||
* @component
|
||||
*/
|
||||
import { forwardRef, useImperativeHandle, useState, useMemo } from 'react';
|
||||
import { Form, Input, Select, Button, Flex } from 'antd';
|
||||
import { Form, Input, Select,
|
||||
// Button,
|
||||
Flex
|
||||
} from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { UploadFileListModalRef } from '../types'
|
||||
@@ -105,9 +108,11 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<Form form={form} layout="vertical">
|
||||
<Form form={form} layout="vertical" initialValues={{ files: [{ type: undefined, url: undefined }] }}>
|
||||
<Form.List name="files">
|
||||
{(fields, { add, remove }) => (
|
||||
{(fields,
|
||||
// { add, remove }
|
||||
) => (
|
||||
<>
|
||||
{/* Render each file entry with type selector and URL input */}
|
||||
{fields.map(({ key, name, ...restField }) => (
|
||||
@@ -116,6 +121,9 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
|
||||
{...restField}
|
||||
name={[name, 'type']}
|
||||
className="rb:mb-0!"
|
||||
rules={[
|
||||
{ required: true, message: t('common.pleaseSelect') }
|
||||
]}
|
||||
>
|
||||
<Select
|
||||
placeholder={t('memoryConversation.fileType')}
|
||||
@@ -126,22 +134,25 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
|
||||
<FormItem
|
||||
{...restField}
|
||||
name={[name, 'url']}
|
||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
||||
rules={[
|
||||
{ required: true, message: t('common.pleaseEnter') },
|
||||
{ type: 'url', message: t('common.callbackUrlInvalid') },
|
||||
]}
|
||||
className="rb:mb-0! rb:flex-1!"
|
||||
>
|
||||
<Input placeholder={t('memoryConversation.fileUrl')} />
|
||||
</FormItem>
|
||||
<div
|
||||
{/* <div
|
||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||
onClick={() => remove(name)}
|
||||
></div>
|
||||
></div> */}
|
||||
</Flex>
|
||||
))}
|
||||
<Form.Item noStyle>
|
||||
{/* <Form.Item noStyle>
|
||||
<Button type="dashed" onClick={() => add()} block>
|
||||
+ {t('common.add')}
|
||||
</Button>
|
||||
</Form.Item>
|
||||
</Form.Item> */}
|
||||
</>
|
||||
)}
|
||||
</Form.List>
|
||||
|
||||
@@ -200,7 +200,7 @@ const Conversation: FC = () => {
  /** Send message and handle streaming response */
  const handleSend = () => {
    if (!token || !shareToken) return
    const files = toolbarRef.current?.getFiles() || []
    const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
    const variables = toolbarRef.current?.getVariables() || []
    let isCanSend = true
    const params: Record<string, any> = {}
@@ -11,7 +11,7 @@ import { useNavigate, useParams, useLocation } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import { useBreadcrumbManager, type BreadcrumbPath } from '@/hooks/useBreadcrumbManager';
import { Button, Spin, message, Switch } from 'antd';
import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk } from '@/api/knowledgeBase';
import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk, getFileUrl } from '@/api/knowledgeBase';
import type { KnowledgeBaseDocumentData, RecallTestData } from '@/views/KnowledgeBase/types';
import { formatDateTime } from '@/utils/format';
import InfoPanel, { type InfoItem } from '../components/InfoPanel';

@@ -138,7 +138,7 @@ const DocumentDetails: FC = () => {
      const response = await getDocumentDetail(documentId);
      setDocument(response);
      setInfoItems(formatDocumentInfo(response));
      const url = `${imagePath}/api/files/${response.file_id}`
      const url = `${window.location.origin}/api/files/${response.file_id}`;
      setFileUrl(url);
      setParserMode(response?.parser_config?.auto_questions || 0)
      // ChunkList will be called automatically in useEffect based on document.progress
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 18:32:00
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-13 14:51:17
|
||||
* @Last Modified time: 2026-03-19 20:23:42
|
||||
*/
|
||||
/**
|
||||
* Relationship Network Component
|
||||
@@ -287,22 +287,26 @@ const RelationshipNetwork:FC = () => {
|
||||
: (selectedNode as RawCommunityNode).properties.community_id
|
||||
? <div className="rb:p-3 rb:pt-0">
|
||||
<div className="rb:font-medium rb:text-[#212332] rb:text-[16px] rb:leading-5.5 rb:pl-1">
|
||||
{(selectedNode as RawCommunityNode).properties.name}
|
||||
</div>
|
||||
<div className="rb:mt-3 rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.summary')}</div>
|
||||
<div className="rb:bg-[#F6F6F6] rb:rounded-xl rb:px-3 rb:py-2.5 rb:mt-2">
|
||||
{(selectedNode as RawCommunityNode).properties.summary}
|
||||
{(selectedNode as RawCommunityNode).properties.name || selectedNode.id}
|
||||
</div>
|
||||
{(selectedNode as RawCommunityNode).properties.summary && <>
|
||||
<div className="rb:mt-3 rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.summary')}</div>
|
||||
<div className="rb:bg-[#F6F6F6] rb:rounded-xl rb:px-3 rb:py-2.5 rb:mt-2">
|
||||
{(selectedNode as RawCommunityNode).properties.summary}
|
||||
</div>
|
||||
</>}
|
||||
<Flex align="center" justify="space-between" className="rb:mt-5!">
|
||||
<span className="rb:text-[#5B6167] rb:font-regular rb:pl-1">{t('userMemory.member_count')}</span>
|
||||
<span className="rb:font-medium">{(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')}</span>
|
||||
</Flex>
|
||||
|
||||
<Divider className='rb:my-2.5!' />
|
||||
<div className="rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.core_entities')}</div>
|
||||
<ul className="rb:list-disc rb:pl-4 rb:text-[#5B6167] rb:mt-2">
|
||||
{(selectedNode as RawCommunityNode).properties.core_entities.map((entity, index) => <li key={index}>{entity}</li>)}
|
||||
</ul>
|
||||
{(selectedNode as RawCommunityNode).properties.core_entities && <>
|
||||
<Divider className='rb:my-2.5!' />
|
||||
<div className="rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.core_entities')}</div>
|
||||
<ul className="rb:list-disc rb:pl-4 rb:text-[#5B6167] rb:mt-2">
|
||||
{(selectedNode as RawCommunityNode).properties.core_entities?.map((entity, index) => <li key={index}>{entity}</li>)}
|
||||
</ul>
|
||||
</>}
|
||||
</div>
|
||||
: <>
|
||||
{(selectedNode as Node).name && (
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import { type FC, useEffect, useState, useMemo } from 'react'
|
||||
import { type FC, useEffect, useState, useMemo, useRef } from 'react'
|
||||
import clsx from 'clsx'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useParams } from 'react-router-dom'
|
||||
import { Row, Col, Skeleton, Button, Divider, Tooltip } from 'antd'
|
||||
|
||||
import InfiniteScroll from 'react-infinite-scroll-component'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import {
|
||||
getConversations,
|
||||
@@ -34,6 +36,8 @@ const WorkingDetail: FC = () => {
|
||||
const { id } = useParams()
|
||||
const [loading, setLoading] = useState<boolean>(false)
|
||||
const [data, setData] = useState<Conversation[]>([])
|
||||
const [hasMore, setHasMore] = useState<boolean>(true)
|
||||
const pageRef = useRef<number>(1)
|
||||
const [messagesLoading, setMessagesLoading] = useState<boolean>(false)
|
||||
const [messages, setMessages] = useState<ChatItem[]>([])
|
||||
const [detailLoading, setDetailLoading] = useState<boolean>(false)
|
||||
@@ -51,16 +55,30 @@ const WorkingDetail: FC = () => {
|
||||
setSelected(null)
|
||||
setDetail(null)
|
||||
setData([])
|
||||
getConversations(id).then((res) => {
|
||||
const response = res as Conversation[]
|
||||
setData(response)
|
||||
setSelected(response[0] || null)
|
||||
setHasMore(true)
|
||||
pageRef.current = 1
|
||||
getConversations(id, 1).then((res) => {
|
||||
const response = res as { items: Conversation[], page: { hasnext: boolean } }
|
||||
setData(response.items)
|
||||
setSelected(response.items[0] || null)
|
||||
setHasMore(response.page.hasnext)
|
||||
})
|
||||
.finally(() => {
|
||||
setLoading(false)
|
||||
})
|
||||
}
|
||||
|
||||
const loadMore = () => {
|
||||
if (!id) return
|
||||
const nextPage = pageRef.current + 1
|
||||
getConversations(id, nextPage).then((res) => {
|
||||
const response = res as {items: Conversation[], page: { hasnext: boolean }}
|
||||
setData(prev => [...prev, ...response.items])
|
||||
pageRef.current = nextPage
|
||||
setHasMore(response.page.hasnext)
|
||||
})
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (!id || !selected || !selected.id) return
|
||||
getDetail(selected.id)
|
||||
@@ -103,22 +121,30 @@ const WorkingDetail: FC = () => {
|
||||
: data.length === 0
|
||||
? <Empty />
|
||||
:(
|
||||
<Row gutter={16} className="rb:h-full">
|
||||
<Row gutter={16}>
|
||||
<Col span={5}>
|
||||
<div className="rb:h-full! rb:border-r rb:border-[#EAECEE] rb:py-3 rb:px-4">
|
||||
{data.map(item => (
|
||||
<div key={item.id} className="rb:mb-3">
|
||||
<Tooltip title={item.title}>
|
||||
<div className={clsx("rb:p-[8px_13px] rb:rounded-lg rb:leading-5 rb:cursor-pointer rb:hover:bg-[#F0F3F8] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap", {
|
||||
'rb:bg-[#FFFFFF] rb:shadow-[0px_2px_4px_0px_rgba(0,0,0,0.15)] rb:font-medium rb:hover:bg-[#FFFFFF]!': item.id === selected?.id,
|
||||
})}
|
||||
onClick={() => setSelected(item)}
|
||||
>
|
||||
{item.title}
|
||||
</div>
|
||||
</Tooltip>
|
||||
</div>
|
||||
))}
|
||||
<div id="conversation-list" className="rb:h-[calc(100vh-76px)]! rb:border-r rb:border-[#EAECEE] rb:py-3 rb:px-4 rb:overflow-y-auto">
|
||||
<InfiniteScroll
|
||||
dataLength={data.length}
|
||||
next={loadMore}
|
||||
hasMore={hasMore}
|
||||
loader={null}
|
||||
scrollableTarget="conversation-list"
|
||||
>
|
||||
{data.map(item => (
|
||||
<div key={item.id} className="rb:mb-3">
|
||||
<Tooltip title={item.title}>
|
||||
<div className={clsx("rb:p-[8px_13px] rb:rounded-lg rb:leading-5 rb:cursor-pointer rb:hover:bg-[#F0F3F8] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap", {
|
||||
'rb:bg-[#FFFFFF] rb:shadow-[0px_2px_4px_0px_rgba(0,0,0,0.15)] rb:font-medium rb:hover:bg-[#FFFFFF]!': item.id === selected?.id,
|
||||
})}
|
||||
onClick={() => setSelected(item)}
|
||||
>
|
||||
{item.title}
|
||||
</div>
|
||||
</Tooltip>
|
||||
</div>
|
||||
))}
|
||||
</InfiniteScroll>
|
||||
</div>
|
||||
</Col>
|
||||
{selected && <>
|
||||
|
||||
@@ -2,7 +2,7 @@
 * @Author: ZhaoYing
 * @Date: 2026-02-06 21:10:56
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-18 20:46:35
 * @Last Modified time: 2026-03-19 18:41:07
 */
/**
 * Workflow Chat Component

@@ -151,7 +151,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work

    setLoading(true)
    const message = msg
    const files = toolbarRef.current?.getFiles() || []
    const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
    setChatList(prev => [...prev, {
      role: 'user',
      content: message,
@@ -18,8 +18,8 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
  const isUserInputRef = useRef(false);

  useEffect(() => {
    // 监听编辑器变化,标记是否为用户输入
    const removeListener = editor.registerUpdateListener(({ editorState }) => {
    const removeListener = editor.registerUpdateListener(({ editorState, tags }) => {
      if (tags.has('programmatic')) return;
      editorState.read(() => {
        const root = $getRoot();
        const textContent = root.getTextContent();

@@ -107,7 +107,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
        });
        root.append(paragraph);
      }
    }, { discrete: true });
    }, { discrete: true, tag: 'programmatic' });
  });
}