diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 807c59f4..864bee4a 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -77,6 +77,7 @@ celery_app.conf.update( # Worker 设置 (per-worker settings are in docker-compose command line) worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution + worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING # 结果过期时间 result_expires=3600, # 结果保存1小时 diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 451dcdf7..869eb039 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -8,6 +8,7 @@ from fastapi import APIRouter from . import ( api_key_controller, app_controller, + app_log_controller, auth_controller, chunk_controller, document_controller, @@ -70,6 +71,7 @@ manager_router.include_router(chunk_controller.router) manager_router.include_router(test_controller.router) manager_router.include_router(knowledgeshare_controller.router) manager_router.include_router(app_controller.router) +manager_router.include_router(app_log_controller.router) manager_router.include_router(upload_controller.router) manager_router.include_router(memory_agent_controller.router) manager_router.include_router(memory_dashboard_controller.router) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index e9b539df..3ba9c3a9 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -57,6 +57,7 @@ def list_apps( page: int = 1, pagesize: int = 10, ids: Optional[str] = None, + api_key: Optional[str] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): @@ -65,10 +66,25 @@ def list_apps( - 默认包含本工作空间的应用和分享给本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用 - 当提供 ids 参数时,按逗号分割获取指定应用,不分页 + - 当提供 api_key 参数时,查找该 API Key 关联的应用 """ + from sqlalchemy import select as sa_select + from app.models.api_key_model import ApiKey + workspace_id = current_user.current_workspace_id service = app_service.AppService(db) + # 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程 + if api_key: + matched_id = db.execute( + sa_select(ApiKey.resource_id).where( + ApiKey.workspace_id == workspace_id, + ApiKey.api_key == api_key, + ApiKey.resource_id.isnot(None), + ) + ).scalar_one_or_none() + ids = str(matched_id) if matched_id else "" + # 当 ids 存在且不为 None 时,根据 ids 获取应用 if ids is not None: app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] diff --git a/api/app/controllers/app_log_controller.py b/api/app/controllers/app_log_controller.py new file mode 100644 index 00000000..dfd10644 --- /dev/null +++ b/api/app/controllers/app_log_controller.py @@ -0,0 +1,129 @@ +"""应用日志(消息记录)接口""" +import uuid +from typing import Optional + +from fastapi import APIRouter, Depends, Query +from sqlalchemy import select, desc, func +from sqlalchemy.orm import Session + +from app.core.logging_config import get_business_logger +from app.core.response_utils import success +from app.db import get_db +from app.dependencies import get_current_user, cur_workspace_access_guard +from app.models.conversation_model import Conversation, Message +from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage +from app.schemas.response_schema import PageData, PageMeta +from app.services.app_service import AppService + +router = APIRouter(prefix="/apps", tags=["App Logs"]) +logger = get_business_logger() + + +@router.get("/{app_id}/logs", summary="应用日志 - 会话列表") +@cur_workspace_access_guard() +def list_app_logs( + app_id: uuid.UUID, + page: int = Query(1, ge=1), + pagesize: int = Query(20, ge=1, le=100), + user_id: Optional[str] = None, + is_draft: Optional[bool] = None, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """查看应用下所有会话记录(分页) + + - 支持按 user_id 筛选 + - 支持按 is_draft 筛选(草稿会话 / 发布会话) + - 按最新更新时间倒序排列 + """ + workspace_id = current_user.current_workspace_id + + # 验证应用访问权限 + service = AppService(db) + service.get_app(app_id, workspace_id) + + stmt = select(Conversation).where( + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True), + ) + + if user_id: + stmt = stmt.where(Conversation.user_id == user_id) + + if is_draft is not None: + stmt = stmt.where(Conversation.is_draft == is_draft) + + total = int(db.execute( + select(func.count()).select_from(stmt.subquery()) + ).scalar_one()) + + stmt = stmt.order_by(desc(Conversation.updated_at)) + stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) + + conversations = list(db.scalars(stmt).all()) + + items = [AppLogConversation.model_validate(c) for c in conversations] + meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) + + logger.info( + "查询应用日志会话列表", + extra={"app_id": str(app_id), "total": total, "page": page} + ) + + return success(data=PageData(page=meta, items=items)) + + +@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情") +@cur_workspace_access_guard() +def get_app_log_detail( + app_id: uuid.UUID, + conversation_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """查看某会话的完整消息记录 + + - 返回会话基本信息 + 所有消息(按时间正序) + - 消息 meta_data 包含模型名、token 用量等信息 + """ + workspace_id = current_user.current_workspace_id + + # 验证应用访问权限 + service = AppService(db) + service.get_app(app_id, workspace_id) + + # 查询会话(确保属于该应用和工作空间) + conversation = db.scalars( + select(Conversation).where( + Conversation.id == conversation_id, + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True), + ) + ).first() + + if not conversation: + from app.core.exceptions import ResourceNotFoundException + raise ResourceNotFoundException("会话", str(conversation_id)) + + # 查询消息(按时间正序) + messages = list(db.scalars( + select(Message) + .where(Message.conversation_id == conversation_id) + .order_by(Message.created_at) + ).all()) + + detail = AppLogConversationDetail.model_validate(conversation) + detail.messages = [AppLogMessage.model_validate(m) for m in messages] + + logger.info( + "查询应用日志会话详情", + extra={ + "app_id": str(app_id), + "conversation_id": str(conversation_id), + "message_count": len(messages) + } + ) + + return success(data=detail) diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py index 55149cce..14962a72 100644 --- a/api/app/controllers/file_storage_controller.py +++ b/api/app/controllers/file_storage_controller.py @@ -14,6 +14,9 @@ Routes: import os import uuid from typing import Any +import httpx +import mimetypes +from urllib.parse import urlparse, unquote from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse, RedirectResponse @@ -290,6 +293,101 @@ async def upload_file_with_share_token( ) +@router.get("/files/info-by-url", response_model=ApiResponse) +async def get_file_info_by_url( + url: str, +): + """ + Get file information by network URL (no authentication required). + + Fetches file metadata from a remote URL via HTTP HEAD request. + Falls back to GET request if HEAD is not supported. + Returns file type, name, and size. + + Args: + url: The network URL of the file. + + Returns: + ApiResponse with file information. + """ + api_logger.info(f"File info by URL request: url={url}") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # Try HEAD request first + response = await client.head(url, follow_redirects=True) + + # If HEAD fails, try GET request (some servers don't support HEAD) + if response.status_code != 200: + api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request") + response = await client.get(url, follow_redirects=True) + + if response.status_code != 200: + api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unable to access file: HTTP {response.status_code}" + ) + + # Get file size from Content-Length header or actual content + file_size = response.headers.get("Content-Length") + if file_size: + file_size = int(file_size) + elif hasattr(response, 'content'): + file_size = len(response.content) + else: + file_size = None + + # Get content type from Content-Type header + content_type = response.headers.get("Content-Type", "application/octet-stream") + # Remove charset and other parameters from content type + content_type = content_type.split(';')[0].strip() + + # Extract filename from Content-Disposition or URL + file_name = None + content_disposition = response.headers.get("Content-Disposition") + if content_disposition and "filename=" in content_disposition: + parts = content_disposition.split("filename=") + if len(parts) > 1: + file_name = parts[1].strip('"').strip("'") + + if not file_name: + parsed_url = urlparse(url) + file_name = unquote(os.path.basename(parsed_url.path)) or "unknown" + + # Extract file extension from filename + _, file_ext = os.path.splitext(file_name) + + # If no extension found, infer from content type + if not file_ext: + ext = mimetypes.guess_extension(content_type) + if ext: + file_ext = ext + file_name = f"{file_name}{file_ext}" + + api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}") + + return success( + data={ + "url": url, + "file_name": file_name, + "file_ext": file_ext.lower() if file_ext else "", + "file_size": file_size, + "content_type": content_type, + }, + msg="File information retrieved successfully" + ) + + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Unexpected error: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve file information: {str(e)}" + ) + + @router.get("/files/{file_id}", response_model=Any) async def download_file( request: Request, @@ -697,3 +795,44 @@ async def permanent_download_file( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to retrieve file: {str(e)}" ) + + +@router.get("/files/{file_id}/status", response_model=ApiResponse) +async def get_file_status( + file_id: uuid.UUID, + db: Session = Depends(get_db), +): + """ + Get file upload/processing status (no authentication required). + + This endpoint is used to check if a file (e.g., TTS audio) is ready. + Returns status: pending, completed, or failed. + + Args: + file_id: The UUID of the file. + db: Database session. + + Returns: + ApiResponse with file status and metadata. + """ + api_logger.info(f"File status request: file_id={file_id}") + + # Query file metadata from database + file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first() + if not file_metadata: + api_logger.warning(f"File not found in database: file_id={file_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The file does not exist" + ) + + return success( + data={ + "file_id": str(file_id), + "status": file_metadata.status, + "file_name": file_metadata.file_name, + "file_size": file_metadata.file_size, + "content_type": file_metadata.content_type, + }, + msg="File status retrieved successfully" + ) diff --git a/api/app/core/logging_config.py b/api/app/core/logging_config.py index 28a98a46..d0dda84b 100644 --- a/api/app/core/logging_config.py +++ b/api/app/core/logging_config.py @@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") - # Fallback to console only if file write fails print(f"Warning: Could not write to timing log: {e}") - # Always print to console (backward compatible behavior) - print(f"✓ {step_name}: {duration:.2f}s") + # Always log at INFO level (avoids Celery treating stdout as WARNING) + _timing_logger = logging.getLogger(__name__) + _timing_logger.info(f"✓ {step_name}: {duration:.2f}s") def get_agent_logger(name: str = "agent_service", diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index b62eb50a..f782d44b 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write +from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -169,8 +169,8 @@ async def write( ) if success: logger.info("Successfully saved all data to Neo4j") - # 写入成功后,异步触发聚类(不阻塞写入响应) - schedule_clustering_after_write( + # 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突) + await _trigger_clustering_sync( all_entity_nodes, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index d9c04f8b..0fa6a833 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -71,13 +71,11 @@ class LabelPropagationEngine: connector: Neo4jConnector, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, ): self.connector = connector self.repo = CommunityRepository(connector) self.llm_model_id = llm_model_id self.embedding_model_id = embedding_model_id - self.embedding_model_id = embedding_model_id # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -239,6 +237,7 @@ class LabelPropagationEngine: await self.repo.upsert_community(new_cid, end_user_id, member_count=1) await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}") + await self._generate_community_metadata([new_cid], end_user_id) return # 统计邻居社区分布 @@ -273,7 +272,8 @@ class LabelPropagationEngine: await self._evaluate_merge( list(community_ids_in_neighbors), end_user_id ) - await self._generate_community_metadata([target_cid], end_user_id) + # 新实体加入后成员变化,强制重新生成元数据 + await self._generate_community_metadata([target_cid], end_user_id, force=True) async def _evaluate_merge( self, community_ids: List[str], end_user_id: str @@ -453,7 +453,7 @@ class LabelPropagationEngine: return lines async def _generate_community_metadata( - self, community_ids: List[str], end_user_id: str + self, community_ids: List[str], end_user_id: str, force: bool = False ) -> None: """ 为一个或多个社区生成并写入元数据。 @@ -462,69 +462,82 @@ class LabelPropagationEngine: 1. 逐个社区调 LLM 生成 name / summary(串行) 2. 收集所有 summary,一次性批量 embed 3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata - """ - if not community_ids: - return + Args: + force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后) + """ from app.db import get_db_context from app.core.memory.utils.llm.llm_utils import MemoryClientFactory - # --- 阶段1:并发调 LLM 生成每个社区的 name / summary --- - async def _build_one(cid: str): - members = await self.repo.get_community_members(cid, end_user_id) - if not members: + async def _build_one(cid: str) -> Optional[Dict]: + try: + if not force: + check_embedding = bool(self.embedding_model_id) + if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding): + return None + + members = await self.repo.get_community_members(cid, end_user_id) + if not members: + logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成") + return None + + sorted_members = sorted( + members, + key=lambda m: m.get("activation_value") or 0, + reverse=True, + ) + core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] + all_names = [m["name"] for m in members if m.get("name")] + + name = "、".join(core_entities[:3]) if core_entities else cid[:8] + summary = f"包含实体:{', '.join(all_names)}" + + if self.llm_model_id: + try: + entity_list_str = "\n".join(self._build_entity_lines(members)) + relationships = await self.repo.get_community_relationships(cid, end_user_id) + rel_lines = [ + f"- {r['subject']} → {r['predicate']} → {r['object']}" + for r in relationships + if r.get("subject") and r.get("predicate") and r.get("object") + ] + rel_section = ( + f"\n实体间关系:\n" + "\n".join(rel_lines) + if rel_lines else "" + ) + prompt = ( + f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过80个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) + with get_db_context() as db: + llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) + response = await llm_client.chat([{"role": "user", "content": prompt}]) + text = response.content if hasattr(response, "content") else str(response) + + for line in text.strip().splitlines(): + if line.startswith("名称:"): + name = line[3:].strip() + elif line.startswith("摘要:"): + summary = line[3:].strip() + except Exception as e: + logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}") + + return { + "community_id": cid, + "end_user_id": end_user_id, + "name": name, + "summary": summary, + "core_entities": core_entities, + "summary_embedding": None, + } + except Exception as e: + logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True) return None - sorted_members = sorted( - members, - key=lambda m: m.get("activation_value") or 0, - reverse=True, - ) - core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] - - entity_list_str = "\n".join(self._build_entity_lines(members)) - - # 方案四:注入社区内实体间关系三元组 - relationships = await self.repo.get_community_relationships(cid, end_user_id) - rel_lines = [ - f"- {r['subject']} → {r['predicate']} → {r['object']}" - for r in relationships - if r.get("subject") and r.get("predicate") and r.get("object") - ] - rel_section = ( - f"\n实体间关系:\n" + "\n".join(rel_lines) - if rel_lines else "" - ) - - prompt = ( - f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" - f"请为这组实体所代表的主题:\n" - f"1. 起一个简洁的中文名称(不超过10个字)\n" - f"2. 写一句话摘要(不超过80个字)\n\n" - f"严格按以下格式输出,不要有其他内容:\n" - f"名称:<名称>\n摘要:<摘要>" - ) - with get_db_context() as db: - llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) - response = await llm_client.chat([{"role": "user", "content": prompt}]) - text = response.content if hasattr(response, "content") else str(response) - - name, summary = "", "" - for line in text.strip().splitlines(): - if line.startswith("名称:"): - name = line[3:].strip() - elif line.startswith("摘要:"): - summary = line[3:].strip() - - return { - "community_id": cid, - "end_user_id": end_user_id, - "name": name, - "summary": summary, - "core_entities": core_entities, - "summary_embedding": None, - } - results = await asyncio.gather( *[_build_one(cid) for cid in community_ids], return_exceptions=True, @@ -537,15 +550,20 @@ class LabelPropagationEngine: metadata_list.append(res) if not metadata_list: + logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}") return # --- 阶段2:批量生成 summary_embedding --- - summaries = [m["summary"] for m in metadata_list] - with get_db_context() as db: - embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) - embeddings = await embedder.response(summaries) - for i, meta in enumerate(metadata_list): - meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + if self.embedding_model_id: + try: + summaries = [m["summary"] for m in metadata_list] + with get_db_context() as db: + embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) + embeddings = await embedder.response(summaries) + for i, meta in enumerate(metadata_list): + meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + except Exception as e: + logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True) # --- 阶段3:写入(单个 or 批量)--- if len(metadata_list) == 1: @@ -558,17 +576,13 @@ class LabelPropagationEngine: core_entities=m["core_entities"], summary_embedding=m["summary_embedding"], ) - if result: - logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...") - else: - logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False") + if not result: + logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败") else: ok = await self.repo.batch_update_community_metadata(metadata_list) - if ok: - logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功") - else: - logger.warning(f"[Clustering] 批量写入社区元数据失败") + if not ok: + logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败") @staticmethod def _new_community_id() -> str: - return str(uuid.uuid4()) + return str(uuid.uuid4()) \ No newline at end of file diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 248067e7..967f529e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -9,6 +9,7 @@ """ import asyncio +import logging import os import hashlib import json @@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene ScenePatterns ) +logger = logging.getLogger(__name__) + class DialogExtractionResponse(BaseModel): """对话级一次性抽取的结构化返回,用于加速剪枝。 @@ -706,7 +709,7 @@ class SemanticPruner: # 阈值保护:最高0.9 proportion = float(self.config.pruning_threshold) if proportion > 0.9: - print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9") + logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9") proportion = 0.9 if proportion < 0.0: proportion = 0.0 @@ -905,7 +908,7 @@ class SemanticPruner: # Safety: avoid empty dataset if not result: - print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") + logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") return dialogs return result @@ -915,8 +918,7 @@ class SemanticPruner: try: self.run_logs.append(msg) except Exception: - # 任何异常都不影响打印 pass - print(msg) + logger.debug(msg) diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py index 72f3641e..33838061 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py @@ -5,8 +5,11 @@ """ import asyncio +import logging from typing import Any, Dict, List, Tuple +logger = logging.getLogger(__name__) + from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.models.message_models import DialogData from app.core.models.base import RedBearModelConfig @@ -48,9 +51,9 @@ class EmbeddingGenerator: return await self.embedder_client.response(texts) # 分批并行处理 - print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理") + logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理") batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)] - print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本") + logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本") # 并行发送所有批次 batch_results = await asyncio.gather(*[ @@ -62,7 +65,7 @@ class EmbeddingGenerator: for batch_result in batch_results: embeddings.extend(batch_result) - print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量") + logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量") return embeddings async def generate_statement_embeddings( @@ -77,7 +80,7 @@ class EmbeddingGenerator: Returns: 每个对话的陈述句嵌入向量映射列表 """ - print("\n=== 生成陈述句嵌入向量 ===") + logger.debug("=== 生成陈述句嵌入向量 ===") # 收集所有陈述句 all_statements = [] @@ -102,7 +105,7 @@ class EmbeddingGenerator: stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id stmt_embedding_maps[d_idx][stmt_id] = embedding - print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量") + logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量") return stmt_embedding_maps async def generate_chunk_embeddings( @@ -117,7 +120,7 @@ class EmbeddingGenerator: Returns: 每个对话的分块嵌入向量映射列表 """ - print("\n=== 生成分块嵌入向量 ===") + logger.debug("=== 生成分块嵌入向量 ===") # 收集所有分块 all_chunks = [] @@ -138,7 +141,7 @@ class EmbeddingGenerator: chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id chunk_embedding_maps[d_idx][chunk_id] = embedding - print(f"为 {len(all_chunks)} 个分块生成了嵌入向量") + logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量") return chunk_embedding_maps async def generate_dialog_embeddings( @@ -172,7 +175,7 @@ class EmbeddingGenerator: Returns: (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表) """ - print("\n=== 生成所有嵌入向量 ===") + logger.debug("=== 生成所有嵌入向量 ===") # 并发生成陈述句和分块嵌入向量 stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather( @@ -183,9 +186,7 @@ class EmbeddingGenerator: # 对话嵌入向量(当前跳过) dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs) - print( - f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量" - ) + logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量") return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings @@ -201,7 +202,7 @@ class EmbeddingGenerator: Returns: 更新后的三元组映射列表(实体包含嵌入向量) """ - print("\n=== 生成实体嵌入向量 ===") + logger.debug("=== 生成实体嵌入向量 ===") entity_texts: List[str] = [] entity_refs: List[Any] = [] @@ -219,7 +220,7 @@ class EmbeddingGenerator: entity_refs.append(ent) if not entity_texts: - print("没有找到需要生成嵌入向量的实体") + logger.debug("没有找到需要生成嵌入向量的实体") return triplet_maps # 批量生成嵌入向量 @@ -227,13 +228,13 @@ class EmbeddingGenerator: # 打印前几个嵌入向量的维度 for i in range(min(5, len(embeddings))): - print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") + logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") # 将嵌入向量赋值给实体 for ent, emb in zip(entity_refs, embeddings): setattr(ent, "name_embedding", emb) - print(f"为 {len(entity_refs)} 个实体生成了嵌入向量") + logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量") return triplet_maps @@ -296,7 +297,7 @@ async def embedding_generation_all( Returns: (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表) """ - print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===") + logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===") generator = EmbeddingGenerator(embedding_id) diff --git a/api/app/models/user_model.py b/api/app/models/user_model.py index b6de28ec..81319789 100644 --- a/api/app/models/user_model.py +++ b/api/app/models/user_model.py @@ -9,7 +9,7 @@ class User(Base): __tablename__ = "users" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - username = Column(String, unique=True, index=True, nullable=False) + username = Column(String, index=True, nullable=False) # 社区版:用户名不唯一,仅邮箱唯一 email = Column(String, unique=True, index=True, nullable=False) hashed_password = Column(String, nullable=False) is_active = Column(Boolean, default=True, nullable=False) diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 42c178b3..786f7bbe 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -1,10 +1,13 @@ from typing import List, Optional +import logging from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector +logger = logging.getLogger(__name__) + async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): """Delete all nodes in the database.""" @@ -217,10 +220,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector summaries=flattened ) created_ids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") + logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids except Exception as e: - print(f"Failed to save MemorySummary nodes to Neo4j: {e}") + logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 7273340e..bd448c99 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -300,7 +300,7 @@ class CommunityRepository: ) return bool(result) except Exception as e: - logger.error(f"update_community_metadata failed: {e}") + logger.error(f"update_community_metadata failed: {e}", exc_info=True) return False async def batch_update_community_metadata( diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 0cdaeb59..fe1cb252 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1069,6 +1069,7 @@ Graph_Node_query = """ COMMUNITY_NODE_UPSERT = """ MERGE (c:Community {community_id: $community_id}) +ON CREATE SET c.id = $community_id SET c.end_user_id = $end_user_id, c.member_count = $member_count, c.updated_at = datetime() @@ -1175,7 +1176,8 @@ RETURN c.community_id AS community_id, cnt AS member_count UPDATE_COMMUNITY_METADATA = """ MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -SET c.name = $name, +SET c.id = coalesce(c.id, $community_id), + c.name = $name, c.summary = $summary, c.core_entities = $core_entities, c.summary_embedding = $summary_embedding, @@ -1186,7 +1188,8 @@ RETURN c.community_id AS community_id BATCH_UPDATE_COMMUNITY_METADATA = """ UNWIND $communities AS row MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id}) -SET c.name = row.name, +SET c.id = coalesce(c.id, row.community_id), + c.name = row.name, c.summary = row.summary, c.core_entities = row.core_entities, c.summary_embedding = row.summary_embedding, @@ -1270,6 +1273,40 @@ RETURN startNode(r) = e AS r_from_e """ +CHECK_COMMUNITY_IS_COMPLETE = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +RETURN ( + c.name IS NOT NULL AND c.name <> '' AND + c.summary IS NOT NULL AND c.summary <> '' AND + c.core_entities IS NOT NULL +) AS is_complete +""" + +CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +RETURN ( + c.name IS NOT NULL AND c.name <> '' AND + c.summary IS NOT NULL AND c.summary <> '' AND + c.core_entities IS NOT NULL AND + c.summary_embedding IS NOT NULL +) AS is_complete +""" + +GET_INCOMPLETE_COMMUNITIES = """ +MATCH (c:Community {end_user_id: $end_user_id}) +WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL + OR c.name = '' OR c.summary = '' +RETURN c.community_id AS community_id +""" + +GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """ +MATCH (c:Community {end_user_id: $end_user_id}) +WHERE c.name IS NULL OR c.name = '' + OR c.summary IS NULL OR c.summary = '' + OR c.core_entities IS NULL + OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)') +RETURN c.community_id AS community_id +""" # Community keyword search: matches name or summary via fulltext index SEARCH_COMMUNITIES_BY_KEYWORD = """ @@ -1325,39 +1362,4 @@ RETURN s.statement AS statement, c.name AS community_name ORDER BY COALESCE(s.activation_value, 0) DESC LIMIT $limit -""" - -CHECK_COMMUNITY_IS_COMPLETE = """ -MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -RETURN ( - c.name IS NOT NULL AND c.name <> '' AND - c.summary IS NOT NULL AND c.summary <> '' AND - c.core_entities IS NOT NULL -) AS is_complete -""" - -CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ -MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -RETURN ( - c.name IS NOT NULL AND c.name <> '' AND - c.summary IS NOT NULL AND c.summary <> '' AND - c.core_entities IS NOT NULL AND - c.summary_embedding IS NOT NULL -) AS is_complete -""" - -GET_INCOMPLETE_COMMUNITIES = """ -MATCH (c:Community {end_user_id: $end_user_id}) -WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL - OR c.name = '' OR c.summary = '' -RETURN c.community_id AS community_id -""" - -GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """ -MATCH (c:Community {end_user_id: $end_user_id}) -WHERE c.name IS NULL OR c.name = '' - OR c.summary IS NULL OR c.summary = '' - OR c.core_entities IS NULL - OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)') -RETURN c.community_id AS community_id -""" +""" \ No newline at end of file diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 34497d5b..d2c4b9bd 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -162,7 +162,7 @@ async def save_dialog_and_statements_to_neo4j( """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. 只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过 - schedule_clustering_after_write() 显式触发。 + _trigger_clustering_sync() 显式触发。 Args: dialogue_nodes: List of DialogueNode objects to save @@ -303,16 +303,13 @@ async def save_dialog_and_statements_to_neo4j( return False -def schedule_clustering_after_write( +async def _trigger_clustering_sync( entity_nodes: List, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, ) -> None: """ - 写入 Neo4j 成功后,调度后台聚类任务。 - - 可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。 - 使用 asyncio.create_task 异步触发,不阻塞写入响应。 + 同步等待聚类完成,避免与其他 LLM 任务并发冲突。 """ if not entity_nodes: return @@ -324,8 +321,8 @@ def schedule_clustering_after_write( end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] - logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) + logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") + await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) async def _trigger_clustering( diff --git a/api/app/schemas/app_log_schema.py b/api/app/schemas/app_log_schema.py new file mode 100644 index 00000000..bda78138 --- /dev/null +++ b/api/app/schemas/app_log_schema.py @@ -0,0 +1,53 @@ +"""应用日志(消息记录)Schema""" +import uuid +import datetime +from typing import Optional, Dict, Any, List + +from pydantic import BaseModel, Field, ConfigDict, field_serializer + + +class AppLogMessage(BaseModel): + """单条消息记录""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + conversation_id: uuid.UUID + role: str = Field(description="角色: user / assistant / system") + content: str + meta_data: Optional[Dict[str, Any]] = None + created_at: datetime.datetime + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("meta_data", when_used="json") + def _serialize_meta_data(self, data: Optional[Dict[str, Any]]): + return data or {} + + +class AppLogConversation(BaseModel): + """会话摘要(用于列表)""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + app_id: uuid.UUID + user_id: Optional[str] = None + title: Optional[str] = None + message_count: int = 0 + is_draft: bool + created_at: datetime.datetime + updated_at: datetime.datetime + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class AppLogConversationDetail(AppLogConversation): + """会话详情(包含消息列表)""" + messages: List[AppLogMessage] = Field(default_factory=list) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 6fcf680b..f87e5f5a 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -129,39 +129,12 @@ class AppChatService: ) # 加载历史消息 - messages = self.conversation_service.get_messages( + history = await self.conversation_service.get_conversation_history( conversation_id=conversation_id, - limit=10 + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni ) - history = [] - for msg in messages: - content = [{"type": "text", "text": msg.content}] - - # 处理 meta_data 中的 files - if msg.meta_data and msg.meta_data.get("files"): - files = msg.meta_data.get("files", []) - # 使用 MultimodalService 处理文件 - multimodal_service = MultimodalService(self.db, api_config=model_info) - - # 将 files 转换为 FileInput 格式 - file_inputs = [] - for file in files: - from app.schemas.app_schema import FileInput, TransferMethod - file_input = FileInput( - type=file.get("type"), - transfer_method=TransferMethod.REMOTE_URL, - url=file.get("url") - ) - file_inputs.append(file_input) - - history_processed_files = await multimodal_service.history_process_files(files=file_inputs) - - content.extend(history_processed_files) - - history.append({ - "role": msg.role, - "content": content - }) # 处理多模态文件 processed_files = None @@ -206,7 +179,8 @@ class AppChatService: # 构建用户消息内容(含多模态文件) human_meta = { - "files": [] + "files": [], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -221,6 +195,13 @@ class AppChatService: "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } + # 保存消息 if audio_url: assistant_meta["audio_url"] = audio_url @@ -251,6 +232,7 @@ class AppChatService: "suggested_questions": suggested_questions, "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), "audio_url": audio_url, + "audio_status": "pending" } async def agnet_chat_stream( @@ -350,39 +332,12 @@ class AppChatService: ) # 加载历史消息 - messages = self.conversation_service.get_messages( + history = await self.conversation_service.get_conversation_history( conversation_id=conversation_id, - limit=10 + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni ) - history = [] - for msg in messages: - content = [{"type": "text", "text": msg.content}] - - # 处理 meta_data 中的 files - if msg.meta_data and msg.meta_data.get("files"): - history_files = msg.meta_data.get("files", []) - # 使用 MultimodalService 处理文件 - multimodal_service = MultimodalService(self.db, api_config=model_info) - - # 将 files 转换为 FileInput 格式 - file_inputs = [] - for file in history_files: - from app.schemas.app_schema import FileInput, TransferMethod - file_input = FileInput( - type=file.get("type"), - transfer_method=TransferMethod.REMOTE_URL, - url=file.get("url") - ) - file_inputs.append(file_input) - - history_processed_files = await multimodal_service.history_process_files(files=file_inputs) - - content.extend(history_processed_files) - - history.append({ - "role": msg.role, - "content": content - }) # 处理多模态文件 processed_files = None @@ -433,7 +388,7 @@ class AppChatService: elapsed_time = time.time() - start_time ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) - # 发送结束事件(包含 suggested_questions、tts、citations) + # 发送结束事件(包含 suggested_questions、tts、audio_status、citations) end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} sq_config = features_config.get("suggested_questions_after_answer", {}) if isinstance(sq_config, dict) and sq_config.get("enabled"): @@ -443,11 +398,23 @@ class AppChatService: "api_base": api_key_obj.api_base}, {} ) end_data["audio_url"] = stream_audio_url + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None end_data["citations"] = self.agent_service._filter_citations(features_config, []) # 保存消息 human_meta = { - "files":[] + "files":[], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -457,11 +424,16 @@ class AppChatService: if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } if stream_audio_url: assistant_meta["audio_url"] = stream_audio_url diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index f8a01a40..014d96b7 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -274,7 +274,8 @@ class ConversationService: self, conversation_id: uuid.UUID, max_history: Optional[int] = None, - api_config: Optional[ModelInfo] = None + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[dict]: """ Retrieve historical conversation messages formatted as dictionaries. @@ -282,7 +283,8 @@ class ConversationService: Args: conversation_id (uuid.UUID): Conversation UUID. max_history (Optional[int]): Maximum number of messages to retrieve. - api_config (Optional[ModelInfo]): Model API configuration for multimodal processing. + current_provider (Optional[str]): Current provider for file handling. + current_is_omni (Optional[bool]): Current omni flag for file handling. Returns: List[dict]: List of message dictionaries with keys 'role' and 'content'. @@ -292,38 +294,30 @@ class ConversationService: limit=max_history ) - # 转换为字典格式 history = [] for msg in messages: - content = [{"type": "text", "text": msg.content}] - - # 处理 meta_data 中的 files - if msg.meta_data and msg.meta_data.get("files"): - files = msg.meta_data.get("files", []) - if api_config: - # 使用 MultimodalService 处理文件 - from app.services.multimodal_service import MultimodalService - multimodal_service = MultimodalService(self.db, api_config=api_config) - - # 将 files 转换为 FileInput 格式 - file_inputs = [] - for file in files: - from app.schemas.app_schema import FileInput, TransferMethod - file_input = FileInput( - type=file.get("type"), - transfer_method=TransferMethod.REMOTE_URL, - url=file.get("url") - ) - file_inputs.append(file_input) - - processed_files = await multimodal_service.history_process_files(files=file_inputs) - - content.extend(processed_files) - - history.append({ + msg_dict = { "role": msg.role, - "content": content - }) + "content": [{"type": "text", "text": msg.content}] + } + + # 处理用户消息中的多模态文件 + if msg.role == "user" and msg.meta_data: + history_files = msg.meta_data.get("history_files", {}) + + if history_files and current_provider and current_is_omni is not None: + # 检查是否需要重新处理文件 + stored_provider = history_files.get("provider") + stored_is_omni = history_files.get("is_omni") + + # 如果provider或is_omni不匹配,需要重新处理 + if stored_provider != current_provider or stored_is_omni != current_is_omni: + continue + + # provider和is_omni匹配,直接使用存储的内容 + msg_dict["content"].extend(history_files.get("content")) + + history.append(msg_dict) return history @@ -539,6 +533,7 @@ class ConversationService: provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type llm = RedBearLLM( @@ -546,7 +541,8 @@ class ConversationService: model_name=model_name, provider=provider, api_key=api_key, - base_url=api_base + base_url=api_base, + is_omni=is_omni ), type=ModelType(model_type) ) @@ -554,15 +550,8 @@ class ConversationService: conversation_messages = await self.get_conversation_history( conversation_id=conversation_id, max_history=20, - api_config=ModelInfo( - model_name=model_name, - provider=provider, - api_key=api_key, - api_base=api_base, - capability=api_config.capability, - is_omni=api_config.is_omni, - model_type=model_type - ) + current_provider=provider, + current_is_omni=is_omni ) if len(conversation_messages) == 0: return ConversationOut( diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 5989f0f8..88a62ee8 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -592,8 +592,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - api_config=model_info, - max_history=10 + max_history=10, + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -661,7 +662,10 @@ class AgentRunService: }) }, files=files, - audio_url=audio_url + processed_files=processed_files, + audio_url=audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) response = { @@ -678,6 +682,7 @@ class AgentRunService: ) if not sub_agent else [], "citations": self._filter_citations(features_config, result.get("citations", [])), "audio_url": audio_url, + "audio_status": "pending" } logger.info( @@ -830,8 +835,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - api_config=model_info, - max_history=memory_config.get("max_history", 10) + max_history=memory_config.get("max_history", 10), + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -909,10 +915,13 @@ class AgentRunService: "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} }, files=files, - audio_url=stream_audio_url + processed_files=processed_files, + audio_url=stream_audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) - # 12. 发送结束事件(包含 suggested_questions 和 tts) + # 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status) end_data: Dict[str, Any] = { "conversation_id": conversation_id, "elapsed_time": elapsed_time, @@ -923,6 +932,17 @@ class AgentRunService: features_config, full_content, api_key_config, effective_params ) end_data["audio_url"] = stream_audio_url + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None end_data["citations"] = self._filter_citations(features_config, []) yield self._format_sse_event("end", end_data) @@ -1119,14 +1139,17 @@ class AgentRunService: async def _load_conversation_history( self, conversation_id: str, - api_config: ModelInfo | None = None, - max_history: int = 10 + max_history: int = 10, + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[Dict[str, str]]: - """加载会话历史消息 + """加载会话历史消息,并根据当前模型配置处理多模态文件 Args: conversation_id: 会话ID max_history: 最大历史消息数量 + current_provider: 当前模型的provider + current_is_omni: 当前模型的is_omni Returns: List[Dict]: 历史消息列表 @@ -1138,7 +1161,8 @@ class AgentRunService: history = await conversation_service.get_conversation_history( conversation_id=uuid.UUID(conversation_id), max_history=max_history, - api_config=api_config + current_provider=current_provider, + current_is_omni=current_is_omni ) logger.debug( @@ -1166,7 +1190,10 @@ class AgentRunService: app_id: Optional[uuid.UUID] = None, user_id: Optional[str] = None, files: Optional[List[FileInput]] = None, - audio_url: Optional[str] = None + processed_files: Optional[List[Dict[str, Any]]] = None, + audio_url: Optional[str] = None, + provider: Optional[str] = None, + is_omni: Optional[bool] = None ) -> None: """保存会话消息(会话已通过 _ensure_conversation 确保存在) @@ -1177,6 +1204,11 @@ class AgentRunService: app_id: 应用ID(未使用,保留用于兼容性) user_id: 用户ID(未使用,保留用于兼容性) meta_data: token消耗 + files: 原始文件输入 + processed_files: 处理后的文件 + audio_url: 音频URL + provider: 模型供应商 + is_omni: 是否为全模态模型 """ try: from app.services.conversation_service import ConversationService @@ -1186,15 +1218,24 @@ class AgentRunService: # 保存消息(会话已经存在) human_meta = { - "files": [] + "files": [], + "history_files": {} } if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + + # 保存 history_files,包含 provider 和 is_omni 信息 + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": provider, + "is_omni": is_omni + } + # 保存用户消息 conversation_service.add_message( conversation_id=conv_uuid, @@ -1420,8 +1461,9 @@ class AgentRunService: workspace_id: Optional[uuid.UUID] = None, ) -> tuple[Optional[str], Optional[asyncio.Task]]: """文本流式输入并行合成音频。 - 返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。 + 返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。 调用方向 text_queue put 文本 chunk,结束时 put None。 + 前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。 """ tts_config = features_config.get("text_to_speech", {}) if not isinstance(tts_config, dict) or not tts_config.get("enabled"): @@ -1808,6 +1850,7 @@ class AgentRunService: ), "cost_estimate": self._estimate_cost(usage, model_info["model_config"]), "audio_url": result.get("audio_url"), + "audio_status": result.get("audio_status"), "citations": result.get("citations", []), "suggested_questions": result.get("suggested_questions", []), "error": None @@ -1885,6 +1928,7 @@ class AgentRunService: "results": [{ **r, "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), } for r in results], @@ -2016,6 +2060,7 @@ class AgentRunService: full_content = "" returned_conversation_id = model_conversation_id audio_url = None + audio_status = None citations = [] suggested_questions = [] @@ -2074,6 +2119,7 @@ class AgentRunService: # 从 end 事件中提取 features 输出字段 if event_type == "end" and event_data: audio_url = event_data.get("audio_url") + audio_status = event_data.get("audio_status") citations = event_data.get("citations", []) suggested_questions = event_data.get("suggested_questions", []) @@ -2103,6 +2149,7 @@ class AgentRunService: "message": full_content, "elapsed_time": elapsed, "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "error": None @@ -2117,6 +2164,7 @@ class AgentRunService: "elapsed_time": elapsed, "message_length": len(full_content), "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "timestamp": time.time() @@ -2253,6 +2301,7 @@ class AgentRunService: "message": r.get("message"), "elapsed_time": r.get("elapsed_time", 0), "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), "error": r.get("error") diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index af9a04e2..dc064540 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -350,9 +350,6 @@ class MemoryAgentService: langchain_messages.append(HumanMessage(content=msg['content'])) elif msg['role'] == 'assistant': langchain_messages.append(AIMessage(content=msg['content'])) - print(100 * '-') - print(langchain_messages) - print(100 * '-') # 初始状态 - 包含所有必要字段 initial_state = { "messages": langchain_messages, diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py index e23b1ac3..b5522b74 100644 --- a/api/app/services/user_service.py +++ b/api/app/services/user_service.py @@ -78,18 +78,7 @@ def create_user(db: Session, user: UserCreate) -> User: business_logger.info(f"创建用户: {user.username}, email: {user.email}") try: - # 检查用户名是否已存在 - business_logger.debug(f"检查用户名是否已存在: {user.username}") - db_user_by_username = user_repository.get_user_by_username(db, username=user.username) - if db_user_by_username: - business_logger.warning(f"用户名已存在: {user.username}") - raise BusinessException( - "用户名已存在", - code=BizCode.DUPLICATE_NAME, - context={"username": user.username, "email": user.email} - ) - - # 检查邮箱是否已注册 + # 检查邮箱是否已注册(邮箱保持唯一) business_logger.debug(f"检查邮箱是否已注册: {user.email}") db_user_by_email = user_repository.get_user_by_email(db, email=user.email) if db_user_by_email: @@ -164,22 +153,7 @@ def create_superuser(db: Session, user: UserCreate, current_user: User) -> User: ) try: - # 检查用户名是否已存在 - business_logger.debug(f"检查用户名是否已存在: {user.username}") - db_user_by_username = user_repository.get_user_by_username(db, username=user.username) - if db_user_by_username: - business_logger.warning(f"用户名已存在: {user.username}") - raise BusinessException( - "用户名已存在", - code=BizCode.DUPLICATE_NAME, - context={ - "username": user.username, - "email": user.email, - "created_by": str(current_user.id) - } - ) - - # 检查邮箱是否已注册 + # 检查邮箱是否已注册(邮箱保持唯一) business_logger.debug(f"检查邮箱是否已注册: {user.email}") db_user_by_email = user_repository.get_user_by_email(db, email=user.email) if db_user_by_email: diff --git a/api/app/tasks.py b/api/app/tasks.py index 3a237d82..354951c6 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -2760,7 +2760,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace patch_fail = 0 for cid in incomplete_ids: try: - await engine._generate_community_metadata(cid, end_user_id) + await engine._generate_community_metadata([cid], end_user_id) patch_ok += 1 except Exception as patch_err: patch_fail += 1 diff --git a/api/migrations/versions/05a681a6ca93_202603231611.py b/api/migrations/versions/05a681a6ca93_202603231611.py new file mode 100644 index 00000000..5ab9c4de --- /dev/null +++ b/api/migrations/versions/05a681a6ca93_202603231611.py @@ -0,0 +1,32 @@ +"""202603231611 + +Revision ID: 05a681a6ca93 +Revises: 74b51dfece29 +Create Date: 2026-03-23 16:12:44.110292 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '05a681a6ca93' +down_revision: Union[str, None] = '74b51dfece29' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_users_username'), table_name='users') + op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_users_username'), table_name='users') + op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True) + # ### end Alembic commands ###