From bd31aa5abf44778fc3e224392a1f661410ad1f2f Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 20 Mar 2026 16:08:40 +0800 Subject: [PATCH 01/12] feat: remove username uniqueness constraint for community edition - Remove unique=True from username column in User model - Remove username duplicate check in create_user and create_superuser - Add migration to drop unique index on username, keep email unique --- api/app/models/user_model.py | 2 +- api/app/services/user_service.py | 30 ++---------------------------- 2 files changed, 3 insertions(+), 29 deletions(-) diff --git a/api/app/models/user_model.py b/api/app/models/user_model.py index b6de28ec..81319789 100644 --- a/api/app/models/user_model.py +++ b/api/app/models/user_model.py @@ -9,7 +9,7 @@ class User(Base): __tablename__ = "users" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - username = Column(String, unique=True, index=True, nullable=False) + username = Column(String, index=True, nullable=False) # 社区版:用户名不唯一,仅邮箱唯一 email = Column(String, unique=True, index=True, nullable=False) hashed_password = Column(String, nullable=False) is_active = Column(Boolean, default=True, nullable=False) diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py index e23b1ac3..b5522b74 100644 --- a/api/app/services/user_service.py +++ b/api/app/services/user_service.py @@ -78,18 +78,7 @@ def create_user(db: Session, user: UserCreate) -> User: business_logger.info(f"创建用户: {user.username}, email: {user.email}") try: - # 检查用户名是否已存在 - business_logger.debug(f"检查用户名是否已存在: {user.username}") - db_user_by_username = user_repository.get_user_by_username(db, username=user.username) - if db_user_by_username: - business_logger.warning(f"用户名已存在: {user.username}") - raise BusinessException( - "用户名已存在", - code=BizCode.DUPLICATE_NAME, - context={"username": user.username, "email": user.email} - ) - - # 检查邮箱是否已注册 + # 检查邮箱是否已注册(邮箱保持唯一) business_logger.debug(f"检查邮箱是否已注册: {user.email}") db_user_by_email = user_repository.get_user_by_email(db, email=user.email) if db_user_by_email: @@ -164,22 +153,7 @@ def create_superuser(db: Session, user: UserCreate, current_user: User) -> User: ) try: - # 检查用户名是否已存在 - business_logger.debug(f"检查用户名是否已存在: {user.username}") - db_user_by_username = user_repository.get_user_by_username(db, username=user.username) - if db_user_by_username: - business_logger.warning(f"用户名已存在: {user.username}") - raise BusinessException( - "用户名已存在", - code=BizCode.DUPLICATE_NAME, - context={ - "username": user.username, - "email": user.email, - "created_by": str(current_user.id) - } - ) - - # 检查邮箱是否已注册 + # 检查邮箱是否已注册(邮箱保持唯一) business_logger.debug(f"检查邮箱是否已注册: {user.email}") db_user_by_email = user_repository.get_user_by_email(db, email=user.email) if db_user_by_email: From 911d5e0b34f7c183787aea78c7067955196e10bb Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 20 Mar 2026 17:07:23 +0800 Subject: [PATCH 02/12] feat(app): support searching application list by API Key --- api/app/controllers/app_controller.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index e9b539df..1eb45b89 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -57,6 +57,7 @@ def list_apps( page: int = 1, pagesize: int = 10, ids: Optional[str] = None, + api_key: Optional[str] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): @@ -65,10 +66,28 @@ def list_apps( - 默认包含本工作空间的应用和分享给本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用 - 当提供 ids 参数时,按逗号分割获取指定应用,不分页 + - 当提供 api_key 参数时,查找该 API Key 关联的应用 """ + from sqlalchemy import select as sa_select + from app.models.api_key_model import ApiKey + workspace_id = current_user.current_workspace_id service = app_service.AppService(db) + # 通过 API Key 搜索:查出关联的 app id,复用 ids 分支返回 + if api_key: + matched = db.execute( + sa_select(ApiKey.resource_id).where( + ApiKey.workspace_id == workspace_id, + ApiKey.api_key.like(f"%{api_key}%"), + ApiKey.resource_id.isnot(None), + ) + ).scalars().all() + app_ids = [str(rid) for rid in matched] + items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) if app_ids else [] + items = [service._convert_to_schema(app, workspace_id) for app in items_orm] + return success(data=items) + # 当 ids 存在且不为 None 时,根据 ids 获取应用 if ids is not None: app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] From bf5c4628c35541434be0b228141c4dae417e30e7 Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 20 Mar 2026 18:02:03 +0800 Subject: [PATCH 03/12] fix: use exact match instead of LIKE for api_key lookup, reuse ids branch flow --- api/app/controllers/app_controller.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 1eb45b89..3ba9c3a9 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -74,19 +74,16 @@ def list_apps( workspace_id = current_user.current_workspace_id service = app_service.AppService(db) - # 通过 API Key 搜索:查出关联的 app id,复用 ids 分支返回 + # 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程 if api_key: - matched = db.execute( + matched_id = db.execute( sa_select(ApiKey.resource_id).where( ApiKey.workspace_id == workspace_id, - ApiKey.api_key.like(f"%{api_key}%"), + ApiKey.api_key == api_key, ApiKey.resource_id.isnot(None), ) - ).scalars().all() - app_ids = [str(rid) for rid in matched] - items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) if app_ids else [] - items = [service._convert_to_schema(app, workspace_id) for app in items_orm] - return success(data=items) + ).scalar_one_or_none() + ids = str(matched_id) if matched_id else "" # 当 ids 存在且不为 None 时,根据 ids 获取应用 if ids is not None: From 27672cfaa09aa6a53280d42bf145b6e0f0886b29 Mon Sep 17 00:00:00 2001 From: wxy Date: Mon, 23 Mar 2026 12:05:18 +0800 Subject: [PATCH 04/12] feat(app): add app message log query API --- api/app/controllers/__init__.py | 2 + api/app/controllers/app_log_controller.py | 129 ++++++++++++++++++++++ api/app/schemas/app_log_schema.py | 53 +++++++++ 3 files changed, 184 insertions(+) create mode 100644 api/app/controllers/app_log_controller.py create mode 100644 api/app/schemas/app_log_schema.py diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 585de2ed..50e9e0b0 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -8,6 +8,7 @@ from fastapi import APIRouter from . import ( api_key_controller, app_controller, + app_log_controller, auth_controller, chunk_controller, document_controller, @@ -69,6 +70,7 @@ manager_router.include_router(chunk_controller.router) manager_router.include_router(test_controller.router) manager_router.include_router(knowledgeshare_controller.router) manager_router.include_router(app_controller.router) +manager_router.include_router(app_log_controller.router) manager_router.include_router(upload_controller.router) manager_router.include_router(memory_agent_controller.router) manager_router.include_router(memory_dashboard_controller.router) diff --git a/api/app/controllers/app_log_controller.py b/api/app/controllers/app_log_controller.py new file mode 100644 index 00000000..a8f6d532 --- /dev/null +++ b/api/app/controllers/app_log_controller.py @@ -0,0 +1,129 @@ +"""应用日志(消息记录)接口""" +import uuid +from typing import Optional + +from fastapi import APIRouter, Depends +from sqlalchemy import select, desc, func +from sqlalchemy.orm import Session + +from app.core.logging_config import get_business_logger +from app.core.response_utils import success +from app.db import get_db +from app.dependencies import get_current_user, cur_workspace_access_guard +from app.models.conversation_model import Conversation, Message +from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage +from app.schemas.response_schema import PageData, PageMeta +from app.services.app_service import AppService + +router = APIRouter(prefix="/apps", tags=["App Logs"]) +logger = get_business_logger() + + +@router.get("/{app_id}/logs", summary="应用日志 - 会话列表") +@cur_workspace_access_guard() +def list_app_logs( + app_id: uuid.UUID, + page: int = 1, + pagesize: int = 20, + user_id: Optional[str] = None, + is_draft: Optional[bool] = None, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """查看应用下所有会话记录(分页) + + - 支持按 user_id 筛选 + - 支持按 is_draft 筛选(草稿会话 / 发布会话) + - 按最新更新时间倒序排列 + """ + workspace_id = current_user.current_workspace_id + + # 验证应用访问权限 + service = AppService(db) + service.get_app(app_id, workspace_id) + + stmt = select(Conversation).where( + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True), + ) + + if user_id: + stmt = stmt.where(Conversation.user_id == user_id) + + if is_draft is not None: + stmt = stmt.where(Conversation.is_draft == is_draft) + + total = int(db.execute( + select(func.count()).select_from(stmt.subquery()) + ).scalar_one()) + + stmt = stmt.order_by(desc(Conversation.updated_at)) + stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) + + conversations = list(db.scalars(stmt).all()) + + items = [AppLogConversation.model_validate(c) for c in conversations] + meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) + + logger.info( + "查询应用日志会话列表", + extra={"app_id": str(app_id), "total": total, "page": page} + ) + + return success(data=PageData(page=meta, items=items)) + + +@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情") +@cur_workspace_access_guard() +def get_app_log_detail( + app_id: uuid.UUID, + conversation_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """查看某会话的完整消息记录 + + - 返回会话基本信息 + 所有消息(按时间正序) + - 消息 meta_data 包含模型名、token 用量等信息 + """ + workspace_id = current_user.current_workspace_id + + # 验证应用访问权限 + service = AppService(db) + service.get_app(app_id, workspace_id) + + # 查询会话(确保属于该应用和工作空间) + conversation = db.scalars( + select(Conversation).where( + Conversation.id == conversation_id, + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True), + ) + ).first() + + if not conversation: + from app.core.exceptions import ResourceNotFoundException + raise ResourceNotFoundException("会话", str(conversation_id)) + + # 查询消息(按时间正序) + messages = list(db.scalars( + select(Message) + .where(Message.conversation_id == conversation_id) + .order_by(Message.created_at) + ).all()) + + detail = AppLogConversationDetail.model_validate(conversation) + detail.messages = [AppLogMessage.model_validate(m) for m in messages] + + logger.info( + "查询应用日志会话详情", + extra={ + "app_id": str(app_id), + "conversation_id": str(conversation_id), + "message_count": len(messages) + } + ) + + return success(data=detail) diff --git a/api/app/schemas/app_log_schema.py b/api/app/schemas/app_log_schema.py new file mode 100644 index 00000000..e386b5e9 --- /dev/null +++ b/api/app/schemas/app_log_schema.py @@ -0,0 +1,53 @@ +"""应用日志(消息记录)Schema""" +import uuid +import datetime +from typing import Optional, Dict, Any, List + +from pydantic import BaseModel, Field, ConfigDict, field_serializer + + +class AppLogMessage(BaseModel): + """单条消息记录""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + conversation_id: uuid.UUID + role: str = Field(description="角色: user / assistant / system") + content: str + meta_data: Optional[Dict[str, Any]] = None + created_at: datetime.datetime + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("meta_data", when_used="json") + def _serialize_meta_data(self, data: Optional[Dict[str, Any]]): + return data or {} + + +class AppLogConversation(BaseModel): + """会话摘要(用于列表)""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + app_id: uuid.UUID + user_id: Optional[str] = None + title: Optional[str] = None + message_count: int = 0 + is_draft: bool + created_at: datetime.datetime + updated_at: datetime.datetime + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class AppLogConversationDetail(AppLogConversation): + """会话详情(包含消息列表)""" + messages: List[AppLogMessage] = [] From c70ac1339e766e4d4fdd478e233b6263fa23fc82 Mon Sep 17 00:00:00 2001 From: wxy Date: Mon, 23 Mar 2026 13:45:56 +0800 Subject: [PATCH 05/12] fix(app): validate pagination params and fix mutable default in schema --- api/app/controllers/app_log_controller.py | 6 +++--- api/app/schemas/app_log_schema.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/app/controllers/app_log_controller.py b/api/app/controllers/app_log_controller.py index a8f6d532..dfd10644 100644 --- a/api/app/controllers/app_log_controller.py +++ b/api/app/controllers/app_log_controller.py @@ -2,7 +2,7 @@ import uuid from typing import Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from sqlalchemy import select, desc, func from sqlalchemy.orm import Session @@ -23,8 +23,8 @@ logger = get_business_logger() @cur_workspace_access_guard() def list_app_logs( app_id: uuid.UUID, - page: int = 1, - pagesize: int = 20, + page: int = Query(1, ge=1), + pagesize: int = Query(20, ge=1, le=100), user_id: Optional[str] = None, is_draft: Optional[bool] = None, db: Session = Depends(get_db), diff --git a/api/app/schemas/app_log_schema.py b/api/app/schemas/app_log_schema.py index e386b5e9..bda78138 100644 --- a/api/app/schemas/app_log_schema.py +++ b/api/app/schemas/app_log_schema.py @@ -50,4 +50,4 @@ class AppLogConversation(BaseModel): class AppLogConversationDetail(AppLogConversation): """会话详情(包含消息列表)""" - messages: List[AppLogMessage] = [] + messages: List[AppLogMessage] = Field(default_factory=list) From 93deb286a3097641bc55bb9d813e6d03ff6fd019 Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 23 Mar 2026 16:14:46 +0800 Subject: [PATCH 06/12] [add] migration script --- .../versions/05a681a6ca93_202603231611.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 api/migrations/versions/05a681a6ca93_202603231611.py diff --git a/api/migrations/versions/05a681a6ca93_202603231611.py b/api/migrations/versions/05a681a6ca93_202603231611.py new file mode 100644 index 00000000..5ab9c4de --- /dev/null +++ b/api/migrations/versions/05a681a6ca93_202603231611.py @@ -0,0 +1,32 @@ +"""202603231611 + +Revision ID: 05a681a6ca93 +Revises: 74b51dfece29 +Create Date: 2026-03-23 16:12:44.110292 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '05a681a6ca93' +down_revision: Union[str, None] = '74b51dfece29' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_users_username'), table_name='users') + op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_users_username'), table_name='users') + op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True) + # ### end Alembic commands ### From 31b8a3764eff5d7675c5bd1a244bfa82dd5f5c66 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 23 Mar 2026 16:38:47 +0800 Subject: [PATCH 07/12] =?UTF-8?q?=E3=80=90change=E3=80=91=201.Standardize?= =?UTF-8?q?=20log=20specifications=EF=BC=9B2.Cluster=20settings=20trigger?= =?UTF-8?q?=20explicitly?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/celery_app.py | 1 + api/app/core/logging_config.py | 5 +- .../core/memory/agent/utils/write_tools.py | 6 +- .../core/memory/llm_tools/openai_client.py | 18 +- .../clustering_engine/label_propagation.py | 166 ++++++++++-------- .../data_preprocessing/data_pruning.py | 10 +- .../embedding_generation.py | 33 ++-- api/app/repositories/neo4j/add_nodes.py | 7 +- .../neo4j/community_repository.py | 2 +- api/app/repositories/neo4j/cypher_queries.py | 78 ++++---- api/app/repositories/neo4j/graph_saver.py | 13 +- api/app/services/memory_agent_service.py | 3 - api/app/tasks.py | 2 +- 13 files changed, 186 insertions(+), 158 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 807c59f4..864bee4a 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -77,6 +77,7 @@ celery_app.conf.update( # Worker 设置 (per-worker settings are in docker-compose command line) worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution + worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING # 结果过期时间 result_expires=3600, # 结果保存1小时 diff --git a/api/app/core/logging_config.py b/api/app/core/logging_config.py index 28a98a46..d0dda84b 100644 --- a/api/app/core/logging_config.py +++ b/api/app/core/logging_config.py @@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") - # Fallback to console only if file write fails print(f"Warning: Could not write to timing log: {e}") - # Always print to console (backward compatible behavior) - print(f"✓ {step_name}: {duration:.2f}s") + # Always log at INFO level (avoids Celery treating stdout as WARNING) + _timing_logger = logging.getLogger(__name__) + _timing_logger.info(f"✓ {step_name}: {duration:.2f}s") def get_agent_logger(name: str = "agent_service", diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index b62eb50a..f782d44b 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write +from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -169,8 +169,8 @@ async def write( ) if success: logger.info("Successfully saved all data to Neo4j") - # 写入成功后,异步触发聚类(不阻塞写入响应) - schedule_clustering_after_write( + # 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突) + await _trigger_clustering_sync( all_entity_nodes, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index 43c2b445..4536f62d 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -82,16 +82,26 @@ class OpenAIClient(LLMClient): LLMClientException: LLM 调用失败 """ try: - template = """{messages}""" - prompt = ChatPromptTemplate.from_template(template) - chain = prompt | self.client + from langchain_core.messages import HumanMessage, SystemMessage, AIMessage + + # 将 dict 消息列表转换为 LangChain 消息对象 + lc_messages = [] + for m in messages: + role = m.get("role", "user") + content = m.get("content", "") + if role == "system": + lc_messages.append(SystemMessage(content=content)) + elif role == "assistant": + lc_messages.append(AIMessage(content=content)) + else: + lc_messages.append(HumanMessage(content=content)) # 添加 Langfuse 回调(如果可用) config = {} if self.langfuse_handler: config["callbacks"] = [self.langfuse_handler] - response = await chain.ainvoke({"messages": messages}, config=config) + response = await self.client.ainvoke(lc_messages, config=config) return response except Exception as e: diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index d9c04f8b..0fa6a833 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -71,13 +71,11 @@ class LabelPropagationEngine: connector: Neo4jConnector, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, ): self.connector = connector self.repo = CommunityRepository(connector) self.llm_model_id = llm_model_id self.embedding_model_id = embedding_model_id - self.embedding_model_id = embedding_model_id # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -239,6 +237,7 @@ class LabelPropagationEngine: await self.repo.upsert_community(new_cid, end_user_id, member_count=1) await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}") + await self._generate_community_metadata([new_cid], end_user_id) return # 统计邻居社区分布 @@ -273,7 +272,8 @@ class LabelPropagationEngine: await self._evaluate_merge( list(community_ids_in_neighbors), end_user_id ) - await self._generate_community_metadata([target_cid], end_user_id) + # 新实体加入后成员变化,强制重新生成元数据 + await self._generate_community_metadata([target_cid], end_user_id, force=True) async def _evaluate_merge( self, community_ids: List[str], end_user_id: str @@ -453,7 +453,7 @@ class LabelPropagationEngine: return lines async def _generate_community_metadata( - self, community_ids: List[str], end_user_id: str + self, community_ids: List[str], end_user_id: str, force: bool = False ) -> None: """ 为一个或多个社区生成并写入元数据。 @@ -462,69 +462,82 @@ class LabelPropagationEngine: 1. 逐个社区调 LLM 生成 name / summary(串行) 2. 收集所有 summary,一次性批量 embed 3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata - """ - if not community_ids: - return + Args: + force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后) + """ from app.db import get_db_context from app.core.memory.utils.llm.llm_utils import MemoryClientFactory - # --- 阶段1:并发调 LLM 生成每个社区的 name / summary --- - async def _build_one(cid: str): - members = await self.repo.get_community_members(cid, end_user_id) - if not members: + async def _build_one(cid: str) -> Optional[Dict]: + try: + if not force: + check_embedding = bool(self.embedding_model_id) + if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding): + return None + + members = await self.repo.get_community_members(cid, end_user_id) + if not members: + logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成") + return None + + sorted_members = sorted( + members, + key=lambda m: m.get("activation_value") or 0, + reverse=True, + ) + core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] + all_names = [m["name"] for m in members if m.get("name")] + + name = "、".join(core_entities[:3]) if core_entities else cid[:8] + summary = f"包含实体:{', '.join(all_names)}" + + if self.llm_model_id: + try: + entity_list_str = "\n".join(self._build_entity_lines(members)) + relationships = await self.repo.get_community_relationships(cid, end_user_id) + rel_lines = [ + f"- {r['subject']} → {r['predicate']} → {r['object']}" + for r in relationships + if r.get("subject") and r.get("predicate") and r.get("object") + ] + rel_section = ( + f"\n实体间关系:\n" + "\n".join(rel_lines) + if rel_lines else "" + ) + prompt = ( + f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过80个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) + with get_db_context() as db: + llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) + response = await llm_client.chat([{"role": "user", "content": prompt}]) + text = response.content if hasattr(response, "content") else str(response) + + for line in text.strip().splitlines(): + if line.startswith("名称:"): + name = line[3:].strip() + elif line.startswith("摘要:"): + summary = line[3:].strip() + except Exception as e: + logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}") + + return { + "community_id": cid, + "end_user_id": end_user_id, + "name": name, + "summary": summary, + "core_entities": core_entities, + "summary_embedding": None, + } + except Exception as e: + logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True) return None - sorted_members = sorted( - members, - key=lambda m: m.get("activation_value") or 0, - reverse=True, - ) - core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] - - entity_list_str = "\n".join(self._build_entity_lines(members)) - - # 方案四:注入社区内实体间关系三元组 - relationships = await self.repo.get_community_relationships(cid, end_user_id) - rel_lines = [ - f"- {r['subject']} → {r['predicate']} → {r['object']}" - for r in relationships - if r.get("subject") and r.get("predicate") and r.get("object") - ] - rel_section = ( - f"\n实体间关系:\n" + "\n".join(rel_lines) - if rel_lines else "" - ) - - prompt = ( - f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" - f"请为这组实体所代表的主题:\n" - f"1. 起一个简洁的中文名称(不超过10个字)\n" - f"2. 写一句话摘要(不超过80个字)\n\n" - f"严格按以下格式输出,不要有其他内容:\n" - f"名称:<名称>\n摘要:<摘要>" - ) - with get_db_context() as db: - llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) - response = await llm_client.chat([{"role": "user", "content": prompt}]) - text = response.content if hasattr(response, "content") else str(response) - - name, summary = "", "" - for line in text.strip().splitlines(): - if line.startswith("名称:"): - name = line[3:].strip() - elif line.startswith("摘要:"): - summary = line[3:].strip() - - return { - "community_id": cid, - "end_user_id": end_user_id, - "name": name, - "summary": summary, - "core_entities": core_entities, - "summary_embedding": None, - } - results = await asyncio.gather( *[_build_one(cid) for cid in community_ids], return_exceptions=True, @@ -537,15 +550,20 @@ class LabelPropagationEngine: metadata_list.append(res) if not metadata_list: + logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}") return # --- 阶段2:批量生成 summary_embedding --- - summaries = [m["summary"] for m in metadata_list] - with get_db_context() as db: - embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) - embeddings = await embedder.response(summaries) - for i, meta in enumerate(metadata_list): - meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + if self.embedding_model_id: + try: + summaries = [m["summary"] for m in metadata_list] + with get_db_context() as db: + embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) + embeddings = await embedder.response(summaries) + for i, meta in enumerate(metadata_list): + meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + except Exception as e: + logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True) # --- 阶段3:写入(单个 or 批量)--- if len(metadata_list) == 1: @@ -558,17 +576,13 @@ class LabelPropagationEngine: core_entities=m["core_entities"], summary_embedding=m["summary_embedding"], ) - if result: - logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...") - else: - logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False") + if not result: + logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败") else: ok = await self.repo.batch_update_community_metadata(metadata_list) - if ok: - logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功") - else: - logger.warning(f"[Clustering] 批量写入社区元数据失败") + if not ok: + logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败") @staticmethod def _new_community_id() -> str: - return str(uuid.uuid4()) + return str(uuid.uuid4()) \ No newline at end of file diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 248067e7..967f529e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -9,6 +9,7 @@ """ import asyncio +import logging import os import hashlib import json @@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene ScenePatterns ) +logger = logging.getLogger(__name__) + class DialogExtractionResponse(BaseModel): """对话级一次性抽取的结构化返回,用于加速剪枝。 @@ -706,7 +709,7 @@ class SemanticPruner: # 阈值保护:最高0.9 proportion = float(self.config.pruning_threshold) if proportion > 0.9: - print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9") + logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9") proportion = 0.9 if proportion < 0.0: proportion = 0.0 @@ -905,7 +908,7 @@ class SemanticPruner: # Safety: avoid empty dataset if not result: - print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") + logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") return dialogs return result @@ -915,8 +918,7 @@ class SemanticPruner: try: self.run_logs.append(msg) except Exception: - # 任何异常都不影响打印 pass - print(msg) + logger.debug(msg) diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py index 72f3641e..33838061 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py @@ -5,8 +5,11 @@ """ import asyncio +import logging from typing import Any, Dict, List, Tuple +logger = logging.getLogger(__name__) + from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.models.message_models import DialogData from app.core.models.base import RedBearModelConfig @@ -48,9 +51,9 @@ class EmbeddingGenerator: return await self.embedder_client.response(texts) # 分批并行处理 - print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理") + logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理") batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)] - print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本") + logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本") # 并行发送所有批次 batch_results = await asyncio.gather(*[ @@ -62,7 +65,7 @@ class EmbeddingGenerator: for batch_result in batch_results: embeddings.extend(batch_result) - print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量") + logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量") return embeddings async def generate_statement_embeddings( @@ -77,7 +80,7 @@ class EmbeddingGenerator: Returns: 每个对话的陈述句嵌入向量映射列表 """ - print("\n=== 生成陈述句嵌入向量 ===") + logger.debug("=== 生成陈述句嵌入向量 ===") # 收集所有陈述句 all_statements = [] @@ -102,7 +105,7 @@ class EmbeddingGenerator: stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id stmt_embedding_maps[d_idx][stmt_id] = embedding - print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量") + logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量") return stmt_embedding_maps async def generate_chunk_embeddings( @@ -117,7 +120,7 @@ class EmbeddingGenerator: Returns: 每个对话的分块嵌入向量映射列表 """ - print("\n=== 生成分块嵌入向量 ===") + logger.debug("=== 生成分块嵌入向量 ===") # 收集所有分块 all_chunks = [] @@ -138,7 +141,7 @@ class EmbeddingGenerator: chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id chunk_embedding_maps[d_idx][chunk_id] = embedding - print(f"为 {len(all_chunks)} 个分块生成了嵌入向量") + logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量") return chunk_embedding_maps async def generate_dialog_embeddings( @@ -172,7 +175,7 @@ class EmbeddingGenerator: Returns: (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表) """ - print("\n=== 生成所有嵌入向量 ===") + logger.debug("=== 生成所有嵌入向量 ===") # 并发生成陈述句和分块嵌入向量 stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather( @@ -183,9 +186,7 @@ class EmbeddingGenerator: # 对话嵌入向量(当前跳过) dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs) - print( - f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量" - ) + logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量") return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings @@ -201,7 +202,7 @@ class EmbeddingGenerator: Returns: 更新后的三元组映射列表(实体包含嵌入向量) """ - print("\n=== 生成实体嵌入向量 ===") + logger.debug("=== 生成实体嵌入向量 ===") entity_texts: List[str] = [] entity_refs: List[Any] = [] @@ -219,7 +220,7 @@ class EmbeddingGenerator: entity_refs.append(ent) if not entity_texts: - print("没有找到需要生成嵌入向量的实体") + logger.debug("没有找到需要生成嵌入向量的实体") return triplet_maps # 批量生成嵌入向量 @@ -227,13 +228,13 @@ class EmbeddingGenerator: # 打印前几个嵌入向量的维度 for i in range(min(5, len(embeddings))): - print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") + logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") # 将嵌入向量赋值给实体 for ent, emb in zip(entity_refs, embeddings): setattr(ent, "name_embedding", emb) - print(f"为 {len(entity_refs)} 个实体生成了嵌入向量") + logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量") return triplet_maps @@ -296,7 +297,7 @@ async def embedding_generation_all( Returns: (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表) """ - print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===") + logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===") generator = EmbeddingGenerator(embedding_id) diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 42c178b3..786f7bbe 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -1,10 +1,13 @@ from typing import List, Optional +import logging from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector +logger = logging.getLogger(__name__) + async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): """Delete all nodes in the database.""" @@ -217,10 +220,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector summaries=flattened ) created_ids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") + logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids except Exception as e: - print(f"Failed to save MemorySummary nodes to Neo4j: {e}") + logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 7273340e..bd448c99 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -300,7 +300,7 @@ class CommunityRepository: ) return bool(result) except Exception as e: - logger.error(f"update_community_metadata failed: {e}") + logger.error(f"update_community_metadata failed: {e}", exc_info=True) return False async def batch_update_community_metadata( diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 0cdaeb59..fe1cb252 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1069,6 +1069,7 @@ Graph_Node_query = """ COMMUNITY_NODE_UPSERT = """ MERGE (c:Community {community_id: $community_id}) +ON CREATE SET c.id = $community_id SET c.end_user_id = $end_user_id, c.member_count = $member_count, c.updated_at = datetime() @@ -1175,7 +1176,8 @@ RETURN c.community_id AS community_id, cnt AS member_count UPDATE_COMMUNITY_METADATA = """ MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -SET c.name = $name, +SET c.id = coalesce(c.id, $community_id), + c.name = $name, c.summary = $summary, c.core_entities = $core_entities, c.summary_embedding = $summary_embedding, @@ -1186,7 +1188,8 @@ RETURN c.community_id AS community_id BATCH_UPDATE_COMMUNITY_METADATA = """ UNWIND $communities AS row MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id}) -SET c.name = row.name, +SET c.id = coalesce(c.id, row.community_id), + c.name = row.name, c.summary = row.summary, c.core_entities = row.core_entities, c.summary_embedding = row.summary_embedding, @@ -1270,6 +1273,40 @@ RETURN startNode(r) = e AS r_from_e """ +CHECK_COMMUNITY_IS_COMPLETE = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +RETURN ( + c.name IS NOT NULL AND c.name <> '' AND + c.summary IS NOT NULL AND c.summary <> '' AND + c.core_entities IS NOT NULL +) AS is_complete +""" + +CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +RETURN ( + c.name IS NOT NULL AND c.name <> '' AND + c.summary IS NOT NULL AND c.summary <> '' AND + c.core_entities IS NOT NULL AND + c.summary_embedding IS NOT NULL +) AS is_complete +""" + +GET_INCOMPLETE_COMMUNITIES = """ +MATCH (c:Community {end_user_id: $end_user_id}) +WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL + OR c.name = '' OR c.summary = '' +RETURN c.community_id AS community_id +""" + +GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """ +MATCH (c:Community {end_user_id: $end_user_id}) +WHERE c.name IS NULL OR c.name = '' + OR c.summary IS NULL OR c.summary = '' + OR c.core_entities IS NULL + OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)') +RETURN c.community_id AS community_id +""" # Community keyword search: matches name or summary via fulltext index SEARCH_COMMUNITIES_BY_KEYWORD = """ @@ -1325,39 +1362,4 @@ RETURN s.statement AS statement, c.name AS community_name ORDER BY COALESCE(s.activation_value, 0) DESC LIMIT $limit -""" - -CHECK_COMMUNITY_IS_COMPLETE = """ -MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -RETURN ( - c.name IS NOT NULL AND c.name <> '' AND - c.summary IS NOT NULL AND c.summary <> '' AND - c.core_entities IS NOT NULL -) AS is_complete -""" - -CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """ -MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -RETURN ( - c.name IS NOT NULL AND c.name <> '' AND - c.summary IS NOT NULL AND c.summary <> '' AND - c.core_entities IS NOT NULL AND - c.summary_embedding IS NOT NULL -) AS is_complete -""" - -GET_INCOMPLETE_COMMUNITIES = """ -MATCH (c:Community {end_user_id: $end_user_id}) -WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL - OR c.name = '' OR c.summary = '' -RETURN c.community_id AS community_id -""" - -GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """ -MATCH (c:Community {end_user_id: $end_user_id}) -WHERE c.name IS NULL OR c.name = '' - OR c.summary IS NULL OR c.summary = '' - OR c.core_entities IS NULL - OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)') -RETURN c.community_id AS community_id -""" +""" \ No newline at end of file diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 34497d5b..d2c4b9bd 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -162,7 +162,7 @@ async def save_dialog_and_statements_to_neo4j( """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. 只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过 - schedule_clustering_after_write() 显式触发。 + _trigger_clustering_sync() 显式触发。 Args: dialogue_nodes: List of DialogueNode objects to save @@ -303,16 +303,13 @@ async def save_dialog_and_statements_to_neo4j( return False -def schedule_clustering_after_write( +async def _trigger_clustering_sync( entity_nodes: List, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, ) -> None: """ - 写入 Neo4j 成功后,调度后台聚类任务。 - - 可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。 - 使用 asyncio.create_task 异步触发,不阻塞写入响应。 + 同步等待聚类完成,避免与其他 LLM 任务并发冲突。 """ if not entity_nodes: return @@ -324,8 +321,8 @@ def schedule_clustering_after_write( end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] - logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) + logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") + await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) async def _trigger_clustering( diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index af9a04e2..dc064540 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -350,9 +350,6 @@ class MemoryAgentService: langchain_messages.append(HumanMessage(content=msg['content'])) elif msg['role'] == 'assistant': langchain_messages.append(AIMessage(content=msg['content'])) - print(100 * '-') - print(langchain_messages) - print(100 * '-') # 初始状态 - 包含所有必要字段 initial_state = { "messages": langchain_messages, diff --git a/api/app/tasks.py b/api/app/tasks.py index 3a237d82..354951c6 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -2760,7 +2760,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace patch_fail = 0 for cid in incomplete_ids: try: - await engine._generate_community_metadata(cid, end_user_id) + await engine._generate_community_metadata([cid], end_user_id) patch_ok += 1 except Exception as patch_err: patch_fail += 1 From a3428c2435271c0ffab64aa2d38f7465a92c9432 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 23 Mar 2026 17:04:30 +0800 Subject: [PATCH 08/12] feat(app): 1. Handling the storage of multimodal messages and adapting to the loading of historical messages for multi-round conversations; 2. Obtain the interface for retrieving the voice status of the reply; 3. File Information Retrieval Interface --- .../controllers/file_storage_controller.py | 187 +++++++++++++++++- api/app/services/app_chat_service.py | 69 ++++--- api/app/services/conversation_service.py | 64 ++++-- api/app/services/draft_run_service.py | 77 ++++++-- 4 files changed, 342 insertions(+), 55 deletions(-) diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py index ff284f39..14962a72 100644 --- a/api/app/controllers/file_storage_controller.py +++ b/api/app/controllers/file_storage_controller.py @@ -14,6 +14,9 @@ Routes: import os import uuid from typing import Any +import httpx +import mimetypes +from urllib.parse import urlparse, unquote from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse, RedirectResponse @@ -91,7 +94,7 @@ async def upload_file( if file_size > settings.MAX_FILE_SIZE: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit" ) @@ -172,7 +175,6 @@ async def upload_file_with_share_token( # Get share and release info from share_token service = ReleaseShareService(db) - share_info = service.get_shared_release_info(share_token=share_data.share_token) # Get share object to access app_id share = service.repo.get_by_share_token(share_data.share_token) @@ -291,6 +293,101 @@ async def upload_file_with_share_token( ) +@router.get("/files/info-by-url", response_model=ApiResponse) +async def get_file_info_by_url( + url: str, +): + """ + Get file information by network URL (no authentication required). + + Fetches file metadata from a remote URL via HTTP HEAD request. + Falls back to GET request if HEAD is not supported. + Returns file type, name, and size. + + Args: + url: The network URL of the file. + + Returns: + ApiResponse with file information. + """ + api_logger.info(f"File info by URL request: url={url}") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # Try HEAD request first + response = await client.head(url, follow_redirects=True) + + # If HEAD fails, try GET request (some servers don't support HEAD) + if response.status_code != 200: + api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request") + response = await client.get(url, follow_redirects=True) + + if response.status_code != 200: + api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unable to access file: HTTP {response.status_code}" + ) + + # Get file size from Content-Length header or actual content + file_size = response.headers.get("Content-Length") + if file_size: + file_size = int(file_size) + elif hasattr(response, 'content'): + file_size = len(response.content) + else: + file_size = None + + # Get content type from Content-Type header + content_type = response.headers.get("Content-Type", "application/octet-stream") + # Remove charset and other parameters from content type + content_type = content_type.split(';')[0].strip() + + # Extract filename from Content-Disposition or URL + file_name = None + content_disposition = response.headers.get("Content-Disposition") + if content_disposition and "filename=" in content_disposition: + parts = content_disposition.split("filename=") + if len(parts) > 1: + file_name = parts[1].strip('"').strip("'") + + if not file_name: + parsed_url = urlparse(url) + file_name = unquote(os.path.basename(parsed_url.path)) or "unknown" + + # Extract file extension from filename + _, file_ext = os.path.splitext(file_name) + + # If no extension found, infer from content type + if not file_ext: + ext = mimetypes.guess_extension(content_type) + if ext: + file_ext = ext + file_name = f"{file_name}{file_ext}" + + api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}") + + return success( + data={ + "url": url, + "file_name": file_name, + "file_ext": file_ext.lower() if file_ext else "", + "file_size": file_size, + "content_type": content_type, + }, + msg="File information retrieved successfully" + ) + + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Unexpected error: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve file information: {str(e)}" + ) + + @router.get("/files/{file_id}", response_model=Any) async def download_file( request: Request, @@ -499,6 +596,51 @@ async def get_file_url( ) +@router.get("/files/{file_id}/public-url", response_model=ApiResponse) +async def get_permanent_file_url( + file_id: uuid.UUID, + db: Session = Depends(get_db), + storage_service: FileStorageService = Depends(get_file_storage_service), +): + """ + 获取文件的永久公开 URL(无过期时间)。 + + - 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置) + - 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限) + """ + file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first() + if not file_metadata: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist") + + if file_metadata.status != "completed": + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"File upload not completed, status: {file_metadata.status}") + + file_key = file_metadata.file_key + storage = storage_service.storage + + try: + if isinstance(storage, LocalStorage): + url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}" + else: + url = await storage.get_permanent_url(file_key) + if not url: + raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Permanent URL not supported for current storage backend") + + api_logger.info(f"Generated permanent URL: file_id={file_id}") + return success( + data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name}, + msg="Permanent file URL generated successfully" + ) + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Failed to generate permanent URL: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to generate permanent URL: {str(e)}") + + @router.get("/public/{file_id}", response_model=Any) async def public_download_file( request: Request, @@ -653,3 +795,44 @@ async def permanent_download_file( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to retrieve file: {str(e)}" ) + + +@router.get("/files/{file_id}/status", response_model=ApiResponse) +async def get_file_status( + file_id: uuid.UUID, + db: Session = Depends(get_db), +): + """ + Get file upload/processing status (no authentication required). + + This endpoint is used to check if a file (e.g., TTS audio) is ready. + Returns status: pending, completed, or failed. + + Args: + file_id: The UUID of the file. + db: Database session. + + Returns: + ApiResponse with file status and metadata. + """ + api_logger.info(f"File status request: file_id={file_id}") + + # Query file metadata from database + file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first() + if not file_metadata: + api_logger.warning(f"File not found in database: file_id={file_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The file does not exist" + ) + + return success( + data={ + "file_id": str(file_id), + "status": file_metadata.status, + "file_name": file_metadata.file_name, + "file_size": file_metadata.file_size, + "content_type": file_metadata.content_type, + }, + msg="File status retrieved successfully" + ) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 604514b4..94a6dddb 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -119,14 +119,12 @@ class AppChatService: ) # 加载历史消息 - messages = self.conversation_service.get_messages( + history = self.conversation_service.get_conversation_history( conversation_id=conversation_id, - limit=10 + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] # 处理多模态文件 processed_files = None @@ -180,7 +178,8 @@ class AppChatService: # 构建用户消息内容(含多模态文件) human_meta = { - "files": [] + "files": [], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -195,6 +194,13 @@ class AppChatService: "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } + # 保存消息 if audio_url: assistant_meta["audio_url"] = audio_url @@ -225,6 +231,7 @@ class AppChatService: "suggested_questions": suggested_questions, "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), "audio_url": audio_url, + "audio_status": "pending" } async def agnet_chat_stream( @@ -314,17 +321,12 @@ class AppChatService: ) # 加载历史消息 - history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation_id, - limit=memory_config.get("max_history", 10) - ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] + history = self.conversation_service.get_conversation_history( + conversation_id=conversation_id, + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni + ) # 处理多模态文件 processed_files = None @@ -347,8 +349,14 @@ class AppChatService: total_tokens = 0 text_queue: asyncio.Queue = asyncio.Queue() + api_key_config = { + "model_name": api_key_obj.model_name, + "api_key": api_key_obj.api_key, + "api_base": api_key_obj.api_base, + "provider": api_key_obj.provider, + } stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming( - features_config, api_key_obj, + features_config, api_key_config, text_queue=text_queue, tenant_id=tenant_id, workspace_id=workspace_id ) @@ -378,7 +386,7 @@ class AppChatService: elapsed_time = time.time() - start_time ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) - # 发送结束事件(包含 suggested_questions、tts、citations) + # 发送结束事件(包含 suggested_questions、tts、audio_status、citations) end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} sq_config = features_config.get("suggested_questions_after_answer", {}) if isinstance(sq_config, dict) and sq_config.get("enabled"): @@ -388,11 +396,23 @@ class AppChatService: "api_base": api_key_obj.api_base}, {} ) end_data["audio_url"] = stream_audio_url + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None end_data["citations"] = self.agent_service._filter_citations(features_config, []) # 保存消息 human_meta = { - "files":[] + "files":[], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -402,11 +422,16 @@ class AppChatService: if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } if stream_audio_url: assistant_meta["audio_url"] = stream_audio_url diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index aff5f533..d7bb3595 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -119,25 +119,27 @@ class ConversationService: def get_user_conversations( self, - user_id: uuid.UUID - ) -> list[Conversation]: + user_id: uuid.UUID, + page: int = 1, + page_size: int = 20 + ) -> tuple[list[Conversation], int]: """ - Retrieve recent conversations for a specific user - - This method delegates persistence logic to the repository layer and - applies service-level defaults (e.g. recent conversation limit). + Retrieve recent conversations for a specific user with pagination. Args: user_id (uuid.UUID): Unique identifier of the user. + page (int): Page number (1-based). Defaults to 1. + page_size (int): Number of items per page. Defaults to 20. Returns: - list[Conversation]: A list of recent conversation entities. + tuple[list[Conversation], int]: A list of recent conversation entities and total count. """ - conversations = self.conversation_repo.get_conversation_by_user_id( + conversations, total = self.conversation_repo.get_conversation_by_user_id( user_id, - limit=10 + page=page, + page_size=page_size ) - return conversations + return conversations, total def list_conversations( self, @@ -270,7 +272,9 @@ class ConversationService: def get_conversation_history( self, conversation_id: uuid.UUID, - max_history: Optional[int] = None + max_history: Optional[int] = None, + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[dict]: """ Retrieve historical conversation messages formatted as dictionaries. @@ -278,6 +282,8 @@ class ConversationService: Args: conversation_id (uuid.UUID): Conversation UUID. max_history (Optional[int]): Maximum number of messages to retrieve. + current_provider (Optional[str]): Current provider for file handling. + current_is_omni (Optional[bool]): Current omni flag for file handling. Returns: List[dict]: List of message dictionaries with keys 'role' and 'content'. @@ -287,14 +293,30 @@ class ConversationService: limit=max_history ) - # 转换为字典格式 - history = [ - { + history = [] + for msg in messages: + msg_dict = { "role": msg.role, - "content": msg.content + "content": [{"type": "text", "text": msg.content}] } - for msg in messages - ] + + # 处理用户消息中的多模态文件 + if msg.role == "user" and msg.meta_data: + history_files = msg.meta_data.get("history_files", {}) + + if history_files and current_provider and current_is_omni is not None: + # 检查是否需要重新处理文件 + stored_provider = history_files.get("provider") + stored_is_omni = history_files.get("is_omni") + + # 如果provider或is_omni不匹配,需要重新处理 + if stored_provider != current_provider or stored_is_omni != current_is_omni: + continue + + # provider和is_omni匹配,直接使用存储的内容 + msg_dict["content"].extend(history_files.get("content")) + + history.append(msg_dict) return history @@ -510,6 +532,7 @@ class ConversationService: provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type llm = RedBearLLM( @@ -517,14 +540,17 @@ class ConversationService: model_name=model_name, provider=provider, api_key=api_key, - base_url=api_base + base_url=api_base, + is_omni=is_omni ), type=ModelType(model_type) ) conversation_messages = self.get_conversation_history( conversation_id=conversation_id, - max_history=20 + max_history=20, + current_provider=provider, + current_is_omni=is_omni ) if len(conversation_messages) == 0: return ConversationOut( diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index ba41d323..13243a92 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -582,7 +582,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=10 + max_history=10, + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -659,7 +661,10 @@ class AgentRunService: }) }, files=files, - audio_url=audio_url + processed_files=processed_files, + audio_url=audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) response = { @@ -676,6 +681,7 @@ class AgentRunService: ) if not sub_agent else [], "citations": self._filter_citations(features_config, result.get("citations", [])), "audio_url": audio_url, + "audio_status": "pending" } logger.info( @@ -818,7 +824,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=memory_config.get("max_history", 10) + max_history=memory_config.get("max_history", 10), + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -905,10 +913,13 @@ class AgentRunService: "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} }, files=files, - audio_url=stream_audio_url + processed_files=processed_files, + audio_url=stream_audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) - # 12. 发送结束事件(包含 suggested_questions 和 tts) + # 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status) end_data: Dict[str, Any] = { "conversation_id": conversation_id, "elapsed_time": elapsed_time, @@ -919,6 +930,17 @@ class AgentRunService: features_config, full_content, api_key_config, effective_params ) end_data["audio_url"] = stream_audio_url + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None end_data["citations"] = self._filter_citations(features_config, []) yield self._format_sse_event("end", end_data) @@ -1115,13 +1137,17 @@ class AgentRunService: async def _load_conversation_history( self, conversation_id: str, - max_history: int = 10 + max_history: int = 10, + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[Dict[str, str]]: - """加载会话历史消息 + """加载会话历史消息,并根据当前模型配置处理多模态文件 Args: conversation_id: 会话ID max_history: 最大历史消息数量 + current_provider: 当前模型的provider + current_is_omni: 当前模型的is_omni Returns: List[Dict]: 历史消息列表 @@ -1131,7 +1157,9 @@ class AgentRunService: conversation_service = ConversationService(self.db) history = conversation_service.get_conversation_history( conversation_id=uuid.UUID(conversation_id), - max_history=max_history + max_history=max_history, + current_provider=current_provider, + current_is_omni=current_is_omni ) logger.debug( @@ -1159,7 +1187,10 @@ class AgentRunService: app_id: Optional[uuid.UUID] = None, user_id: Optional[str] = None, files: Optional[List[FileInput]] = None, - audio_url: Optional[str] = None + processed_files: Optional[List[Dict[str, Any]]] = None, + audio_url: Optional[str] = None, + provider: Optional[str] = None, + is_omni: Optional[bool] = None ) -> None: """保存会话消息(会话已通过 _ensure_conversation 确保存在) @@ -1170,6 +1201,11 @@ class AgentRunService: app_id: 应用ID(未使用,保留用于兼容性) user_id: 用户ID(未使用,保留用于兼容性) meta_data: token消耗 + files: 原始文件输入 + processed_files: 处理后的文件 + audio_url: 音频URL + provider: 模型供应商 + is_omni: 是否为全模态模型 """ try: from app.services.conversation_service import ConversationService @@ -1179,15 +1215,24 @@ class AgentRunService: # 保存消息(会话已经存在) human_meta = { - "files": [] + "files": [], + "history_files": {} } if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + + # 保存 history_files,包含 provider 和 is_omni 信息 + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": provider, + "is_omni": is_omni + } + # 保存用户消息 conversation_service.add_message( conversation_id=conv_uuid, @@ -1413,8 +1458,9 @@ class AgentRunService: workspace_id: Optional[uuid.UUID] = None, ) -> tuple[Optional[str], Optional[asyncio.Task]]: """文本流式输入并行合成音频。 - 返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。 + 返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。 调用方向 text_queue put 文本 chunk,结束时 put None。 + 前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。 """ tts_config = features_config.get("text_to_speech", {}) if not isinstance(tts_config, dict) or not tts_config.get("enabled"): @@ -1801,6 +1847,7 @@ class AgentRunService: ), "cost_estimate": self._estimate_cost(usage, model_info["model_config"]), "audio_url": result.get("audio_url"), + "audio_status": result.get("audio_status"), "citations": result.get("citations", []), "suggested_questions": result.get("suggested_questions", []), "error": None @@ -1878,6 +1925,7 @@ class AgentRunService: "results": [{ **r, "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), } for r in results], @@ -2009,6 +2057,7 @@ class AgentRunService: full_content = "" returned_conversation_id = model_conversation_id audio_url = None + audio_status = None citations = [] suggested_questions = [] @@ -2067,6 +2116,7 @@ class AgentRunService: # 从 end 事件中提取 features 输出字段 if event_type == "end" and event_data: audio_url = event_data.get("audio_url") + audio_status = event_data.get("audio_status") citations = event_data.get("citations", []) suggested_questions = event_data.get("suggested_questions", []) @@ -2096,6 +2146,7 @@ class AgentRunService: "message": full_content, "elapsed_time": elapsed, "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "error": None @@ -2110,6 +2161,7 @@ class AgentRunService: "elapsed_time": elapsed, "message_length": len(full_content), "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "timestamp": time.time() @@ -2246,6 +2298,7 @@ class AgentRunService: "message": r.get("message"), "elapsed_time": r.get("elapsed_time", 0), "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), "error": r.get("error") From efeead41b22eff25a78e8831c1fbb0fa86a996d6 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 23 Mar 2026 17:10:49 +0800 Subject: [PATCH 09/12] feat(app): 1. Handling the storage of multimodal messages and adapting to the loading of historical messages for multi-round conversations; 2. Obtain the interface for retrieving the voice status of the reply; 3. File Information Retrieval Interface --- .../controllers/file_storage_controller.py | 187 +++++++++++++++++- api/app/services/app_chat_service.py | 69 ++++--- api/app/services/conversation_service.py | 64 ++++-- api/app/services/draft_run_service.py | 77 ++++++-- 4 files changed, 342 insertions(+), 55 deletions(-) diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py index ff284f39..14962a72 100644 --- a/api/app/controllers/file_storage_controller.py +++ b/api/app/controllers/file_storage_controller.py @@ -14,6 +14,9 @@ Routes: import os import uuid from typing import Any +import httpx +import mimetypes +from urllib.parse import urlparse, unquote from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse, RedirectResponse @@ -91,7 +94,7 @@ async def upload_file( if file_size > settings.MAX_FILE_SIZE: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit" ) @@ -172,7 +175,6 @@ async def upload_file_with_share_token( # Get share and release info from share_token service = ReleaseShareService(db) - share_info = service.get_shared_release_info(share_token=share_data.share_token) # Get share object to access app_id share = service.repo.get_by_share_token(share_data.share_token) @@ -291,6 +293,101 @@ async def upload_file_with_share_token( ) +@router.get("/files/info-by-url", response_model=ApiResponse) +async def get_file_info_by_url( + url: str, +): + """ + Get file information by network URL (no authentication required). + + Fetches file metadata from a remote URL via HTTP HEAD request. + Falls back to GET request if HEAD is not supported. + Returns file type, name, and size. + + Args: + url: The network URL of the file. + + Returns: + ApiResponse with file information. + """ + api_logger.info(f"File info by URL request: url={url}") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # Try HEAD request first + response = await client.head(url, follow_redirects=True) + + # If HEAD fails, try GET request (some servers don't support HEAD) + if response.status_code != 200: + api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request") + response = await client.get(url, follow_redirects=True) + + if response.status_code != 200: + api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unable to access file: HTTP {response.status_code}" + ) + + # Get file size from Content-Length header or actual content + file_size = response.headers.get("Content-Length") + if file_size: + file_size = int(file_size) + elif hasattr(response, 'content'): + file_size = len(response.content) + else: + file_size = None + + # Get content type from Content-Type header + content_type = response.headers.get("Content-Type", "application/octet-stream") + # Remove charset and other parameters from content type + content_type = content_type.split(';')[0].strip() + + # Extract filename from Content-Disposition or URL + file_name = None + content_disposition = response.headers.get("Content-Disposition") + if content_disposition and "filename=" in content_disposition: + parts = content_disposition.split("filename=") + if len(parts) > 1: + file_name = parts[1].strip('"').strip("'") + + if not file_name: + parsed_url = urlparse(url) + file_name = unquote(os.path.basename(parsed_url.path)) or "unknown" + + # Extract file extension from filename + _, file_ext = os.path.splitext(file_name) + + # If no extension found, infer from content type + if not file_ext: + ext = mimetypes.guess_extension(content_type) + if ext: + file_ext = ext + file_name = f"{file_name}{file_ext}" + + api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}") + + return success( + data={ + "url": url, + "file_name": file_name, + "file_ext": file_ext.lower() if file_ext else "", + "file_size": file_size, + "content_type": content_type, + }, + msg="File information retrieved successfully" + ) + + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Unexpected error: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve file information: {str(e)}" + ) + + @router.get("/files/{file_id}", response_model=Any) async def download_file( request: Request, @@ -499,6 +596,51 @@ async def get_file_url( ) +@router.get("/files/{file_id}/public-url", response_model=ApiResponse) +async def get_permanent_file_url( + file_id: uuid.UUID, + db: Session = Depends(get_db), + storage_service: FileStorageService = Depends(get_file_storage_service), +): + """ + 获取文件的永久公开 URL(无过期时间)。 + + - 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置) + - 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限) + """ + file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first() + if not file_metadata: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist") + + if file_metadata.status != "completed": + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"File upload not completed, status: {file_metadata.status}") + + file_key = file_metadata.file_key + storage = storage_service.storage + + try: + if isinstance(storage, LocalStorage): + url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}" + else: + url = await storage.get_permanent_url(file_key) + if not url: + raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Permanent URL not supported for current storage backend") + + api_logger.info(f"Generated permanent URL: file_id={file_id}") + return success( + data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name}, + msg="Permanent file URL generated successfully" + ) + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Failed to generate permanent URL: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to generate permanent URL: {str(e)}") + + @router.get("/public/{file_id}", response_model=Any) async def public_download_file( request: Request, @@ -653,3 +795,44 @@ async def permanent_download_file( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to retrieve file: {str(e)}" ) + + +@router.get("/files/{file_id}/status", response_model=ApiResponse) +async def get_file_status( + file_id: uuid.UUID, + db: Session = Depends(get_db), +): + """ + Get file upload/processing status (no authentication required). + + This endpoint is used to check if a file (e.g., TTS audio) is ready. + Returns status: pending, completed, or failed. + + Args: + file_id: The UUID of the file. + db: Database session. + + Returns: + ApiResponse with file status and metadata. + """ + api_logger.info(f"File status request: file_id={file_id}") + + # Query file metadata from database + file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first() + if not file_metadata: + api_logger.warning(f"File not found in database: file_id={file_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The file does not exist" + ) + + return success( + data={ + "file_id": str(file_id), + "status": file_metadata.status, + "file_name": file_metadata.file_name, + "file_size": file_metadata.file_size, + "content_type": file_metadata.content_type, + }, + msg="File status retrieved successfully" + ) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 604514b4..94a6dddb 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -119,14 +119,12 @@ class AppChatService: ) # 加载历史消息 - messages = self.conversation_service.get_messages( + history = self.conversation_service.get_conversation_history( conversation_id=conversation_id, - limit=10 + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] # 处理多模态文件 processed_files = None @@ -180,7 +178,8 @@ class AppChatService: # 构建用户消息内容(含多模态文件) human_meta = { - "files": [] + "files": [], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -195,6 +194,13 @@ class AppChatService: "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } + # 保存消息 if audio_url: assistant_meta["audio_url"] = audio_url @@ -225,6 +231,7 @@ class AppChatService: "suggested_questions": suggested_questions, "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), "audio_url": audio_url, + "audio_status": "pending" } async def agnet_chat_stream( @@ -314,17 +321,12 @@ class AppChatService: ) # 加载历史消息 - history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation_id, - limit=memory_config.get("max_history", 10) - ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] + history = self.conversation_service.get_conversation_history( + conversation_id=conversation_id, + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni + ) # 处理多模态文件 processed_files = None @@ -347,8 +349,14 @@ class AppChatService: total_tokens = 0 text_queue: asyncio.Queue = asyncio.Queue() + api_key_config = { + "model_name": api_key_obj.model_name, + "api_key": api_key_obj.api_key, + "api_base": api_key_obj.api_base, + "provider": api_key_obj.provider, + } stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming( - features_config, api_key_obj, + features_config, api_key_config, text_queue=text_queue, tenant_id=tenant_id, workspace_id=workspace_id ) @@ -378,7 +386,7 @@ class AppChatService: elapsed_time = time.time() - start_time ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) - # 发送结束事件(包含 suggested_questions、tts、citations) + # 发送结束事件(包含 suggested_questions、tts、audio_status、citations) end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} sq_config = features_config.get("suggested_questions_after_answer", {}) if isinstance(sq_config, dict) and sq_config.get("enabled"): @@ -388,11 +396,23 @@ class AppChatService: "api_base": api_key_obj.api_base}, {} ) end_data["audio_url"] = stream_audio_url + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None end_data["citations"] = self.agent_service._filter_citations(features_config, []) # 保存消息 human_meta = { - "files":[] + "files":[], + "history_files": {} } assistant_meta = { "model": api_key_obj.model_name, @@ -402,11 +422,16 @@ class AppChatService: if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": api_key_obj.provider, + "is_omni": api_key_obj.is_omni + } if stream_audio_url: assistant_meta["audio_url"] = stream_audio_url diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index aff5f533..d7bb3595 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -119,25 +119,27 @@ class ConversationService: def get_user_conversations( self, - user_id: uuid.UUID - ) -> list[Conversation]: + user_id: uuid.UUID, + page: int = 1, + page_size: int = 20 + ) -> tuple[list[Conversation], int]: """ - Retrieve recent conversations for a specific user - - This method delegates persistence logic to the repository layer and - applies service-level defaults (e.g. recent conversation limit). + Retrieve recent conversations for a specific user with pagination. Args: user_id (uuid.UUID): Unique identifier of the user. + page (int): Page number (1-based). Defaults to 1. + page_size (int): Number of items per page. Defaults to 20. Returns: - list[Conversation]: A list of recent conversation entities. + tuple[list[Conversation], int]: A list of recent conversation entities and total count. """ - conversations = self.conversation_repo.get_conversation_by_user_id( + conversations, total = self.conversation_repo.get_conversation_by_user_id( user_id, - limit=10 + page=page, + page_size=page_size ) - return conversations + return conversations, total def list_conversations( self, @@ -270,7 +272,9 @@ class ConversationService: def get_conversation_history( self, conversation_id: uuid.UUID, - max_history: Optional[int] = None + max_history: Optional[int] = None, + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[dict]: """ Retrieve historical conversation messages formatted as dictionaries. @@ -278,6 +282,8 @@ class ConversationService: Args: conversation_id (uuid.UUID): Conversation UUID. max_history (Optional[int]): Maximum number of messages to retrieve. + current_provider (Optional[str]): Current provider for file handling. + current_is_omni (Optional[bool]): Current omni flag for file handling. Returns: List[dict]: List of message dictionaries with keys 'role' and 'content'. @@ -287,14 +293,30 @@ class ConversationService: limit=max_history ) - # 转换为字典格式 - history = [ - { + history = [] + for msg in messages: + msg_dict = { "role": msg.role, - "content": msg.content + "content": [{"type": "text", "text": msg.content}] } - for msg in messages - ] + + # 处理用户消息中的多模态文件 + if msg.role == "user" and msg.meta_data: + history_files = msg.meta_data.get("history_files", {}) + + if history_files and current_provider and current_is_omni is not None: + # 检查是否需要重新处理文件 + stored_provider = history_files.get("provider") + stored_is_omni = history_files.get("is_omni") + + # 如果provider或is_omni不匹配,需要重新处理 + if stored_provider != current_provider or stored_is_omni != current_is_omni: + continue + + # provider和is_omni匹配,直接使用存储的内容 + msg_dict["content"].extend(history_files.get("content")) + + history.append(msg_dict) return history @@ -510,6 +532,7 @@ class ConversationService: provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type llm = RedBearLLM( @@ -517,14 +540,17 @@ class ConversationService: model_name=model_name, provider=provider, api_key=api_key, - base_url=api_base + base_url=api_base, + is_omni=is_omni ), type=ModelType(model_type) ) conversation_messages = self.get_conversation_history( conversation_id=conversation_id, - max_history=20 + max_history=20, + current_provider=provider, + current_is_omni=is_omni ) if len(conversation_messages) == 0: return ConversationOut( diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index ba41d323..13243a92 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -582,7 +582,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=10 + max_history=10, + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -659,7 +661,10 @@ class AgentRunService: }) }, files=files, - audio_url=audio_url + processed_files=processed_files, + audio_url=audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) response = { @@ -676,6 +681,7 @@ class AgentRunService: ) if not sub_agent else [], "citations": self._filter_citations(features_config, result.get("citations", [])), "audio_url": audio_url, + "audio_status": "pending" } logger.info( @@ -818,7 +824,9 @@ class AgentRunService: # 6. 加载历史消息 history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=memory_config.get("max_history", 10) + max_history=memory_config.get("max_history", 10), + current_provider=api_key_config.get("provider"), + current_is_omni=api_key_config.get("is_omni", False) ) # 6. 处理多模态文件 @@ -905,10 +913,13 @@ class AgentRunService: "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} }, files=files, - audio_url=stream_audio_url + processed_files=processed_files, + audio_url=stream_audio_url, + provider=api_key_config.get("provider"), + is_omni=api_key_config.get("is_omni", False) ) - # 12. 发送结束事件(包含 suggested_questions 和 tts) + # 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status) end_data: Dict[str, Any] = { "conversation_id": conversation_id, "elapsed_time": elapsed_time, @@ -919,6 +930,17 @@ class AgentRunService: features_config, full_content, api_key_config, effective_params ) end_data["audio_url"] = stream_audio_url + # 检查TTS是否已完成(非阻塞,不取消任务) + audio_status = "pending" + if tts_task is not None and tts_task.done(): + # 任务已完成,检查是否有异常 + try: + tts_task.result() + audio_status = "completed" + except Exception as e: + logger.warning(f"TTS任务异常: {e}") + audio_status = "failed" + end_data["audio_status"] = audio_status if stream_audio_url else None end_data["citations"] = self._filter_citations(features_config, []) yield self._format_sse_event("end", end_data) @@ -1115,13 +1137,17 @@ class AgentRunService: async def _load_conversation_history( self, conversation_id: str, - max_history: int = 10 + max_history: int = 10, + current_provider: Optional[str] = None, + current_is_omni: Optional[bool] = None ) -> List[Dict[str, str]]: - """加载会话历史消息 + """加载会话历史消息,并根据当前模型配置处理多模态文件 Args: conversation_id: 会话ID max_history: 最大历史消息数量 + current_provider: 当前模型的provider + current_is_omni: 当前模型的is_omni Returns: List[Dict]: 历史消息列表 @@ -1131,7 +1157,9 @@ class AgentRunService: conversation_service = ConversationService(self.db) history = conversation_service.get_conversation_history( conversation_id=uuid.UUID(conversation_id), - max_history=max_history + max_history=max_history, + current_provider=current_provider, + current_is_omni=current_is_omni ) logger.debug( @@ -1159,7 +1187,10 @@ class AgentRunService: app_id: Optional[uuid.UUID] = None, user_id: Optional[str] = None, files: Optional[List[FileInput]] = None, - audio_url: Optional[str] = None + processed_files: Optional[List[Dict[str, Any]]] = None, + audio_url: Optional[str] = None, + provider: Optional[str] = None, + is_omni: Optional[bool] = None ) -> None: """保存会话消息(会话已通过 _ensure_conversation 确保存在) @@ -1170,6 +1201,11 @@ class AgentRunService: app_id: 应用ID(未使用,保留用于兼容性) user_id: 用户ID(未使用,保留用于兼容性) meta_data: token消耗 + files: 原始文件输入 + processed_files: 处理后的文件 + audio_url: 音频URL + provider: 模型供应商 + is_omni: 是否为全模态模型 """ try: from app.services.conversation_service import ConversationService @@ -1179,15 +1215,24 @@ class AgentRunService: # 保存消息(会话已经存在) human_meta = { - "files": [] + "files": [], + "history_files": {} } if files: for f in files: - # url = await MultimodalService(self.db).get_file_url(f) human_meta["files"].append({ "type": f.type, "url": f.url }) + + # 保存 history_files,包含 provider 和 is_omni 信息 + if processed_files: + human_meta["history_files"] = { + "content": processed_files, + "provider": provider, + "is_omni": is_omni + } + # 保存用户消息 conversation_service.add_message( conversation_id=conv_uuid, @@ -1413,8 +1458,9 @@ class AgentRunService: workspace_id: Optional[uuid.UUID] = None, ) -> tuple[Optional[str], Optional[asyncio.Task]]: """文本流式输入并行合成音频。 - 返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。 + 返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。 调用方向 text_queue put 文本 chunk,结束时 put None。 + 前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。 """ tts_config = features_config.get("text_to_speech", {}) if not isinstance(tts_config, dict) or not tts_config.get("enabled"): @@ -1801,6 +1847,7 @@ class AgentRunService: ), "cost_estimate": self._estimate_cost(usage, model_info["model_config"]), "audio_url": result.get("audio_url"), + "audio_status": result.get("audio_status"), "citations": result.get("citations", []), "suggested_questions": result.get("suggested_questions", []), "error": None @@ -1878,6 +1925,7 @@ class AgentRunService: "results": [{ **r, "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), } for r in results], @@ -2009,6 +2057,7 @@ class AgentRunService: full_content = "" returned_conversation_id = model_conversation_id audio_url = None + audio_status = None citations = [] suggested_questions = [] @@ -2067,6 +2116,7 @@ class AgentRunService: # 从 end 事件中提取 features 输出字段 if event_type == "end" and event_data: audio_url = event_data.get("audio_url") + audio_status = event_data.get("audio_status") citations = event_data.get("citations", []) suggested_questions = event_data.get("suggested_questions", []) @@ -2096,6 +2146,7 @@ class AgentRunService: "message": full_content, "elapsed_time": elapsed, "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "error": None @@ -2110,6 +2161,7 @@ class AgentRunService: "elapsed_time": elapsed, "message_length": len(full_content), "audio_url": audio_url, + "audio_status": audio_status, "citations": citations, "suggested_questions": suggested_questions, "timestamp": time.time() @@ -2246,6 +2298,7 @@ class AgentRunService: "message": r.get("message"), "elapsed_time": r.get("elapsed_time", 0), "audio_url": r.get("audio_url"), + "audio_status": r.get("audio_status"), "citations": r.get("citations", []), "suggested_questions": r.get("suggested_questions", []), "error": r.get("error") From 48e2e613bbb322650f78c26dc497d26eacd83bdc Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 23 Mar 2026 17:34:54 +0800 Subject: [PATCH 10/12] =?UTF-8?q?=E3=80=90change=E3=80=91Restore=20chat=20?= =?UTF-8?q?mode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/memory/llm_tools/openai_client.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index 4536f62d..0b75de3a 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -82,26 +82,17 @@ class OpenAIClient(LLMClient): LLMClientException: LLM 调用失败 """ try: - from langchain_core.messages import HumanMessage, SystemMessage, AIMessage - - # 将 dict 消息列表转换为 LangChain 消息对象 - lc_messages = [] - for m in messages: - role = m.get("role", "user") - content = m.get("content", "") - if role == "system": - lc_messages.append(SystemMessage(content=content)) - elif role == "assistant": - lc_messages.append(AIMessage(content=content)) - else: - lc_messages.append(HumanMessage(content=content)) + # 使用 Langfuse 回调(如果可用) + template = """{messages}""" + prompt = ChatPromptTemplate.from_template(template) + chain = prompt | self.client # 添加 Langfuse 回调(如果可用) config = {} if self.langfuse_handler: config["callbacks"] = [self.langfuse_handler] - response = await self.client.ainvoke(lc_messages, config=config) + response = await chain.ainvoke({"messages": messages}, config=config) return response except Exception as e: From f86448f4bfd6f45e976230088dec9721835a55ba Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 23 Mar 2026 17:39:17 +0800 Subject: [PATCH 11/12] =?UTF-8?q?=E3=80=90change=E3=80=91=20Restore=20chat?= =?UTF-8?q?=20mode=201?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/core/memory/llm_tools/openai_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index 0b75de3a..43c2b445 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -82,7 +82,6 @@ class OpenAIClient(LLMClient): LLMClientException: LLM 调用失败 """ try: - # 使用 Langfuse 回调(如果可用) template = """{messages}""" prompt = ChatPromptTemplate.from_template(template) chain = prompt | self.client From 6348304b7dc9dbf125156de3dd47c37231ef9d73 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 23 Mar 2026 18:52:23 +0800 Subject: [PATCH 12/12] fix(app): Error occurred while processing the experience sharing and loading the historical messages. --- api/app/services/app_chat_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 645de979..f87e5f5a 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -129,7 +129,7 @@ class AppChatService: ) # 加载历史消息 - history = self.conversation_service.get_conversation_history( + history = await self.conversation_service.get_conversation_history( conversation_id=conversation_id, max_history=10, current_provider=api_key_obj.provider, @@ -332,7 +332,7 @@ class AppChatService: ) # 加载历史消息 - history = self.conversation_service.get_conversation_history( + history = await self.conversation_service.get_conversation_history( conversation_id=conversation_id, max_history=10, current_provider=api_key_obj.provider,