diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index aac2aa84..dfb63dad 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -1,6 +1,8 @@ import asyncio import json import logging +import os +import threading from typing import Dict, Any, Optional import redis.asyncio as redis @@ -21,6 +23,50 @@ pool = ConnectionPool.from_url( ) aio_redis = redis.StrictRedis(connection_pool=pool) +_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}" + +# Thread-local storage for connection pools. +# Each thread (and each forked process) gets its own pool to avoid +# "Future attached to a different loop" errors in Celery --pool=threads +# and stale connections after fork in --pool=prefork. +_thread_local = threading.local() + + +def get_thread_safe_redis() -> redis.StrictRedis: + """Return a Redis client whose connection pool is bound to the current + thread, process **and** event loop. + + The pool is recreated when: + - The PID changes (fork, Celery --pool=prefork) + - The thread has no pool yet (Celery --pool=threads) + - The previously-cached event loop has been closed (Celery tasks call + ``_shutdown_loop_gracefully`` which closes the loop after each run) + """ + current_pid = os.getpid() + cached_loop = getattr(_thread_local, "loop", None) + loop_stale = cached_loop is not None and cached_loop.is_closed() + + if not hasattr(_thread_local, "pool") \ + or getattr(_thread_local, "pid", None) != current_pid \ + or loop_stale: + _thread_local.pid = current_pid + # Python 3.10+: get_event_loop() raises RuntimeError in threads + # where no loop has been set yet (e.g. Celery --pool=threads). + try: + _thread_local.loop = asyncio.get_event_loop() + except RuntimeError: + _thread_local.loop = None + _thread_local.pool = ConnectionPool.from_url( + _REDIS_URL, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD, + decode_responses=True, + max_connections=5, + health_check_interval=30, + ) + + return redis.StrictRedis(connection_pool=_thread_local.pool) + async def get_redis_connection(): """获取Redis连接""" @@ -44,10 +90,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None): val = json.dumps(val, ensure_ascii=False) if expire is not None: - # 设置带过期时间的键值 await aio_redis.set(key, val, ex=expire) else: - # 设置永久键值 await aio_redis.set(key, val) except Exception as e: logger.error(f"Redis set错误: {str(e)}") diff --git a/api/app/cache/memory/activity_stats_cache.py b/api/app/cache/memory/activity_stats_cache.py index 6b162cdd..e0008353 100644 --- a/api/app/cache/memory/activity_stats_cache.py +++ b/api/app/cache/memory/activity_stats_cache.py @@ -10,7 +10,7 @@ import logging from typing import Optional, Dict, Any from datetime import datetime -from app.aioRedis import aio_redis +from app.aioRedis import get_thread_safe_redis logger = logging.getLogger(__name__) @@ -68,7 +68,7 @@ class ActivityStatsCache: "cached": True, } value = json.dumps(payload, ensure_ascii=False) - await aio_redis.set(key, value, ex=expire) + await get_thread_safe_redis().set(key, value, ex=expire) logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") return True except Exception as e: @@ -90,7 +90,7 @@ class ActivityStatsCache: """ try: key = cls._get_key(workspace_id) - value = await aio_redis.get(key) + value = await get_thread_safe_redis().get(key) if value: payload = json.loads(value) logger.info(f"命中活动统计缓存: {key}") @@ -116,7 +116,7 @@ class ActivityStatsCache: """ try: key = cls._get_key(workspace_id) - result = await aio_redis.delete(key) + result = await get_thread_safe_redis().delete(key) logger.info(f"删除活动统计缓存: {key}, 结果: {result}") return result > 0 except Exception as e: diff --git a/api/app/cache/memory/interest_memory.py b/api/app/cache/memory/interest_memory.py index 108e2a37..2881f06c 100644 --- a/api/app/cache/memory/interest_memory.py +++ b/api/app/cache/memory/interest_memory.py @@ -9,7 +9,7 @@ import logging from typing import Optional, List, Dict, Any from datetime import datetime -from app.aioRedis import aio_redis +from app.aioRedis import get_thread_safe_redis logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ class InterestMemoryCache: "cached": True, } value = json.dumps(payload, ensure_ascii=False) - await aio_redis.set(key, value, ex=expire) + await get_thread_safe_redis().set(key, value, ex=expire) logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒") return True except Exception as e: @@ -86,7 +86,7 @@ class InterestMemoryCache: """ try: key = cls._get_key(end_user_id, language) - value = await aio_redis.get(key) + value = await get_thread_safe_redis().get(key) if value: payload = json.loads(value) logger.info(f"命中兴趣分布缓存: {key}") @@ -114,7 +114,7 @@ class InterestMemoryCache: """ try: key = cls._get_key(end_user_id, language) - result = await aio_redis.delete(key) + result = await get_thread_safe_redis().delete(key) logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") return result > 0 except Exception as e: diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 3ba9c3a9..74991bcf 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -57,7 +57,6 @@ def list_apps( page: int = 1, pagesize: int = 10, ids: Optional[str] = None, - api_key: Optional[str] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): @@ -66,7 +65,7 @@ def list_apps( - 默认包含本工作空间的应用和分享给本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用 - 当提供 ids 参数时,按逗号分割获取指定应用,不分页 - - 当提供 api_key 参数时,查找该 API Key 关联的应用 + - search 参数支持:应用名称模糊搜索、API Key 精确搜索 """ from sqlalchemy import select as sa_select from app.models.api_key_model import ApiKey @@ -74,23 +73,34 @@ def list_apps( workspace_id = current_user.current_workspace_id service = app_service.AppService(db) - # 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程 - if api_key: - matched_id = db.execute( - sa_select(ApiKey.resource_id).where( - ApiKey.workspace_id == workspace_id, - ApiKey.api_key == api_key, - ApiKey.resource_id.isnot(None), - ) - ).scalar_one_or_none() - ids = str(matched_id) if matched_id else "" + # 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索 + if search: + search = search.strip() + # 尝试作为 API Key 精确匹配(API Key 通常较长) + if len(search) >= 10: + matched_id = db.execute( + sa_select(ApiKey.resource_id).where( + ApiKey.workspace_id == workspace_id, + ApiKey.api_key == search, + ApiKey.resource_id.isnot(None), + ) + ).scalar_one_or_none() + if matched_id: + # 找到 API Key,直接返回关联的应用 + ids = str(matched_id) - # 当 ids 存在且不为 None 时,根据 ids 获取应用 + # 当 ids 存在时,根据 ids 获取应用(不分页) if ids is not None: app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] - items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) - items = [service._convert_to_schema(app, workspace_id) for app in items_orm] - return success(data=items) + if app_ids: + items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) + items = [service._convert_to_schema(app, workspace_id) for app in items_orm] + # 返回标准分页格式 + meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False) + return success(data=PageData(page=meta, items=items)) + # ids 为空时,返回空列表 + meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False) + return success(data=PageData(page=meta, items=[])) # 正常分页查询 items_orm, total = app_service.list_apps( diff --git a/api/app/controllers/app_log_controller.py b/api/app/controllers/app_log_controller.py index dfd10644..92b5becd 100644 --- a/api/app/controllers/app_log_controller.py +++ b/api/app/controllers/app_log_controller.py @@ -3,17 +3,16 @@ import uuid from typing import Optional from fastapi import APIRouter, Depends, Query -from sqlalchemy import select, desc, func from sqlalchemy.orm import Session from app.core.logging_config import get_business_logger from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user, cur_workspace_access_guard -from app.models.conversation_model import Conversation, Message -from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage +from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail from app.schemas.response_schema import PageData, PageMeta from app.services.app_service import AppService +from app.services.app_log_service import AppLogService router = APIRouter(prefix="/apps", tags=["App Logs"]) logger = get_business_logger() @@ -25,52 +24,35 @@ def list_app_logs( app_id: uuid.UUID, page: int = Query(1, ge=1), pagesize: int = Query(20, ge=1, le=100), - user_id: Optional[str] = None, is_draft: Optional[bool] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): """查看应用下所有会话记录(分页) - - 支持按 user_id 筛选 - 支持按 is_draft 筛选(草稿会话 / 发布会话) - 按最新更新时间倒序排列 + - 所有人(包括共享者和被共享者)都只能查看自己的会话记录 """ workspace_id = current_user.current_workspace_id # 验证应用访问权限 - service = AppService(db) - service.get_app(app_id, workspace_id) + app_service = AppService(db) + app_service.get_app(app_id, workspace_id) - stmt = select(Conversation).where( - Conversation.app_id == app_id, - Conversation.workspace_id == workspace_id, - Conversation.is_active.is_(True), + # 使用 Service 层查询 + log_service = AppLogService(db) + conversations, total = log_service.list_conversations( + app_id=app_id, + workspace_id=workspace_id, + page=page, + pagesize=pagesize, + is_draft=is_draft ) - if user_id: - stmt = stmt.where(Conversation.user_id == user_id) - - if is_draft is not None: - stmt = stmt.where(Conversation.is_draft == is_draft) - - total = int(db.execute( - select(func.count()).select_from(stmt.subquery()) - ).scalar_one()) - - stmt = stmt.order_by(desc(Conversation.updated_at)) - stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) - - conversations = list(db.scalars(stmt).all()) - items = [AppLogConversation.model_validate(c) for c in conversations] meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) - logger.info( - "查询应用日志会话列表", - extra={"app_id": str(app_id), "total": total, "page": page} - ) - return success(data=PageData(page=meta, items=items)) @@ -86,44 +68,22 @@ def get_app_log_detail( - 返回会话基本信息 + 所有消息(按时间正序) - 消息 meta_data 包含模型名、token 用量等信息 + - 所有人(包括共享者和被共享者)都只能查看自己的会话详情 """ workspace_id = current_user.current_workspace_id # 验证应用访问权限 - service = AppService(db) - service.get_app(app_id, workspace_id) + app_service = AppService(db) + app_service.get_app(app_id, workspace_id) - # 查询会话(确保属于该应用和工作空间) - conversation = db.scalars( - select(Conversation).where( - Conversation.id == conversation_id, - Conversation.app_id == app_id, - Conversation.workspace_id == workspace_id, - Conversation.is_active.is_(True), - ) - ).first() - - if not conversation: - from app.core.exceptions import ResourceNotFoundException - raise ResourceNotFoundException("会话", str(conversation_id)) - - # 查询消息(按时间正序) - messages = list(db.scalars( - select(Message) - .where(Message.conversation_id == conversation_id) - .order_by(Message.created_at) - ).all()) - - detail = AppLogConversationDetail.model_validate(conversation) - detail.messages = [AppLogMessage.model_validate(m) for m in messages] - - logger.info( - "查询应用日志会话详情", - extra={ - "app_id": str(app_id), - "conversation_id": str(conversation_id), - "message_count": len(messages) - } + # 使用 Service 层查询 + log_service = AppLogService(db) + conversation = log_service.get_conversation_detail( + app_id=app_id, + conversation_id=conversation_id, + workspace_id=workspace_id ) + detail = AppLogConversationDetail.model_validate(conversation) + return success(data=detail) diff --git a/api/app/controllers/memory_forget_controller.py b/api/app/controllers/memory_forget_controller.py index 2b5ef72f..51ce92b3 100644 --- a/api/app/controllers/memory_forget_controller.py +++ b/api/app/controllers/memory_forget_controller.py @@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import ( ForgettingCurveRequest, ForgettingCurveResponse, ForgettingCurvePoint, + PendingNodesResponse, ) from app.schemas.response_schema import ApiResponse from app.services.memory_forget_service import MemoryForgetService @@ -308,6 +309,100 @@ async def get_forgetting_stats( return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e)) +@router.get("/pending-nodes", response_model=ApiResponse) +async def get_pending_nodes( + end_user_id: str, + page: int = 1, + pagesize: int = 10, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + 获取待遗忘节点列表(独立分页接口) + + 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。 + 此接口独立分页,与 /stats 接口分离。 + + Args: + end_user_id: 组ID(即 end_user_id,必填) + page: 页码(从1开始,默认1) + pagesize: 每页数量(默认10) + current_user: 当前用户 + db: 数据库会话 + + Returns: + ApiResponse: 包含待遗忘节点列表和分页信息的响应 + + Examples: + - 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10 + - 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20 + + Notes: + - page 从1开始,pagesize 必须大于0 + - 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}} + """ + workspace_id = current_user.current_workspace_id + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + # 验证 end_user_id 必填 + if not end_user_id: + api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id") + return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required") + + # 通过 end_user_id 获取关联的 config_id + try: + from app.services.memory_agent_service import get_end_user_connected_config + + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + config_id = resolve_config_id(config_id, db) + + if config_id is None: + api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置") + return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None") + + api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}") + except ValueError as e: + api_logger.warning(f"获取终端用户配置失败: {str(e)}") + return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError") + except Exception as e: + api_logger.error(f"获取终端用户配置时发生错误: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e)) + + # 验证分页参数 + if page < 1: + return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1") + if pagesize < 1: + return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1") + + api_logger.info( + f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: " + f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}" + ) + + try: + # 调用服务层获取待遗忘节点列表 + result = await forget_service.get_pending_nodes( + db=db, + end_user_id=end_user_id, + config_id=config_id, + page=page, + pagesize=pagesize + ) + + # 构建响应 + response_data = PendingNodesResponse(**result) + + return success(data=response_data.model_dump(), msg="查询成功") + + except Exception as e: + api_logger.error(f"获取待遗忘节点列表失败: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e)) + + @router.post("/forgetting_curve", response_model=ApiResponse) async def get_forgetting_curve( request: ForgettingCurveRequest, diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 26902b07..c10ad14b 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -27,6 +27,7 @@ from app.services.conversation_service import ConversationService from app.services.release_share_service import ReleaseShareService from app.services.shared_chat_service import SharedChatService from app.services.workflow_service import WorkflowService +from app.models.file_metadata_model import FileMetadata from app.utils.app_config_utils import workflow_config_4_app_release, \ agent_config_4_app_release, multi_agent_config_4_app_release @@ -259,8 +260,41 @@ def get_conversation( conv_service = ConversationService(db) messages = conv_service.get_messages(conversation_id) - # 构建响应 - conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump() + file_ids = [] + message_file_id_map = {} + + # 第一次遍历:解析 audio_url,收集所有有效的 file_id + for idx, m in enumerate(messages): + if m.role == "assistant" and m.meta_data: + audio_url = m.meta_data.get("audio_url") + if not audio_url: + continue + try: + file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1]) + except (ValueError, IndexError): + # audio_url 无法解析为 UUID,标记为 unknown + m.meta_data["audio_status"] = "unknown" + continue + + file_ids.append(file_id) + message_file_id_map[idx] = file_id + + # 批量查询所有相关的 FileMetadata + file_status_map = {} + if file_ids: + file_metas = ( + db.query(FileMetadata) + .filter(FileMetadata.id.in_(set(file_ids))) + .all() + ) + file_status_map = {fm.id: fm.status for fm in file_metas} + + # 第二次遍历:将查询结果映射回消息 + for idx, file_id in message_file_id_map.items(): + m = messages[idx] + m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown") + + conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json") conv_dict["messages"] = [ conversation_schema.Message.model_validate(m) for m in messages ] @@ -320,6 +354,16 @@ async def chat( other_id=other_id, original_user_id=user_id ) + + # Only extract and set memory_config_id when the end user doesn't have one yet + if not new_end_user.memory_config_id: + from app.services.memory_config_service import MemoryConfigService + memory_config_service = MemoryConfigService(db) + memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {}) + if memory_config_id: + new_end_user.memory_config_id = memory_config_id + db.commit() + db.refresh(new_end_user) end_user_id = str(new_end_user.id) # appid = share.app_id diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 32a911f9..d4573464 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -91,7 +91,7 @@ async def chat( app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id) other_id = payload.user_id - workspace_id = app.workspace_id + workspace_id = api_key_auth.workspace_id end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( app_id=app.id, diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py index 16213690..cc16a6b4 100644 --- a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -111,6 +111,18 @@ def get_current_user_info( break api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}") + + # 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限 + if current_user.external_source: + from premium.sso.models import SSOSource + source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first() + if source and source.permissions: + result_schema.permissions = source.permissions + else: + result_schema.permissions = [] + else: + result_schema.permissions = ["all"] + return success(data=result_schema, msg=t("users.info.get_success")) @@ -135,7 +147,6 @@ def get_tenant_superusers( return success(data=superusers_schema, msg=t("users.list.superusers_success")) - @router.get("/{user_id}", response_model=ApiResponse) def get_user_info_by_id( user_id: uuid.UUID, diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 55bcb8ba..1f437973 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -151,11 +151,6 @@ async def write( # Step 3: Save all data to Neo4j database step_start = time.time() - from app.repositories.neo4j.create_indexes import create_fulltext_indexes - try: - await create_fulltext_indexes() - except Exception as e: - logger.error(f"Error creating indexes: {e}", exc_info=True) # 添加死锁重试机制 max_retries = 3 @@ -279,5 +274,21 @@ async def write( except Exception as cache_err: logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + # Close LLM/Embedder underlying httpx clients to prevent + # 'RuntimeError: Event loop is closed' during garbage collection + for client_obj in (llm_client, embedder_client): + try: + underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None) + if underlying is None: + continue + # Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model + inner = getattr(underlying, '_model', underlying) + # LangChain OpenAI models expose async_client (httpx.AsyncClient) + http_client = getattr(inner, 'async_client', None) + if http_client is not None and hasattr(http_client, 'aclose'): + await http_client.aclose() + except Exception: + pass + logger.info("=== Pipeline Complete ===") logger.info(f"Total execution time: {total_time:.2f} seconds") diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index 43c2b445..c70fef5f 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -65,7 +65,7 @@ class OpenAIClient(LLMClient): type=type_ ) - logger.info(f"OpenAI 客户端初始化完成: type={type_}") + logger.debug(f"OpenAI 客户端初始化完成: type={type_}") async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: """ diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 967f529e..5390197a 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene logger = logging.getLogger(__name__) +def message_has_files(message: "ConversationMessage") -> bool: + """检查消息是否包含文件。 + + Args: + message: 待检查的消息对象 + + Returns: + bool: 如果消息包含文件则返回 True,否则返回 False + """ + return message.files and len(message.files) > 0 + + class DialogExtractionResponse(BaseModel): """对话级一次性抽取的结构化返回,用于加速剪枝。 @@ -128,7 +140,7 @@ class SemanticPruner: 1. 空消息 2. 场景特定填充词库精确匹配 3. 常见寒暄精确匹配 - 4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了") + 4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了") 5. 纯表情/标点 """ t = message.msg.strip() @@ -482,6 +494,11 @@ class SemanticPruner: """ to_delete_ids: set = set() for m in msgs: + # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 + if message_has_files(m): + self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}") + continue + # 填充检测优先:先判断是否为填充,再看 LLM 保护 if self._is_filler_message(m): to_delete_ids.add(id(m)) @@ -549,6 +566,11 @@ class SemanticPruner: to_delete_ids: set = set() for m in msgs: msg_text = m.msg.strip() + + # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 + if message_has_files(m): + self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}") + continue # 第一优先级:填充消息无论模式直接删除,不参与后续场景判断 if self._is_filler_message(m): @@ -801,6 +823,12 @@ class SemanticPruner: for idx, m in enumerate(msgs): msg_text = m.msg.strip() + + # 最高优先级保护:带有文件的消息一律保留,不参与分类 + if message_has_files(m): + self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}") + llm_protected_msgs.append((idx, m)) # 放入保护列表 + continue if self._msg_matches_tokens(m, preserve_tokens): llm_protected_msgs.append((idx, m)) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 5ef7db0e..b20112a2 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -182,7 +182,7 @@ class ExtractionOrchestrator: list[StatementEntityEdge], list[EntityEntityEdge], list[PerceptualEdge], - dict + list[DialogData] ]: """ 运行完整的知识提取流水线(优化版:并行执行) @@ -295,6 +295,7 @@ class ExtractionOrchestrator: statement_entity_edges, entity_entity_edges, dialog_data_list, + dedup_details, ) = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, @@ -306,6 +307,11 @@ class ExtractionOrchestrator: dialog_data_list, ) + # 步骤 7: 同步用户别名到数据库表(仅正式模式) + if not is_pilot_run: + logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表") + await self._update_end_user_other_name(entity_nodes, dialog_data_list) + logger.info(f"知识提取流水线运行完成({mode_str})") return ( dialogue_nodes, @@ -1399,7 +1405,8 @@ class ExtractionOrchestrator: logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}") else: first_alias = current_aliases[0].strip() if current_aliases else "" - if first_alias: + # 确保 first_alias 不是占位名称 + if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES: db.add(EndUserInfo( end_user_id=end_user_uuid, other_name=first_alias, @@ -1415,29 +1422,33 @@ class ExtractionOrchestrator: + # 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中 + USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'} + def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]: """从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序) - 这个方法直接返回 LLM 提取的别名列表,不做任何修改。 + 这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。 第一个别名将被用作 other_name。 Args: entity_nodes: 实体节点列表 Returns: - 别名列表(保持 LLM 提取的原始顺序) + 别名列表(保持 LLM 提取的原始顺序,已过滤占位名称) """ - USER_NAMES = {'用户', '我', 'User', 'I'} for entity in entity_nodes: - if getattr(entity, 'name', '').strip() in USER_NAMES: + if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES: aliases = getattr(entity, 'aliases', []) or [] - logger.debug(f"提取到用户别名(原始顺序): {aliases}") - return aliases + # 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name + filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES] + logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}") + return filtered return [] async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]: - """从 Neo4j 查询用户实体的完整 aliases 列表""" + """从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)""" cypher = """ MATCH (e:ExtractedEntity) WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I'] @@ -1451,7 +1462,10 @@ class ExtractionOrchestrator: aliases = result[0].get('aliases') or [] if not aliases: logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}") - return aliases + return [] + # 过滤掉占位名称,防止历史脏数据传播 + filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES] + return filtered def _resolve_other_name( self, @@ -1463,14 +1477,25 @@ class ExtractionOrchestrator: 决定 other_name 是否需要更新,返回新值;无需更新返回 None。 决策规则: - - 为空 → 用本次对话第一个别名 + - 为空或为占位名称 → 用本次对话第一个别名 - 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除) - 否则 → 保持不变(返回 None) + + 注意:返回值不允许是占位名称("用户"、"我"、"User"、"I") """ - if not current or not current.strip(): - return current_aliases[0].strip() if current_aliases else None + # 当前值为空或为占位名称时,需要更新 + if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES: + candidate = current_aliases[0].strip() if current_aliases else None + # 确保候选值不是占位名称 + if candidate and candidate in self.USER_PLACEHOLDER_NAMES: + return None + return candidate if current not in neo4j_aliases: - return neo4j_aliases[0].strip() if neo4j_aliases else None + candidate = neo4j_aliases[0].strip() if neo4j_aliases else None + # 确保候选值不是占位名称 + if candidate and candidate in self.USER_PLACEHOLDER_NAMES: + return None + return candidate return None @@ -1492,6 +1517,7 @@ class ExtractionOrchestrator: list[StatementChunkEdge], list[StatementEntityEdge], list[EntityEntityEdge], + list[DialogData], dict ]: """ @@ -1555,6 +1581,8 @@ class ExtractionOrchestrator: statement_chunk_edges, dedup_statement_entity_edges, dedup_entity_entity_edges, + dialog_data_list, + dedup_details, ) final_entity_nodes = dedup_entity_nodes @@ -1562,7 +1590,16 @@ class ExtractionOrchestrator: final_entity_entity_edges = dedup_entity_entity_edges else: # 正式模式:执行完整的两阶段去重 - result_tuple = await dedup_layers_and_merge_and_return( + ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + final_entity_nodes, + statement_chunk_edges, + final_statement_entity_edges, + final_entity_entity_edges, + dedup_details, + ) = await dedup_layers_and_merge_and_return( dialogue_nodes, chunk_nodes, statement_nodes, @@ -1576,21 +1613,21 @@ class ExtractionOrchestrator: llm_client=self.llm_client, ) - # 解包返回值 - ( - _, - _, - _, - final_entity_nodes, - _, - final_statement_entity_edges, - final_entity_entity_edges, - dedup_details, - ) = result_tuple - # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) + result_tuple = ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + final_entity_nodes, + statement_chunk_edges, + final_statement_entity_edges, + final_entity_entity_edges, + dialog_data_list, + dedup_details, + ) + logger.info( f"去重后: {len(final_entity_nodes)} 个实体节点, " f"{len(final_statement_entity_edges)} 条陈述句-实体边, " diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index f9f2f45c..6605532d 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement. {% if language == "zh" %} - 用户实体的 name 字段:使用 "用户" 或 "我" - 用户的真实姓名:放入 aliases + - **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等** - 示例: * "我叫李明" → name="用户", aliases=["李明"] + * ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases) + * ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases) {% else %} - User entity name field: use "User" or "I" - User's real name: put in aliases + - **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.** - Examples: * "I'm John" → name="User", aliases=["John"] + * ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases) + * ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases) {% endif %} diff --git a/api/app/core/storage/oss.py b/api/app/core/storage/oss.py index 1db86fef..c6c6ec48 100644 --- a/api/app/core/storage/oss.py +++ b/api/app/core/storage/oss.py @@ -44,6 +44,8 @@ class OSSStorage(StorageBackend): access_key_id: str, access_key_secret: str, bucket_name: str, + connect_timeout: int = 30, + multipart_threshold: int = 10 * 1024 * 1024, # 10MB ): """ Initialize the OSSStorage backend. @@ -53,6 +55,8 @@ class OSSStorage(StorageBackend): access_key_id: The Aliyun access key ID. access_key_secret: The Aliyun access key secret. bucket_name: The name of the OSS bucket. + connect_timeout: Connection timeout in seconds (default: 30). + multipart_threshold: File size threshold for multipart upload (default: 10MB). Raises: StorageConfigError: If any required configuration is missing. @@ -69,10 +73,17 @@ class OSSStorage(StorageBackend): self.endpoint = endpoint self.bucket_name = bucket_name + self.multipart_threshold = multipart_threshold try: auth = oss2.Auth(access_key_id, access_key_secret) - self.bucket = oss2.Bucket(auth, endpoint, bucket_name) + # 设置超时和重试 + self.bucket = oss2.Bucket( + auth, + endpoint, + bucket_name, + connect_timeout=connect_timeout + ) logger.info( f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}" ) @@ -108,21 +119,38 @@ class OSSStorage(StorageBackend): if content_type: headers["Content-Type"] = content_type - self.bucket.put_object(file_key, content, headers=headers if headers else None) + # 大文件使用分片上传 + if len(content) > self.multipart_threshold: + logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)") + upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id + parts = [] + part_size = 5 * 1024 * 1024 # 5MB per part + part_num = 1 + + for offset in range(0, len(content), part_size): + chunk = content[offset:offset + part_size] + result = self.bucket.upload_part(file_key, upload_id, part_num, chunk) + parts.append(oss2.models.PartInfo(part_num, result.etag)) + part_num += 1 + + self.bucket.complete_multipart_upload(file_key, upload_id, parts) + else: + self.bucket.put_object(file_key, content, headers=headers if headers else None) + logger.info(f"File uploaded to OSS successfully: {file_key}") return file_key except OssError as e: logger.error(f"OSS error uploading file {file_key}: {e}") raise StorageUploadError( - message=f"Failed to upload file to OSS: {e.message}", + message=f"Failed to upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: logger.error(f"Failed to upload file to OSS {file_key}: {e}") raise StorageUploadError( - message=f"Failed to upload file to OSS: {e}", + message=f"Failed to upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) @@ -135,28 +163,73 @@ class OSSStorage(StorageBackend): ) -> int: """Upload from async stream to OSS. Returns total bytes written.""" buf = io.BytesIO() + headers = {"Content-Type": content_type} if content_type else None + upload_id = None + try: + # 收集流数据 + total_size = 0 async for chunk in stream: + if not chunk: + continue buf.write(chunk) + total_size += len(chunk) + content = buf.getvalue() - headers = {"Content-Type": content_type} if content_type else None - self.bucket.put_object(file_key, content, headers=headers) - logger.info(f"File stream uploaded to OSS successfully: {file_key}") - return len(content) + + if not content: + raise StorageUploadError( + message="Empty stream content", + file_key=file_key, + ) + + # 大文件使用分片上传 + if len(content) > self.multipart_threshold: + logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)") + upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id + parts = [] + part_size = 5 * 1024 * 1024 # 5MB + part_num = 1 + + for offset in range(0, len(content), part_size): + chunk = content[offset:offset + part_size] + result = self.bucket.upload_part(file_key, upload_id, part_num, chunk) + parts.append(oss2.models.PartInfo(part_num, result.etag)) + part_num += 1 + + self.bucket.complete_multipart_upload(file_key, upload_id, parts) + else: + self.bucket.put_object(file_key, content, headers=headers) + + logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)") + return total_size + except OssError as e: + if upload_id: + try: + self.bucket.abort_multipart_upload(file_key, upload_id) + except: + pass logger.error(f"OSS error stream uploading file {file_key}: {e}") raise StorageUploadError( - message=f"Failed to stream upload file to OSS: {e.message}", + message=f"Failed to stream upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: + if upload_id: + try: + self.bucket.abort_multipart_upload(file_key, upload_id) + except: + pass logger.error(f"Failed to stream upload file to OSS {file_key}: {e}") raise StorageUploadError( - message=f"Failed to stream upload file to OSS: {e}", + message=f"Failed to stream upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) + finally: + buf.close() async def download(self, file_key: str) -> bytes: """ @@ -182,14 +255,14 @@ class OSSStorage(StorageBackend): except OssError as e: logger.error(f"OSS error downloading file {file_key}: {e}") raise StorageDownloadError( - message=f"Failed to download file from OSS: {e.message}", + message=f"Failed to download file from OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: logger.error(f"Failed to download file from OSS {file_key}: {e}") raise StorageDownloadError( - message=f"Failed to download file from OSS: {e}", + message=f"Failed to download file from OSS: {str(e)}", file_key=file_key, cause=e, ) @@ -215,14 +288,14 @@ class OSSStorage(StorageBackend): except OssError as e: logger.error(f"OSS error deleting file {file_key}: {e}") raise StorageDeleteError( - message=f"Failed to delete file from OSS: {e.message}", + message=f"Failed to delete file from OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: logger.error(f"Failed to delete file from OSS {file_key}: {e}") raise StorageDeleteError( - message=f"Failed to delete file from OSS: {e}", + message=f"Failed to delete file from OSS: {str(e)}", file_key=file_key, cause=e, ) diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 6df6df51..b437d021 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -99,7 +99,7 @@ class SimpleMCPClient: # 建立 SSE 连接 response = await self._session.get(self.server_url) - if response.status != 200: + if response.status not in (200, 202): error_text = await response.text() raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}") @@ -190,7 +190,9 @@ class SimpleMCPClient: try: async with self._session.post(self._endpoint_url, json=request) as response: - if response.status != 200: + # MCP SSE 协议:POST 请求返回 200 或 202 均为正常 + # 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回 + if response.status not in (200, 202): error_text = await response.text() raise MCPConnectionError(f"请求失败 {response.status}: {error_text}") @@ -205,7 +207,7 @@ class SimpleMCPClient: raise MCPConnectionError("endpoint URL 未初始化") async with self._session.post(self._endpoint_url, json=notification) as response: - if response.status != 200: + if response.status not in (200, 202): logger.warning(f"通知发送失败: {response.status}") async def _initialize_modelscope_session(self): diff --git a/api/app/main.py b/api/app/main.py index f4c23ca8..9e501f11 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -1,5 +1,6 @@ import os import subprocess +from app.repositories.neo4j.create_indexes import create_all_indexes from contextlib import asynccontextmanager from fastapi import FastAPI, APIRouter @@ -60,8 +61,10 @@ async def lifespan(app: FastAPI): logger.warning(f"加载预定义模型时出错: {str(e)}") else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") - + await create_all_indexes() logger.info("应用程序启动完成") + + yield # 应用关闭事件 logger.info("应用程序正在关闭") diff --git a/api/app/models/user_model.py b/api/app/models/user_model.py index 81319789..c0b17d14 100644 --- a/api/app/models/user_model.py +++ b/api/app/models/user_model.py @@ -19,9 +19,12 @@ class User(Base): last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空 # SSO 外部关联字段 - external_id = Column(String(100), nullable=True) # 外部用户ID + external_id = Column(String(100), nullable=True) # 外部用户 ID external_source = Column(String(50), nullable=True) # 来源系统 + # 用户联系方式 + phone = Column(String(50), nullable=True) # 用户电话 + # 用户语言偏好 preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文 diff --git a/api/app/repositories/conversation_repository.py b/api/app/repositories/conversation_repository.py index 90f2d6ec..0676a255 100644 --- a/api/app/repositories/conversation_repository.py +++ b/api/app/repositories/conversation_repository.py @@ -199,6 +199,96 @@ class ConversationRepository: ) return conversations, total + def list_app_conversations( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + is_draft: Optional[bool] = None, + page: int = 1, + pagesize: int = 20 + ) -> tuple[list[Conversation], int]: + """ + 查询应用日志会话列表(带分页和过滤) + + Args: + app_id: 应用 ID + workspace_id: 工作空间 ID + is_draft: 是否草稿会话(None 表示不过滤) + page: 页码(从 1 开始) + pagesize: 每页数量 + + Returns: + Tuple[List[Conversation], int]: (会话列表,总数) + """ + stmt = select(Conversation).where( + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True) + ) + + if is_draft is not None: + stmt = stmt.where(Conversation.is_draft == is_draft) + + # Calculate total number of records + total = int(self.db.execute( + select(func.count()).select_from(stmt.subquery()) + ).scalar_one()) + + # Apply pagination + stmt = stmt.order_by(desc(Conversation.updated_at)) + stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) + + conversations = list(self.db.scalars(stmt).all()) + + logger.info( + "Listed app conversations successfully", + extra={ + "app_id": str(app_id), + "workspace_id": str(workspace_id), + "returned": len(conversations), + "total": total + } + ) + return conversations, total + + def get_conversation_for_app_log( + self, + conversation_id: uuid.UUID, + app_id: uuid.UUID, + workspace_id: uuid.UUID + ) -> Conversation: + """ + 查询应用日志的会话详情 + + Args: + conversation_id: 会话 ID + app_id: 应用 ID + workspace_id: 工作空间 ID + + Returns: + Conversation: 会话对象 + + Raises: + ResourceNotFoundException: 当会话不存在时 + """ + logger.info(f"Fetching conversation for app log: {conversation_id}") + + stmt = select(Conversation).where( + Conversation.id == conversation_id, + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True) + ) + + conversation = self.db.scalars(stmt).first() + + if not conversation: + logger.warning(f"Conversation not found: {conversation_id}") + raise ResourceNotFoundException("会话", str(conversation_id)) + + logger.info(f"Conversation fetched successfully: {conversation_id}") + return conversation + def soft_delete_conversation_by_conversation_id( self, conversation_id: uuid.UUID, @@ -290,6 +380,34 @@ class MessageRepository: self.db.add(message) return message + def get_messages_by_conversation( + self, + conversation_id: uuid.UUID + ) -> list[Message]: + """ + 查询会话的所有消息(按时间正序) + + Args: + conversation_id: 会话 ID + + Returns: + List[Message]: 消息列表 + """ + stmt = select(Message).where( + Message.conversation_id == conversation_id + ).order_by(Message.created_at) + + messages = list(self.db.scalars(stmt).all()) + + logger.info( + "Fetched messages for conversation", + extra={ + "conversation_id": str(conversation_id), + "message_count": len(messages) + } + ) + return messages + def get_message_by_conversation_id( self, conversation_id: uuid.UUID, diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 3c1dd16f..aad80707 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -132,6 +132,82 @@ class EndUserRepository: db_logger.error(f"获取或创建终端用户时出错: {str(e)}") raise + def get_or_create_end_user_with_config( + self, + app_id: Optional[uuid.UUID], + workspace_id: uuid.UUID, + other_id: str, + memory_config_id: Optional[uuid.UUID] = None, + other_name: Optional[str] = None + ) -> EndUser: + """获取或创建终端用户,并在单次事务中关联记忆配置。 + + 与 get_or_create_end_user 类似,但额外支持在创建/获取时 + 一并设置 memory_config_id,避免多次提交。 + + Args: + app_id: 应用ID(可为 None) + workspace_id: 工作空间ID + other_id: 第三方ID + memory_config_id: 记忆配置ID(可选,仅在用户尚无配置时设置) + other_name: 用户名称(用于创建 EndUserInfo) + + Returns: + EndUser: 终端用户对象(已关联记忆配置) + """ + try: + end_user = ( + self.db.query(EndUser) + .filter( + EndUser.workspace_id == workspace_id, + EndUser.other_id == other_id + ) + .order_by(EndUser.created_at.asc()) + .first() + ) + + if end_user: + db_logger.debug(f"找到现有终端用户: workspace_id={workspace_id}, other_id={other_id}") + if app_id is not None: + end_user.app_id = app_id + if memory_config_id and not end_user.memory_config_id: + end_user.memory_config_id = memory_config_id + self.db.commit() + self.db.refresh(end_user) + return end_user + + # 创建新用户 + end_user = EndUser( + app_id=app_id, + workspace_id=workspace_id, + other_id=other_id, + memory_config_id=memory_config_id, + ) + self.db.add(end_user) + self.db.flush() + + end_user_info = EndUserInfo( + end_user_id=end_user.id, + other_name=other_name or "", + aliases=[], + meta_data={} + ) + self.db.add(end_user_info) + + self.db.commit() + self.db.refresh(end_user) + + db_logger.info( + f"创建新终端用户及其信息: (other_id: {other_id}) for workspace {workspace_id}, " + f"memory_config_id={memory_config_id}" + ) + return end_user + + except Exception as e: + self.db.rollback() + db_logger.error(f"获取或创建终端用户(含配置)时出错: {str(e)}") + raise + def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]: """根据ID获取终端用户(用于缓存操作) @@ -515,6 +591,51 @@ class EndUserRepository: ) raise + def batch_update_memory_config_id_by_app( + self, + app_id: uuid.UUID, + memory_config_id: uuid.UUID + ) -> int: + """批量更新应用下所有终端用户的 memory_config_id + + Args: + app_id: 应用ID + memory_config_id: 新的记忆配置ID + + Returns: + int: 更新的终端用户数量 + + Raises: + Exception: 数据库操作失败时抛出 + """ + try: + from sqlalchemy import update + + stmt = ( + update(EndUser) + .where(EndUser.app_id == app_id) + .values(memory_config_id=memory_config_id) + ) + + result = self.db.execute(stmt) + self.db.commit() + + updated_count = result.rowcount + + db_logger.info( + f"批量更新终端用户记忆配置: app_id={app_id}, " + f"memory_config_id={memory_config_id}, updated_count={updated_count}" + ) + + return updated_count + except Exception as e: + self.db.rollback() + db_logger.error( + f"批量更新终端用户记忆配置时出错: app_id={app_id}, " + f"memory_config_id={memory_config_id}, error={str(e)}" + ) + raise + def count_by_memory_config_id( self, memory_config_id: uuid.UUID diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index d9e94117..5132aa09 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -1,62 +1,47 @@ +import asyncio from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - async def create_fulltext_indexes(): """Create full-text indexes for keyword search with BM25 scoring.""" connector = Neo4jConnector() try: - print("\n" + "=" * 70) - print("Creating Full-Text Indexes (for keyword search)") - print("=" * 70) + # 创建 Statements 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: statementsFulltext") + """) # # 创建 Dialogues 索引 # await connector.execute_query(""" # CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content] # OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } # """) - # 创建 Entities 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: entitiesFulltext") + """) # 创建 Chunks 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: chunksFulltext") + """) # 创建 MemorySummary 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: summariesFulltext") - + """) # 创建 Community 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } """) - print("✓ Created: communitiesFulltext") - print("\nFull-text indexes created successfully with BM25 support.") - except Exception as e: - print(f"✗ Error creating full-text indexes: {e}") finally: await connector.close() - - async def create_vector_indexes(): """Create vector indexes for fast embedding similarity search. @@ -65,12 +50,7 @@ async def create_vector_indexes(): """ connector = Neo4jConnector() try: - print("\n" + "=" * 70) - print("Creating Vector Indexes (for embedding search)") - print("=" * 70) - print("Note: Adjust vector.dimensions if using different embedding model") - print(" Current setting: 1024 dimensions (for bge-m3)") - print() + # Statement embedding index await connector.execute_query(""" @@ -82,7 +62,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: statement_embedding_index") + # Chunk embedding index await connector.execute_query(""" @@ -94,7 +74,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: chunk_embedding_index") + # Entity name embedding index await connector.execute_query(""" @@ -106,7 +86,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: entity_embedding_index") + # Memory summary embedding index await connector.execute_query(""" @@ -118,8 +98,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: summary_embedding_index") - + # Community summary embedding index await connector.execute_query(""" CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS @@ -129,8 +108,7 @@ async def create_vector_indexes(): `vector.dimensions`: 1024, `vector.similarity_function`: 'cosine' }} - """) - print("✓ Created: community_summary_embedding_index") + """) # Dialogue embedding index (optional) await connector.execute_query(""" @@ -142,91 +120,15 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: dialogue_embedding_index") - - # Community summary embedding index - await connector.execute_query(""" - CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS - FOR (c:Community) - ON c.summary_embedding - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """) - print("✓ Created: community_summary_embedding_index") - print("\nVector indexes created successfully!") - print("\nExpected performance improvement:") - print(" Before: ~1.4s for embedding search") - print(" After: ~0.05-0.2s for embedding search (10-30x faster!)") - - except Exception as e: - print(f"✗ Error creating vector indexes: {e}") finally: await connector.close() - - -async def create_config_id_indexes(): - """Create indexes on config_id fields for improved query performance. - - These indexes enable fast filtering of nodes by configuration ID, - which is essential for configuration isolation and multi-tenant scenarios. - """ - connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Config ID Indexes") - print("=" * 70) - - # Dialogue.config_id index - await connector.execute_query(""" - CREATE INDEX dialogue_config_id_index IF NOT EXISTS - FOR (d:Dialogue) ON (d.config_id) - """) - print("✓ Created: dialogue_config_id_index") - - # Statement.config_id index - await connector.execute_query(""" - CREATE INDEX statement_config_id_index IF NOT EXISTS - FOR (s:Statement) ON (s.config_id) - """) - print("✓ Created: statement_config_id_index") - - # ExtractedEntity.config_id index - await connector.execute_query(""" - CREATE INDEX entity_config_id_index IF NOT EXISTS - FOR (e:ExtractedEntity) ON (e.config_id) - """) - print("✓ Created: entity_config_id_index") - - # MemorySummary.config_id index - await connector.execute_query(""" - CREATE INDEX summary_config_id_index IF NOT EXISTS - FOR (m:MemorySummary) ON (m.config_id) - """) - print("✓ Created: summary_config_id_index") - - print("\nConfig ID indexes created successfully!") - print("These indexes enable fast filtering by configuration ID.") - - except Exception as e: - print(f"✗ Error creating config_id indexes: {e}") - finally: - await connector.close() - - async def create_unique_constraints(): """Create uniqueness constraints for core node identifiers. - Ensures concurrent MERGE operations remain safe and prevents duplicates. """ connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Unique Constraints") - print("=" * 70) - + try: # Dialogue.id unique await connector.execute_query( """ @@ -234,8 +136,7 @@ async def create_unique_constraints(): FOR (d:Dialogue) REQUIRE d.id IS UNIQUE """ ) - print("✓ Created: dialog_id_unique") - + # Statement.id unique await connector.execute_query( """ @@ -243,8 +144,7 @@ async def create_unique_constraints(): FOR (s:Statement) REQUIRE s.id IS UNIQUE """ ) - print("✓ Created: statement_id_unique") - + # Chunk.id unique await connector.execute_query( """ @@ -252,112 +152,13 @@ async def create_unique_constraints(): FOR (c:Chunk) REQUIRE c.id IS UNIQUE """ ) - print("✓ Created: chunk_id_unique") - - print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.") - except Exception as e: - print(f"✗ Error creating unique constraints: {e}") + finally: await connector.close() - - async def create_all_indexes(): """Create all indexes and constraints in one go.""" - print("\n" + "=" * 70) - print("Neo4j Index & Constraint Setup") - print("=" * 70) - print("This will create:") - print(" 1. Full-text indexes (for keyword/BM25 search)") - print(" 2. Vector indexes (for embedding similarity search)") - print(" 3. Config ID indexes (for configuration isolation)") - print(" 4. Unique constraints (for data integrity)") - print("=" * 70) - await create_fulltext_indexes() await create_vector_indexes() - await create_config_id_indexes() await create_unique_constraints() - - print("\n" + "=" * 70) print("✓ All indexes and constraints created successfully!") - print("=" * 70) - print("\nTo verify, run in Neo4j Browser:") - print(" SHOW INDEXES") - print(" SHOW CONSTRAINTS") - print() - - -async def check_indexes(): - """Check what indexes currently exist.""" - connector = Neo4jConnector() - - try: - print("\n" + "=" * 70) - print("Checking Existing Indexes") - print("=" * 70) - query = "SHOW INDEXES" - result = await connector.execute_query(query) - - fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT'] - vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR'] - range_indexes = [idx for idx in result if idx.get('type') == 'RANGE'] - - print(f"\nFull-text indexes: {len(fulltext_indexes)}") - for idx in fulltext_indexes: - print(f" ✓ {idx.get('name')}") - - print(f"\nVector indexes: {len(vector_indexes)}") - for idx in vector_indexes: - print(f" ✓ {idx.get('name')}") - - print(f"\nRange indexes (including config_id): {len(range_indexes)}") - for idx in range_indexes: - print(f" ✓ {idx.get('name')}") - - if not vector_indexes: - print("\n⚠️ WARNING: No vector indexes found!") - print(" Embedding search will be VERY SLOW (~1.4s)") - print(" Run: python create_indexes.py") - - # Check for config_id indexes - config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')] - if len(config_id_indexes) < 4: - print("\n⚠️ WARNING: Not all config_id indexes found!") - print(f" Expected 4, found {len(config_id_indexes)}") - print(" Run: python create_indexes.py config_id") - - print("=" * 70) - - finally: - await connector.close() - - -if __name__ == "__main__": - import asyncio - import sys - - if len(sys.argv) > 1: - command = sys.argv[1] - if command == "check": - asyncio.run(check_indexes()) - elif command == "fulltext": - asyncio.run(create_fulltext_indexes()) - elif command == "vector": - asyncio.run(create_vector_indexes()) - elif command == "config_id": - asyncio.run(create_config_id_indexes()) - elif command == "constraints": - asyncio.run(create_unique_constraints()) - else: - print(f"Unknown command: {command}") - print("\nUsage:") - print(" python create_indexes.py # Create all indexes") - print(" python create_indexes.py check # Check existing indexes") - print(" python create_indexes.py fulltext # Create only full-text indexes") - print(" python create_indexes.py vector # Create only vector indexes") - print(" python create_indexes.py config_id # Create only config_id indexes") - print(" python create_indexes.py constraints # Create only constraints") - else: - asyncio.run(create_all_indexes()) - diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index c08f9d0e..26ffe350 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -340,17 +340,22 @@ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """ CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) WITH e, score -UNION -MATCH (e:ExtractedEntity) -WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) - AND e.aliases IS NOT NULL - AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q)) -WITH e, +WITH collect({entity: e, score: score}) AS fulltextResults + +OPTIONAL MATCH (ae:ExtractedEntity) +WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id) + AND ae.aliases IS NOT NULL + AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q)) +WITH fulltextResults, collect(ae) AS aliasEntities + +UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score: CASE - WHEN ANY(alias IN e.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 - WHEN ANY(alias IN e.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 ELSE 0.8 - END AS score + END +}]) AS row +WITH row.entity AS e, row.score AS score WITH DISTINCT e, MAX(score) AS score OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) diff --git a/api/app/repositories/user_repository.py b/api/app/repositories/user_repository.py index 3f8919aa..af4449e5 100644 --- a/api/app/repositories/user_repository.py +++ b/api/app/repositories/user_repository.py @@ -158,22 +158,26 @@ class UserRepository: raise def get_users_by_tenant( - self, - tenant_id: uuid.UUID, - skip: int = 0, + self, + tenant_id: uuid.UUID, + skip: int = 0, limit: int = 100, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> List[User]: """获取租户下的用户列表""" db_logger.debug(f"查询租户用户: tenant_id={tenant_id}") - + try: query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id) - + if is_active is not None: query = query.filter(User.is_active == is_active) - + + if is_superuser is not None: + query = query.filter(User.is_superuser == is_superuser) + if search: query = query.filter( or_( @@ -181,7 +185,7 @@ class UserRepository: User.email.ilike(f"%{search}%") ) ) - + users = query.offset(skip).limit(limit).all() db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}") return users @@ -190,18 +194,22 @@ class UserRepository: raise def count_users_by_tenant( - self, + self, tenant_id: uuid.UUID, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> int: """统计租户下的用户数量""" try: query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id) - + if is_active is not None: query = query.filter(User.is_active == is_active) - + + if is_superuser is not None: + query = query.filter(User.is_superuser == is_superuser) + if search: query = query.filter( or_( @@ -209,7 +217,7 @@ class UserRepository: User.email.ilike(f"%{search}%") ) ) - + return query.scalar() except Exception as e: db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}") diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index e34945eb..f1e9132f 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -276,7 +276,7 @@ class AgentConfigCreate(BaseModel): # 记忆配置 memory: MemoryConfig = Field( - default_factory=lambda: MemoryConfig(enabled=True), + default_factory=lambda: MemoryConfig(enabled=False), description="对话历史记忆配置" ) diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 711b6de9..bfcf6337 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel): last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)") +class PageInfo(BaseModel): + """分页信息模型""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + page: int = Field(..., description="当前页码(从1开始)") + pagesize: int = Field(..., description="每页数量") + total: int = Field(..., description="总记录数") + hasnext: bool = Field(..., description="是否有下一页") + + +class PendingNodesResponse(BaseModel): + """待遗忘节点列表响应模型(独立分页接口)""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表") + page: PageInfo = Field(..., description="分页信息") + + class ForgettingStatsResponse(BaseModel): """遗忘引擎统计信息响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") @@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel): node_distribution: Dict[str, int] = Field(..., description="节点类型分布") recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") - pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)") timestamp: int = Field(..., description="统计时间(时间戳)") diff --git a/api/app/schemas/user_schema.py b/api/app/schemas/user_schema.py index 6b880696..aa9ac256 100644 --- a/api/app/schemas/user_schema.py +++ b/api/app/schemas/user_schema.py @@ -1,6 +1,6 @@ from dataclasses import field from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict -from typing import Optional +from typing import Optional, List import datetime import uuid @@ -20,6 +20,7 @@ class UserCreate(UserBase): class UserUpdate(BaseModel): username: Optional[str] = None email: Optional[EmailStr] = None + phone: Optional[str] = None is_active: Optional[bool] = None is_superuser: Optional[bool] = None @@ -85,6 +86,8 @@ class User(UserBase): current_workspace_name: Optional[str] = None role: Optional[WorkspaceRole] = None preferred_language: Optional[str] = "zh" # 用户语言偏好 + phone: Optional[str] = None # 用户电话 + permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制 # 将 datetime 转换为毫秒时间戳 @validator("created_at", pre=True) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 17c2f98c..df81568f 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -141,13 +141,13 @@ class AppChatService: # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 is_new_conversation = len(history) == 0 if is_new_conversation: - opening = self.agent_service._get_opening_statement(features_config, True, variables) + opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables) if opening: self.conversation_service.add_message( conversation_id=conversation_id, role="assistant", content=opening, - meta_data={} + meta_data={"suggested_questions": suggested_questions} ) # 重新加载历史(包含刚写入的开场白) history = await self.conversation_service.get_conversation_history( @@ -378,13 +378,13 @@ class AppChatService: # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 is_new_conversation = len(history) == 0 if is_new_conversation: - opening = self.agent_service._get_opening_statement(features_config, True, variables) + opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables) if opening: self.conversation_service.add_message( conversation_id=conversation_id, role="assistant", content=opening, - meta_data={} + meta_data={"suggested_questions": suggested_questions} ) # 重新加载历史(包含刚写入的开场白) history = await self.conversation_service.get_conversation_history( diff --git a/api/app/services/app_log_service.py b/api/app/services/app_log_service.py new file mode 100644 index 00000000..856045d1 --- /dev/null +++ b/api/app/services/app_log_service.py @@ -0,0 +1,128 @@ +"""应用日志服务层""" +import uuid +from typing import Optional, Tuple +from datetime import datetime + +from sqlalchemy.orm import Session + +from app.core.logging_config import get_business_logger +from app.models.conversation_model import Conversation, Message +from app.repositories.conversation_repository import ConversationRepository, MessageRepository + +logger = get_business_logger() + + +class AppLogService: + """应用日志服务""" + + def __init__(self, db: Session): + self.db = db + self.conversation_repository = ConversationRepository(db) + self.message_repository = MessageRepository(db) + + def list_conversations( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + page: int = 1, + pagesize: int = 20, + is_draft: Optional[bool] = None, + ) -> Tuple[list[Conversation], int]: + """ + 查询应用日志会话列表 + + Args: + app_id: 应用 ID + workspace_id: 工作空间 ID + page: 页码(从 1 开始) + pagesize: 每页数量 + is_draft: 是否草稿会话(None 表示不过滤) + + Returns: + Tuple[list[Conversation], int]: (会话列表,总数) + """ + logger.info( + "查询应用日志会话列表", + extra={ + "app_id": str(app_id), + "workspace_id": str(workspace_id), + "page": page, + "pagesize": pagesize, + "is_draft": is_draft + } + ) + + # 使用 Repository 查询 + conversations, total = self.conversation_repository.list_app_conversations( + app_id=app_id, + workspace_id=workspace_id, + is_draft=is_draft, + page=page, + pagesize=pagesize + ) + + logger.info( + "查询应用日志会话列表成功", + extra={ + "app_id": str(app_id), + "total": total, + "returned": len(conversations) + } + ) + + return conversations, total + + def get_conversation_detail( + self, + app_id: uuid.UUID, + conversation_id: uuid.UUID, + workspace_id: uuid.UUID + ) -> Conversation: + """ + 查询会话详情(包含消息) + + Args: + app_id: 应用 ID + conversation_id: 会话 ID + workspace_id: 工作空间 ID + + Returns: + Conversation: 包含消息的会话对象 + + Raises: + ResourceNotFoundException: 当会话不存在时 + """ + logger.info( + "查询应用日志会话详情", + extra={ + "app_id": str(app_id), + "conversation_id": str(conversation_id), + "workspace_id": str(workspace_id) + } + ) + + # 查询会话 + conversation = self.conversation_repository.get_conversation_for_app_log( + conversation_id=conversation_id, + app_id=app_id, + workspace_id=workspace_id + ) + + # 查询消息(按时间正序) + messages = self.message_repository.get_messages_by_conversation( + conversation_id=conversation_id + ) + + # 将消息附加到会话对象 + conversation.messages = messages + + logger.info( + "查询应用日志会话详情成功", + extra={ + "app_id": str(app_id), + "conversation_id": str(conversation_id), + "message_count": len(messages) + } + ) + + return conversation diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 4dcabff8..377f9479 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1084,7 +1084,6 @@ class AppService: if not exists: cleaned["memory_config_id"] = None cleaned.pop("memory_content", None) - cleaned["enabled"] = False return cleaned exists = self.db.query( @@ -1096,7 +1095,6 @@ class AppService: if not exists: cleaned["memory_config_id"] = None cleaned.pop("memory_content", None) - cleaned["enabled"] = False return cleaned @@ -1684,15 +1682,15 @@ class AppService: return config.config_id - def _update_endusers_memory_config_by_workspace( + def _update_endusers_memory_config_by_app( self, - workspace_id: uuid.UUID, + app_id: uuid.UUID, memory_config_id: uuid.UUID ) -> int: """批量更新应用下所有终端用户的 memory_config_id Args: - workspace_id: 工作空间ID + app_id: 应用ID memory_config_id: 新的记忆配置ID Returns: @@ -1701,8 +1699,8 @@ class AppService: from app.repositories.end_user_repository import EndUserRepository repo = EndUserRepository(self.db) - updated_count = repo.batch_update_memory_config_id_by_workspace( - workspace_id=workspace_id, + updated_count = repo.batch_update_memory_config_id_by_app( + app_id=app_id, memory_config_id=memory_config_id ) @@ -1753,12 +1751,16 @@ class AppService: miss_params = [] if agent_cfg.default_model_config_id is None: - miss_params.append("model config") + miss_params.append("模型配置") if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"): - miss_params.append("memory config") + miss_params.append("记忆配置") if miss_params: - raise BusinessException(f"{', '.join(miss_params)} is required") + raise BusinessException( + f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。", + BizCode.CONFIG_MISSING, + context={"missing_params": miss_params}, + ) config = { "system_prompt": agent_cfg.system_prompt, @@ -1877,8 +1879,8 @@ class AppService: if memory_config_id: app = self.db.query(App).filter(App.id == app_id).first() if app: - updated_count = self._update_endusers_memory_config_by_workspace( - app.workspace_id, memory_config_id + updated_count = self._update_endusers_memory_config_by_app( + app_id, memory_config_id ) logger.info( f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, " @@ -2014,7 +2016,7 @@ class AppService: if memory_config_id: - updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id) + updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id) logger.info( f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, " f"memory_config_id={memory_config_id}, updated_count={updated_count}" diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 014d96b7..bd7f7496 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -214,7 +214,7 @@ class ConversationService: conversation.message_count += 1 - if conversation.message_count == 1 and role == "user": + if conversation.message_count <= 2 and role == "user": conversation.title = ( content[:50] + ("..." if len(content) > 50 else "") ) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index aef54847..4b503f2b 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -448,15 +448,16 @@ class AgentRunService: features_config: Dict[str, Any], is_new_conversation: bool, variables: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + ) -> tuple[Any, Any]: """首轮对话时返回开场白文本(支持变量替换),否则返回 None""" if not is_new_conversation: - return None + return None, None opening = features_config.get("opening_statement", {}) if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")): - return None + return None, None statement = opening["statement"] + suggested_questions = opening["suggested_questions"] # 如果有变量,进行替换(仅支持 {{var_name}} 格式) if variables: @@ -464,7 +465,7 @@ class AgentRunService: placeholder = f"{{{{{var_name}}}}}" statement = statement.replace(placeholder, str(var_value)) - return statement + return statement, suggested_questions @staticmethod def _filter_citations( @@ -598,13 +599,16 @@ class AgentRunService: # 5. 处理会话ID(创建或验证),新会话时写入开场白 is_new_conversation = not conversation_id - opening = self._get_opening_statement(features_config, is_new_conversation, variables) + opening, suggested_questions = None, None + if not sub_agent: + opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, user_id=user_id, - opening_statement=opening + opening_statement=opening, + suggested_questions=suggested_questions ) model_info = ModelInfo( @@ -839,14 +843,17 @@ class AgentRunService: # 5. 处理会话ID(创建或验证),新会话时写入开场白 is_new_conversation = not conversation_id - opening = self._get_opening_statement(features_config, is_new_conversation, variables) + opening, suggested_questions = None, None + if not sub_agent: + opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, user_id=user_id, sub_agent=sub_agent, - opening_statement=opening + opening_statement=opening, + suggested_questions=suggested_questions ) model_info = ModelInfo( @@ -1050,7 +1057,8 @@ class AgentRunService: workspace_id: uuid.UUID, user_id: Optional[str], sub_agent: bool = False, - opening_statement: Optional[str] = None + opening_statement: Optional[str] = None, + suggested_questions: Optional[List[str]] = None ) -> str: """确保会话存在(创建或验证) @@ -1061,6 +1069,7 @@ class AgentRunService: user_id: 用户ID sub_agent: 是否为子代理 opening_statement: 开场白(新会话时作为第一条消息写入) + suggested_questions: 预设问题列表 Returns: str: 会话ID @@ -1104,7 +1113,7 @@ class AgentRunService: conversation_id=uuid.UUID(new_conv_id), role="assistant", content=opening_statement, - meta_data={} + meta_data={"suggested_questions": suggested_questions} ) logger.debug(f"已保存开场白到会话 {new_conv_id}") diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index 11118571..2d91f025 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -204,30 +204,35 @@ class MemoryForgetService: end_user_id: str, forgetting_threshold: float, min_days_since_access: int, - limit: int = 20 - ) -> list[Dict[str, Any]]: + page: Optional[int] = None, + pagesize: Optional[int] = None + ) -> Dict[str, Any]: """ 获取待遗忘节点列表 - - 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数) - + + 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。 + Args: connector: Neo4j 连接器 end_user_id: 组ID forgetting_threshold: 遗忘阈值 min_days_since_access: 最小未访问天数 - limit: 返回节点数量限制 - + page: 页码(可选,从1开始) + pagesize: 每页数量(可选) + Returns: - list: 待遗忘节点列表 + dict: 包含待遗忘节点列表和分页信息的字典 + - items: 待遗忘节点列表 + - page: 分页信息(分页时) """ from datetime import timedelta - + # 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区) min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access) min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') - - query = """ + + # 基础查询(用于获取总数) + count_query = """ MATCH (n) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) AND n.end_user_id = $end_user_id @@ -235,10 +240,22 @@ class MemoryForgetService: AND n.activation_value < $threshold AND n.last_access_time IS NOT NULL AND datetime(n.last_access_time) < datetime($min_access_time_str) - RETURN + RETURN count(n) as total + """ + + # 数据查询 + data_query = """ + MATCH (n) + WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) + AND n.end_user_id = $end_user_id + AND n.activation_value IS NOT NULL + AND n.activation_value < $threshold + AND n.last_access_time IS NOT NULL + AND datetime(n.last_access_time) < datetime($min_access_time_str) + RETURN elementId(n) as node_id, labels(n)[0] as node_type, - CASE + CASE WHEN n:Statement THEN n.statement WHEN n:ExtractedEntity THEN n.name WHEN n:MemorySummary THEN n.content @@ -247,18 +264,32 @@ class MemoryForgetService: n.activation_value as activation_value, n.last_access_time as last_access_time ORDER BY n.activation_value ASC - LIMIT $limit """ - + + # 如果启用分页,添加 SKIP 和 LIMIT + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + data_query += " SKIP $skip LIMIT $limit" + params = { 'end_user_id': end_user_id, 'threshold': forgetting_threshold, - 'min_access_time_str': min_access_time_str, - 'limit': limit + 'min_access_time_str': min_access_time_str } - - results = await connector.execute_query(query, **params) - + + # 获取总数(分页时需要) + total = 0 + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + count_results = await connector.execute_query(count_query, **params) + if count_results: + total = count_results[0]['total'] + + # 添加分页参数 + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + params['skip'] = (page - 1) * pagesize + params['limit'] = pagesize + + results = await connector.execute_query(data_query, **params) + pending_nodes = [] for result in results: # 将节点类型标签转换为小写 @@ -267,7 +298,7 @@ class MemoryForgetService: node_type_label = 'entity' elif node_type_label == 'memorysummary': node_type_label = 'summary' - + # 将 Neo4j DateTime 对象转换为时间戳(毫秒) last_access_time = result['last_access_time'] last_access_dt = convert_neo4j_datetime_to_python(last_access_time) @@ -278,7 +309,7 @@ class MemoryForgetService: last_access_timestamp = int(last_access_dt.timestamp() * 1000) else: last_access_timestamp = 0 - + pending_nodes.append({ 'node_id': str(result['node_id']), 'node_type': node_type_label, @@ -286,8 +317,20 @@ class MemoryForgetService: 'activation_value': result['activation_value'], 'last_access_time': last_access_timestamp }) - - return pending_nodes + + # 构建返回结果 + result: Dict[str, Any] = {'items': pending_nodes} + + # 如果启用分页,添加分页信息 + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + result['page'] = { + 'page': page, + 'pagesize': pagesize, + 'total': total, + 'hasnext': (page * pagesize) < total + } + + return result async def trigger_forgetting_cycle( self, @@ -636,7 +679,7 @@ class MemoryForgetService: api_logger.error(f"获取历史趋势数据失败: {str(e)}") # 失败时返回空列表,不影响主流程 - # 获取待遗忘节点列表(前20个满足遗忘条件的节点) + # 获取待遗忘节点列表 pending_nodes = [] try: if end_user_id: @@ -652,8 +695,7 @@ class MemoryForgetService: connector=connector, end_user_id=end_user_id, forgetting_threshold=forgetting_threshold, - min_days_since_access=int(min_days), - limit=20 + min_days_since_access=int(min_days) ) api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点") @@ -661,24 +703,79 @@ class MemoryForgetService: except Exception as e: api_logger.error(f"获取待遗忘节点失败: {str(e)}") # 失败时返回空列表,不影响主流程 - - # 构建统计信息 + + # 构建统计信息(不包含 pending_nodes,已分离到独立接口) stats = { 'activation_metrics': activation_metrics, 'node_distribution': node_distribution, 'recent_trends': recent_trends, - 'pending_nodes': pending_nodes, 'timestamp': int(datetime.now().timestamp() * 1000) } - + api_logger.info( f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, " f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, " - f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}" + f"trend_days={len(recent_trends)}" ) - + return stats - + + async def get_pending_nodes( + self, + db: Session, + end_user_id: str, + config_id: Optional[UUID] = None, + page: int = 1, + pagesize: int = 10 + ) -> Dict[str, Any]: + """ + 获取待遗忘节点列表(独立分页接口) + + 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。 + + Args: + db: 数据库会话 + end_user_id: 组ID(必填) + config_id: 配置ID(可选,用于获取遗忘阈值) + page: 页码(从1开始,默认1) + pagesize: 每页数量(默认10) + + Returns: + dict: 包含待遗忘节点列表和分页信息的字典 + - items: 待遗忘节点列表 + - page: 分页信息 + """ + # 获取遗忘引擎组件 + _, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id) + + connector = forgetting_scheduler.connector + forgetting_threshold = config['forgetting_threshold'] + + # 验证 min_days_since_access 配置值 + min_days = config.get('min_days_since_access') + if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0: + api_logger.warning( + f"min_days_since_access 配置无效: {min_days}, 使用默认值 7" + ) + min_days = 7 + + # 调用内部方法获取分页数据 + pending_nodes_result = await self._get_pending_forgetting_nodes( + connector=connector, + end_user_id=end_user_id, + forgetting_threshold=forgetting_threshold, + min_days_since_access=int(min_days), + page=page, + pagesize=pagesize + ) + + api_logger.info( + f"成功获取待遗忘节点列表: end_user_id={end_user_id}, " + f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}" + ) + + return pending_nodes_result + async def get_forgetting_curve( self, db: Session, diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 120cccb7..2e9f809a 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -12,6 +12,9 @@ import base64 import csv import io import json +import re +import olefile +import struct import zipfile from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional @@ -602,31 +605,75 @@ class MultimodalService: try: word_file = io.BytesIO(file_content) doc = Document(word_file) - return '\n'.join(p.text for p in doc.paragraphs) + text_lines = [] + for p in doc.paragraphs: + text = p.text.strip() + if text: + text_lines.append(text) + + for table in doc.tables: + for row in table.rows: + for cell in row.cells: + text = cell.text.strip() + if text: + text_lines.append(text) + + full_text = "\n".join(text_lines) + return full_text.strip() or "[docx 文件无文本内容]" except Exception as e: - logger.error(f"提取 docx 文本失败: {e}") + logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True) return f"[docx 提取失败: {str(e)}]" - # 旧版 .doc(OLE2 格式) + # 旧版 .doc(OLE2/CFB 格式),按 Word Binary Format 规范解析 piece table try: - import olefile ole = olefile.OleFileIO(io.BytesIO(file_content)) - if not ole.exists('WordDocument'): - return "[doc 提取失败: 未找到 WordDocument 流]" - # 读取 WordDocument 流,提取可见 ASCII/Unicode 文本 - stream = ole.openstream('WordDocument').read() - # Word Binary Format: 文本在流中以 UTF-16-LE 编码存储 - # 简单提取:过滤出可打印字符段 - try: - text = stream.decode('utf-16-le', errors='ignore') - except Exception: - text = stream.decode('latin-1', errors='ignore') - # 过滤控制字符,保留可打印内容 - import re - text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text) - text = re.sub(r' +', ' ', text).strip() + word_stream = ole.openstream('WordDocument').read() + + # FIB offset 0xA bit9 决定使用 0Table 还是 1Table + fib_flags = struct.unpack_from(' List[UserModel]: """获取租户下的用户列表""" @@ -155,6 +156,7 @@ class TenantService: skip=skip, limit=limit, is_active=is_active, + is_superuser=is_superuser, search=search ) @@ -162,12 +164,14 @@ class TenantService: self, tenant_id: uuid.UUID, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> int: """统计租户下的用户数量""" return self.user_repo.count_users_by_tenant( tenant_id=tenant_id, is_active=is_active, + is_superuser=is_superuser, search=search ) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 942e01a0..ab51d922 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -472,6 +472,21 @@ class UserMemoryService: # 定义允许更新的字段白名单 allowed_fields = {'other_name', 'aliases', 'meta_data'} + # 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中 + _user_placeholder_names = {'用户', '我', 'User', 'I'} + + # 过滤 other_name:不允许设置为占位名称 + if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names: + logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name") + del update_data['other_name'] + + # 过滤 aliases:移除占位名称和非字符串值 + if 'aliases' in update_data and update_data['aliases']: + update_data['aliases'] = [ + a for a in update_data['aliases'] + if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names + ] + # 检查是否更新了 aliases 字段 aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases diff --git a/api/app/tasks.py b/api/app/tasks.py index d5f09a29..72421a5f 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,5 +1,4 @@ import asyncio -import hashlib import os import re import shutil @@ -38,12 +37,10 @@ from app.db import get_db, get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema -from app.schemas.model_schema import ModelInfo from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config from app.services.memory_forget_service import MemoryForgetService -from app.services.memory_perceptual_service import MemoryPerceptualService from app.utils.config_utils import resolve_config_id -from app.utils.redis_lock import RedisLock +from app.utils.redis_lock import RedisFairLock logger = get_logger(__name__) @@ -104,7 +101,12 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]: def set_asyncio_event_loop(): - """Set the asyncio event loop for the current thread.""" + """Ensure an open asyncio event loop exists for the current thread. + + Reuses the existing event loop if one is available and still open. + Creates and installs a new event loop only when the current one is + closed or missing (e.g. after ``_shutdown_loop_gracefully``). + """ try: loop = asyncio.get_event_loop() if loop.is_closed(): @@ -116,6 +118,30 @@ def set_asyncio_event_loop(): return loop +def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop): + """Gracefully shutdown pending async generators and tasks on the event loop. + + This prevents 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ + by giving pending aclose() coroutines a chance to run before the loop is discarded. + + Note: This only tears down the given loop. Callers that need a fresh event + loop afterwards should use ``set_asyncio_event_loop()`` explicitly. + """ + try: + # Cancel and collect all remaining tasks + all_tasks = asyncio.all_tasks(loop) + if all_tasks: + for task in all_tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*all_tasks, return_exceptions=True)) + # Shutdown async generators (triggers __aclose__ on httpx clients etc.) + loop.run_until_complete(loop.shutdown_asyncgens()) + except Exception: + pass + finally: + loop.close() + + @celery_app.task(name="tasks.process_item") def process_item(item: dict): """ @@ -1148,8 +1174,28 @@ def write_message_task( logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result + redis_client = get_sync_redis_client() + lock = None + if redis_client is not None: + lock = RedisFairLock( + key=f"memory_write:{end_user_id}", + redis_client=redis_client, + expire=600, + timeout=3600, + auto_renewal=True, + ) + if not lock.acquire(): + logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}") + return { + "status": "SKIPPED", + "error": "acquire lock timeout", + "end_user_id": end_user_id, + "config_id": str(config_id), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } + try: - # 尝试获取现有事件循环,如果不存在则创建新的 loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) @@ -1158,7 +1204,6 @@ def write_message_task( logger.info(f"[CELERY WRITE] Task completed successfully " f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") - # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: _r = get_sync_redis_client() if _r is not None: @@ -1199,6 +1244,15 @@ def write_message_task( "elapsed_time": elapsed_time, "task_id": self.request.id } + finally: + if lock is not None: + try: + lock.release() + except Exception as e: + logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") + # Gracefully shutdown the event loop to prevent + # 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ + _shutdown_loop_gracefully(loop) # unused task @@ -2879,3 +2933,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace "elapsed_time": time.time() - start_time, "task_id": self.request.id, } + + +# unused task \ No newline at end of file diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py index 99f62d84..a86ba46e 100644 --- a/api/app/utils/redis_lock.py +++ b/api/app/utils/redis_lock.py @@ -1,6 +1,7 @@ import redis import uuid import time +import threading UNLOCK_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then @@ -10,45 +11,136 @@ else end """ +RENEW_SCRIPT = """ +if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("expire", KEYS[1], ARGV[2]) +else + return 0 +end +""" -class RedisLock: +CLEANUP_DEAD_HEAD_SCRIPT = """ +local queue_key = KEYS[1] +local lock_key = KEYS[2] + +local first = redis.call("lindex", queue_key, 0) +if not first then + return 0 +end + +if redis.call("exists", lock_key) == 1 then + return 0 +end + +redis.call("lpop", queue_key) +return 1 +""" + +SAFE_RELEASE_QUEUE_SCRIPT = """ +local queue_key = KEYS[1] +local value = ARGV[1] + +local first = redis.call("lindex", queue_key, 0) +if first == value then + redis.call("lpop", queue_key) + return 1 +end +return 0 +""" + + +def _ensure_str(val): + """统一将 Redis 返回值转为 str,兼容 decode_responses=True/False""" + if val is None: + return None + if isinstance(val, bytes): + return val.decode("utf-8") + return str(val) + + +class RedisFairLock: def __init__( self, key: str, redis_client: redis.StrictRedis, - expire: int = 60, - retry_interval: float = 0.1, - timeout: float = 30 - + expire: int = 30, + retry_interval: float = 0.05, + timeout: float = 600, + auto_renewal: bool = True ): self.key = key - self.expire = expire + self.queue_key = f"{key}:queue" self.value = str(uuid.uuid4()) - self._locked = False + self.expire = expire self.retry_interval = retry_interval self.timeout = timeout - self.redis_client = redis_client + self.redis = redis_client + self._locked = False + self.auto_renewal = auto_renewal + self._renew_thread = None + self._stop_renew = threading.Event() - def acquire(self) -> bool: + def acquire(self): start = time.time() + + self.redis.rpush(self.queue_key, self.value) + while True: - ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) - if ok: - self._locked = True - return True - if time.time() - start >= self.timeout: + first = _ensure_str(self.redis.lindex(self.queue_key, 0)) + + if first == self.value: + ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire) + if ok: + self._locked = True + + if self.auto_renewal: + self._start_renewal() + return True + + if first: + self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key) + + if time.time() - start > self.timeout: + self.redis.lrem(self.queue_key, 0, self.value) return False + time.sleep(self.retry_interval) + def _renewal_loop(self): + while not self._stop_renew.is_set(): + time.sleep(self.expire / 3) + if self._stop_renew.is_set(): + break + + self.redis.eval( + RENEW_SCRIPT, + 1, + self.key, + self.value, + str(self.expire) + ) + + def _start_renewal(self): + self._stop_renew = threading.Event() + self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True) + self._renew_thread.start() + + def _stop_renewal(self): + self._stop_renew.set() + if self._renew_thread: + self._renew_thread.join(timeout=1) + def release(self): if not self._locked: return - self.redis_client.eval( - UNLOCK_SCRIPT, - 1, - self.key, - self.value - ) + + if self.auto_renewal: + self._stop_renewal() + + self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value) + + self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value) + self._locked = False def __enter__(self): @@ -59,3 +151,4 @@ class RedisLock: def __exit__(self, exc_type, exc_val, exc_tb): self.release() + diff --git a/api/migrations/versions/4e89970f9e7c_202603271515.py b/api/migrations/versions/4e89970f9e7c_202603271515.py new file mode 100644 index 00000000..f37c4b27 --- /dev/null +++ b/api/migrations/versions/4e89970f9e7c_202603271515.py @@ -0,0 +1,30 @@ +"""202603271515 + +Revision ID: 4e89970f9e7c +Revises: 6b8a461148ff +Create Date: 2026-03-27 15:12:27.518344 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4e89970f9e7c' +down_revision: Union[str, None] = '6b8a461148ff' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('users', sa.Column('phone', sa.String(length=50), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('users', 'phone') + # ### end Alembic commands ### diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index ee71bea8..077cdf53 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -154,6 +154,8 @@ export const analyticsRefresh = (end_user_id: string) => { export const getForgetStats = (end_user_id: string) => { return request.get(`/memory/forget-memory/stats`, { end_user_id }) } +// 获取带遗忘节点列表 +export const getForgetPendingNodesUrl = '/memory/forget-memory/pending-nodes' // Implicit Memory - Preferences export const getImplicitPreferences = (end_user_id: string) => { return request.get(`/memory/implicit-memory/preferences/${end_user_id}`) diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index ddb25838..0276916f 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -37,11 +37,11 @@ const ChatContent: FC = ({ const prevDataLengthRef = useRef(data.length); const isScrolledToBottomRef = useRef(true); const audioRef = useRef(null) - const [playingIndex, setPlayingIndex] = useState(null) + const [playingIndex, setPlayingIndex] = useState(null) - const handlePlay = (index: number, audio_url: string, audio_status?: string) => { - if (audio_status !== 'completed' && !audio_status) return - if (playingIndex === index) { + const handlePlay = (audio_url: string, audio_status?: string) => { + if (audio_status !== 'completed' && typeof audio_status === 'string') return + if (playingIndex === audio_url) { audioRef.current?.pause() setPlayingIndex(null) return @@ -52,7 +52,7 @@ const ChatContent: FC = ({ const audio = new Audio(audio_url) audioRef.current = audio audio.play() - setPlayingIndex(index) + setPlayingIndex(audio_url) audio.onended = () => setPlayingIndex(null) } @@ -79,12 +79,16 @@ const ChatContent: FC = ({ } }; }, []); - + // Auto-scroll to bottom when data changes to show latest messages // When data array length remains unchanged, if data is updated and user manually scrolled up, don't auto-scroll to bottom // When data array length changes, auto-scroll to bottom // If already scrolled to bottom, will auto-scroll to bottom useEffect(() => { + if (playingIndex && !data.some(item => item.meta_data?.audio_url === playingIndex)) { + audioRef.current?.pause() + setPlayingIndex(null) + } setTimeout(() => { if (scrollContainerRef.current) { // Auto-scroll if data length changed OR user is currently at bottom @@ -204,16 +208,16 @@ const ChatContent: FC = ({ {item.meta_data?.audio_url && <> - {playingIndex !== index && item.meta_data?.audio_status === 'pending' + {playingIndex !== item.meta_data?.audio_url && item.meta_data?.audio_status === 'pending' ? - : playingIndex !== index + : playingIndex !== item.meta_data?.audio_url ? handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> + })} onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> :
handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} + onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> } diff --git a/web/src/components/Empty/BodyWrapper.tsx b/web/src/components/Empty/BodyWrapper.tsx index 067b743c..5d23b55c 100644 --- a/web/src/components/Empty/BodyWrapper.tsx +++ b/web/src/components/Empty/BodyWrapper.tsx @@ -24,16 +24,17 @@ interface BodyWrapperProps { /** Whether to show loading state */ loading?: boolean /** Whether the content is empty */ - empty: boolean + empty: boolean; + className?: string; } -const BodyWrapper: FC = ({ children, loading = false, empty }) => { +const BodyWrapper: FC = ({ children, loading = false, empty, className = 'rb:max-h-[calc(100%-48px)]!' }) => { // Show loading spinner while data is being fetched if (loading) { - return + return } // Show empty state when no data is available if (!loading && empty) { - return + return } // Render actual content when data is loaded and available return children diff --git a/web/src/components/SiderMenu/index.tsx b/web/src/components/SiderMenu/index.tsx index 3bd0cea3..21f7fd36 100644 --- a/web/src/components/SiderMenu/index.tsx +++ b/web/src/components/SiderMenu/index.tsx @@ -128,6 +128,7 @@ const Menu: FC<{ /** Filter menus based on user role and source */ useEffect(() => { + if (!user) return let menuList: MenuItem[] = [] if (user.role === 'member' && source === 'space') { @@ -136,7 +137,7 @@ const Menu: FC<{ menuList = allMenus[source] || [] } - const noAuthList = ['user', 'pricing'].filter(vo => !user.permissions?.includes(vo) && !user.permissions?.includes('all')) + const noAuthList = ['user', 'pricing'].filter(vo => (Array.isArray(user.permissions) && !user.permissions?.includes(vo) && !user.permissions?.includes('all')) || !Array.isArray(user.permissions)) if (noAuthList && !noAuthList?.includes('all')) { const filterMenus = (list: MenuItem[]): MenuItem[] =>{ diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 2975796a..47f68705 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1509,7 +1509,7 @@ export const en = { EPISODIC_MEMORY: 'Episodic Memory', FORGET_MEMORY: 'Forget Memory', - endUserProfile: 'Profile', + endUserProfile: 'Permanent Memory', editEndUserProfile: 'Edit', other_name: 'Name', position: 'Position', @@ -1828,6 +1828,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re memoryTipTitle: 'Are you sure you want to enable conversation memory? Conversations will be saved to the memory store.', stopAudioRecorder: 'Stop Recording', startAudioRecorder: 'Start Recording', + citations: 'Citations', }, login: { title: 'Red Bear Memory Science', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 3edd84e3..8c22a506 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1507,7 +1507,7 @@ export const zh = { EPISODIC_MEMORY: '情景记忆', FORGET_MEMORY: '遗忘记忆', - endUserProfile: '核心档案', + endUserProfile: '永久记忆', editEndUserProfile: '编辑', other_name: '名称', position: '职位', diff --git a/web/src/styles/index.css b/web/src/styles/index.css index 684a4ba6..c6e7f359 100644 --- a/web/src/styles/index.css +++ b/web/src/styles/index.css @@ -377,9 +377,10 @@ body { .ant-input-filled, .ant-select-filled:not(.ant-select-customize-input) .ant-select-selector { background-color: #FFFFFF; + border-color: #FFFFFF; } .ant-input-filled:hover, -.ant-select-filled:not(.ant-select-customize-input) .ant-select-selector { +.ant-select-filled:not(.ant-select-disabled):not(.ant-select-customize-input):not(.ant-pagination-size-changer):hover .ant-select-selector { background-color: #FFFFFF; border-color: #171719; } @@ -402,7 +403,7 @@ body { color: #FF5D34; } -.spin.ant-spin-nested-loading .ant-spin-container::after { +.spin .ant-spin-nested-loading .ant-spin-container::after { background: transparent; } .upload-block, diff --git a/web/src/views/ApplicationConfig/Logs.tsx b/web/src/views/ApplicationConfig/Logs.tsx index 88fa2607..cf56059c 100644 --- a/web/src/views/ApplicationConfig/Logs.tsx +++ b/web/src/views/ApplicationConfig/Logs.tsx @@ -34,16 +34,16 @@ const Statistics: FC = () => { className: 'rb:text-[#212332]' }, { - title: t('user.createTime'), + title: t('application.created_at'), dataIndex: 'created_at', key: 'created_at', render: (createdAt: string) => formatDateTime(createdAt, 'YYYY-MM-DD HH:mm:ss'), }, { - title: t('user.lastLoginTime'), - dataIndex: 'last_login_at', - key: 'last_login_at', - render: (lastLoginAt: string) => lastLoginAt ? formatDateTime(lastLoginAt, 'YYYY-MM-DD HH:mm:ss') : '-', + title: t('common.updated_at'), + dataIndex: 'updated_at', + key: 'updated_at', + render: (updatedAt: string) => updatedAt ? formatDateTime(updatedAt, 'YYYY-MM-DD HH:mm:ss') : '-', }, { title: t('common.operation'), diff --git a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx index 8e6fc875..bebf6ebd 100644 --- a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx +++ b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx @@ -220,31 +220,31 @@ const ConfigHeader: FC = ({ />
diff --git a/web/src/views/ApplicationConfig/components/FeaturesConfig/index.tsx b/web/src/views/ApplicationConfig/components/FeaturesConfig/index.tsx index dba03ab2..3fb7bc93 100644 --- a/web/src/views/ApplicationConfig/components/FeaturesConfig/index.tsx +++ b/web/src/views/ApplicationConfig/components/FeaturesConfig/index.tsx @@ -49,7 +49,7 @@ const FeaturesConfig: FC = ({ ?
diff --git a/web/src/views/ApplicationManagement/index.tsx b/web/src/views/ApplicationManagement/index.tsx index 4d49635c..3b444a3d 100644 --- a/web/src/views/ApplicationManagement/index.tsx +++ b/web/src/views/ApplicationManagement/index.tsx @@ -216,7 +216,7 @@ const ApplicationManagement: React.FC = () => { 'rb:text-[#155EEF]': key === 'type', })}> {key === 'source' && item.is_shared - ? t('application.shared') + ? item.source_workspace_name : key === 'source' && !item.is_shared ? t('application.configuration') : key === 'created_at' diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index 80394317..d4d25070 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -64,6 +64,13 @@ const Conversation: FC = () => { const [config, setConfig] = useState>({}) const [audioStatusMap, setAudioStatusMap] = useState>({}) + useEffect(() => { + return () => { + audioPollingRef.current.forEach((timer) => clearInterval(timer)) + audioPollingRef.current.clear() + } + }, []) + useEffect(() => { const shareToken = localStorage.getItem(`shareToken_${token}`) setShareToken(shareToken) @@ -144,13 +151,29 @@ const Conversation: FC = () => { } useEffect(() => { - audioPollingRef.current.forEach((timer) => clearInterval(timer)) - audioPollingRef.current.clear() if (conversation_id) { getConversationDetail(token as string, conversation_id) .then(res => { const response = res as { messages: ChatItem[] } - setChatList(response?.messages || []) + const messages = response?.messages || [] + const historyAudioUrls = new Set(messages.map(m => m.meta_data?.audio_url).filter(Boolean)) + audioPollingRef.current.forEach((timer, key) => { + if (!historyAudioUrls.has(key)) { + clearInterval(timer) + audioPollingRef.current.delete(key) + } + }) + messages.forEach(msg => { + if (msg.role === 'assistant' && msg.meta_data?.audio_url && msg.meta_data?.audio_status === 'pending') { + startAudioPolling(msg.meta_data.audio_url, msg.meta_data.audio_url) + } + }) + setChatList(messages.map(msg => { + if (msg.role === 'assistant' && msg.meta_data?.audio_url && audioPollingRef.current.has(msg.meta_data.audio_url)) { + return { ...msg, meta_data: { ...msg.meta_data, audio_status: 'pending' } } + } + return msg + })) }) } else { if (features?.opening_statement?.statement) { @@ -228,6 +251,28 @@ const Conversation: FC = () => { })) }, [audioStatusMap, chatList.length]) + const startAudioPolling = (audioUrl: string, idToPoll: string) => { + if (audioPollingRef.current.has(idToPoll)) return + const fileId = audioUrl.split('/').pop() + if (!fileId) return + const timer = setInterval(() => { + getFileStatusById(fileId) + .then(res => { + const { status } = res as { status: string } + if (status && status !== 'pending') { + setAudioStatusMap(prev => ({ ...prev, [idToPoll]: status })) + clearInterval(audioPollingRef.current.get(idToPoll)) + audioPollingRef.current.delete(idToPoll) + } + }) + .catch(() => { + clearInterval(audioPollingRef.current.get(idToPoll)) + audioPollingRef.current.delete(idToPoll) + }) + }, 2000) + audioPollingRef.current.set(idToPoll, timer) + } + /** Send message and handle streaming response */ const handleSend = (msg?: string) => { if (!token || !shareToken) return @@ -287,35 +332,8 @@ const Conversation: FC = () => { const { file_id } = item.data as { file_id?: string } const idToPoll = file_id || audio_url || '' const fileId = audio_url.split('/').pop() - if (fileId && idToPoll && !audioPollingRef.current.has(idToPoll)) { - - const timer = setInterval(() => { - getFileStatusById(fileId) - .then(res => { - const { status } = res as { status: string } - if (status && status !== 'pending') { - setAudioStatusMap(prev => ({ - ...prev, - [idToPoll]: status - })) - clearInterval(audioPollingRef.current.get(idToPoll)) - audioPollingRef.current.delete(idToPoll) - getHistory(true) - if (currentConversationId && currentConversationId !== conversation_id) { - setConversationId(currentConversationId) - } - } - }) - .catch(() => { - clearInterval(audioPollingRef.current.get(idToPoll)) - audioPollingRef.current.delete(idToPoll) - getHistory(true) - if (currentConversationId && currentConversationId !== conversation_id) { - setConversationId(currentConversationId) - } - }) - }, 2000) - audioPollingRef.current.set(idToPoll, timer) + if (fileId && idToPoll) { + startAudioPolling(audio_url, idToPoll) } } else { getHistory(true) @@ -327,6 +345,10 @@ const Conversation: FC = () => { updateAssistantMessage(content, audio_url, undefined, citations) } setLoading(false) + getHistory(true) + if (currentConversationId && currentConversationId !== conversation_id) { + setConversationId(currentConversationId) + } break } }) diff --git a/web/src/views/InviteRegister/index.tsx b/web/src/views/InviteRegister/index.tsx index 42cffff1..72ae55e5 100644 --- a/web/src/views/InviteRegister/index.tsx +++ b/web/src/views/InviteRegister/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:37:12 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-04 10:05:39 + * @Last Modified time: 2026-03-27 22:22:18 */ /** * Invite Register Page @@ -144,7 +144,7 @@ const InviteRegister: React.FC = () => { }).then((res) => { const response = res as LoginInfo; updateLoginInfo(response); - navigate('/'); + navigate('/', { replace: true }); }).finally(() => { setLoading(false); }); diff --git a/web/src/views/Prompt/index.tsx b/web/src/views/Prompt/index.tsx index 095a8daa..13c09042 100644 --- a/web/src/views/Prompt/index.tsx +++ b/web/src/views/Prompt/index.tsx @@ -243,6 +243,7 @@ const Prompt: FC = () => { diff --git a/web/src/views/Prompt/pages/History.tsx b/web/src/views/Prompt/pages/History.tsx index 573b4a90..19c033ed 100644 --- a/web/src/views/Prompt/pages/History.tsx +++ b/web/src/views/Prompt/pages/History.tsx @@ -116,13 +116,13 @@ const History: React.FC = () => {
{formatDateTime(item.created_at, 'YYYY/MM/DD HH:mm')}
-
handleClick('detail', item)} >
-
handleClick('edit', item)} >
-
handleClick('delete', item)} >
diff --git a/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx b/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx index 33da0b04..ccfbc14d 100644 --- a/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx +++ b/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx @@ -65,8 +65,8 @@ const CommunityNetwork: FC<{ onSelectCommunity?: (node: RawCommunityNode) => voi }, [id]) if (loading) { - return - + return +
diff --git a/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx b/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx index 2510aaa9..04391107 100644 --- a/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx @@ -12,6 +12,7 @@ import { Row, Col, Progress, App, Table } from 'antd' import RbCard from '@/components/RbCard/Card' import { getForgetStats, + getForgetPendingNodesUrl, } from '@/api/memory' import type { ForgetData } from '../types' import ActivationMetricsPieCard from '../components/ActivationMetricsPieCard' @@ -19,6 +20,7 @@ import RecentTrendsLineCard from '../components/RecentTrendsLineCard' import { formatDateTime } from '@/utils/format' import StatusTag from '@/components/StatusTag' import ForgetRefreshModal from '../components/ForgetRefreshModal'; +import RbTable from '@/components/Table' /** Maps node type keys to StatusTag colour presets for the pending-nodes table. */ const statusTagColors: Record = { @@ -191,7 +193,9 @@ const ForgetDetail = forwardRef((_props, ref) => { bodyClassName="rb:p-3! rb:py-0! rb:h-[calc(100%-54px)]" className="rb:h-full!" > - { render: (activation_value) => {activation_value} }, ]} - pagination={{ - pageSize: 5, - showQuickJumper: true, - className: 'rb:mt-5! rb:mb-5.75!' - }} className="table-header-has-bg" />