def get_thread_safe_redis() -> redis.StrictRedis:
    """Return a Redis client backed by a pool private to this thread, process and loop.

    The connection pool is cached in thread-local storage and rebuilt when:

    - the thread has no pool yet (e.g. Celery ``--pool=threads``);
    - the PID differs from the one recorded with the pool (fork, e.g. Celery
      ``--pool=prefork``);
    - the event loop recorded with the pool has been closed;
    - the current event loop is not the one recorded with the pool.

    The last rule also covers a pool that was first built while no loop was
    available (recorded loop is ``None``): it is replaced on the first call
    made from inside a loop instead of being shared across loops.

    Returns:
        A new ``redis.StrictRedis`` wrapper around the cached pool.
    """
    current_pid = os.getpid()

    # Prefer the running loop: this helper is normally called from inside a
    # coroutine, and get_running_loop() has no deprecated side effects.
    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop. Fall back to the thread's configured loop, which
        # may itself be missing in worker threads (RuntimeError).
        try:
            current_loop = asyncio.get_event_loop()
        except RuntimeError:
            current_loop = None

    cached_loop = getattr(_thread_local, "loop", None)
    loop_closed = cached_loop is not None and cached_loop.is_closed()
    loop_changed = current_loop is not None and current_loop is not cached_loop

    needs_new_pool = (
        not hasattr(_thread_local, "pool")
        or getattr(_thread_local, "pid", None) != current_pid
        or loop_closed
        or loop_changed
    )

    if needs_new_pool:
        # NOTE(review): the superseded pool is dropped without disconnect();
        # confirm that relying on garbage collection is acceptable here.
        _thread_local.pid = current_pid
        _thread_local.loop = current_loop
        _thread_local.pool = ConnectionPool.from_url(
            _REDIS_URL,
            db=settings.REDIS_DB,
            password=settings.REDIS_PASSWORD,
            decode_responses=True,
            max_connections=5,
            health_check_interval=30,
        )

    return redis.StrictRedis(connection_pool=_thread_local.pool)
Exception as e: diff --git a/api/app/cache/memory/interest_memory.py b/api/app/cache/memory/interest_memory.py index 108e2a37..2881f06c 100644 --- a/api/app/cache/memory/interest_memory.py +++ b/api/app/cache/memory/interest_memory.py @@ -9,7 +9,7 @@ import logging from typing import Optional, List, Dict, Any from datetime import datetime -from app.aioRedis import aio_redis +from app.aioRedis import get_thread_safe_redis logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ class InterestMemoryCache: "cached": True, } value = json.dumps(payload, ensure_ascii=False) - await aio_redis.set(key, value, ex=expire) + await get_thread_safe_redis().set(key, value, ex=expire) logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒") return True except Exception as e: @@ -86,7 +86,7 @@ class InterestMemoryCache: """ try: key = cls._get_key(end_user_id, language) - value = await aio_redis.get(key) + value = await get_thread_safe_redis().get(key) if value: payload = json.loads(value) logger.info(f"命中兴趣分布缓存: {key}") @@ -114,7 +114,7 @@ class InterestMemoryCache: """ try: key = cls._get_key(end_user_id, language) - result = await aio_redis.delete(key) + result = await get_thread_safe_redis().delete(key) logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") return result > 0 except Exception as e: diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 3ba9c3a9..74991bcf 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -57,7 +57,6 @@ def list_apps( page: int = 1, pagesize: int = 10, ids: Optional[str] = None, - api_key: Optional[str] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): @@ -66,7 +65,7 @@ def list_apps( - 默认包含本工作空间的应用和分享给本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用 - 当提供 ids 参数时,按逗号分割获取指定应用,不分页 - - 当提供 api_key 参数时,查找该 API Key 关联的应用 + - search 参数支持:应用名称模糊搜索、API Key 精确搜索 """ from sqlalchemy import select as sa_select from app.models.api_key_model import ApiKey 
@@ -74,23 +73,34 @@ def list_apps( workspace_id = current_user.current_workspace_id service = app_service.AppService(db) - # 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程 - if api_key: - matched_id = db.execute( - sa_select(ApiKey.resource_id).where( - ApiKey.workspace_id == workspace_id, - ApiKey.api_key == api_key, - ApiKey.resource_id.isnot(None), - ) - ).scalar_one_or_none() - ids = str(matched_id) if matched_id else "" + # 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索 + if search: + search = search.strip() + # 尝试作为 API Key 精确匹配(API Key 通常较长) + if len(search) >= 10: + matched_id = db.execute( + sa_select(ApiKey.resource_id).where( + ApiKey.workspace_id == workspace_id, + ApiKey.api_key == search, + ApiKey.resource_id.isnot(None), + ) + ).scalar_one_or_none() + if matched_id: + # 找到 API Key,直接返回关联的应用 + ids = str(matched_id) - # 当 ids 存在且不为 None 时,根据 ids 获取应用 + # 当 ids 存在时,根据 ids 获取应用(不分页) if ids is not None: app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] - items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) - items = [service._convert_to_schema(app, workspace_id) for app in items_orm] - return success(data=items) + if app_ids: + items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) + items = [service._convert_to_schema(app, workspace_id) for app in items_orm] + # 返回标准分页格式 + meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False) + return success(data=PageData(page=meta, items=items)) + # ids 为空时,返回空列表 + meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False) + return success(data=PageData(page=meta, items=[])) # 正常分页查询 items_orm, total = app_service.list_apps( diff --git a/api/app/controllers/app_log_controller.py b/api/app/controllers/app_log_controller.py index dfd10644..92b5becd 100644 --- a/api/app/controllers/app_log_controller.py +++ b/api/app/controllers/app_log_controller.py @@ -3,17 +3,16 @@ import uuid from typing import Optional from fastapi import APIRouter, Depends, Query 
-from sqlalchemy import select, desc, func from sqlalchemy.orm import Session from app.core.logging_config import get_business_logger from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user, cur_workspace_access_guard -from app.models.conversation_model import Conversation, Message -from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage +from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail from app.schemas.response_schema import PageData, PageMeta from app.services.app_service import AppService +from app.services.app_log_service import AppLogService router = APIRouter(prefix="/apps", tags=["App Logs"]) logger = get_business_logger() @@ -25,52 +24,35 @@ def list_app_logs( app_id: uuid.UUID, page: int = Query(1, ge=1), pagesize: int = Query(20, ge=1, le=100), - user_id: Optional[str] = None, is_draft: Optional[bool] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): """查看应用下所有会话记录(分页) - - 支持按 user_id 筛选 - 支持按 is_draft 筛选(草稿会话 / 发布会话) - 按最新更新时间倒序排列 + - 所有人(包括共享者和被共享者)都只能查看自己的会话记录 """ workspace_id = current_user.current_workspace_id # 验证应用访问权限 - service = AppService(db) - service.get_app(app_id, workspace_id) + app_service = AppService(db) + app_service.get_app(app_id, workspace_id) - stmt = select(Conversation).where( - Conversation.app_id == app_id, - Conversation.workspace_id == workspace_id, - Conversation.is_active.is_(True), + # 使用 Service 层查询 + log_service = AppLogService(db) + conversations, total = log_service.list_conversations( + app_id=app_id, + workspace_id=workspace_id, + page=page, + pagesize=pagesize, + is_draft=is_draft ) - if user_id: - stmt = stmt.where(Conversation.user_id == user_id) - - if is_draft is not None: - stmt = stmt.where(Conversation.is_draft == is_draft) - - total = int(db.execute( - select(func.count()).select_from(stmt.subquery()) - ).scalar_one()) - - stmt = 
stmt.order_by(desc(Conversation.updated_at)) - stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) - - conversations = list(db.scalars(stmt).all()) - items = [AppLogConversation.model_validate(c) for c in conversations] meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) - logger.info( - "查询应用日志会话列表", - extra={"app_id": str(app_id), "total": total, "page": page} - ) - return success(data=PageData(page=meta, items=items)) @@ -86,44 +68,22 @@ def get_app_log_detail( - 返回会话基本信息 + 所有消息(按时间正序) - 消息 meta_data 包含模型名、token 用量等信息 + - 所有人(包括共享者和被共享者)都只能查看自己的会话详情 """ workspace_id = current_user.current_workspace_id # 验证应用访问权限 - service = AppService(db) - service.get_app(app_id, workspace_id) + app_service = AppService(db) + app_service.get_app(app_id, workspace_id) - # 查询会话(确保属于该应用和工作空间) - conversation = db.scalars( - select(Conversation).where( - Conversation.id == conversation_id, - Conversation.app_id == app_id, - Conversation.workspace_id == workspace_id, - Conversation.is_active.is_(True), - ) - ).first() - - if not conversation: - from app.core.exceptions import ResourceNotFoundException - raise ResourceNotFoundException("会话", str(conversation_id)) - - # 查询消息(按时间正序) - messages = list(db.scalars( - select(Message) - .where(Message.conversation_id == conversation_id) - .order_by(Message.created_at) - ).all()) - - detail = AppLogConversationDetail.model_validate(conversation) - detail.messages = [AppLogMessage.model_validate(m) for m in messages] - - logger.info( - "查询应用日志会话详情", - extra={ - "app_id": str(app_id), - "conversation_id": str(conversation_id), - "message_count": len(messages) - } + # 使用 Service 层查询 + log_service = AppLogService(db) + conversation = log_service.get_conversation_detail( + app_id=app_id, + conversation_id=conversation_id, + workspace_id=workspace_id ) + detail = AppLogConversationDetail.model_validate(conversation) + return success(data=detail) diff --git 
a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index fe4337d1..bedee987 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -1,3 +1,5 @@ +import asyncio +import uuid from fastapi import APIRouter, Depends, HTTPException, status, Query from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -47,64 +49,64 @@ def get_workspace_total_end_users( @router.get("/end_users", response_model=ApiResponse) async def get_workspace_end_users( + workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"), + keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"), + page: int = Query(1, ge=1, description="页码,从1开始"), + pagesize: int = Query(10, ge=1, description="每页数量"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """ - 获取工作空间的宿主列表(高性能优化版本 v2) - - 优化策略: - 1. 批量查询 end_users(一次查询而非循环) - 2. 并发查询所有用户的记忆数量(Neo4j) - 3. RAG 模式使用批量查询(一次 SQL) - 4. 只返回必要字段减少数据传输 - 5. 添加短期缓存减少重复查询 - 6. 
并发执行配置查询和记忆数量查询 - - 返回格式: - { - "end_user": {"id": "uuid", "other_name": "名称"}, - "memory_num": {"total": 数量}, - "memory_config": {"memory_config_id": "id", "memory_config_name": "名称"} - } + 获取工作空间的宿主列表(分页查询,支持模糊搜索) + + 返回工作空间下的宿主列表,支持分页查询和模糊搜索。 + 通过 keyword 参数同时模糊匹配 other_name 和 id 字段。 + + Args: + workspace_id: 工作空间ID(可选,默认当前用户工作空间) + keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id) + page: 页码(从1开始,默认1) + pagesize: 每页数量(默认10) + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含宿主列表和分页信息 """ - import asyncio - import json - from app.aioRedis import aio_redis_get, aio_redis_set - - workspace_id = current_user.current_workspace_id - - # 尝试从缓存获取(30秒缓存) - cache_key = f"end_users:workspace:{workspace_id}" - try: - cached_data = await aio_redis_get(cache_key) - if cached_data: - api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}") - return success(data=json.loads(cached_data), msg="宿主列表获取成功") - except Exception as e: - api_logger.warning(f"Redis 缓存读取失败: {str(e)}") - + # 如果未提供 workspace_id,使用当前用户的工作空间 + if workspace_id is None: + workspace_id = current_user.current_workspace_id # 获取当前空间类型 current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表") - - # 获取 end_users(已优化为批量查询) - end_users = memory_dashboard_service.get_workspace_end_users( + api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}") + + # 获取分页的 end_users + end_users_result = memory_dashboard_service.get_workspace_end_users_paginated( db=db, workspace_id=workspace_id, - current_user=current_user + current_user=current_user, + page=page, + pagesize=pagesize, + keyword=keyword ) + + end_users = end_users_result.get("items", []) + total = end_users_result.get("total", 0) + if not end_users: - api_logger.info("工作空间下没有宿主") - # 缓存空结果,避免重复查询 - try: - await aio_redis_set(cache_key, json.dumps([]), expire=30) - except Exception as e: 
- api_logger.warning(f"Redis 缓存写入失败: {str(e)}") - return success(data=[], msg="宿主列表获取成功") - + api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}") + return success(data={ + "items": [], + "page": { + "page": page, + "pagesize": pagesize, + "total": total, + "hasnext": (page * pagesize) < total + } + }, msg="宿主列表获取成功") + end_user_ids = [str(user.id) for user in end_users] - + # 并发执行两个独立的查询任务 async def get_memory_configs(): """获取记忆配置(在线程池中执行同步查询)""" @@ -116,7 +118,7 @@ async def get_workspace_end_users( except Exception as e: api_logger.error(f"批量获取记忆配置失败: {str(e)}") return {} - + async def get_memory_nums(): """获取记忆数量""" if current_workspace_type == "rag": @@ -130,26 +132,18 @@ async def get_workspace_end_users( except Exception as e: api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}") return {uid: {"total": 0} for uid in end_user_ids} - + elif current_workspace_type == "neo4j": - # Neo4j 模式:并发查询(带并发限制) - # 使用信号量限制并发数,避免大量用户时压垮 Neo4j - MAX_CONCURRENT_QUERIES = 10 - semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES) - - async def get_neo4j_memory_num(end_user_id: str): - async with semaphore: - try: - return await memory_storage_service.search_all(end_user_id) - except Exception as e: - api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}") - return {"total": 0} - - memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids]) - return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))} - + # Neo4j 模式:批量查询(简化版本,只返回total) + try: + batch_result = await memory_storage_service.search_all_batch(end_user_ids) + return {uid: {"total": count} for uid, count in batch_result.items()} + except Exception as e: + api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}") + return {uid: {"total": 0} for uid in end_user_ids} + return {uid: {"total": 0} for uid in end_user_ids} - + # 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据 try: from app.celery_app import celery_app as _celery_app @@ -170,13 +164,13 @@ async 
@router.get("/pending-nodes", response_model=ApiResponse)
async def get_pending_nodes(
    end_user_id: str,
    page: int = 1,
    pagesize: int = 10,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Return one page of nodes that are pending forgetting.

    Lists nodes that meet the forgetting conditions (activation below the
    threshold and last access older than the minimum number of days). This
    endpoint paginates on its own and is separate from ``/stats``.

    Args:
        end_user_id: Group ID, i.e. the end user ID (required).
        page: Page number, starting at 1 (default 1).
        pagesize: Items per page (default 10).
        current_user: The authenticated user.
        db: Database session.

    Returns:
        ApiResponse whose data has the form
        ``{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}``.

    Examples:
        - Page 1, 10 per page: GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
        - Page 2, 20 per page: GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
    """
    workspace_id = current_user.current_workspace_id
    # The user must have selected a workspace.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    # end_user_id is mandatory.
    if not end_user_id:
        api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
        return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")

    # Validate pagination before any lookup so that malformed requests are
    # rejected without a database round trip.
    if page < 1:
        return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
    if pagesize < 1:
        return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")

    # Resolve the memory config connected to this end user.
    try:
        from app.services.memory_agent_service import get_end_user_connected_config

        connected_config = get_end_user_connected_config(end_user_id, db)
        config_id = connected_config.get("memory_config_id")
        config_id = resolve_config_id(config_id, db)

        if config_id is None:
            api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
            return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")

        api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
    except ValueError as e:
        api_logger.warning(f"获取终端用户配置失败: {str(e)}")
        return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
    except Exception as e:
        api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))

    api_logger.info(
        f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
        f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
    )

    try:
        # Delegate the query to the service layer.
        result = await forget_service.get_pending_nodes(
            db=db,
            end_user_id=end_user_id,
            config_id=config_id,
            page=page,
            pagesize=pagesize
        )

        # Validate the payload shape through the response schema.
        response_data = PendingNodesResponse(**result)

        return success(data=response_data.model_dump(), msg="查询成功")

    except Exception as e:
        api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
uuid.UUID(audio_url.rstrip("/").split("/")[-1]) + except (ValueError, IndexError): + # audio_url 无法解析为 UUID,标记为 unknown + m.meta_data["audio_status"] = "unknown" + continue + + file_ids.append(file_id) + message_file_id_map[idx] = file_id + + # 批量查询所有相关的 FileMetadata + file_status_map = {} + if file_ids: + file_metas = ( + db.query(FileMetadata) + .filter(FileMetadata.id.in_(set(file_ids))) + .all() + ) + file_status_map = {fm.id: fm.status for fm in file_metas} + + # 第二次遍历:将查询结果映射回消息 + for idx, file_id in message_file_id_map.items(): + m = messages[idx] + m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown") + + conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json") conv_dict["messages"] = [ conversation_schema.Message.model_validate(m) for m in messages ] @@ -320,6 +354,16 @@ async def chat( other_id=other_id, original_user_id=user_id ) + + # Only extract and set memory_config_id when the end user doesn't have one yet + if not new_end_user.memory_config_id: + from app.services.memory_config_service import MemoryConfigService + memory_config_service = MemoryConfigService(db) + memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {}) + if memory_config_id: + new_end_user.memory_config_id = memory_config_id + db.commit() + db.refresh(new_end_user) end_user_id = str(new_end_user.id) # appid = share.app_id @@ -410,30 +454,6 @@ async def chat( agent_config = agent_config_4_app_release(release) if payload.stream: - # async def event_generator(): - # async for event in service.chat_stream( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ): - # yield event - - # return 
StreamingResponse( - # event_generator(), - # media_type="text/event-stream", - # headers={ - # "Cache-Control": "no-cache", - # "Connection": "keep-alive", - # "X-Accel-Buffering": "no" - # } - # ) async def event_generator(): async for event in app_chat_service.agnet_chat_stream( message=payload.message, @@ -459,20 +479,6 @@ async def chat( "X-Accel-Buffering": "no" } ) - # 非流式返回 - # result = await service.chat( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ) - # return success(data=conversation_schema.ChatResponse(**result)) result = await app_chat_service.agnet_chat( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID @@ -531,48 +537,6 @@ async def chat( ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) - # 多 Agent 流式返回 - # if payload.stream: - # async def event_generator(): - # async for event in service.multi_agent_chat_stream( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ): - # yield event - - # return StreamingResponse( - # event_generator(), - # media_type="text/event-stream", - # headers={ - # "Cache-Control": "no-cache", - # "Connection": "keep-alive", - # "X-Accel-Buffering": "no" - # } - # ) - - # # 多 Agent 非流式返回 - # result = await service.multi_agent_chat( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 
转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ) - - # return success(data=conversation_schema.ChatResponse(**result)) elif app_type == AppType.WORKFLOW: config = workflow_config_4_app_release(release) if not config.id: diff --git a/api/app/controllers/service/__init__.py b/api/app/controllers/service/__init__.py index 8c679c1f..96da0949 100644 --- a/api/app/controllers/service/__init__.py +++ b/api/app/controllers/service/__init__.py @@ -4,7 +4,7 @@ 认证方式: API Key """ from fastapi import APIRouter -from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller +from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller # 创建 V1 API 路由器 service_router = APIRouter() @@ -16,5 +16,6 @@ service_router.include_router(rag_api_document_controller.router) service_router.include_router(rag_api_file_controller.router) service_router.include_router(rag_api_chunk_controller.router) service_router.include_router(memory_api_controller.router) +service_router.include_router(end_user_api_controller.router) __all__ = ["service_router"] diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 32a911f9..d4573464 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -91,7 +91,7 @@ async def chat( app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id) other_id = payload.user_id - workspace_id = app.workspace_id + workspace_id = api_key_auth.workspace_id end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( 
@router.post("/create")
@require_api_key(scopes=["memory"])
async def create_end_user(
    request: Request,
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
    message: str = Body(..., description="Request body"),
):
    """
    Create or retrieve an end user for the workspace.

    Creates a new end user and connects it to a memory configuration.
    If an end user with the same other_id already exists in the workspace,
    returns the existing one.

    Optionally accepts a memory_config_id to connect the end user to a specific
    memory configuration. If not provided, falls back to the workspace default config.

    Raises:
        BusinessException: If the body is not a JSON object, if
            memory_config_id is not a valid UUID, or if the referenced
            memory config cannot be found.
    """
    # NOTE(review): `message` is declared as a body parameter but never read;
    # the payload is parsed from the raw request instead. Confirm it is needed.
    body = await request.json()
    if not isinstance(body, dict):
        # `**body` below needs a mapping; reject arrays/scalars explicitly
        # rather than failing with a TypeError.
        raise BusinessException(
            "Request body must be a JSON object",
            BizCode.INVALID_PARAMETER
        )
    payload = CreateEndUserRequest(**body)
    workspace_id = api_key_auth.workspace_id

    logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {workspace_id}")

    # Resolve memory_config_id: explicit > workspace default
    memory_config_id = None
    config_service = MemoryConfigService(db)

    if payload.memory_config_id:
        try:
            # str() keeps this working if the schema yields a UUID object.
            memory_config_id = uuid.UUID(str(payload.memory_config_id))
        except ValueError as exc:
            raise BusinessException(
                f"Invalid memory_config_id format: {payload.memory_config_id}",
                BizCode.INVALID_PARAMETER
            ) from exc
        config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
        if not config:
            raise BusinessException(
                f"Memory config not found: {payload.memory_config_id}",
                BizCode.MEMORY_CONFIG_NOT_FOUND
            )
        # Use the id of the config that was actually resolved.
        memory_config_id = config.config_id
    else:
        default_config = config_service.get_workspace_default_config(workspace_id)
        if default_config:
            memory_config_id = default_config.config_id
            logger.info(f"Using workspace default memory config: {memory_config_id}")
        else:
            logger.warning(f"No default memory config found for workspace: {workspace_id}")

    end_user_repo = EndUserRepository(db)
    end_user = end_user_repo.get_or_create_end_user_with_config(
        app_id=api_key_auth.resource_id,
        workspace_id=workspace_id,
        other_id=payload.other_id,
        memory_config_id=memory_config_id,
    )

    logger.info(f"End user ready: {end_user.id}")

    result = {
        "id": str(end_user.id),
        "other_id": end_user.other_id or "",
        "other_name": end_user.other_name or "",
        "workspace_id": str(end_user.workspace_id),
        "memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
    }

    return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -111,6 +111,18 @@ def get_current_user_info( break api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}") + + # 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限 + if current_user.external_source: + from premium.sso.models import SSOSource + source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first() + if source and source.permissions: + result_schema.permissions = source.permissions + else: + result_schema.permissions = [] + else: + result_schema.permissions = ["all"] + return success(data=result_schema, msg=t("users.info.get_success")) @@ -135,7 +147,6 @@ def get_tenant_superusers( return success(data=superusers_schema, msg=t("users.list.superusers_success")) - @router.get("/{user_id}", response_model=ApiResponse) def get_user_info_by_id( user_id: uuid.UUID, diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 9776cc29..3bb2252f 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,18 +11,14 @@ LangChain Agent 封装 import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.write_graph import write_long_term -from app.db import get_db -from app.core.logging_config import get_business_logger -from app.core.models import RedBearLLM, RedBearModelConfig -from app.models.models_model import ModelType, ModelProvider -from app.services.memory_agent_service import ( - get_end_user_connected_config, -) from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool +from app.core.logging_config import get_business_logger +from app.core.models import RedBearLLM, RedBearModelConfig +from 
app.models.models_model import ModelType + logger = get_business_logger() @@ -226,10 +222,9 @@ class LangChainAgent: Returns: List[BaseMessage]: 消息列表 """ - messages = [] + messages:list = [SystemMessage(content=self.system_prompt)] # 添加系统提示词 - messages.append(SystemMessage(content=self.system_prompt)) # 添加历史消息 if history: @@ -320,12 +315,7 @@ class LangChainAgent: message: str, history: Optional[List[Dict[str, str]]] = None, context: Optional[str] = None, - end_user_id: Optional[str] = None, - config_id: Optional[str] = None, # 添加这个参数 - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + files: Optional[List[Dict[str, Any]]] = None ) -> Dict[str, Any]: """执行对话 @@ -333,32 +323,12 @@ class LangChainAgent: message: 用户消息 history: 历史消息列表 [{"role": "user/assistant", "content": "..."}] context: 上下文信息(如知识库检索结果) + files: 多模态文件 Returns: Dict: 包含 content 和元数据的字典 """ - message_chat = message start_time = time.time() - actual_config_id = config_id - # If config_id is None, try to get from end_user's connected config - if actual_config_id is None and end_user_id: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) - db = next(get_db()) - try: - connected_config = get_end_user_connected_config(end_user_id, db) - actual_config_id = connected_config.get("memory_config_id") - except Exception as e: - logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") - finally: - db.close() - except Exception as e: - logger.warning(f"Failed to get db session: {e}") - actual_end_user_id = end_user_id if end_user_id is not None else "unknown" - logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') - print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ 
-445,9 +415,6 @@ class LangChainAgent: logger.info(f"最终提取的内容长度: {len(content)}") elapsed_time = time.time() - start_time - if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, - actual_config_id) response = { "content": content, "model": self.model_name, @@ -478,12 +445,7 @@ class LangChainAgent: message: str, history: Optional[List[Dict[str, str]]] = None, context: Optional[str] = None, - end_user_id: Optional[str] = None, - config_id: Optional[str] = None, - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + files: Optional[List[Dict[str, Any]]] = None ) -> AsyncGenerator[str | int, None]: """执行流式对话 @@ -491,6 +453,7 @@ class LangChainAgent: message: 用户消息 history: 历史消息列表 context: 上下文信息 + files: 多模态文件 Yields: str: 消息内容块 @@ -501,23 +464,6 @@ class LangChainAgent: logger.info(f" Has tools: {bool(self.tools)}") logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") logger.info("=" * 80) - message_chat = message - actual_config_id = config_id - # If config_id is None, try to get from end_user's connected config - if actual_config_id is None and end_user_id: - try: - db = next(get_db()) - try: - connected_config = get_end_user_connected_config(end_user_id, db) - actual_config_id = connected_config.get("memory_config_id") - except Exception as e: - logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") - finally: - db.close() - except Exception as e: - logger.warning(f"Failed to get db session: {e}") - - # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -527,17 +473,18 @@ class LangChainAgent: ) chunk_count = 0 - yielded_content = False # 统一使用 agent 的 astream_events 实现流式输出 logger.debug("使用 Agent astream_events 实现流式输出") full_content = '' try: + last_event = {} async for event in 
self.agent.astream_events( {"messages": messages}, version="v2", config={"recursion_limit": self.max_iterations} ): + last_event = event chunk_count += 1 kind = event.get("event") @@ -551,7 +498,6 @@ class LangChainAgent: if isinstance(chunk_content, str) and chunk_content: full_content += chunk_content yield chunk_content - yielded_content = True elif isinstance(chunk_content, list): # 多模态响应:提取文本部分 for item in chunk_content: @@ -562,18 +508,15 @@ class LangChainAgent: if text: full_content += text yield text - yielded_content = True # OpenAI 格式: {"type": "text", "text": "..."} elif item.get("type") == "text": text = item.get("text", "") if text: full_content += text yield text - yielded_content = True elif isinstance(item, str): full_content += item yield item - yielded_content = True elif kind == "on_llm_stream": # 另一种 LLM 流式事件 @@ -584,7 +527,6 @@ class LangChainAgent: if isinstance(chunk_content, str) and chunk_content: full_content += chunk_content yield chunk_content - yielded_content = True elif isinstance(chunk_content, list): # 多模态响应:提取文本部分 for item in chunk_content: @@ -595,22 +537,18 @@ class LangChainAgent: if text: full_content += text yield text - yielded_content = True # OpenAI 格式: {"type": "text", "text": "..."} elif item.get("type") == "text": text = item.get("text", "") if text: full_content += text yield text - yielded_content = True elif isinstance(item, str): full_content += item yield item - yielded_content = True elif isinstance(chunk, str): full_content += chunk yield chunk - yielded_content = True # 记录工具调用(可选) elif kind == "on_tool_start": @@ -620,17 +558,14 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") # 统计token消耗 - # 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages - output_messages = event.get("data", {}).get("output", {}).get("messages", []) + output_messages = last_event.get("data", {}).get("output", {}).get("messages", []) for msg in reversed(output_messages): if isinstance(msg, AIMessage): 
stream_total_tokens = self._extract_tokens_from_message(msg) logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}") yield stream_total_tokens break - if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, - actual_config_id) + except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 2074b6ca..74fb6bae 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.memory_short_repository import LongTermMemoryRepository from app.schemas.memory_agent_schema import AgentMemory_Long_Term -from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task from app.utils.config_utils import resolve_config_id @@ -21,25 +20,6 @@ logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): - """ - Write messages to RAG storage system - - Combines user and AI messages into a single string format and stores them - in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval. 
- - Args: - end_user_id: User identifier for the conversation - user_message: User's input message content - ai_message: AI's response message content - user_rag_memory_id: RAG memory identifier for storage location - """ - # RAG mode: combine messages into string format (maintain original logic) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - - async def write( storage_type, end_user_id, @@ -118,7 +98,7 @@ async def write( logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') -async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope): +async def term_memory_save(end_user_id, strategy_type, scope): """ Save long-term memory data to database @@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty to long-term memory storage. 
Args: - long_term_messages: Long-term message data to be saved - actual_config_id: Configuration identifier for memory settings end_user_id: User identifier for memory association - type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) + strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) scope: Scope/window size for memory processing """ with get_db_context() as db_session: @@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty from app.core.memory.agent.utils.redis_tool import write_store result = write_store.get_session_by_userid(end_user_id) - if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + if not result: + logger.warning(f"No write data found for user {end_user_id}") + return + if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]: data = await format_parsing(result, "dict") chunk_data = data[:scope] if len(chunk_data) == scope: @@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty logger.info(f'写入短长期:') -"""Window-based dialogue processing""" - - async def window_dialogue(end_user_id, langchain_messages, memory_config, scope): """ Process dialogue based on window size and write to Neo4j @@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) langchain_messages: Original message data list scope: Window size determining when to trigger long-term storage """ - scope = scope - is_end_user_id = count_store.get_sessions_count(end_user_id) - if is_end_user_id is not False: - is_end_user_id = count_store.get_sessions_count(end_user_id)[0] - redis_messages = count_store.get_sessions_count(end_user_id)[1] - if is_end_user_id and int(is_end_user_id) != int(scope): - is_end_user_id += 1 - langchain_messages += redis_messages - count_store.update_sessions_count(end_user_id, is_end_user_id, 
langchain_messages) - elif int(is_end_user_id) == int(scope): + is_end_user_has_history = count_store.get_sessions_count(end_user_id) + if is_end_user_has_history: + end_user_visit_count, redis_messages = is_end_user_has_history + else: + count_store.save_sessions_count(end_user_id, 1, langchain_messages) + return + end_user_visit_count += 1 + if end_user_visit_count < scope: + redis_messages.extend(langchain_messages) + count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages) + else: logger.info('写入长期记忆NEO4J') - formatted_messages = redis_messages + redis_messages.extend(langchain_messages) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly) if hasattr(memory_config, 'config_id'): config_id = memory_config.config_id else: config_id = memory_config - await write( - AgentMemory_Long_Term.STORAGE_NEO4J, - end_user_id, - "", - "", - None, - end_user_id, - config_id, - formatted_messages + write_message_task.delay( + end_user_id, # end_user_id: User ID + redis_messages, # message: JSON string format message list + config_id, # config_id: Configuration ID string + AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" + "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) ) - count_store.update_sessions_count(end_user_id, 1, langchain_messages) - else: - count_store.save_sessions_count(end_user_id, 1, langchain_messages) - - -"""Time-based memory processing""" + count_store.update_sessions_count(end_user_id, 0, []) async def memory_long_term_storage(end_user_id, memory_config, time): @@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config return result_dict except Exception as e: - print(f"[aggregate_judgment] 发生错误: {e}") - import traceback - traceback.print_exc() + logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True) return { "is_same_event": False, diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py 
b/api/app/core/memory/agent/langgraph_graph/write_graph.py index bf3c6597..32fc7d8a 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,49 +1,25 @@ -import asyncio -import json -import sys import warnings -from contextlib import asynccontextmanager -from langgraph.constants import END, START -from langgraph.graph import StateGraph -from app.db import get_db, get_db_context from app.core.logging_config import get_agent_logger -from app.core.memory.agent.utils.llm_tools import WriteState -from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ + aggregate_judgment +from app.core.memory.agent.utils.redis_tool import write_store +from app.db import get_db_context from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_config_service import MemoryConfigService +from app.services.memory_konwledges_server import write_rag warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) -if sys.platform.startswith("win"): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - -@asynccontextmanager -async def make_write_graph(): - """ - Create a write graph workflow for memory operations. 
- - Args: - user_id: User identifier - tools: MCP tools loaded from session - apply_id: Application identifier - end_user_id: Group identifier - memory_config: MemoryConfig object containing all configuration - """ - workflow = StateGraph(WriteState) - workflow.add_node("save_neo4j", write_node) - workflow.add_edge(START, "save_neo4j") - workflow.add_edge("save_neo4j", END) - - graph = workflow.compile() - - yield graph - - -async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', - end_user_id: str = '', scope: int = 6): +async def long_term_storage( + long_term_type: str, + langchain_messages: list, + memory_config_id: str, + end_user_id: str, + scope: int = 6 +): """ Handle long-term memory storage with different strategies @@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l Args: long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') langchain_messages: List of messages to store - memory_config: Memory configuration identifier + memory_config_id: Memory configuration identifier end_user_id: User group identifier scope: Scope parameter for chunk-based storage (default: 6) """ - from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ - aggregate_judgment - from app.core.memory.agent.utils.redis_tool import write_store + if langchain_messages is None: + langchain_messages = [] + write_store.save_session_write(end_user_id, langchain_messages) # 获取数据库会话 with get_db_context() as db_session: config_service = MemoryConfigService(db_session) memory_config = config_service.load_memory_config( - config_id=memory_config, # 改为整数 + config_id=memory_config_id, # 改为整数 service_name="MemoryAgentService" ) if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK: - '''Strategy 1: Dialogue window with 6 rounds of conversation''' + # Dialogue window with 6 rounds of conversation await 
window_dialogue(end_user_id, langchain_messages, memory_config, scope) if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME: - """Time-based strategy""" + # Time-based strategy await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE) if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE: - """Strategy 3: Aggregate judgment""" + # Aggregate judgment await aggregate_judgment(end_user_id, langchain_messages, memory_config) -async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id): +async def write_long_term( + storage_type: str, + end_user_id: str, + messages: list[dict], + user_rag_memory_id: str, + actual_config_id: str +): """ Write long-term memory with different storage types @@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u Args: storage_type: Type of storage (RAG or traditional) end_user_id: User group identifier - message_chat: User message content - aimessages: AI response messages + messages: message list user_rag_memory_id: RAG memory identifier actual_config_id: Actual configuration ID """ - from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save - from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages if storage_type == AgentMemory_Long_Term.STORAGE_RAG: - await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) + message_content = [] + for message in messages: + message_content.append(f'{message.get("role")}:{message.get("content")}') + messages_string = "\n".join(message_content) + await write_rag(end_user_id, messages_string, user_rag_memory_id) else: # AI reply writing (user messages and AI replies paired, written as complete dialogue at once) CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK SCOPE = 
AgentMemory_Long_Term.DEFAULT_SCOPE - long_term_messages = await agent_chat_messages(message_chat, aimessages) - await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) - await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) - -# async def main(): -# """主函数 - 运行工作流""" -# langchain_messages = [ -# { -# "role": "user", -# "content": "今天周五去爬山" -# }, -# { -# "role": "assistant", -# "content": "好耶" -# } -# -# ] -# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID -# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" -# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) -# -# -# -# if __name__ == "__main__": -# import asyncio -# asyncio.run(main()) + await long_term_storage(long_term_type=CHUNK, + langchain_messages=messages, + memory_config_id=actual_config_id, + end_user_id=end_user_id, + scope=SCOPE) + await term_memory_save(end_user_id, CHUNK, scope=SCOPE) diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index c5729628..82b22c9e 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -3,8 +3,9 @@ import uuid from app.core.config import settings from typing import List, Dict, Any, Optional, Union +from app.core.logging_config import get_logger from app.core.memory.agent.utils.redis_base import ( - serialize_messages, + serialize_messages, deserialize_messages, fix_encoding, format_session_data, @@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import ( get_current_timestamp ) - +logger = get_logger(__name__) class RedisWriteStore: """Redis Write 类型存储类,用于管理 save_session_write 相关的数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -66,10 
+67,10 @@ class RedisWriteStore: }) result = pipe.execute() - print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") + logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") return session_id except Exception as e: - print(f"[save_session_write] 保存会话失败: {e}") + logger.error(f"[save_session_write] 保存会话失败: {e}") raise e def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]: @@ -99,7 +100,7 @@ class RedisWriteStore: for key, data in zip(keys, all_data): if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == userid: # 从 key 中提取 session_id: session:write:{session_id} @@ -108,16 +109,16 @@ class RedisWriteStore: "sessionid": session_id, "messages": fix_encoding(data.get('messages', '')) }) - + if not results: return False - - print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") + + logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") return results except Exception as e: - print(f"[get_session_by_userid] 查询失败: {e}") + logger.error(f"[get_session_by_userid] 查询失败: {e}") return False - + def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]: """ 通过 end_user_id 获取所有 write 类型的会话数据 @@ -144,7 +145,7 @@ class RedisWriteStore: # 只查询 write 类型的 key keys = self.r.keys('session:write:*') if not keys: - print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") + logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") return False # 批量获取数据 @@ -158,12 +159,12 @@ class RedisWriteStore: for key, data in zip(keys, all_data): if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == end_user_id: # 从 key 中提取 session_id: session:write:{session_id} session_id = key.split(':')[-1] - + # 构建完整的会话信息 session_info = { "session_id": session_id, @@ -173,23 +174,21 @@ class RedisWriteStore: "starttime": data.get('starttime', '') } results.append(session_info) - + if 
not results: - print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") + logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") return False - + # 按时间排序(最新的在前) results.sort(key=lambda x: x.get('starttime', ''), reverse=True) - - print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") + + logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") return results except Exception as e: - print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}") - import traceback - traceback.print_exc() + logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True) return False - def find_user_recent_sessions(self, userid: str, + def find_user_recent_sessions(self, userid: str, minutes: int = 5) -> List[Dict[str, str]]: """ 根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据 @@ -203,11 +202,11 @@ class RedisWriteStore: """ import time start_time = time.time() - + # 只查询 write 类型的 key keys = self.r.keys('session:write:*') if not keys: - print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") return [] # 批量获取数据 @@ -221,7 +220,7 @@ class RedisWriteStore: for data in all_data: if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == userid and data.get('starttime'): # write 类型没有 aimessages,所以 Answer 为空 @@ -230,15 +229,14 @@ class RedisWriteStore: "Answer": "", "starttime": data.get('starttime', '') }) - + # 根据时间范围过滤 filtered_items = filter_by_time_range(matched_items, minutes) # 排序并移除时间字段 - result_items = sort_and_limit_results(filtered_items, limit=None) - print(result_items) + result_items = sort_and_limit_results(filtered_items) elapsed_time = time.time() - start_time - print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " + logger.debug(f"[find_user_recent_sessions] 
userid={userid}, minutes={minutes}, " f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") return result_items @@ -258,7 +256,7 @@ class RedisWriteStore: class RedisCountStore: """Redis Count 类型存储类,用于管理访问次数统计相关的数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -278,7 +276,7 @@ class RedisCountStore: decode_responses=True, encoding='utf-8' ) - self.uudi = session_id + self.uuid = session_id def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str: """ @@ -295,26 +293,26 @@ class RedisCountStore: session_id = str(uuid.uuid4()) key = generate_session_key(session_id, key_type="count") index_key = f'session:count:index:{end_user_id}' # 索引键 - + pipe = self.r.pipeline() pipe.hset(key, mapping={ - "id": self.uudi, + "id": self.uuid, "end_user_id": end_user_id, "count": int(count), "messages": serialize_messages(messages), "starttime": get_current_timestamp() }) pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 - + # 创建索引:end_user_id -> session_id 映射 pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60) - + result = pipe.execute() - - print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") + + logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") return session_id - def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]: + def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool: """ 通过 end_user_id 查询访问次数统计 @@ -327,7 +325,7 @@ class RedisCountStore: try: # 使用索引键快速查找 index_key = f'session:count:index:{end_user_id}' - + # 检查索引键类型,避免 WRONGTYPE 错误 try: key_type = self.r.type(index_key) @@ -335,35 +333,40 @@ class RedisCountStore: self.r.delete(index_key) return False except Exception as type_error: - print(f"[get_sessions_count] 检查键类型失败: {type_error}") - + logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}") + session_id = self.r.get(index_key) - + if not session_id: return False - + # 直接获取数据 key = 
generate_session_key(session_id, key_type="count") data = self.r.hgetall(key) - + if not data: # 索引存在但数据不存在,清理索引 self.r.delete(index_key) return False - + count = data.get('count') messages_str = data.get('messages') - + if count is not None: - messages = deserialize_messages(messages_str) - return [int(count), messages] - + messages: list[dict] = deserialize_messages(messages_str) + return int(count), messages + return False except Exception as e: - print(f"[get_sessions_count] 查询失败: {e}") + logger.error(f"[get_sessions_count] 查询失败: {e}") return False - def update_sessions_count(self, end_user_id: str, new_count: int, - messages: Any) -> bool: + + def update_sessions_count( + self, + end_user_id: str, + new_count: int, + messages: Any + ) -> bool: """ 通过 end_user_id 修改访问次数统计(优化版:使用索引) @@ -378,39 +381,39 @@ class RedisCountStore: try: # 使用索引键快速查找 index_key = f'session:count:index:{end_user_id}' - + # 检查索引键类型,避免 WRONGTYPE 错误 try: key_type = self.r.type(index_key) if key_type != 'string' and key_type != 'none': # 索引键类型错误,删除并返回 False - print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") + logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") self.r.delete(index_key) - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") return False except Exception as type_error: - print(f"[update_sessions_count] 检查键类型失败: {type_error}") - + logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}") + session_id = self.r.get(index_key) - + if not session_id: - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") return False - + # 直接更新数据 key = generate_session_key(session_id, key_type="count") messages_str = serialize_messages(messages) - + pipe = self.r.pipeline() - pipe.hset(key, 'count', int(new_count)) + pipe.hset(key, 'count', str(new_count)) pipe.hset(key, 'messages', 
messages_str) result = pipe.execute() - - print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") + + logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") return True - + except Exception as e: - print(f"[update_sessions_count] 更新失败: {e}") + logger.debug(f"[update_sessions_count] 更新失败: {e}") return False def delete_all_count_sessions(self) -> int: @@ -428,7 +431,7 @@ class RedisCountStore: class RedisSessionStore: """Redis 会话存储类,用于管理会话数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -451,9 +454,9 @@ class RedisSessionStore: self.uudi = session_id # ==================== 写入操作 ==================== - - def save_session(self, userid: str, messages: str, aimessages: str, - apply_id: str, end_user_id: str) -> str: + + def save_session(self, userid: str, messages: str, aimessages: str, + apply_id: str, end_user_id: str) -> str: """ 写入一条会话数据,返回 session_id @@ -483,14 +486,14 @@ class RedisSessionStore: }) result = pipe.execute() - print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") + logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") return session_id except Exception as e: - print(f"[save_session] 保存会话失败: {e}") + logger.error(f"[save_session] 保存会话失败: {e}") raise e # ==================== 读取操作 ==================== - + def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """ 读取一条会话数据 @@ -520,8 +523,8 @@ class RedisSessionStore: sessions[sid] = self.get_session(sid) return sessions - def find_user_apply_group(self, sessionid: str, apply_id: str, - end_user_id: str) -> List[Dict[str, str]]: + def find_user_apply_group(self, sessionid: str, apply_id: str, + end_user_id: str) -> List[Dict[str, str]]: """ 根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条 @@ -535,10 +538,10 @@ class RedisSessionStore: """ import time start_time = time.time() - + keys = 
self.r.keys('session:*') if not keys: - print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") return [] # 批量获取数据 @@ -556,21 +559,21 @@ class RedisSessionStore: continue if (data.get('apply_id') == apply_id and - data.get('end_user_id') == end_user_id): + data.get('end_user_id') == end_user_id): # 支持模糊匹配或完全匹配 sessionid if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: matched_items.append(format_session_data(data, include_time=True)) - + # 排序、限制数量并移除时间字段 result_items = sort_and_limit_results(matched_items, limit=6) elapsed_time = time.time() - start_time - print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") + logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") return result_items # ==================== 更新操作 ==================== - + def update_session(self, session_id: str, field: str, value: Any) -> bool: """ 更新单个字段 @@ -591,7 +594,7 @@ class RedisSessionStore: return bool(results[0]) # ==================== 删除操作 ==================== - + def delete_session(self, session_id: str) -> int: """ 删除单条会话 @@ -632,7 +635,7 @@ class RedisSessionStore: keys = self.r.keys('session:*') if not keys: - print("[delete_duplicate_sessions] 没有会话数据") + logger.debug("[delete_duplicate_sessions] 没有会话数据") return 0 # 批量获取所有数据 @@ -678,7 +681,7 @@ class RedisSessionStore: deleted_count += len(batch) elapsed_time = time.time() - start_time - print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") + logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") return deleted_count diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 55bcb8ba..1f437973 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ 
b/api/app/core/memory/agent/utils/write_tools.py @@ -151,11 +151,6 @@ async def write( # Step 3: Save all data to Neo4j database step_start = time.time() - from app.repositories.neo4j.create_indexes import create_fulltext_indexes - try: - await create_fulltext_indexes() - except Exception as e: - logger.error(f"Error creating indexes: {e}", exc_info=True) # 添加死锁重试机制 max_retries = 3 @@ -279,5 +274,21 @@ async def write( except Exception as cache_err: logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + # Close LLM/Embedder underlying httpx clients to prevent + # 'RuntimeError: Event loop is closed' during garbage collection + for client_obj in (llm_client, embedder_client): + try: + underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None) + if underlying is None: + continue + # Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model + inner = getattr(underlying, '_model', underlying) + # LangChain OpenAI models expose async_client (httpx.AsyncClient) + http_client = getattr(inner, 'async_client', None) + if http_client is not None and hasattr(http_client, 'aclose'): + await http_client.aclose() + except Exception: + pass + logger.info("=== Pipeline Complete ===") logger.info(f"Total execution time: {total_time:.2f} seconds") diff --git a/api/app/core/memory/llm_tools/llm_client.py b/api/app/core/memory/llm_tools/llm_client.py index e26aba3e..49cd9434 100644 --- a/api/app/core/memory/llm_tools/llm_client.py +++ b/api/app/core/memory/llm_tools/llm_client.py @@ -56,7 +56,7 @@ class LLMClient(ABC): self.max_retries = self.config.max_retries self.timeout = self.config.timeout - logger.info( + logger.debug( f"初始化 LLM 客户端: provider={self.provider}, " f"model={self.model_name}, max_retries={self.max_retries}" ) diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index 43c2b445..c70fef5f 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ 
b/api/app/core/memory/llm_tools/openai_client.py @@ -65,7 +65,7 @@ class OpenAIClient(LLMClient): type=type_ ) - logger.info(f"OpenAI 客户端初始化完成: type={type_}") + logger.debug(f"OpenAI 客户端初始化完成: type={type_}") async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: """ diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 967f529e..5390197a 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene logger = logging.getLogger(__name__) +def message_has_files(message: "ConversationMessage") -> bool: + """检查消息是否包含文件。 + + Args: + message: 待检查的消息对象 + + Returns: + bool: 如果消息包含文件则返回 True,否则返回 False + """ + return message.files and len(message.files) > 0 + + class DialogExtractionResponse(BaseModel): """对话级一次性抽取的结构化返回,用于加速剪枝。 @@ -128,7 +140,7 @@ class SemanticPruner: 1. 空消息 2. 场景特定填充词库精确匹配 3. 常见寒暄精确匹配 - 4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了") + 4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了") 5. 
纯表情/标点 """ t = message.msg.strip() @@ -482,6 +494,11 @@ class SemanticPruner: """ to_delete_ids: set = set() for m in msgs: + # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 + if message_has_files(m): + self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}") + continue + # 填充检测优先:先判断是否为填充,再看 LLM 保护 if self._is_filler_message(m): to_delete_ids.add(id(m)) @@ -549,6 +566,11 @@ class SemanticPruner: to_delete_ids: set = set() for m in msgs: msg_text = m.msg.strip() + + # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 + if message_has_files(m): + self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}") + continue # 第一优先级:填充消息无论模式直接删除,不参与后续场景判断 if self._is_filler_message(m): @@ -801,6 +823,12 @@ class SemanticPruner: for idx, m in enumerate(msgs): msg_text = m.msg.strip() + + # 最高优先级保护:带有文件的消息一律保留,不参与分类 + if message_has_files(m): + self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}") + llm_protected_msgs.append((idx, m)) # 放入保护列表 + continue if self._msg_matches_tokens(m, preserve_tokens): llm_protected_msgs.append((idx, m)) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 5ef7db0e..b20112a2 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -182,7 +182,7 @@ class ExtractionOrchestrator: list[StatementEntityEdge], list[EntityEntityEdge], list[PerceptualEdge], - dict + list[DialogData] ]: """ 运行完整的知识提取流水线(优化版:并行执行) @@ -295,6 +295,7 @@ class ExtractionOrchestrator: statement_entity_edges, entity_entity_edges, dialog_data_list, + dedup_details, ) = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, @@ -306,6 +307,11 @@ class ExtractionOrchestrator: dialog_data_list, ) + # 步骤 7: 同步用户别名到数据库表(仅正式模式) + if not is_pilot_run: + logger.info("步骤 7: 同步用户别名到 end_user 和 
end_user_info 表") + await self._update_end_user_other_name(entity_nodes, dialog_data_list) + logger.info(f"知识提取流水线运行完成({mode_str})") return ( dialogue_nodes, @@ -1399,7 +1405,8 @@ class ExtractionOrchestrator: logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}") else: first_alias = current_aliases[0].strip() if current_aliases else "" - if first_alias: + # 确保 first_alias 不是占位名称 + if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES: db.add(EndUserInfo( end_user_id=end_user_uuid, other_name=first_alias, @@ -1415,29 +1422,33 @@ class ExtractionOrchestrator: + # 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中 + USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'} + def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]: """从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序) - 这个方法直接返回 LLM 提取的别名列表,不做任何修改。 + 这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。 第一个别名将被用作 other_name。 Args: entity_nodes: 实体节点列表 Returns: - 别名列表(保持 LLM 提取的原始顺序) + 别名列表(保持 LLM 提取的原始顺序,已过滤占位名称) """ - USER_NAMES = {'用户', '我', 'User', 'I'} for entity in entity_nodes: - if getattr(entity, 'name', '').strip() in USER_NAMES: + if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES: aliases = getattr(entity, 'aliases', []) or [] - logger.debug(f"提取到用户别名(原始顺序): {aliases}") - return aliases + # 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name + filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES] + logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}") + return filtered return [] async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]: - """从 Neo4j 查询用户实体的完整 aliases 列表""" + """从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)""" cypher = """ MATCH (e:ExtractedEntity) WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I'] @@ -1451,7 +1462,10 @@ class ExtractionOrchestrator: aliases = result[0].get('aliases') or [] if not aliases: logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}") - return 
aliases + return [] + # 过滤掉占位名称,防止历史脏数据传播 + filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES] + return filtered def _resolve_other_name( self, @@ -1463,14 +1477,25 @@ class ExtractionOrchestrator: 决定 other_name 是否需要更新,返回新值;无需更新返回 None。 决策规则: - - 为空 → 用本次对话第一个别名 + - 为空或为占位名称 → 用本次对话第一个别名 - 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除) - 否则 → 保持不变(返回 None) + + 注意:返回值不允许是占位名称("用户"、"我"、"User"、"I") """ - if not current or not current.strip(): - return current_aliases[0].strip() if current_aliases else None + # 当前值为空或为占位名称时,需要更新 + if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES: + candidate = current_aliases[0].strip() if current_aliases else None + # 确保候选值不是占位名称 + if candidate and candidate in self.USER_PLACEHOLDER_NAMES: + return None + return candidate if current not in neo4j_aliases: - return neo4j_aliases[0].strip() if neo4j_aliases else None + candidate = neo4j_aliases[0].strip() if neo4j_aliases else None + # 确保候选值不是占位名称 + if candidate and candidate in self.USER_PLACEHOLDER_NAMES: + return None + return candidate return None @@ -1492,6 +1517,7 @@ class ExtractionOrchestrator: list[StatementChunkEdge], list[StatementEntityEdge], list[EntityEntityEdge], + list[DialogData], dict ]: """ @@ -1555,6 +1581,8 @@ class ExtractionOrchestrator: statement_chunk_edges, dedup_statement_entity_edges, dedup_entity_entity_edges, + dialog_data_list, + dedup_details, ) final_entity_nodes = dedup_entity_nodes @@ -1562,7 +1590,16 @@ class ExtractionOrchestrator: final_entity_entity_edges = dedup_entity_entity_edges else: # 正式模式:执行完整的两阶段去重 - result_tuple = await dedup_layers_and_merge_and_return( + ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + final_entity_nodes, + statement_chunk_edges, + final_statement_entity_edges, + final_entity_entity_edges, + dedup_details, + ) = await dedup_layers_and_merge_and_return( dialogue_nodes, chunk_nodes, statement_nodes, @@ -1576,21 +1613,21 @@ class ExtractionOrchestrator: 
llm_client=self.llm_client, ) - # 解包返回值 - ( - _, - _, - _, - final_entity_nodes, - _, - final_statement_entity_edges, - final_entity_entity_edges, - dedup_details, - ) = result_tuple - # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) + result_tuple = ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + final_entity_nodes, + statement_chunk_edges, + final_statement_entity_edges, + final_entity_entity_edges, + dialog_data_list, + dedup_details, + ) + logger.info( f"去重后: {len(final_entity_nodes)} 个实体节点, " f"{len(final_statement_entity_edges)} 条陈述句-实体边, " diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index f9f2f45c..6605532d 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement. {% if language == "zh" %} - 用户实体的 name 字段:使用 "用户" 或 "我" - 用户的真实姓名:放入 aliases + - **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等** - 示例: * "我叫李明" → name="用户", aliases=["李明"] + * ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases) + * ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases) {% else %} - User entity name field: use "User" or "I" - User's real name: put in aliases + - **🚨 NEVER put "User" or "I" in aliases. 
Aliases must only contain real names, nicknames, etc.** - Examples: * "I'm John" → name="User", aliases=["John"] + * ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases) + * ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases) {% endif %} diff --git a/api/app/core/storage/oss.py b/api/app/core/storage/oss.py index 1db86fef..c6c6ec48 100644 --- a/api/app/core/storage/oss.py +++ b/api/app/core/storage/oss.py @@ -44,6 +44,8 @@ class OSSStorage(StorageBackend): access_key_id: str, access_key_secret: str, bucket_name: str, + connect_timeout: int = 30, + multipart_threshold: int = 10 * 1024 * 1024, # 10MB ): """ Initialize the OSSStorage backend. @@ -53,6 +55,8 @@ class OSSStorage(StorageBackend): access_key_id: The Aliyun access key ID. access_key_secret: The Aliyun access key secret. bucket_name: The name of the OSS bucket. + connect_timeout: Connection timeout in seconds (default: 30). + multipart_threshold: File size threshold for multipart upload (default: 10MB). Raises: StorageConfigError: If any required configuration is missing. 
@@ -69,10 +73,17 @@ class OSSStorage(StorageBackend): self.endpoint = endpoint self.bucket_name = bucket_name + self.multipart_threshold = multipart_threshold try: auth = oss2.Auth(access_key_id, access_key_secret) - self.bucket = oss2.Bucket(auth, endpoint, bucket_name) + # 设置超时和重试 + self.bucket = oss2.Bucket( + auth, + endpoint, + bucket_name, + connect_timeout=connect_timeout + ) logger.info( f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}" ) @@ -108,21 +119,38 @@ class OSSStorage(StorageBackend): if content_type: headers["Content-Type"] = content_type - self.bucket.put_object(file_key, content, headers=headers if headers else None) + # 大文件使用分片上传 + if len(content) > self.multipart_threshold: + logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)") + upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id + parts = [] + part_size = 5 * 1024 * 1024 # 5MB per part + part_num = 1 + + for offset in range(0, len(content), part_size): + chunk = content[offset:offset + part_size] + result = self.bucket.upload_part(file_key, upload_id, part_num, chunk) + parts.append(oss2.models.PartInfo(part_num, result.etag)) + part_num += 1 + + self.bucket.complete_multipart_upload(file_key, upload_id, parts) + else: + self.bucket.put_object(file_key, content, headers=headers if headers else None) + logger.info(f"File uploaded to OSS successfully: {file_key}") return file_key except OssError as e: logger.error(f"OSS error uploading file {file_key}: {e}") raise StorageUploadError( - message=f"Failed to upload file to OSS: {e.message}", + message=f"Failed to upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: logger.error(f"Failed to upload file to OSS {file_key}: {e}") raise StorageUploadError( - message=f"Failed to upload file to OSS: {e}", + message=f"Failed to upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) @@ -135,28 +163,73 @@ 
class OSSStorage(StorageBackend): ) -> int: """Upload from async stream to OSS. Returns total bytes written.""" buf = io.BytesIO() + headers = {"Content-Type": content_type} if content_type else None + upload_id = None + try: + # 收集流数据 + total_size = 0 async for chunk in stream: + if not chunk: + continue buf.write(chunk) + total_size += len(chunk) + content = buf.getvalue() - headers = {"Content-Type": content_type} if content_type else None - self.bucket.put_object(file_key, content, headers=headers) - logger.info(f"File stream uploaded to OSS successfully: {file_key}") - return len(content) + + if not content: + raise StorageUploadError( + message="Empty stream content", + file_key=file_key, + ) + + # 大文件使用分片上传 + if len(content) > self.multipart_threshold: + logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)") + upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id + parts = [] + part_size = 5 * 1024 * 1024 # 5MB + part_num = 1 + + for offset in range(0, len(content), part_size): + chunk = content[offset:offset + part_size] + result = self.bucket.upload_part(file_key, upload_id, part_num, chunk) + parts.append(oss2.models.PartInfo(part_num, result.etag)) + part_num += 1 + + self.bucket.complete_multipart_upload(file_key, upload_id, parts) + else: + self.bucket.put_object(file_key, content, headers=headers) + + logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)") + return total_size + except OssError as e: + if upload_id: + try: + self.bucket.abort_multipart_upload(file_key, upload_id) + except: + pass logger.error(f"OSS error stream uploading file {file_key}: {e}") raise StorageUploadError( - message=f"Failed to stream upload file to OSS: {e.message}", + message=f"Failed to stream upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: + if upload_id: + try: + self.bucket.abort_multipart_upload(file_key, upload_id) + except: + pass 
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}") raise StorageUploadError( - message=f"Failed to stream upload file to OSS: {e}", + message=f"Failed to stream upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) + finally: + buf.close() async def download(self, file_key: str) -> bytes: """ @@ -182,14 +255,14 @@ class OSSStorage(StorageBackend): except OssError as e: logger.error(f"OSS error downloading file {file_key}: {e}") raise StorageDownloadError( - message=f"Failed to download file from OSS: {e.message}", + message=f"Failed to download file from OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: logger.error(f"Failed to download file from OSS {file_key}: {e}") raise StorageDownloadError( - message=f"Failed to download file from OSS: {e}", + message=f"Failed to download file from OSS: {str(e)}", file_key=file_key, cause=e, ) @@ -215,14 +288,14 @@ class OSSStorage(StorageBackend): except OssError as e: logger.error(f"OSS error deleting file {file_key}: {e}") raise StorageDeleteError( - message=f"Failed to delete file from OSS: {e.message}", + message=f"Failed to delete file from OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: logger.error(f"Failed to delete file from OSS {file_key}: {e}") raise StorageDeleteError( - message=f"Failed to delete file from OSS: {e}", + message=f"Failed to delete file from OSS: {str(e)}", file_key=file_key, cause=e, ) diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py index 2da0d3a8..eed44278 100644 --- a/api/app/core/workflow/engine/state_manager.py +++ b/api/app/core/workflow/engine/state_manager.py @@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType def merge_activate_state(x, y): - return { - k: x.get(k, False) or y.get(k, False) - for k in set(x) | set(y) - } + merged = dict(x) + for k, v in y.items(): + merged[k] = merged.get(k, False) or v + return merged def merge_looping_state(x, y): 
diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index 60f1257e..7faca82d 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta logger = logging.getLogger(__name__) +VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}") + + +class LazyVariableDict: + def __init__(self, source, literal): + self._source: dict[str, VariableStruct[Any]] = source + self._literal: bool = literal + self._cache = {} + + def keys(self): + return self._source.keys() + + def _resolve(self, key): + if key in self._cache: + return self._cache[key] + var_struct = self._source.get(key) + if var_struct is None: + raise KeyError(key) + value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value() + self._cache[key] = value + return value + + def get(self, key, default=None): + try: + return self._resolve(key) + except KeyError: + return default + + def __getitem__(self, key): + return self._resolve(key) + + def __getattr__(self, key): + if key.startswith('_'): + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") + return self._resolve(key) + + def __contains__(self, key): + return key in self._source + + def __iter__(self): + return iter(self._source) + + def __len__(self): + return len(self._source) + class VariableSelector: """变量选择器 @@ -117,8 +162,7 @@ class VariablePool: @staticmethod def transform_selector(selector): - pattern = r"\{\{\s*(.*?)\s*\}\}" - variable_literal = re.sub(pattern, r"\1", selector).strip() + variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip() selector = VariableSelector.from_string(variable_literal).path if len(selector) != 2: raise ValueError(f"Selector not valid - {selector}") @@ -303,6 +347,16 @@ class VariablePool: """ return self._get_variable_struct(selector) is not None + def 
lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict: + return LazyVariableDict(self.variables.get(namespace, {}), literal) + + def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]: + return { + ns: LazyVariableDict(vars_dict, literal) + for ns, vars_dict in self.variables.items() + if ns not in ("sys", "conv") + } + def get_all_system_vars(self, literal=False) -> dict[str, Any]: """获取所有系统变量 @@ -479,5 +533,3 @@ class VariablePoolInitializer: var_type=var_type, mut=False ) - - diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 8567ebbe..bedf6165 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -552,9 +552,9 @@ class BaseNode(ABC): return render_template( template=template, - conv_vars=variable_pool.get_all_conversation_vars(literal=True), - node_outputs=variable_pool.get_all_node_outputs(literal=True), - system_vars=variable_pool.get_all_system_vars(literal=True), + conv_vars=variable_pool.lazy_namespace("conv", literal=True), + node_outputs=variable_pool.lazy_all_node_outputs(literal=True), + system_vars=variable_pool.lazy_namespace("sys", literal=True), strict=strict ) @@ -579,9 +579,9 @@ class BaseNode(ABC): return evaluate_condition( expression=expression, - conv_var=variable_pool.get_all_conversation_vars(), - node_outputs=variable_pool.get_all_node_outputs(), - system_vars=variable_pool.get_all_system_vars() + conv_var=variable_pool.lazy_namespace("conv"), + node_outputs=variable_pool.lazy_all_node_outputs(), + system_vars=variable_pool.lazy_namespace("sys") ) @staticmethod diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index 84901bad..e555a228 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import 
VariablePool from app.core.workflow.nodes.cycle_graph import LoopNodeConfig from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance -from app.core.workflow.utils.expression_evaluator import evaluate_expression logger = logging.getLogger(__name__) @@ -85,12 +84,7 @@ class LoopRuntime: for variable in self.typed_config.cycle_vars: if variable.input_type == ValueInputType.VARIABLE: - value = evaluate_expression( - expression=variable.value, - conv_var=self.variable_pool.get_all_conversation_vars(), - node_outputs=self.variable_pool.get_all_node_outputs(), - system_vars=self.variable_pool.get_all_system_vars(), - ) + value = self.variable_pool.get_value(variable.value) else: value = TypeTransformer.transform(variable.value, variable.type) await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True) @@ -98,12 +92,7 @@ class LoopRuntime: **self.state ) loopstate["node_outputs"][self.node_id] = { - variable.name: evaluate_expression( - expression=variable.value, - conv_var=self.variable_pool.get_all_conversation_vars(), - node_outputs=self.variable_pool.get_all_node_outputs(), - system_vars=self.variable_pool.get_all_system_vars(), - ) + variable.name: self.variable_pool.get_value(variable.value) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py index 40641f3c..bd828760 100644 --- a/api/app/core/workflow/nodes/document_extractor/node.py +++ b/api/app/core/workflow/nodes/document_extractor/node.py @@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode): # Reuse cached bytes if already fetched if f.get_content(): file_input.set_content(f.get_content()) - 
text = await svc._extract_document_text(file_input) + text = await svc.extract_document_text(file_input) chunks.append(text) except Exception as e: logger.error( diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 92699cb4..d0b6d098 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -1,19 +1,23 @@ +import asyncio import logging import uuid from typing import Any +from langchain_core.documents import Document + from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.models import RedBearRerank, RedBearModelConfig -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector +from app.core.rag.models.chunk import DocumentChunk +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.core.workflow.variable.base_variable import VariableType from app.db import get_db_read -from app.models import knowledge_model, knowledgeshare_model, ModelType -from app.repositories import knowledge_repository, knowledgeshare_repository +from app.models import knowledge_model, ModelType +from app.repositories import knowledge_repository from app.schemas.chunk_schema import RetrieveType from app.services.model_service import ModelConfigService @@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: KnowledgeRetrievalNodeConfig | None = None - self.vector_service: 
ElasticSearchVector | None = None def _output_types(self) -> dict[str, VariableType]: return { @@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode): unique.append(doc) return unique - def _get_existing_kb_ids(self, db, kb_ids): + def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]: """ - Resolve all accessible and valid knowledge base IDs for retrieval. - - This includes: - - Private knowledge bases owned by the user - - Shared knowledge bases - - Source knowledge bases mapped via knowledge sharing relationships - + Reorder the list of document blocks and return the top_k results most relevant to the query Args: - db: Database session. - kb_ids (list[UUID]): Knowledge base IDs from node configuration. + query: query string + docs: List of document chunk to be rearranged + top_k: The number of top-level documents returned Returns: - list[UUID]: Final list of valid knowledge base IDs. + Rearranged document chunk list (sorted in descending order of relevance) + + Raises: + ValueError: If the input document list is empty or top_k is invalid """ - filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private) - - existing_ids = knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - - filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share) - - share_ids = knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - - if share_ids: - filters = [ - knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids) + reranker = self.get_reranker_model() + # parameter validation + if not docs: + raise ValueError("retrieval chunks be empty") + if top_k <= 0: + raise ValueError("top_k must be a positive integer") + try: + # Convert to LangChain Document object + documents = [ + Document( + page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute + metadata=doc.metadata or {} # Deal with possible None metadata + ) + for doc in docs 
] - items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( - db=db, - filters=filters + + # Perform reordering (compress_documents will automatically handle relevance scores and indexing) + reranked_docs = list(reranker.compress_documents(documents, query)) + + # Sort in descending order based on relevance score + reranked_docs.sort( + key=lambda x: x.metadata.get("relevance_score", 0), + reverse=True ) - existing_ids.extend(items) - return existing_ids + # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"] + result = [] + for item in reranked_docs[:top_k]: + for doc in docs: + if doc.page_content == item.page_content: + doc.metadata["score"] = item.metadata["relevance_score"] + result.append(doc) + return result + except Exception as e: + raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e def get_reranker_model(self) -> RedBearRerank: """ @@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode): ) return reranker - def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config): + async def knowledge_retrieval(self, db, query, db_knowledge, kb_config): + rs = [] if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER: children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) + tasks = [] for child in children: if not (child and child.chunk_num > 0 and child.status == 1): continue - kb_config.kb_id = child.id - self.knowledge_retrieval(db, query, rs, child, kb_config) - return - self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + child_kb_config = kb_config.model_copy() + child_kb_config.kb_id = child.id + tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config)) + if tasks: + result = await asyncio.gather(*tasks) + for _ in result: + rs.extend(_) + return rs + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) indices = 
f"Vector_index_{kb_config.kb_id}_Node".lower() match kb_config.retrieve_type: case RetrieveType.PARTICIPLE: - rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold)) + rs.extend( + await asyncio.to_thread( + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) + ) case RetrieveType.SEMANTIC: - rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight)) + rs.extend( + await asyncio.to_thread( + vector_service.search_by_vector, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) + ) case RetrieveType.HYBRID: - rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight) - rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold) + rs1_task = asyncio.to_thread( + vector_service.search_by_vector, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) + rs2_task = asyncio.to_thread( + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) + rs1, rs2 = await asyncio.gather(rs1_task, rs2_task) # Deduplicate hybrid retrieval results unique_rs = self._deduplicate_docs(rs1, rs2) if not unique_rs: - return + return [] if self.typed_config.reranker_id: - self.vector_service.reranker = self.get_reranker_model() - rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + rs.extend( + 
await asyncio.to_thread( + self.rerank, + **{"query": query, "docs": unique_rs, "top_k": kb_config.top_k} + ) + ) else: rs.extend(sorted( unique_rs, @@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode): )[:kb_config.top_k]) case _: raise RuntimeError("Unknown retrieval type") + return rs async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ @@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode): knowledge_bases = self.typed_config.knowledge_bases rs = [] + tasks = [] for kb_config in knowledge_bases: db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) if not db_knowledge: raise RuntimeError("The knowledge base does not exist or access is denied.") - self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config) + tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config)) + if tasks: + result = await asyncio.gather(*tasks) + for _ in result: + rs.extend(_) if not rs: return [] if self.typed_config.reranker_id: - self.vector_service.reranker = self.get_reranker_model() - final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) + final_rs = await asyncio.to_thread( + self.rerank, + **{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k} + ) else: final_rs = sorted( rs, diff --git a/api/app/core/workflow/utils/expression_evaluator.py b/api/app/core/workflow/utils/expression_evaluator.py index 4bc5fc4c..05a3294b 100644 --- a/api/app/core/workflow/utils/expression_evaluator.py +++ b/api/app/core/workflow/utils/expression_evaluator.py @@ -4,32 +4,33 @@ from typing import Any from simpleeval import simple_eval, NameNotDefined, InvalidExpression +from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN + logger = logging.getLogger(__name__) +_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}") + class ExpressionEvaluator: """Safe expression evaluator for workflow variables 
and node outputs.""" - + # Reserved namespaces RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} @classmethod def normalize_template(cls, template: str) -> str: - pattern = re.compile( - r"\{\{\s*(\d+)\.(\w+)\s*}}" - ) - return pattern.sub( + return _NORMALIZE_PATTERN.sub( r'{{ node["\1"].\2 }}', template ) @classmethod def evaluate( - cls, - expression: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + cls, + expression: str, + conv_vars: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None ) -> Any: """ Safely evaluate an expression using workflow variables. @@ -49,48 +50,47 @@ class ExpressionEvaluator: # Remove Jinja2-style brackets if present expression = expression.strip() expression = cls.normalize_template(expression) - pattern = r"\{\{\s*(.*?)\s*\}\}" - expression = re.sub(pattern, r"\1", expression).strip() + expression = VARIABLE_PATTERN.sub(r"\1", expression).strip() # Build context for evaluation context = { - "conv": conv_vars, # conversation variables - "node": node_outputs, # node outputs - "sys": system_vars or {}, # system variables + "conv": conv_vars, # conversation variables + "node": node_outputs, # node outputs + "sys": system_vars or {}, # system variables } - context.update(conv_vars) - context["nodes"] = node_outputs + # context.update(conv_vars) + # context["nodes"] = node_outputs context.update(node_outputs) - + try: # simpleeval supports safe operations: # arithmetic, comparisons, logical ops, attribute/dict/list access result = simple_eval(expression, names=context) return result - + except NameNotDefined as e: logger.error(f"Undefined variable in expression: {expression}, error: {e}") raise ValueError(f"Undefined variable: {e}") - + except InvalidExpression as e: logger.error(f"Invalid expression syntax: {expression}, error: {e}") raise ValueError(f"Invalid expression syntax: {e}") - + except SyntaxError as e: logger.error(f"Syntax 
error in expression: {expression}, error: {e}") raise ValueError(f"Syntax error: {e}") - + except Exception as e: logger.error(f"Expression evaluation failed: {expression}, error: {e}") raise ValueError(f"Expression evaluation failed: {e}") - + @staticmethod def evaluate_bool( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + expression: str, + conv_var: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None ) -> bool: """ Evaluate a boolean expression (for conditions). @@ -108,7 +108,7 @@ class ExpressionEvaluator: expression, conv_var, node_outputs, system_vars ) return bool(result) - + @staticmethod def validate_variable_names(variables: list[dict]) -> list[str]: """ @@ -121,7 +121,7 @@ class ExpressionEvaluator: list[str]: List of error messages. Empty if all names are valid. """ errors = [] - + for var in variables: var_name = var.get("name", "") @@ -134,16 +134,16 @@ class ExpressionEvaluator: errors.append( f"Variable name '{var_name}' is not a valid Python identifier" ) - + return errors # 便捷函数 def evaluate_expression( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] + expression: str, + conv_var: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, dict[str, Any] | LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict ) -> Any: """Evaluate an expression (convenience function).""" return ExpressionEvaluator.evaluate( @@ -152,11 +152,11 @@ def evaluate_expression( def evaluate_condition( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None -) -> bool: + expression: str, + conv_var: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, dict[str, Any] | LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict +) -> Any: """Evaluate a boolean condition expression (convenience 
function).""" return ExpressionEvaluator.evaluate_bool( expression, conv_var, node_outputs, system_vars diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index 6a73efc4..2c2d0f67 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -1,7 +1,8 @@ """ -模板渲染器 +Template Renderer -使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。 +Provides safe template rendering using Jinja2, supporting variable references +and expressions. """ import logging @@ -10,11 +11,15 @@ from typing import Any from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined +from app.core.workflow.engine.variable_pool import LazyVariableDict + logger = logging.getLogger(__name__) +_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}") + class SafeUndefined(Undefined): - """访问未定义属性不会报错,返回空字符串""" + """Return empty string instead of raising error when accessing undefined variables""" __slots__ = () def _fail_with_undefined_error(self, *args, **kwargs): @@ -26,26 +31,22 @@ class SafeUndefined(Undefined): class TemplateRenderer: - """模板渲染器""" - def __init__(self, strict: bool = True): - """初始化渲染器 - + """Initialize renderer + Args: - strict: 是否使用严格模式(未定义变量会抛出异常) + strict: Whether to enable strict mode (raise error on undefined variables) """ self.strict = strict self.env = Environment( undefined=StrictUndefined if strict else SafeUndefined, - autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML + autoescape=False # Disable auto-escaping since we handle plain text instead of HTML ) @staticmethod def normalize_template(template: str) -> str: - pattern = re.compile( - r"\{\{\s*(\d+)\.(\w+)\s*}}" - ) - return pattern.sub( + """Normalize template syntax (convert numeric node reference to dict access)""" + return _NORMALIZE_PATTERN.sub( r'{{ node["\1"].\2 }}', template ) @@ -53,24 +54,24 @@ class TemplateRenderer: def render( self, template: str, - conv_vars: dict[str, 
Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + conv_vars: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | dict[str, LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict | None = None ) -> str: - """渲染模板 - + """Render template + Args: - template: 模板字符串 - conv_vars: 会话变量 - node_outputs: 节点输出结果 - system_vars: 系统变量 - + template: Template string + conv_vars: Conversation variables + node_outputs: Node outputs + system_vars: System variables + Returns: - 渲染后的字符串 - + Rendered string + Raises: - ValueError: 模板语法错误或变量未定义 - + ValueError: If template syntax is invalid or variables are undefined + Examples: >>> renderer = TemplateRenderer() >>> renderer.render( @@ -80,122 +81,119 @@ class TemplateRenderer: ... {} ... ) 'Hello World!' - + >>> renderer.render( - ... "分析结果: {{node.analyze.output}}", + ... "Analysis result: {{node.analyze.output}}", ... {}, - ... {"analyze": {"output": "正面情绪"}}, + ... {"analyze": {"output": "positive sentiment"}}, ... {} ... ) - '分析结果: 正面情绪' + 'Analysis result: positive sentiment' """ - # 构建命名空间上下文 + # Build namespace context context = { - "conv": conv_vars, # 会话变量:{{conv.user_name}} - "node": node_outputs, # 节点输出:{{node.node_1.output}} - "sys": system_vars, # 系统变量:{{sys.execution_id}} + "conv": conv_vars, # Conversation variables: {{conv.user_name}} + "node": node_outputs, # Node outputs: {{node.node_1.output}} + "sys": system_vars, # System variables: {{sys.execution_id}} } - # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} - # 将所有节点输出添加到顶层上下文 + # Allow direct access to node outputs by node ID: {{llm_qa.output}} if node_outputs: context.update(node_outputs) - # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} - if conv_vars: - context.update(conv_vars) - - context["nodes"] = node_outputs or {} # 旧语法兼容 + # # 支持直接访问会话变量(不需要 conv. 
前缀):{{user_name}} + # if conv_vars: + # context.update(conv_vars) + # + # context["nodes"] = node_outputs or {} # 旧语法兼容 template = self.normalize_template(template) try: tmpl = self.env.from_string(template) return tmpl.render(**context) except TemplateSyntaxError as e: - logger.error(f"模板语法错误: {template}, 错误: {e}") - raise ValueError(f"模板语法错误: {e}") - + logger.error(f"Template syntax error: {template}, error: {e}") + raise ValueError(f"Template syntax error: {e}") except UndefinedError as e: - logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}") - raise ValueError(f"未定义的变量: {e}") - + logger.error(f"Undefined variable in template: {template}, error: {e}") + raise ValueError(f"Undefined variable: {e}") except Exception as e: - logger.error(f"模板渲染异常: {template}, 错误: {e}") - raise ValueError(f"模板渲染失败: {e}") + logger.error(f"Template rendering error: {template}, error: {e}") + raise ValueError(f"Template rendering failed: {e}") def validate(self, template: str) -> list[str]: - """验证模板语法 - + """Validate template syntax + Args: - template: 模板字符串 - + template: Template string + Returns: - 错误列表,如果为空则验证通过 - + List of errors (empty if valid) + Examples: >>> renderer = TemplateRenderer() >>> renderer.validate("Hello {{var.name}}!") [] - - >>> renderer.validate("Hello {{var.name") # 缺少结束标记 - ['模板语法错误: ...'] + + >>> renderer.validate("Hello {{var.name") # Missing closing tag + ['Template syntax error: ...'] """ errors = [] try: self.env.from_string(template) except TemplateSyntaxError as e: - errors.append(f"模板语法错误: {e}") + errors.append(f"Template syntax error: {e}") except Exception as e: - errors.append(f"模板验证失败: {e}") + errors.append(f"Template validation failed: {e}") return errors -# 全局渲染器实例(严格模式) +# Global renderer instances (strict / lenient) _strict_renderer = TemplateRenderer(strict=True) _lenient_renderer = TemplateRenderer(strict=False) def render_template( template: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any], + 
conv_vars: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | dict[str, LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict, strict: bool = True ) -> str: - """渲染模板(便捷函数) - + """Render template (convenience function) + Args: - strict: 严格模式 - template: 模板字符串 - conv_vars: 会话变量 - node_outputs: 节点输出 - system_vars: 系统变量 - + strict: Whether to use strict mode + template: Template string + conv_vars: Conversation variables + node_outputs: Node outputs + system_vars: System variables + Returns: - 渲染后的字符串 - + Rendered string + Examples: >>> render_template( - ... "请分析: {{var.text}}", - ... {"text": "这是一段文本"}, + ... "Analyze: {{var.text}}", + ... {"text": "This is a text"}, ... {}, ... {} ... ) - '请分析: 这是一段文本' + 'Analyze: This is a text' """ renderer = _strict_renderer if strict else _lenient_renderer return renderer.render(template, conv_vars, node_outputs, system_vars) def validate_template(template: str) -> list[str]: - """验证模板语法(便捷函数) - + """Validate template syntax (convenience function) + Args: - template: 模板字符串 - + template: Template string + Returns: - 错误列表 + List of errors """ return _strict_renderer.validate(template) diff --git a/api/app/main.py b/api/app/main.py index f4c23ca8..9e501f11 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -1,5 +1,6 @@ import os import subprocess +from app.repositories.neo4j.create_indexes import create_all_indexes from contextlib import asynccontextmanager from fastapi import FastAPI, APIRouter @@ -60,8 +61,10 @@ async def lifespan(app: FastAPI): logger.warning(f"加载预定义模型时出错: {str(e)}") else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") - + await create_all_indexes() logger.info("应用程序启动完成") + + yield # 应用关闭事件 logger.info("应用程序正在关闭") diff --git a/api/app/models/user_model.py b/api/app/models/user_model.py index 81319789..c0b17d14 100644 --- a/api/app/models/user_model.py +++ b/api/app/models/user_model.py @@ -19,9 +19,12 @@ class User(Base): last_login_at = Column(DateTime, nullable=True) # 
最后登录时间,可为空 # SSO 外部关联字段 - external_id = Column(String(100), nullable=True) # 外部用户ID + external_id = Column(String(100), nullable=True) # 外部用户 ID external_source = Column(String(50), nullable=True) # 来源系统 + # 用户联系方式 + phone = Column(String(50), nullable=True) # 用户电话 + # 用户语言偏好 preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文 diff --git a/api/app/repositories/conversation_repository.py b/api/app/repositories/conversation_repository.py index 90f2d6ec..0676a255 100644 --- a/api/app/repositories/conversation_repository.py +++ b/api/app/repositories/conversation_repository.py @@ -199,6 +199,96 @@ class ConversationRepository: ) return conversations, total + def list_app_conversations( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + is_draft: Optional[bool] = None, + page: int = 1, + pagesize: int = 20 + ) -> tuple[list[Conversation], int]: + """ + 查询应用日志会话列表(带分页和过滤) + + Args: + app_id: 应用 ID + workspace_id: 工作空间 ID + is_draft: 是否草稿会话(None 表示不过滤) + page: 页码(从 1 开始) + pagesize: 每页数量 + + Returns: + Tuple[List[Conversation], int]: (会话列表,总数) + """ + stmt = select(Conversation).where( + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True) + ) + + if is_draft is not None: + stmt = stmt.where(Conversation.is_draft == is_draft) + + # Calculate total number of records + total = int(self.db.execute( + select(func.count()).select_from(stmt.subquery()) + ).scalar_one()) + + # Apply pagination + stmt = stmt.order_by(desc(Conversation.updated_at)) + stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) + + conversations = list(self.db.scalars(stmt).all()) + + logger.info( + "Listed app conversations successfully", + extra={ + "app_id": str(app_id), + "workspace_id": str(workspace_id), + "returned": len(conversations), + "total": total + } + ) + return conversations, total + + def get_conversation_for_app_log( + self, + conversation_id: 
uuid.UUID, + app_id: uuid.UUID, + workspace_id: uuid.UUID + ) -> Conversation: + """ + 查询应用日志的会话详情 + + Args: + conversation_id: 会话 ID + app_id: 应用 ID + workspace_id: 工作空间 ID + + Returns: + Conversation: 会话对象 + + Raises: + ResourceNotFoundException: 当会话不存在时 + """ + logger.info(f"Fetching conversation for app log: {conversation_id}") + + stmt = select(Conversation).where( + Conversation.id == conversation_id, + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True) + ) + + conversation = self.db.scalars(stmt).first() + + if not conversation: + logger.warning(f"Conversation not found: {conversation_id}") + raise ResourceNotFoundException("会话", str(conversation_id)) + + logger.info(f"Conversation fetched successfully: {conversation_id}") + return conversation + def soft_delete_conversation_by_conversation_id( self, conversation_id: uuid.UUID, @@ -290,6 +380,34 @@ class MessageRepository: self.db.add(message) return message + def get_messages_by_conversation( + self, + conversation_id: uuid.UUID + ) -> list[Message]: + """ + 查询会话的所有消息(按时间正序) + + Args: + conversation_id: 会话 ID + + Returns: + List[Message]: 消息列表 + """ + stmt = select(Message).where( + Message.conversation_id == conversation_id + ).order_by(Message.created_at) + + messages = list(self.db.scalars(stmt).all()) + + logger.info( + "Fetched messages for conversation", + extra={ + "conversation_id": str(conversation_id), + "message_count": len(messages) + } + ) + return messages + def get_message_by_conversation_id( self, conversation_id: uuid.UUID, diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 3c1dd16f..aad80707 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -132,6 +132,82 @@ class EndUserRepository: db_logger.error(f"获取或创建终端用户时出错: {str(e)}") raise + def get_or_create_end_user_with_config( + self, + app_id: Optional[uuid.UUID], + 
workspace_id: uuid.UUID, + other_id: str, + memory_config_id: Optional[uuid.UUID] = None, + other_name: Optional[str] = None + ) -> EndUser: + """获取或创建终端用户,并在单次事务中关联记忆配置。 + + 与 get_or_create_end_user 类似,但额外支持在创建/获取时 + 一并设置 memory_config_id,避免多次提交。 + + Args: + app_id: 应用ID(可为 None) + workspace_id: 工作空间ID + other_id: 第三方ID + memory_config_id: 记忆配置ID(可选,仅在用户尚无配置时设置) + other_name: 用户名称(用于创建 EndUserInfo) + + Returns: + EndUser: 终端用户对象(已关联记忆配置) + """ + try: + end_user = ( + self.db.query(EndUser) + .filter( + EndUser.workspace_id == workspace_id, + EndUser.other_id == other_id + ) + .order_by(EndUser.created_at.asc()) + .first() + ) + + if end_user: + db_logger.debug(f"找到现有终端用户: workspace_id={workspace_id}, other_id={other_id}") + if app_id is not None: + end_user.app_id = app_id + if memory_config_id and not end_user.memory_config_id: + end_user.memory_config_id = memory_config_id + self.db.commit() + self.db.refresh(end_user) + return end_user + + # 创建新用户 + end_user = EndUser( + app_id=app_id, + workspace_id=workspace_id, + other_id=other_id, + memory_config_id=memory_config_id, + ) + self.db.add(end_user) + self.db.flush() + + end_user_info = EndUserInfo( + end_user_id=end_user.id, + other_name=other_name or "", + aliases=[], + meta_data={} + ) + self.db.add(end_user_info) + + self.db.commit() + self.db.refresh(end_user) + + db_logger.info( + f"创建新终端用户及其信息: (other_id: {other_id}) for workspace {workspace_id}, " + f"memory_config_id={memory_config_id}" + ) + return end_user + + except Exception as e: + self.db.rollback() + db_logger.error(f"获取或创建终端用户(含配置)时出错: {str(e)}") + raise + def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]: """根据ID获取终端用户(用于缓存操作) @@ -515,6 +591,51 @@ class EndUserRepository: ) raise + def batch_update_memory_config_id_by_app( + self, + app_id: uuid.UUID, + memory_config_id: uuid.UUID + ) -> int: + """批量更新应用下所有终端用户的 memory_config_id + + Args: + app_id: 应用ID + memory_config_id: 新的记忆配置ID + + Returns: + int: 更新的终端用户数量 + + Raises: + 
Exception: 数据库操作失败时抛出 + """ + try: + from sqlalchemy import update + + stmt = ( + update(EndUser) + .where(EndUser.app_id == app_id) + .values(memory_config_id=memory_config_id) + ) + + result = self.db.execute(stmt) + self.db.commit() + + updated_count = result.rowcount + + db_logger.info( + f"批量更新终端用户记忆配置: app_id={app_id}, " + f"memory_config_id={memory_config_id}, updated_count={updated_count}" + ) + + return updated_count + except Exception as e: + self.db.rollback() + db_logger.error( + f"批量更新终端用户记忆配置时出错: app_id={app_id}, " + f"memory_config_id={memory_config_id}, error={str(e)}" + ) + raise + def count_by_memory_config_id( self, memory_config_id: uuid.UUID diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index e64d19a3..3139b851 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -78,6 +78,15 @@ class MemoryConfigRepository: OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count """ + # 批量查询多个用户的记忆数量(简化版本,只返回total) + SEARCH_FOR_ALL_BATCH = """ + MATCH (n) WHERE n.end_user_id IN $end_user_ids + RETURN + n.end_user_id as user_id, + count(n) as total + ORDER BY user_id + """ + # Extracted entity details within group/app/user SEARCH_FOR_DETIALS = """ MATCH (n:ExtractedEntity) diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index d9e94117..5132aa09 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -1,62 +1,47 @@ +import asyncio from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - async def create_fulltext_indexes(): """Create full-text indexes for keyword search with BM25 scoring.""" connector = Neo4jConnector() try: - print("\n" + "=" * 70) - print("Creating Full-Text Indexes (for keyword search)") - print("=" * 70) + # 创建 Statements 索引 await 
connector.execute_query(""" CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: statementsFulltext") + """) # # 创建 Dialogues 索引 # await connector.execute_query(""" # CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content] # OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } # """) - # 创建 Entities 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: entitiesFulltext") + """) # 创建 Chunks 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: chunksFulltext") + """) # 创建 MemorySummary 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: summariesFulltext") - + """) # 创建 Community 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } """) - print("✓ Created: communitiesFulltext") - print("\nFull-text indexes created successfully with BM25 support.") - except Exception as e: - print(f"✗ Error creating full-text indexes: {e}") finally: await connector.close() - - async def create_vector_indexes(): """Create vector indexes for fast embedding similarity search. 
@@ -65,12 +50,7 @@ async def create_vector_indexes(): """ connector = Neo4jConnector() try: - print("\n" + "=" * 70) - print("Creating Vector Indexes (for embedding search)") - print("=" * 70) - print("Note: Adjust vector.dimensions if using different embedding model") - print(" Current setting: 1024 dimensions (for bge-m3)") - print() + # Statement embedding index await connector.execute_query(""" @@ -82,7 +62,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: statement_embedding_index") + # Chunk embedding index await connector.execute_query(""" @@ -94,7 +74,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: chunk_embedding_index") + # Entity name embedding index await connector.execute_query(""" @@ -106,7 +86,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: entity_embedding_index") + # Memory summary embedding index await connector.execute_query(""" @@ -118,8 +98,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: summary_embedding_index") - + # Community summary embedding index await connector.execute_query(""" CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS @@ -129,8 +108,7 @@ async def create_vector_indexes(): `vector.dimensions`: 1024, `vector.similarity_function`: 'cosine' }} - """) - print("✓ Created: community_summary_embedding_index") + """) # Dialogue embedding index (optional) await connector.execute_query(""" @@ -142,91 +120,15 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: dialogue_embedding_index") - - # Community summary embedding index - await connector.execute_query(""" - CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS - FOR (c:Community) - ON c.summary_embedding - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - 
`vector.similarity_function`: 'cosine' - }} - """) - print("✓ Created: community_summary_embedding_index") - print("\nVector indexes created successfully!") - print("\nExpected performance improvement:") - print(" Before: ~1.4s for embedding search") - print(" After: ~0.05-0.2s for embedding search (10-30x faster!)") - - except Exception as e: - print(f"✗ Error creating vector indexes: {e}") finally: await connector.close() - - -async def create_config_id_indexes(): - """Create indexes on config_id fields for improved query performance. - - These indexes enable fast filtering of nodes by configuration ID, - which is essential for configuration isolation and multi-tenant scenarios. - """ - connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Config ID Indexes") - print("=" * 70) - - # Dialogue.config_id index - await connector.execute_query(""" - CREATE INDEX dialogue_config_id_index IF NOT EXISTS - FOR (d:Dialogue) ON (d.config_id) - """) - print("✓ Created: dialogue_config_id_index") - - # Statement.config_id index - await connector.execute_query(""" - CREATE INDEX statement_config_id_index IF NOT EXISTS - FOR (s:Statement) ON (s.config_id) - """) - print("✓ Created: statement_config_id_index") - - # ExtractedEntity.config_id index - await connector.execute_query(""" - CREATE INDEX entity_config_id_index IF NOT EXISTS - FOR (e:ExtractedEntity) ON (e.config_id) - """) - print("✓ Created: entity_config_id_index") - - # MemorySummary.config_id index - await connector.execute_query(""" - CREATE INDEX summary_config_id_index IF NOT EXISTS - FOR (m:MemorySummary) ON (m.config_id) - """) - print("✓ Created: summary_config_id_index") - - print("\nConfig ID indexes created successfully!") - print("These indexes enable fast filtering by configuration ID.") - - except Exception as e: - print(f"✗ Error creating config_id indexes: {e}") - finally: - await connector.close() - - async def create_unique_constraints(): """Create uniqueness constraints for 
core node identifiers. - Ensures concurrent MERGE operations remain safe and prevents duplicates. """ connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Unique Constraints") - print("=" * 70) - + try: # Dialogue.id unique await connector.execute_query( """ @@ -234,8 +136,7 @@ async def create_unique_constraints(): FOR (d:Dialogue) REQUIRE d.id IS UNIQUE """ ) - print("✓ Created: dialog_id_unique") - + # Statement.id unique await connector.execute_query( """ @@ -243,8 +144,7 @@ async def create_unique_constraints(): FOR (s:Statement) REQUIRE s.id IS UNIQUE """ ) - print("✓ Created: statement_id_unique") - + # Chunk.id unique await connector.execute_query( """ @@ -252,112 +152,13 @@ async def create_unique_constraints(): FOR (c:Chunk) REQUIRE c.id IS UNIQUE """ ) - print("✓ Created: chunk_id_unique") - - print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.") - except Exception as e: - print(f"✗ Error creating unique constraints: {e}") + finally: await connector.close() - - async def create_all_indexes(): """Create all indexes and constraints in one go.""" - print("\n" + "=" * 70) - print("Neo4j Index & Constraint Setup") - print("=" * 70) - print("This will create:") - print(" 1. Full-text indexes (for keyword/BM25 search)") - print(" 2. Vector indexes (for embedding similarity search)") - print(" 3. Config ID indexes (for configuration isolation)") - print(" 4. 
Unique constraints (for data integrity)") - print("=" * 70) - await create_fulltext_indexes() await create_vector_indexes() - await create_config_id_indexes() await create_unique_constraints() - - print("\n" + "=" * 70) print("✓ All indexes and constraints created successfully!") - print("=" * 70) - print("\nTo verify, run in Neo4j Browser:") - print(" SHOW INDEXES") - print(" SHOW CONSTRAINTS") - print() - - -async def check_indexes(): - """Check what indexes currently exist.""" - connector = Neo4jConnector() - - try: - print("\n" + "=" * 70) - print("Checking Existing Indexes") - print("=" * 70) - query = "SHOW INDEXES" - result = await connector.execute_query(query) - - fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT'] - vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR'] - range_indexes = [idx for idx in result if idx.get('type') == 'RANGE'] - - print(f"\nFull-text indexes: {len(fulltext_indexes)}") - for idx in fulltext_indexes: - print(f" ✓ {idx.get('name')}") - - print(f"\nVector indexes: {len(vector_indexes)}") - for idx in vector_indexes: - print(f" ✓ {idx.get('name')}") - - print(f"\nRange indexes (including config_id): {len(range_indexes)}") - for idx in range_indexes: - print(f" ✓ {idx.get('name')}") - - if not vector_indexes: - print("\n⚠️ WARNING: No vector indexes found!") - print(" Embedding search will be VERY SLOW (~1.4s)") - print(" Run: python create_indexes.py") - - # Check for config_id indexes - config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')] - if len(config_id_indexes) < 4: - print("\n⚠️ WARNING: Not all config_id indexes found!") - print(f" Expected 4, found {len(config_id_indexes)}") - print(" Run: python create_indexes.py config_id") - - print("=" * 70) - - finally: - await connector.close() - - -if __name__ == "__main__": - import asyncio - import sys - - if len(sys.argv) > 1: - command = sys.argv[1] - if command == "check": - 
asyncio.run(check_indexes()) - elif command == "fulltext": - asyncio.run(create_fulltext_indexes()) - elif command == "vector": - asyncio.run(create_vector_indexes()) - elif command == "config_id": - asyncio.run(create_config_id_indexes()) - elif command == "constraints": - asyncio.run(create_unique_constraints()) - else: - print(f"Unknown command: {command}") - print("\nUsage:") - print(" python create_indexes.py # Create all indexes") - print(" python create_indexes.py check # Check existing indexes") - print(" python create_indexes.py fulltext # Create only full-text indexes") - print(" python create_indexes.py vector # Create only vector indexes") - print(" python create_indexes.py config_id # Create only config_id indexes") - print(" python create_indexes.py constraints # Create only constraints") - else: - asyncio.run(create_all_indexes()) - diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index c08f9d0e..26ffe350 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -340,17 +340,22 @@ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """ CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) WITH e, score -UNION -MATCH (e:ExtractedEntity) -WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) - AND e.aliases IS NOT NULL - AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q)) -WITH e, +WITH collect({entity: e, score: score}) AS fulltextResults + +OPTIONAL MATCH (ae:ExtractedEntity) +WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id) + AND ae.aliases IS NOT NULL + AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q)) +WITH fulltextResults, collect(ae) AS aliasEntities + +UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score: CASE - WHEN ANY(alias IN e.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 - WHEN 
ANY(alias IN e.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 ELSE 0.8 - END AS score + END +}]) AS row +WITH row.entity AS e, row.score AS score WITH DISTINCT e, MAX(score) AS score OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) diff --git a/api/app/repositories/user_repository.py b/api/app/repositories/user_repository.py index 3f8919aa..af4449e5 100644 --- a/api/app/repositories/user_repository.py +++ b/api/app/repositories/user_repository.py @@ -158,22 +158,26 @@ class UserRepository: raise def get_users_by_tenant( - self, - tenant_id: uuid.UUID, - skip: int = 0, + self, + tenant_id: uuid.UUID, + skip: int = 0, limit: int = 100, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> List[User]: """获取租户下的用户列表""" db_logger.debug(f"查询租户用户: tenant_id={tenant_id}") - + try: query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id) - + if is_active is not None: query = query.filter(User.is_active == is_active) - + + if is_superuser is not None: + query = query.filter(User.is_superuser == is_superuser) + if search: query = query.filter( or_( @@ -181,7 +185,7 @@ class UserRepository: User.email.ilike(f"%{search}%") ) ) - + users = query.offset(skip).limit(limit).all() db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}") return users @@ -190,18 +194,22 @@ class UserRepository: raise def count_users_by_tenant( - self, + self, tenant_id: uuid.UUID, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> int: """统计租户下的用户数量""" try: query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id) - + if is_active is not None: query = query.filter(User.is_active == is_active) - + + 
if is_superuser is not None: + query = query.filter(User.is_superuser == is_superuser) + if search: query = query.filter( or_( @@ -209,7 +217,7 @@ class UserRepository: User.email.ilike(f"%{search}%") ) ) - + return query.scalar() except Exception as e: db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}") diff --git a/api/app/repositories/workflow_repository.py b/api/app/repositories/workflow_repository.py index 4e24faa0..a783fe3f 100644 --- a/api/app/repositories/workflow_repository.py +++ b/api/app/repositories/workflow_repository.py @@ -3,9 +3,9 @@ """ import uuid -from typing import Any, Annotated +from typing import Any, Annotated, Literal from sqlalchemy.orm import Session -from sqlalchemy import desc +from sqlalchemy import desc, select from fastapi import Depends from app.models.workflow_model import ( @@ -128,29 +128,36 @@ class WorkflowExecutionRepository: Returns: 执行记录列表 """ - return self.db.query(WorkflowExecution).filter( + stmt = select(WorkflowExecution).filter( WorkflowExecution.app_id == app_id ).order_by( desc(WorkflowExecution.started_at) - ).limit(limit).offset(offset).all() + ).limit(limit).offset(offset) + return list(self.db.execute(stmt).scalars()) def get_by_conversation_id( self, - conversation_id: uuid.UUID + conversation_id: uuid.UUID, + status: Literal["running", "completed", "failed"] = None, + limit_count: int = 50 ) -> list[WorkflowExecution]: """根据会话 ID 获取执行记录列表 Args: + limit_count: conversation_id: 会话 ID + status: 状态(可选) Returns: 执行记录列表 """ - return self.db.query(WorkflowExecution).filter( + stmt = select(WorkflowExecution).filter( WorkflowExecution.conversation_id == conversation_id - ).order_by( - desc(WorkflowExecution.started_at) - ).all() + ) + if status: + stmt = stmt.filter(WorkflowExecution.status == status) + stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count) + return list(self.db.execute(stmt).scalars()) def count_by_app_id(self, app_id: uuid.UUID) -> int: """统计应用的执行次数 @@ -199,11 +206,12 
@@ class WorkflowNodeExecutionRepository: Returns: 节点执行记录列表(按执行顺序排序) """ - return self.db.query(WorkflowNodeExecution).filter( + stmt = select(WorkflowNodeExecution).filter( WorkflowNodeExecution.execution_id == execution_id ).order_by( WorkflowNodeExecution.execution_order - ).all() + ) + return list(self.db.execute(stmt).scalars()) def get_by_node_id( self, @@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository: Returns: 节点执行记录列表 """ - return self.db.query(WorkflowNodeExecution).filter( + stmt = select(WorkflowNodeExecution).filter( WorkflowNodeExecution.execution_id == execution_id, WorkflowNodeExecution.node_id == node_id ).order_by( WorkflowNodeExecution.retry_count - ).all() + ) + return list(self.db.execute(stmt).scalars()) # ==================== 依赖注入函数 ==================== diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index e34945eb..f1e9132f 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -276,7 +276,7 @@ class AgentConfigCreate(BaseModel): # 记忆配置 memory: MemoryConfig = Field( - default_factory=lambda: MemoryConfig(enabled=True), + default_factory=lambda: MemoryConfig(enabled=False), description="对话历史记忆配置" ) diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b4efe61d..97aa5bb5 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -17,6 +17,7 @@ class Write_UserInput(BaseModel): end_user_id: str config_id: Optional[str] = None + class AgentMemory_Long_Term(ABC): """长期记忆配置常量""" STORAGE_NEO4J = "neo4j" @@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC): STRATEGY_CHUNK = "chunk" STRATEGY_TIME = "time" DEFAULT_SCOPE = 6 - TIME_SCOPE=5 -class AgentMemoryDataset(ABC): - PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余'] - NAME='用户' + TIME_SCOPE = 5 + +class AgentMemoryDataset(ABC): + PRONOUN = ['我', '本人', '在下', '自己', '咱', '鄙人', '吴', '余'] + NAME = '用户' diff --git a/api/app/schemas/memory_api_schema.py 
b/api/app/schemas/memory_api_schema.py index 84a34e8a..ff62355f 100644 --- a/api/app/schemas/memory_api_schema.py +++ b/api/app/schemas/memory_api_schema.py @@ -138,21 +138,13 @@ class CreateEndUserRequest(BaseModel): """Request schema for creating an end user. Attributes: - workspace_id: Workspace ID (required) other_id: External user identifier (required) other_name: Display name for the end user + memory_config_id: Optional memory config ID. If not provided, uses workspace default. """ - workspace_id: str = Field(..., description="Workspace ID (required)") other_id: str = Field(..., description="External user identifier (required)") other_name: Optional[str] = Field("", description="Display name") - - @field_validator("workspace_id") - @classmethod - def validate_workspace_id(cls, v: str) -> str: - """Validate that workspace_id is not empty.""" - if not v or not v.strip(): - raise ValueError("workspace_id is required and cannot be empty") - return v.strip() + memory_config_id: Optional[str] = Field(None, description="Memory config ID. 
Falls back to workspace default if not provided.") @field_validator("other_id") @classmethod @@ -171,11 +163,13 @@ class CreateEndUserResponse(BaseModel): other_id: External user identifier other_name: Display name workspace_id: Workspace the user belongs to + memory_config_id: Connected memory config ID """ id: str = Field(..., description="End user UUID") other_id: str = Field(..., description="External user identifier") other_name: str = Field("", description="Display name") workspace_id: str = Field(..., description="Workspace ID") + memory_config_id: Optional[str] = Field(None, description="Connected memory config ID") class MemoryConfigItem(BaseModel): diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 711b6de9..bfcf6337 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel): last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)") +class PageInfo(BaseModel): + """分页信息模型""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + page: int = Field(..., description="当前页码(从1开始)") + pagesize: int = Field(..., description="每页数量") + total: int = Field(..., description="总记录数") + hasnext: bool = Field(..., description="是否有下一页") + + +class PendingNodesResponse(BaseModel): + """待遗忘节点列表响应模型(独立分页接口)""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表") + page: PageInfo = Field(..., description="分页信息") + + class ForgettingStatsResponse(BaseModel): """遗忘引擎统计信息响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") @@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel): node_distribution: Dict[str, int] = Field(..., description="节点类型分布") recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") - pending_nodes: List[PendingForgettingNode] = 
Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)") timestamp: int = Field(..., description="统计时间(时间戳)") diff --git a/api/app/schemas/user_schema.py b/api/app/schemas/user_schema.py index 6b880696..aa9ac256 100644 --- a/api/app/schemas/user_schema.py +++ b/api/app/schemas/user_schema.py @@ -1,6 +1,6 @@ from dataclasses import field from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict -from typing import Optional +from typing import Optional, List import datetime import uuid @@ -20,6 +20,7 @@ class UserCreate(UserBase): class UserUpdate(BaseModel): username: Optional[str] = None email: Optional[EmailStr] = None + phone: Optional[str] = None is_active: Optional[bool] = None is_superuser: Optional[bool] = None @@ -85,6 +86,8 @@ class User(UserBase): current_workspace_name: Optional[str] = None role: Optional[WorkspaceRole] = None preferred_language: Optional[str] = "zh" # 用户语言偏好 + phone: Optional[str] = None # 用户电话 + permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制 # 将 datetime 转换为毫秒时间戳 @validator("created_at", pre=True) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index b5f9f194..a3ba860c 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from app.core.agent.langchain_agent import LangChainAgent from app.core.logging_config import get_business_logger +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.models import MultiAgentConfig, AgentConfig, ModelType from app.models import WorkflowConfig @@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService from app.services.draft_run_service import AgentRunService +from app.services.memory_agent_service import 
get_end_user_connected_config from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multimodal_service import MultimodalService from app.services.workflow_service import WorkflowService -from app.schemas import FileType logger = get_business_logger() @@ -43,18 +44,17 @@ class AppChatService: message: str, conversation_id: uuid.UUID, config: AgentConfig, - user_id: Optional[str] = None, + files: list[FileInput], + user_id: str, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, - workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None + workspace_id: Optional[str] = None ) -> Dict[str, Any]: """聊天(非流式)""" start_time = time.time() - config_id = None # 应用 features 配置 features_config: dict = config.features or {} @@ -93,7 +93,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, + user_id) tools.extend(kb_tools) memory_flag = False if memory: @@ -140,13 +141,13 @@ class AppChatService: # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 is_new_conversation = len(history) == 0 if is_new_conversation: - opening = self.agent_service._get_opening_statement(features_config, True, variables) + opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables) if opening: self.conversation_service.add_message( conversation_id=conversation_id, role="assistant", content=opening, - meta_data={} + meta_data={"suggested_questions": suggested_questions} ) # 重新加载历史(包含刚写入的开场白) history = await 
self.conversation_service.get_conversation_history( @@ -168,11 +169,6 @@ class AppChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag, files=processed_files # 传递处理后的文件 ) @@ -229,6 +225,21 @@ class AppChatService: # 保存消息 if audio_url: assistant_meta["audio_url"] = audio_url + if memory_flag: + connected_config = get_end_user_connected_config(user_id, self.db) + memory_config_id: str = connected_config.get("memory_config_id") + messages = [ + {"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "assistant", "content": result["content"]} + ] + if memory_config_id: + await write_long_term( + storage_type, + user_id, + messages, + user_rag_memory_id, + memory_config_id + ) self.conversation_service.add_message( conversation_id=conversation_id, role="user", @@ -264,20 +275,19 @@ class AppChatService: message: str, conversation_id: uuid.UUID, config: AgentConfig, + files: list[FileInput], user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, - workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None + workspace_id: Optional[str] = None ) -> AsyncGenerator[str, None]: """聊天(流式)""" try: start_time = time.time() - config_id = None message_id = uuid.uuid4() # 应用 features 配置 @@ -319,7 +329,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config( + config.knowledge_retrieval, user_id) tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False @@ -367,13 
+378,13 @@ class AppChatService: # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 is_new_conversation = len(history) == 0 if is_new_conversation: - opening = self.agent_service._get_opening_statement(features_config, True, variables) + opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables) if opening: self.conversation_service.add_message( conversation_id=conversation_id, role="assistant", content=opening, - meta_data={} + meta_data={"suggested_questions": suggested_questions} ) # 重新加载历史(包含刚写入的开场白) history = await self.conversation_service.get_conversation_history( @@ -411,11 +422,6 @@ class AppChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag, files=processed_files ): if isinstance(chunk, int): @@ -459,7 +465,7 @@ class AppChatService: # 保存消息 human_meta = { - "files":[], + "files": [], "history_files": {} } assistant_meta = { @@ -484,6 +490,22 @@ class AppChatService: if stream_audio_url: assistant_meta["audio_url"] = stream_audio_url + + if memory_flag: + connected_config = get_end_user_connected_config(user_id, self.db) + memory_config_id: str = connected_config.get("memory_config_id") + messages = [ + {"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "assistant", "content": full_content} + ] + if memory_config_id: + await write_long_term( + storage_type, + user_id, + messages, + user_rag_memory_id, + memory_config_id + ) self.conversation_service.add_message( conversation_id=conversation_id, role="user", @@ -618,7 +640,6 @@ class AppChatService: # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) - # 3. 
流式执行任务 async for event in orchestrator.execute_stream( message=message, diff --git a/api/app/services/app_log_service.py b/api/app/services/app_log_service.py new file mode 100644 index 00000000..856045d1 --- /dev/null +++ b/api/app/services/app_log_service.py @@ -0,0 +1,128 @@ +"""应用日志服务层""" +import uuid +from typing import Optional, Tuple +from datetime import datetime + +from sqlalchemy.orm import Session + +from app.core.logging_config import get_business_logger +from app.models.conversation_model import Conversation, Message +from app.repositories.conversation_repository import ConversationRepository, MessageRepository + +logger = get_business_logger() + + +class AppLogService: + """应用日志服务""" + + def __init__(self, db: Session): + self.db = db + self.conversation_repository = ConversationRepository(db) + self.message_repository = MessageRepository(db) + + def list_conversations( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + page: int = 1, + pagesize: int = 20, + is_draft: Optional[bool] = None, + ) -> Tuple[list[Conversation], int]: + """ + 查询应用日志会话列表 + + Args: + app_id: 应用 ID + workspace_id: 工作空间 ID + page: 页码(从 1 开始) + pagesize: 每页数量 + is_draft: 是否草稿会话(None 表示不过滤) + + Returns: + Tuple[list[Conversation], int]: (会话列表,总数) + """ + logger.info( + "查询应用日志会话列表", + extra={ + "app_id": str(app_id), + "workspace_id": str(workspace_id), + "page": page, + "pagesize": pagesize, + "is_draft": is_draft + } + ) + + # 使用 Repository 查询 + conversations, total = self.conversation_repository.list_app_conversations( + app_id=app_id, + workspace_id=workspace_id, + is_draft=is_draft, + page=page, + pagesize=pagesize + ) + + logger.info( + "查询应用日志会话列表成功", + extra={ + "app_id": str(app_id), + "total": total, + "returned": len(conversations) + } + ) + + return conversations, total + + def get_conversation_detail( + self, + app_id: uuid.UUID, + conversation_id: uuid.UUID, + workspace_id: uuid.UUID + ) -> Conversation: + """ + 查询会话详情(包含消息) + + Args: + app_id: 应用 ID + 
conversation_id: 会话 ID + workspace_id: 工作空间 ID + + Returns: + Conversation: 包含消息的会话对象 + + Raises: + ResourceNotFoundException: 当会话不存在时 + """ + logger.info( + "查询应用日志会话详情", + extra={ + "app_id": str(app_id), + "conversation_id": str(conversation_id), + "workspace_id": str(workspace_id) + } + ) + + # 查询会话 + conversation = self.conversation_repository.get_conversation_for_app_log( + conversation_id=conversation_id, + app_id=app_id, + workspace_id=workspace_id + ) + + # 查询消息(按时间正序) + messages = self.message_repository.get_messages_by_conversation( + conversation_id=conversation_id + ) + + # 将消息附加到会话对象 + conversation.messages = messages + + logger.info( + "查询应用日志会话详情成功", + extra={ + "app_id": str(app_id), + "conversation_id": str(conversation_id), + "message_count": len(messages) + } + ) + + return conversation diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 4dcabff8..377f9479 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1084,7 +1084,6 @@ class AppService: if not exists: cleaned["memory_config_id"] = None cleaned.pop("memory_content", None) - cleaned["enabled"] = False return cleaned exists = self.db.query( @@ -1096,7 +1095,6 @@ class AppService: if not exists: cleaned["memory_config_id"] = None cleaned.pop("memory_content", None) - cleaned["enabled"] = False return cleaned @@ -1684,15 +1682,15 @@ class AppService: return config.config_id - def _update_endusers_memory_config_by_workspace( + def _update_endusers_memory_config_by_app( self, - workspace_id: uuid.UUID, + app_id: uuid.UUID, memory_config_id: uuid.UUID ) -> int: """批量更新应用下所有终端用户的 memory_config_id Args: - workspace_id: 工作空间ID + app_id: 应用ID memory_config_id: 新的记忆配置ID Returns: @@ -1701,8 +1699,8 @@ class AppService: from app.repositories.end_user_repository import EndUserRepository repo = EndUserRepository(self.db) - updated_count = repo.batch_update_memory_config_id_by_workspace( - workspace_id=workspace_id, + updated_count = 
repo.batch_update_memory_config_id_by_app( + app_id=app_id, memory_config_id=memory_config_id ) @@ -1753,12 +1751,16 @@ class AppService: miss_params = [] if agent_cfg.default_model_config_id is None: - miss_params.append("model config") + miss_params.append("模型配置") if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"): - miss_params.append("memory config") + miss_params.append("记忆配置") if miss_params: - raise BusinessException(f"{', '.join(miss_params)} is required") + raise BusinessException( + f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。", + BizCode.CONFIG_MISSING, + context={"missing_params": miss_params}, + ) config = { "system_prompt": agent_cfg.system_prompt, @@ -1877,8 +1879,8 @@ class AppService: if memory_config_id: app = self.db.query(App).filter(App.id == app_id).first() if app: - updated_count = self._update_endusers_memory_config_by_workspace( - app.workspace_id, memory_config_id + updated_count = self._update_endusers_memory_config_by_app( + app_id, memory_config_id ) logger.info( f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, " @@ -2014,7 +2016,7 @@ class AppService: if memory_config_id: - updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id) + updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id) logger.info( f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, " f"memory_config_id={memory_config_id}, updated_count={updated_count}" diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 014d96b7..bd7f7496 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -214,7 +214,7 @@ class ConversationService: conversation.message_count += 1 - if conversation.message_count == 1 and role == "user": + if conversation.message_count <= 2 and role == "user": conversation.title = ( content[:50] + ("..." 
if len(content) > 50 else "") ) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index e188872f..4b503f2b 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context -from app.models import AgentConfig, ModelConfig, ModelType +from app.models import AgentConfig, ModelConfig from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput, Citation from app.schemas.model_schema import ModelInfo @@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_service import ModelApiKeyService from app.services.multimodal_service import MultimodalService from app.services.tool_service import ToolService -from app.schemas import FileType logger = get_business_logger() @@ -449,15 +448,16 @@ class AgentRunService: features_config: Dict[str, Any], is_new_conversation: bool, variables: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + ) -> tuple[Any, Any]: """首轮对话时返回开场白文本(支持变量替换),否则返回 None""" if not is_new_conversation: - return None + return None, None opening = features_config.get("opening_statement", {}) if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")): - return None + return None, None statement = opening["statement"] + suggested_questions = opening["suggested_questions"] # 如果有变量,进行替换(仅支持 {{var_name}} 格式) if variables: @@ -465,7 +465,7 @@ class AgentRunService: placeholder = f"{{{{{var_name}}}}}" statement = statement.replace(placeholder, str(var_value)) - return statement + return statement, suggested_questions @staticmethod def _filter_citations( @@ -599,13 +599,16 @@ class AgentRunService: # 5. 
处理会话ID(创建或验证),新会话时写入开场白 is_new_conversation = not conversation_id - opening = self._get_opening_statement(features_config, is_new_conversation, variables) + opening, suggested_questions = None, None + if not sub_agent: + opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, user_id=user_id, - opening_statement=opening + opening_statement=opening, + suggested_questions=suggested_questions ) model_info = ModelInfo( @@ -657,11 +660,6 @@ class AgentRunService: message=message, history=history, context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_flag=memory_flag, files=processed_files # 传递处理后的文件 ) @@ -845,14 +843,17 @@ class AgentRunService: # 5. 处理会话ID(创建或验证),新会话时写入开场白 is_new_conversation = not conversation_id - opening = self._get_opening_statement(features_config, is_new_conversation, variables) + opening, suggested_questions = None, None + if not sub_agent: + opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, user_id=user_id, sub_agent=sub_agent, - opening_statement=opening + opening_statement=opening, + suggested_questions=suggested_questions ) model_info = ModelInfo( @@ -911,11 +912,6 @@ class AgentRunService: message=message, history=history, context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_flag=memory_flag, files=processed_files ): if isinstance(chunk, int): @@ -1061,7 +1057,8 @@ class AgentRunService: workspace_id: uuid.UUID, user_id: Optional[str], sub_agent: bool = False, - opening_statement: 
Optional[str] = None + opening_statement: Optional[str] = None, + suggested_questions: Optional[List[str]] = None ) -> str: """确保会话存在(创建或验证) @@ -1072,6 +1069,7 @@ class AgentRunService: user_id: 用户ID sub_agent: 是否为子代理 opening_statement: 开场白(新会话时作为第一条消息写入) + suggested_questions: 预设问题列表 Returns: str: 会话ID @@ -1115,7 +1113,7 @@ class AgentRunService: conversation_id=uuid.UUID(new_conv_id), role="assistant", content=opening_statement, - meta_data={} + meta_data={"suggested_questions": suggested_questions} ) logger.debug(f"已保存开场白到会话 {new_conv_id}") diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 289fd74c..c27a75be 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.core.memory.utils.log.audit_logger import audit_logger from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import ( ) from app.services.memory_perceptual_service import MemoryPerceptualService -try: - from app.core.memory.utils.log.audit_logger import audit_logger -except ImportError: - audit_logger = None logger = get_logger(__name__) config_logger = get_config_logger() @@ -68,24 +65,22 @@ class MemoryAgentService: if str(messages) == 'success': logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 - if audit_logger: - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=True, - duration=duration, details={"message_length": 
len(message)}) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=True, + duration=duration, details={"message_length": len(message)}) return context else: logger.warning(f"Write operation failed for group {end_user_id}") # 记录失败的操作 - if audit_logger: - audit_logger.log_operation( - operation="WRITE", - config_id=config_id, - end_user_id=end_user_id, - success=False, - duration=duration, - error=f"写入失败: {messages[:100]}" - ) + audit_logger.log_operation( + operation="WRITE", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=duration, + error=f"写入失败: {messages[:100]}" + ) raise ValueError(f"写入失败: {messages}") @@ -338,10 +333,9 @@ class MemoryAgentService: logger.error(error_msg) # Log failed operation - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=False, duration=duration, error=error_msg) + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -401,10 +395,10 @@ class MemoryAgentService: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" logger.error(error_msg) - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=False, duration=duration, error=error_msg) + + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) async def read_memory( @@ -469,10 +463,9 @@ class MemoryAgentService: logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") # 导入审计日志记录器 - try: - from 
app.core.memory.utils.log.audit_logger import audit_logger - except ImportError: - audit_logger = None + + + config_load_start = time.time() try: @@ -492,16 +485,15 @@ class MemoryAgentService: logger.error(error_msg) # Log failed operation - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - end_user_id=end_user_id, - success=False, - duration=duration, - error=error_msg - ) + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=duration, + error=error_msg + ) raise ValueError(error_msg) @@ -633,15 +625,15 @@ class MemoryAgentService: total_time = time.time() - start_time logger.info( f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - end_user_id=end_user_id, - success=True, - duration=duration - ) + + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + end_user_id=end_user_id, + success=True, + duration=duration + ) return { "answer": summary, @@ -651,16 +643,16 @@ class MemoryAgentService: # Ensure proper error handling and logging error_msg = f"Read operation failed: {str(e)}" logger.error(error_msg) - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - end_user_id=end_user_id, - success=False, - duration=duration, - error=error_msg - ) + + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=duration, + error=error_msg + ) raise ValueError(error_msg) def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: 
diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index d0078088..791a6fe8 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -1,11 +1,12 @@ from sqlalchemy.orm import Session -from typing import List, Optional +from sqlalchemy import desc, nullslast, or_, and_, cast, String +from typing import List, Optional, Dict, Any import uuid from fastapi import HTTPException from app.models.user_model import User from app.models.app_model import App -from app.models.end_user_model import EndUser +from app.models.end_user_model import EndUser, EndUser as EndUserModel from app.models.memory_increment_model import MemoryIncrement from app.repositories import ( @@ -49,44 +50,40 @@ def get_current_workspace_type( def get_workspace_end_users( - db: Session, - workspace_id: uuid.UUID, + db: Session, + workspace_id: uuid.UUID, current_user: User ) -> List[EndUser]: """获取工作空间的所有宿主(优化版本:减少数据库查询次数) - 返回结果按 created_at 从新到旧排序(NULL 值排在最后) """ business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: + + try: # 查询应用(ORM) apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) - + if not apps_orm: business_logger.info("工作空间下没有应用") return [] - + # 提取所有 app_id # app_ids = [app.id for app in apps_orm] - # 批量查询所有 end_users(一次查询而非循环查询) # 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性 - from app.models.end_user_model import EndUser as EndUserModel - from sqlalchemy import desc, nullslast end_users_orm = db.query(EndUserModel).filter( EndUserModel.workspace_id == workspace_id ).order_by( nullslast(desc(EndUserModel.created_at)), desc(EndUserModel.id) ).all() - + # 转换为 Pydantic 模型(只在需要时转换) end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm] - + business_logger.info(f"成功获取 {len(end_users)} 个宿主记录") return end_users - + except HTTPException: raise except Exception as e: @@ -94,6 +91,85 @@ def 
get_workspace_end_users( raise +def get_workspace_end_users_paginated( + db: Session, + workspace_id: uuid.UUID, + current_user: User, + page: int, + pagesize: int, + keyword: Optional[str] = None +) -> Dict[str, Any]: + """获取工作空间的宿主列表(分页版本,支持模糊搜索) + + 返回结果按 created_at 从新到旧排序(NULL 值排在最后) + 支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段 + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + current_user: 当前用户 + page: 页码(从1开始) + pagesize: 每页数量 + keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id) + + Returns: + dict: 包含 items(宿主列表)和 total(总记录数)的字典 + """ + business_logger.info(f"获取工作空间宿主列表(分页): workspace_id={workspace_id}, keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}") + + try: + # 构建基础查询 + base_query = db.query(EndUserModel).filter( + EndUserModel.workspace_id == workspace_id + ) + + # 构建搜索条件(过滤空字符串和None) + keyword = keyword.strip() if keyword else None + + if keyword: + keyword_pattern = f"%{keyword}%" + # other_name 匹配始终生效;id 匹配仅对 other_name 为空的记录生效 + base_query = base_query.filter( + or_( + EndUserModel.other_name.ilike(keyword_pattern), + and_( + or_( + EndUserModel.other_name.is_(None), + EndUserModel.other_name == "", + ), + cast(EndUserModel.id, String).ilike(keyword_pattern), + ), + ) + ) + business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_name;other_name 为空时匹配 id)") + + # 获取总记录数 + total = base_query.count() + + if total == 0: + business_logger.info("工作空间下没有宿主") + return {"items": [], "total": 0} + + # 分页查询 + # 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性 + end_users_orm = base_query.order_by( + nullslast(desc(EndUserModel.created_at)), + desc(EndUserModel.id) + ).offset((page - 1) * pagesize).limit(pagesize).all() + + # 转换为 Pydantic 模型 + end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm] + + business_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条") + return {"items": end_users, "total": total} + + except HTTPException: + raise + except Exception as e: + business_logger.error(f"获取工作空间宿主列表(分页)失败: 
workspace_id={workspace_id} - {str(e)}") + raise + + def get_workspace_memory_increment( db: Session, workspace_id: uuid.UUID, @@ -638,7 +714,24 @@ def get_rag_content( business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}") continue - # 4. 返回结果 + # 4. 将所有 page_content 拼接后按角色分割为对话列表 + merged_text = "\n".join(page_contents) + conversations = [] + if merged_text.strip(): + import re + # 在任意位置匹配 "user:" 或 "assistant:",不限于行首 + parts = re.split(r'(user|assistant):', merged_text) + # parts 结构: ['', 'user', ' content...', 'assistant', ' content...', ...] + i = 1 + while i < len(parts) - 1: + role = parts[i].strip() + content = parts[i + 1].strip() + # 将 content 中的 \n 还原为真实换行 + content = content.replace("\\n", "\n") + if role in ("user", "assistant") and content: + conversations.append({"role": role, "content": content}) + i += 2 + result = { "page": { "page": page, @@ -646,10 +739,10 @@ def get_rag_content( "total": global_total, "hasnext": offset_end < global_total, }, - "items": page_contents + "items": conversations } - business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)} 条") + business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)} 条对话") return result except Exception as e: diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index 11118571..2d91f025 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -204,30 +204,35 @@ class MemoryForgetService: end_user_id: str, forgetting_threshold: float, min_days_since_access: int, - limit: int = 20 - ) -> list[Dict[str, Any]]: + page: Optional[int] = None, + pagesize: Optional[int] = None + ) -> Dict[str, Any]: """ 获取待遗忘节点列表 - - 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数) - + + 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。 + Args: connector: Neo4j 连接器 end_user_id: 组ID forgetting_threshold: 遗忘阈值 min_days_since_access: 最小未访问天数 - limit: 返回节点数量限制 - + page: 
页码(可选,从1开始) + pagesize: 每页数量(可选) + Returns: - list: 待遗忘节点列表 + dict: 包含待遗忘节点列表和分页信息的字典 + - items: 待遗忘节点列表 + - page: 分页信息(分页时) """ from datetime import timedelta - + # 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区) min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access) min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') - - query = """ + + # 基础查询(用于获取总数) + count_query = """ MATCH (n) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) AND n.end_user_id = $end_user_id @@ -235,10 +240,22 @@ class MemoryForgetService: AND n.activation_value < $threshold AND n.last_access_time IS NOT NULL AND datetime(n.last_access_time) < datetime($min_access_time_str) - RETURN + RETURN count(n) as total + """ + + # 数据查询 + data_query = """ + MATCH (n) + WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) + AND n.end_user_id = $end_user_id + AND n.activation_value IS NOT NULL + AND n.activation_value < $threshold + AND n.last_access_time IS NOT NULL + AND datetime(n.last_access_time) < datetime($min_access_time_str) + RETURN elementId(n) as node_id, labels(n)[0] as node_type, - CASE + CASE WHEN n:Statement THEN n.statement WHEN n:ExtractedEntity THEN n.name WHEN n:MemorySummary THEN n.content @@ -247,18 +264,32 @@ class MemoryForgetService: n.activation_value as activation_value, n.last_access_time as last_access_time ORDER BY n.activation_value ASC - LIMIT $limit """ - + + # 如果启用分页,添加 SKIP 和 LIMIT + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + data_query += " SKIP $skip LIMIT $limit" + params = { 'end_user_id': end_user_id, 'threshold': forgetting_threshold, - 'min_access_time_str': min_access_time_str, - 'limit': limit + 'min_access_time_str': min_access_time_str } - - results = await connector.execute_query(query, **params) - + + # 获取总数(分页时需要) + total = 0 + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + count_results = await connector.execute_query(count_query, 
**params) + if count_results: + total = count_results[0]['total'] + + # 添加分页参数 + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + params['skip'] = (page - 1) * pagesize + params['limit'] = pagesize + + results = await connector.execute_query(data_query, **params) + pending_nodes = [] for result in results: # 将节点类型标签转换为小写 @@ -267,7 +298,7 @@ class MemoryForgetService: node_type_label = 'entity' elif node_type_label == 'memorysummary': node_type_label = 'summary' - + # 将 Neo4j DateTime 对象转换为时间戳(毫秒) last_access_time = result['last_access_time'] last_access_dt = convert_neo4j_datetime_to_python(last_access_time) @@ -278,7 +309,7 @@ class MemoryForgetService: last_access_timestamp = int(last_access_dt.timestamp() * 1000) else: last_access_timestamp = 0 - + pending_nodes.append({ 'node_id': str(result['node_id']), 'node_type': node_type_label, @@ -286,8 +317,20 @@ class MemoryForgetService: 'activation_value': result['activation_value'], 'last_access_time': last_access_timestamp }) - - return pending_nodes + + # 构建返回结果 + result: Dict[str, Any] = {'items': pending_nodes} + + # 如果启用分页,添加分页信息 + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + result['page'] = { + 'page': page, + 'pagesize': pagesize, + 'total': total, + 'hasnext': (page * pagesize) < total + } + + return result async def trigger_forgetting_cycle( self, @@ -636,7 +679,7 @@ class MemoryForgetService: api_logger.error(f"获取历史趋势数据失败: {str(e)}") # 失败时返回空列表,不影响主流程 - # 获取待遗忘节点列表(前20个满足遗忘条件的节点) + # 获取待遗忘节点列表 pending_nodes = [] try: if end_user_id: @@ -652,8 +695,7 @@ class MemoryForgetService: connector=connector, end_user_id=end_user_id, forgetting_threshold=forgetting_threshold, - min_days_since_access=int(min_days), - limit=20 + min_days_since_access=int(min_days) ) api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点") @@ -661,24 +703,79 @@ class MemoryForgetService: except Exception as e: api_logger.error(f"获取待遗忘节点失败: {str(e)}") # 失败时返回空列表,不影响主流程 - - # 
构建统计信息 + + # 构建统计信息(不包含 pending_nodes,已分离到独立接口) stats = { 'activation_metrics': activation_metrics, 'node_distribution': node_distribution, 'recent_trends': recent_trends, - 'pending_nodes': pending_nodes, 'timestamp': int(datetime.now().timestamp() * 1000) } - + api_logger.info( f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, " f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, " - f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}" + f"trend_days={len(recent_trends)}" ) - + return stats - + + async def get_pending_nodes( + self, + db: Session, + end_user_id: str, + config_id: Optional[UUID] = None, + page: int = 1, + pagesize: int = 10 + ) -> Dict[str, Any]: + """ + 获取待遗忘节点列表(独立分页接口) + + 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。 + + Args: + db: 数据库会话 + end_user_id: 组ID(必填) + config_id: 配置ID(可选,用于获取遗忘阈值) + page: 页码(从1开始,默认1) + pagesize: 每页数量(默认10) + + Returns: + dict: 包含待遗忘节点列表和分页信息的字典 + - items: 待遗忘节点列表 + - page: 分页信息 + """ + # 获取遗忘引擎组件 + _, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id) + + connector = forgetting_scheduler.connector + forgetting_threshold = config['forgetting_threshold'] + + # 验证 min_days_since_access 配置值 + min_days = config.get('min_days_since_access') + if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0: + api_logger.warning( + f"min_days_since_access 配置无效: {min_days}, 使用默认值 7" + ) + min_days = 7 + + # 调用内部方法获取分页数据 + pending_nodes_result = await self._get_pending_forgetting_nodes( + connector=connector, + end_user_id=end_user_id, + forgetting_threshold=forgetting_threshold, + min_days_since_access=int(min_days), + page=page, + pagesize=pagesize + ) + + api_logger.info( + f"成功获取待遗忘节点列表: end_user_id={end_user_id}, " + f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}" + ) + + return pending_nodes_result + async def get_forgetting_curve( self, db: Session, diff --git 
a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 3ee238e2..7cf94a1a 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -243,28 +243,9 @@ class MemoryPerceptualService: memory_config: MemoryConfig, file: FileInput ): - memories = self.repository.get_by_url(file.url) - if memories: - business_logger.info(f"Perceptual memory already exists: {file.url}") - if end_user_id not in [memory.end_user_id for memory in memories]: - business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") - memory_cache = memories[0] - memory = self.repository.create_perceptual_memory( - end_user_id=uuid.UUID(end_user_id), - perceptual_type=PerceptualType(memory_cache.perceptual_type), - file_path=memory_cache.file_path, - file_name=memory_cache.file_name, - file_ext=memory_cache.file_ext, - summary=memory_cache.summary, - meta_data=memory_cache.meta_data - ) - self.db.commit() - return memory - else: - for memory in memories: - if memory.end_user_id == uuid.UUID(end_user_id): - return memory llm, model_config = self._get_mutlimodal_client(file.type, memory_config) + if model_config is None or llm is None: + return None multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, provider=model_config.provider, @@ -286,15 +267,20 @@ class MemoryPerceptualService: with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh') - except FileNotFoundError: - raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) + except FileNotFoundError as e: + business_logger.error(f"Failed to generate perceptual memory: {str(e)}") + return None messages = [ {"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": 
rendered_system_message}]}, {"role": RoleType.USER.value, "content": [ {"type": "text", "text": "Summarize the following file"}, file_message ]} ] - result = await llm.ainvoke(messages) + try: + result = await llm.ainvoke(messages) + except Exception as e: + business_logger.error(f"Failed to generate perceptual memory: {str(e)}") + return None content = result.content final_output = "" if isinstance(content, list): diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 58f3e8bd..b3a66734 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -695,6 +695,37 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any] return result +async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]: + """批量查询多个用户的记忆数量(简化版本,只返回total) + + Args: + end_user_ids: 用户ID列表 + + Returns: + Dict[str, int]: 以user_id为key的记忆数量字典 + 格式: {"user_id": total_count} + """ + if not end_user_ids: + return {} + + result = await _neo4j_connector.execute_query( + MemoryConfigRepository.SEARCH_FOR_ALL_BATCH, + end_user_ids=end_user_ids, + ) + + # 转换结果为字典格式,字典格式在查询中无需遍历结果集,直接返回 + data = {} + for row in result: + data[row["user_id"]] = row["total"] + + # 为没有数据的用户填充默认值,转换字典格式还为无数据填充默认值 + for user_id in end_user_ids: + if user_id not in data: + data[user_id] = 0 + + return data + + async def analytics_hot_memory_tags( db: Session, current_user: User, diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index b98674ba..c9266667 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -69,7 +69,8 @@ class ModelConfigService: return items @staticmethod - def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def get_model_by_name(db: Session, name: str, provider: str | None = None, + tenant_id: uuid.UUID | None = None) -> ModelConfig: 
"""根据名称获取模型配置""" model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) if not model: @@ -77,21 +78,22 @@ class ModelConfigService: return model @staticmethod - def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]: + def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ + ModelConfig]: """按名称模糊匹配获取模型配置列表""" return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit) @staticmethod async def validate_model_config( - db: Session, - *, - model_name: str, - provider: str, - api_key: str, - api_base: Optional[str] = None, - model_type: str = "llm", - test_message: str = "Hello", - is_omni: bool = False + db: Session, + *, + model_name: str, + provider: str, + api_key: str, + api_base: Optional[str] = None, + model_type: str = "llm", + test_message: str = "Hello", + is_omni: bool = False ) -> Dict[str, Any]: """验证模型配置是否有效 @@ -158,13 +160,13 @@ class ModelConfigService: # 统一使用 RedBearEmbeddings(自动支持火山引擎多模态) embedding = RedBearEmbeddings(model_config) test_texts = [test_message, "测试文本"] - + # 火山引擎使用 embed_batch,其他使用 embed_documents if provider.lower() == "volcano": vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) else: vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) - + elapsed_time = time.time() - start_time return { @@ -200,11 +202,11 @@ class ModelConfigService: }, "error": None } - + elif model_type_lower == "image": # 图片生成模型验证 from app.core.models.generation import RedBearImageGenerator - + generator = RedBearImageGenerator(model_config) result = await generator.agenerate( prompt="a cute panda", @@ -212,7 +214,7 @@ class ModelConfigService: ) elapsed_time = time.time() - start_time logger.info(f"成功生成图片,结果: {result}") - + return { "valid": True, "message": "图片生成模型配置验证成功", @@ -224,21 +226,21 @@ class ModelConfigService: }, "error": None 
} - + elif model_type_lower == "video": # 视频生成模型验证 from app.core.models.generation import RedBearVideoGenerator - + generator = RedBearVideoGenerator(model_config) result = await generator.agenerate( prompt="a cute panda playing in bamboo forest", duration=5 ) elapsed_time = time.time() - start_time - + # 视频生成是异步任务,返回任务ID task_id = result.get("task_id") if isinstance(result, dict) else None - + return { "valid": True, "message": "视频生成模型配置验证成功", @@ -265,7 +267,6 @@ class ModelConfigService: # 提取详细的错误信息 error_message = str(e) error_type = type(e).__name__ - print("=========error_message:",error_message.lower()) # 特殊处理常见的错误类型 if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower(): # 区域/国家限制(适用于所有提供商) @@ -354,14 +355,16 @@ class ModelConfigService: return model @staticmethod - def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, + tenant_id: uuid.UUID | None = None) -> ModelConfig: """更新模型配置""" existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) @@ -370,25 +373,27 @@ class ModelConfigService: return model @staticmethod - async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + async def create_composite_model(db: Session, model_data: 
model_schema.CompositeModelCreate, + tenant_id: uuid.UUID) -> ModelConfig: """创建组合模型""" - if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - + # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) - + # 检查 API Key 关联的模型配置类型 for model_config in api_key.model_configs: # chat 和 llm 类型可以兼容 compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = model_data.type - - if not (config_type == request_type or + + if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", @@ -399,7 +404,7 @@ class ModelConfigService: # f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型", # BizCode.INVALID_PARAMETER # ) - + # 创建组合模型 model_config_data = { "tenant_id": tenant_id, @@ -418,49 +423,51 @@ class ModelConfigService: model = ModelConfigRepository.create(db, model_config_data) db.flush() - + # 关联 API Keys for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: model.api_keys.append(api_key) - + db.commit() db.refresh(model) return model @staticmethod - async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, + tenant_id: uuid.UUID) -> ModelConfig: """更新组合模型""" existing_model = ModelConfigRepository.get_by_id(db, model_id, 
tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - + if not existing_model.is_composite: raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) - + # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) - + for model_config in api_key.model_configs: compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = existing_model.type - - if not (config_type == request_type or + + if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", BizCode.INVALID_PARAMETER ) - + # 更新基本信息 existing_model.name = model_data.name # existing_model.type = model_data.type @@ -471,14 +478,14 @@ class ModelConfigService: existing_model.is_public = model_data.is_public if "load_balance_strategy" in model_data.model_fields_set: existing_model.load_balance_strategy = model_data.load_balance_strategy - + # 更新 API Keys 关联 existing_model.api_keys.clear() for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: existing_model.api_keys.append(api_key) - + db.commit() db.refresh(existing_model) return existing_model @@ -532,7 +539,7 @@ class ModelApiKeyService: """根据provider为多个ModelConfig创建API Key""" created_keys = [] failed_models = [] # 记录验证失败的模型 - + for model_config_id in 
data.model_config_ids: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: @@ -540,10 +547,10 @@ class ModelApiKeyService: data.is_omni = model_config.is_omni data.capability = model_config.capability - + # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name - + # 检查是否存在API Key(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs @@ -553,7 +560,7 @@ class ModelApiKeyService: ModelApiKey.model_name == model_name, ModelConfig.tenant_id == model_config.tenant_id ).first() - + if existing_key: # 如果已存在,重新激活并更新 if existing_key.is_active: @@ -566,14 +573,14 @@ class ModelApiKeyService: existing_key.model_name = model_name existing_key.capability = data.capability existing_key.is_omni = data.is_omni - + # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) - + created_keys.append(existing_key) continue - + # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, @@ -589,7 +596,7 @@ class ModelApiKeyService: # 记录验证失败的模型,但不抛出异常 failed_models.append(model_name) continue - + # 创建API Key api_key_data = ModelApiKeyCreate( model_config_ids=[model_config_id], @@ -606,12 +613,12 @@ class ModelApiKeyService: ) api_key_obj = ModelApiKeyRepository.create(db, api_key_data) created_keys.append(api_key_obj) - + if created_keys: db.commit() for key in created_keys: db.refresh(key) - + return created_keys, failed_models @staticmethod @@ -626,7 +633,7 @@ class ModelApiKeyService: api_key_data.is_omni = model_config.is_omni if api_key_data.capability is None: api_key_data.capability = model_config.capability - + # 检查API Key是否已存在(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs @@ -650,15 +657,15 @@ class ModelApiKeyService: existing_key.model_name = api_key_data.model_name existing_key.capability = api_key_data.capability 
existing_key.is_omni = api_key_data.is_omni - + # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) - + db.commit() db.refresh(existing_key) return existing_key - + # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, @@ -691,7 +698,7 @@ class ModelApiKeyService: # 获取关联的模型配置以获取模型类型 if existing_api_key.model_configs: model_config = existing_api_key.model_configs[0] - + validation_result = await ModelConfigService.validate_model_config( db=db, model_name=api_key_data.model_name or existing_api_key.model_name, @@ -729,15 +736,15 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: return None - + api_keys = [key for key in model_config.api_keys if key.is_active] if not api_keys: return None - + # 如果是轮询策略,按使用次数最少,次数相同则选最早使用的 if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) - + # 否则返回第一个 return api_keys[0] @@ -760,20 +767,19 @@ class ModelApiKeyService: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - class ModelBaseService: """基础模型服务""" @staticmethod def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List: models = ModelBaseRepository.get_list(db, query) - + provider_groups = {} for m in models: model_dict = model_schema.ModelBase.model_validate(m).model_dump() if tenant_id: model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id) - + provider = m.provider if provider not in provider_groups: provider_groups[provider] = { @@ -781,7 +787,7 @@ class ModelBaseService: "models": [] } provider_groups[provider]["models"].append(model_dict) - + return list(provider_groups.values()) @staticmethod @@ -823,10 +829,10 @@ class ModelBaseService: model_base = ModelBaseRepository.get_by_id(db, 
model_base_id) if not model_base: raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) - + if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id): raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME) - + model_config_data = { "model_id": model_base_id, "tenant_id": tenant_id, diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 4cf3d89d..2e9f809a 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -12,6 +12,9 @@ import base64 import csv import io import json +import re +import olefile +import struct import zipfile from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional @@ -438,13 +441,13 @@ class MultimodalService: if file.transfer_method == TransferMethod.REMOTE_URL: return True, { "type": "text", - "text": f"\n{await self._extract_document_text(file)}\n" + "text": f"\n{await self.extract_document_text(file)}\n" } else: # 本地文件,提取文本内容 server_url = settings.FILE_LOCAL_SERVER_URL file.url = f"{server_url}/storage/permanent/{file.upload_file_id}" - text = await self._extract_document_text(file) + text = await self.extract_document_text(file) file_metadata = self.db.query(FileMetadata).filter( FileMetadata.id == file.upload_file_id ).first() @@ -542,7 +545,7 @@ class MultimodalService: server_url = settings.FILE_LOCAL_SERVER_URL return f"{server_url}/storage/permanent/{file_id}" - async def _extract_document_text(self, file: FileInput) -> str: + async def extract_document_text(self, file: FileInput) -> str: """ 提取文档文本内容 @@ -602,31 +605,75 @@ class MultimodalService: try: word_file = io.BytesIO(file_content) doc = Document(word_file) - return '\n'.join(p.text for p in doc.paragraphs) + text_lines = [] + for p in doc.paragraphs: + text = p.text.strip() + if text: + text_lines.append(text) + + for table in doc.tables: + for row in table.rows: + for cell in row.cells: + text = cell.text.strip() + if text: + 
text_lines.append(text) + + full_text = "\n".join(text_lines) + return full_text.strip() or "[docx 文件无文本内容]" except Exception as e: - logger.error(f"提取 docx 文本失败: {e}") + logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True) return f"[docx 提取失败: {str(e)}]" - # 旧版 .doc(OLE2 格式) + # 旧版 .doc(OLE2/CFB 格式),按 Word Binary Format 规范解析 piece table try: - import olefile ole = olefile.OleFileIO(io.BytesIO(file_content)) - if not ole.exists('WordDocument'): - return "[doc 提取失败: 未找到 WordDocument 流]" - # 读取 WordDocument 流,提取可见 ASCII/Unicode 文本 - stream = ole.openstream('WordDocument').read() - # Word Binary Format: 文本在流中以 UTF-16-LE 编码存储 - # 简单提取:过滤出可打印字符段 - try: - text = stream.decode('utf-16-le', errors='ignore') - except Exception: - text = stream.decode('latin-1', errors='ignore') - # 过滤控制字符,保留可打印内容 - import re - text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text) - text = re.sub(r' +', ' ', text).strip() + word_stream = ole.openstream('WordDocument').read() + + # FIB offset 0xA bit9 决定使用 0Table 还是 1Table + fib_flags = struct.unpack_from(' List[UserModel]: """获取租户下的用户列表""" @@ -155,6 +156,7 @@ class TenantService: skip=skip, limit=limit, is_active=is_active, + is_superuser=is_superuser, search=search ) @@ -162,12 +164,14 @@ class TenantService: self, tenant_id: uuid.UUID, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> int: """统计租户下的用户数量""" return self.user_repo.count_users_by_tenant( tenant_id=tenant_id, is_active=is_active, + is_superuser=is_superuser, search=search ) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 942e01a0..ab51d922 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -472,6 +472,21 @@ class UserMemoryService: # 定义允许更新的字段白名单 allowed_fields = {'other_name', 'aliases', 'meta_data'} + # 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中 + _user_placeholder_names = {'用户', '我', 'User', 'I'} + + # 过滤 
other_name:不允许设置为占位名称 + if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names: + logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name") + del update_data['other_name'] + + # 过滤 aliases:移除占位名称和非字符串值 + if 'aliases' in update_data and update_data['aliases']: + update_data['aliases'] = [ + a for a in update_data['aliases'] + if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names + ] + # 检查是否更新了 aliases 字段 aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index c7d7f2b1..13267078 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -561,6 +561,24 @@ class WorkflowService: storage_type = 'neo4j' return storage_type, user_rag_memory_id + def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None: + executions = self.execution_repo.get_by_conversation_id( + conversation_id=conversation_id, + status="completed", + limit_count=1 + ) + + if executions: + last_state = executions[0].output_data + if isinstance(last_state, dict): + variables = last_state.get("variables", {}) + conv_vars = variables.get("conv", {}) + # input_data["conv"] = conv_vars + # input_data["conv_messages"] = last_state.get("messages") or [] + conv_messages = last_state.get("messages") or [] + return conv_vars, conv_messages + return None + # ==================== 工作流执行 ==================== async def run( @@ -634,18 +652,11 @@ class WorkflowService: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") - executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) - - for exec_res in executions: - if exec_res.status == "completed": - last_state = exec_res.output_data - if isinstance(last_state, dict): - variables = 
last_state.get("variables", {}) - conv_vars = variables.get("conv", {}) - input_data["conv"] = conv_vars - input_data["conv_messages"] = last_state.get("messages") or [] - break - + history = self._get_history_info(conversation_id_uuid) + if history: + conv_vars, conv_messages = history + input_data["conv"] = conv_vars + input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) result = await execute_workflow( @@ -807,17 +818,11 @@ class WorkflowService: storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") - executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) - - for exec_res in executions: - if exec_res.status == "completed": - last_state = exec_res.output_data - if isinstance(last_state, dict): - variables = last_state.get("variables", {}) - conv_vars = variables.get("conv", {}) - input_data["conv"] = conv_vars - input_data["conv_messages"] = last_state.get("messages") or [] - break + history = self._get_history_info(conversation_id_uuid) + if history: + conv_vars, conv_messages = history + input_data["conv"] = conv_vars + input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) message_id = uuid.uuid4() async for event in execute_workflow_stream( diff --git a/api/app/tasks.py b/api/app/tasks.py index d5f09a29..72421a5f 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,5 +1,4 @@ import asyncio -import hashlib import os import re import shutil @@ -38,12 +37,10 @@ from app.db import get_db, get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema -from app.schemas.model_schema import ModelInfo from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config from 
app.services.memory_forget_service import MemoryForgetService -from app.services.memory_perceptual_service import MemoryPerceptualService from app.utils.config_utils import resolve_config_id -from app.utils.redis_lock import RedisLock +from app.utils.redis_lock import RedisFairLock logger = get_logger(__name__) @@ -104,7 +101,12 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]: def set_asyncio_event_loop(): - """Set the asyncio event loop for the current thread.""" + """Ensure an open asyncio event loop exists for the current thread. + + Reuses the existing event loop if one is available and still open. + Creates and installs a new event loop only when the current one is + closed or missing (e.g. after ``_shutdown_loop_gracefully``). + """ try: loop = asyncio.get_event_loop() if loop.is_closed(): @@ -116,6 +118,30 @@ def set_asyncio_event_loop(): return loop +def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop): + """Gracefully shutdown pending async generators and tasks on the event loop. + + This prevents 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ + by giving pending aclose() coroutines a chance to run before the loop is discarded. + + Note: This only tears down the given loop. Callers that need a fresh event + loop afterwards should use ``set_asyncio_event_loop()`` explicitly. + """ + try: + # Cancel and collect all remaining tasks + all_tasks = asyncio.all_tasks(loop) + if all_tasks: + for task in all_tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*all_tasks, return_exceptions=True)) + # Shutdown async generators (triggers __aclose__ on httpx clients etc.) 
+ loop.run_until_complete(loop.shutdown_asyncgens()) + except Exception: + pass + finally: + loop.close() + + @celery_app.task(name="tasks.process_item") def process_item(item: dict): """ @@ -1148,8 +1174,28 @@ def write_message_task( logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result + redis_client = get_sync_redis_client() + lock = None + if redis_client is not None: + lock = RedisFairLock( + key=f"memory_write:{end_user_id}", + redis_client=redis_client, + expire=600, + timeout=3600, + auto_renewal=True, + ) + if not lock.acquire(): + logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}") + return { + "status": "SKIPPED", + "error": "acquire lock timeout", + "end_user_id": end_user_id, + "config_id": str(config_id), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } + try: - # 尝试获取现有事件循环,如果不存在则创建新的 loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) @@ -1158,7 +1204,6 @@ def write_message_task( logger.info(f"[CELERY WRITE] Task completed successfully " f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") - # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: _r = get_sync_redis_client() if _r is not None: @@ -1199,6 +1244,15 @@ def write_message_task( "elapsed_time": elapsed_time, "task_id": self.request.id } + finally: + if lock is not None: + try: + lock.release() + except Exception as e: + logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") + # Gracefully shutdown the event loop to prevent + # 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ + _shutdown_loop_gracefully(loop) # unused task @@ -2879,3 +2933,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace "elapsed_time": time.time() - start_time, "task_id": self.request.id, } + + +# unused task \ No newline at end of file diff --git a/api/app/utils/performance_timer.py b/api/app/utils/performance_timer.py new file mode 100644 index 00000000..6b0ec5d6 --- 
/dev/null +++ b/api/app/utils/performance_timer.py @@ -0,0 +1,37 @@ +""" +性能监控工具模块 + +提供代码块执行时间统计功能,用于接口性能分析。 +如需再次启用性能监控,只需在 controller 中导入 from app.utils.performance_timer import timer 并添加 with timer(...) 包裹需要监控的代码块即可 +""" + +import time +from contextlib import contextmanager +from app.core.logging_config import get_api_logger + +# 获取API专用日志器 +api_logger = get_api_logger() + + +@contextmanager +def timer(label: str, user_count: int = 0): + """上下文管理器:用于测量代码块执行时间 + + Args: + label: 统计标签,用于标识被测量的代码块 + user_count: 用户数,可选参数,用于记录处理的用户数量 + + Usage: + with timer("获取用户列表"): + users = get_users() + + with timer("批量处理", user_count=len(user_ids)): + process_users(user_ids) + """ + start = time.perf_counter() + try: + yield + finally: + elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒 + extra_info = f", 用户数: {user_count}" if user_count > 0 else "" + api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}") diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py index 99f62d84..a86ba46e 100644 --- a/api/app/utils/redis_lock.py +++ b/api/app/utils/redis_lock.py @@ -1,6 +1,7 @@ import redis import uuid import time +import threading UNLOCK_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then @@ -10,45 +11,136 @@ else end """ +RENEW_SCRIPT = """ +if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("expire", KEYS[1], ARGV[2]) +else + return 0 +end +""" -class RedisLock: +CLEANUP_DEAD_HEAD_SCRIPT = """ +local queue_key = KEYS[1] +local lock_key = KEYS[2] + +local first = redis.call("lindex", queue_key, 0) +if not first then + return 0 +end + +if redis.call("exists", lock_key) == 1 then + return 0 +end + +redis.call("lpop", queue_key) +return 1 +""" + +SAFE_RELEASE_QUEUE_SCRIPT = """ +local queue_key = KEYS[1] +local value = ARGV[1] + +local first = redis.call("lindex", queue_key, 0) +if first == value then + redis.call("lpop", queue_key) + return 1 +end +return 0 +""" + + +def _ensure_str(val): + """统一将 Redis 返回值转为 str,兼容 
decode_responses=True/False""" + if val is None: + return None + if isinstance(val, bytes): + return val.decode("utf-8") + return str(val) + + +class RedisFairLock: def __init__( self, key: str, redis_client: redis.StrictRedis, - expire: int = 60, - retry_interval: float = 0.1, - timeout: float = 30 - + expire: int = 30, + retry_interval: float = 0.05, + timeout: float = 600, + auto_renewal: bool = True ): self.key = key - self.expire = expire + self.queue_key = f"{key}:queue" self.value = str(uuid.uuid4()) - self._locked = False + self.expire = expire self.retry_interval = retry_interval self.timeout = timeout - self.redis_client = redis_client + self.redis = redis_client + self._locked = False + self.auto_renewal = auto_renewal + self._renew_thread = None + self._stop_renew = threading.Event() - def acquire(self) -> bool: + def acquire(self): start = time.time() + + self.redis.rpush(self.queue_key, self.value) + while True: - ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) - if ok: - self._locked = True - return True - if time.time() - start >= self.timeout: + first = _ensure_str(self.redis.lindex(self.queue_key, 0)) + + if first == self.value: + ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire) + if ok: + self._locked = True + + if self.auto_renewal: + self._start_renewal() + return True + + if first: + self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key) + + if time.time() - start > self.timeout: + self.redis.lrem(self.queue_key, 0, self.value) return False + time.sleep(self.retry_interval) + def _renewal_loop(self): + while not self._stop_renew.is_set(): + time.sleep(self.expire / 3) + if self._stop_renew.is_set(): + break + + self.redis.eval( + RENEW_SCRIPT, + 1, + self.key, + self.value, + str(self.expire) + ) + + def _start_renewal(self): + self._stop_renew = threading.Event() + self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True) + self._renew_thread.start() + + def 
_stop_renewal(self): + self._stop_renew.set() + if self._renew_thread: + self._renew_thread.join(timeout=1) + def release(self): if not self._locked: return - self.redis_client.eval( - UNLOCK_SCRIPT, - 1, - self.key, - self.value - ) + + if self.auto_renewal: + self._stop_renewal() + + self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value) + + self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value) + self._locked = False def __enter__(self): @@ -59,3 +151,4 @@ class RedisLock: def __exit__(self, exc_type, exc_val, exc_tb): self.release() + diff --git a/api/migrations/versions/4e89970f9e7c_202603271515.py b/api/migrations/versions/4e89970f9e7c_202603271515.py new file mode 100644 index 00000000..f37c4b27 --- /dev/null +++ b/api/migrations/versions/4e89970f9e7c_202603271515.py @@ -0,0 +1,30 @@ +"""202603271515 + +Revision ID: 4e89970f9e7c +Revises: 6b8a461148ff +Create Date: 2026-03-27 15:12:27.518344 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4e89970f9e7c' +down_revision: Union[str, None] = '6b8a461148ff' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('users', sa.Column('phone', sa.String(length=50), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('users', 'phone') + # ### end Alembic commands ### diff --git a/web/src/api/knowledgeBase.ts b/web/src/api/knowledgeBase.ts index 63ec80ae..05200221 100644 --- a/web/src/api/knowledgeBase.ts +++ b/web/src/api/knowledgeBase.ts @@ -68,8 +68,8 @@ export const getModelTypeList = async () => { return response as any[]; }; // 获取模型列表 -export const getModelList = async (pageInfo: PageRequest) => { - const response = await request.get(`${apiPrefix}/models`, { ...pageInfo, is_active: true }); +export const getModelList = async (types: string[], pageInfo: PageRequest) => { + const response = await request.get(`${apiPrefix}/models`, { ...pageInfo, type: types?.join(','), is_active: true }); return response as any; }; //获取模型提供者 diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 1ec2d7dc..077cdf53 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:00:06 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 17:48:01 + * @Last Modified time: 2026-03-31 12:25:53 */ import { request } from '@/utils/request' import type { AxiosRequestConfig } from 'axios' @@ -63,8 +63,8 @@ export const getDashboardData = () => { /****************** User Memory APIs *******************************/ export const userMemoryListUrl = '/dashboard/end_users' -export const getUserMemoryList = () => { - return request.get(userMemoryListUrl) +export const getUserMemoryList = (query?: { keyword?: string }) => { + return request.get(userMemoryListUrl, query) } // User Memory - Total end users export const getTotalEndUsers = () => { @@ -154,6 +154,8 @@ export const analyticsRefresh = (end_user_id: string) => { export const getForgetStats = (end_user_id: string) => { return request.get(`/memory/forget-memory/stats`, { end_user_id }) } +// 获取带遗忘节点列表 +export const getForgetPendingNodesUrl = '/memory/forget-memory/pending-nodes' // Implicit Memory - Preferences export const getImplicitPreferences = 
(end_user_id: string) => { return request.get(`/memory/implicit-memory/preferences/${end_user_id}`) diff --git a/web/src/assets/images/common/delete_red_big.svg b/web/src/assets/images/common/delete_red_big.svg new file mode 100644 index 00000000..7751b4e1 --- /dev/null +++ b/web/src/assets/images/common/delete_red_big.svg @@ -0,0 +1,19 @@ + + + 编组 33 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/edit_bg.svg b/web/src/assets/images/common/edit_bg.svg new file mode 100644 index 00000000..4711afa4 --- /dev/null +++ b/web/src/assets/images/common/edit_bg.svg @@ -0,0 +1,17 @@ + + + 编辑 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/edit_bold.svg b/web/src/assets/images/common/edit_bold.svg new file mode 100644 index 00000000..c41984b2 --- /dev/null +++ b/web/src/assets/images/common/edit_bold.svg @@ -0,0 +1,16 @@ + + + 编辑 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/eye.svg b/web/src/assets/images/common/eye.svg new file mode 100644 index 00000000..c7b531b8 --- /dev/null +++ b/web/src/assets/images/common/eye.svg @@ -0,0 +1,16 @@ + + + link-outlined + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/eye_bg.svg b/web/src/assets/images/common/eye_bg.svg new file mode 100644 index 00000000..275c13c2 --- /dev/null +++ b/web/src/assets/images/common/eye_bg.svg @@ -0,0 +1,17 @@ + + + 编辑 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/link.svg b/web/src/assets/images/common/link.svg new file mode 100644 index 00000000..5773d546 --- /dev/null +++ b/web/src/assets/images/common/link.svg @@ -0,0 +1,13 @@ + + + link-outlined + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/more.svg b/web/src/assets/images/common/more.svg index 0d4d9cd2..6c24cf52 100644 --- 
a/web/src/assets/images/common/more.svg +++ b/web/src/assets/images/common/more.svg @@ -1,12 +1,11 @@ - - 更多 - - - - - - + + 卡片1@3x + + + + + diff --git a/web/src/assets/images/common/more_hover.svg b/web/src/assets/images/common/more_hover.svg index 04fc6eb5..d08ba08f 100644 --- a/web/src/assets/images/common/more_hover.svg +++ b/web/src/assets/images/common/more_hover.svg @@ -1,14 +1,23 @@ - - 更多 - - - - - - - - + + 更多@3x + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/images/conversation/ai.png b/web/src/assets/images/conversation/ai.png new file mode 100644 index 00000000..3783a543 Binary files /dev/null and b/web/src/assets/images/conversation/ai.png differ diff --git a/web/src/assets/images/conversation/analysisEmpty.png b/web/src/assets/images/conversation/analysisEmpty.png index 6d497f31..50adbd82 100644 Binary files a/web/src/assets/images/conversation/analysisEmpty.png and b/web/src/assets/images/conversation/analysisEmpty.png differ diff --git a/web/src/assets/images/conversation/user.png b/web/src/assets/images/conversation/user.png new file mode 100644 index 00000000..671ab044 Binary files /dev/null and b/web/src/assets/images/conversation/user.png differ diff --git a/web/src/assets/images/menuNew/arrow_t_r.svg b/web/src/assets/images/menuNew/arrow_t_r.svg new file mode 100644 index 00000000..884e46c1 --- /dev/null +++ b/web/src/assets/images/menuNew/arrow_t_r.svg @@ -0,0 +1,16 @@ + + + 编组 51 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/logout_red.svg b/web/src/assets/images/menuNew/logout_red.svg new file mode 100644 index 00000000..7057b974 --- /dev/null +++ b/web/src/assets/images/menuNew/logout_red.svg @@ -0,0 +1,17 @@ + + + 退出 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/settings.svg b/web/src/assets/images/menuNew/settings.svg new file mode 100644 index 00000000..9a64bb29 --- /dev/null +++ 
b/web/src/assets/images/menuNew/settings.svg @@ -0,0 +1,19 @@ + + + 设置-界面设置 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/userInfo.svg b/web/src/assets/images/menuNew/userInfo.svg new file mode 100644 index 00000000..0e67a919 --- /dev/null +++ b/web/src/assets/images/menuNew/userInfo.svg @@ -0,0 +1,13 @@ + + + 账户 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/model/bedrock.png b/web/src/assets/images/model/bedrock.png new file mode 100644 index 00000000..a16ee6f7 Binary files /dev/null and b/web/src/assets/images/model/bedrock.png differ diff --git a/web/src/assets/images/model/bedrock.svg b/web/src/assets/images/model/bedrock.svg deleted file mode 100644 index 6a0235af..00000000 --- a/web/src/assets/images/model/bedrock.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - - - - - - - - - diff --git a/web/src/assets/images/model/dashscope.png b/web/src/assets/images/model/dashscope.png index c1aff40e..e57821f0 100644 Binary files a/web/src/assets/images/model/dashscope.png and b/web/src/assets/images/model/dashscope.png differ diff --git a/web/src/assets/images/model/gpustack.png b/web/src/assets/images/model/gpustack.png index b154821d..39d303ae 100644 Binary files a/web/src/assets/images/model/gpustack.png and b/web/src/assets/images/model/gpustack.png differ diff --git a/web/src/assets/images/model/ollama.png b/web/src/assets/images/model/ollama.png new file mode 100644 index 00000000..068d066d Binary files /dev/null and b/web/src/assets/images/model/ollama.png differ diff --git a/web/src/assets/images/model/ollama.svg b/web/src/assets/images/model/ollama.svg deleted file mode 100644 index f8482a96..00000000 --- a/web/src/assets/images/model/ollama.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - - - - - - - - - diff --git a/web/src/assets/images/model/openai.png b/web/src/assets/images/model/openai.png new file mode 100644 index 00000000..db9fabaa Binary files /dev/null and 
b/web/src/assets/images/model/openai.png differ diff --git a/web/src/assets/images/model/openai.svg b/web/src/assets/images/model/openai.svg deleted file mode 100644 index 70686f9b..00000000 --- a/web/src/assets/images/model/openai.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - - diff --git a/web/src/assets/images/model/volcano.png b/web/src/assets/images/model/volcano.png index 9aeb3bf3..ba0dce10 100644 Binary files a/web/src/assets/images/model/volcano.png and b/web/src/assets/images/model/volcano.png differ diff --git a/web/src/assets/images/model/xinference.png b/web/src/assets/images/model/xinference.png new file mode 100644 index 00000000..71a4821b Binary files /dev/null and b/web/src/assets/images/model/xinference.png differ diff --git a/web/src/assets/images/model/xinference.svg b/web/src/assets/images/model/xinference.svg deleted file mode 100644 index f5c5f75e..00000000 --- a/web/src/assets/images/model/xinference.svg +++ /dev/null @@ -1,24 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/web/src/assets/images/prompt/delete.svg b/web/src/assets/images/prompt/delete.svg new file mode 100644 index 00000000..f413ffa0 --- /dev/null +++ b/web/src/assets/images/prompt/delete.svg @@ -0,0 +1,19 @@ + + + 编组 33 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/prompt/delete_hover.svg b/web/src/assets/images/prompt/delete_hover.svg new file mode 100644 index 00000000..aebdc48c --- /dev/null +++ b/web/src/assets/images/prompt/delete_hover.svg @@ -0,0 +1,20 @@ + + + 编组 33 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/prompt/edit.svg b/web/src/assets/images/prompt/edit.svg new file mode 100644 index 00000000..89668678 --- /dev/null +++ b/web/src/assets/images/prompt/edit.svg @@ -0,0 +1,16 @@ + + + 编辑 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/prompt/edit_bg.svg b/web/src/assets/images/prompt/edit_bg.svg new file mode 
100644 index 00000000..4711afa4 --- /dev/null +++ b/web/src/assets/images/prompt/edit_bg.svg @@ -0,0 +1,17 @@ + + + 编辑 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/prompt/eye.svg b/web/src/assets/images/prompt/eye.svg new file mode 100644 index 00000000..df2af1cf --- /dev/null +++ b/web/src/assets/images/prompt/eye.svg @@ -0,0 +1,16 @@ + + + 编辑 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/prompt/eye_bg.svg b/web/src/assets/images/prompt/eye_bg.svg new file mode 100644 index 00000000..275c13c2 --- /dev/null +++ b/web/src/assets/images/prompt/eye_bg.svg @@ -0,0 +1,17 @@ + + + 编辑 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/clear.svg b/web/src/assets/images/workflow/clear.svg new file mode 100644 index 00000000..10502289 --- /dev/null +++ b/web/src/assets/images/workflow/clear.svg @@ -0,0 +1,13 @@ + + + clear-outlined + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/document-extractor.svg b/web/src/assets/images/workflow/document-extractor.svg new file mode 100644 index 00000000..eea39cc6 --- /dev/null +++ b/web/src/assets/images/workflow/document-extractor.svg @@ -0,0 +1,32 @@ + + + 3备份 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/features.svg b/web/src/assets/images/workflow/features.svg new file mode 100644 index 00000000..2ff48584 --- /dev/null +++ b/web/src/assets/images/workflow/features.svg @@ -0,0 +1,15 @@ + + + 参与 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/return.svg b/web/src/assets/images/workflow/return.svg new file mode 100644 index 00000000..b7cfe153 --- /dev/null +++ b/web/src/assets/images/workflow/return.svg @@ -0,0 +1,17 @@ + + + 退出 + + + + + + + + + + + + + + \ No newline at end of file diff --git 
a/web/src/assets/images/workflow/run.svg b/web/src/assets/images/workflow/run.svg new file mode 100644 index 00000000..5d320106 --- /dev/null +++ b/web/src/assets/images/workflow/run.svg @@ -0,0 +1,13 @@ + + + 编组 31 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/save.svg b/web/src/assets/images/workflow/save.svg new file mode 100644 index 00000000..681c7633 --- /dev/null +++ b/web/src/assets/images/workflow/save.svg @@ -0,0 +1,17 @@ + + + 保存 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/variable.svg b/web/src/assets/images/workflow/variable.svg new file mode 100644 index 00000000..cdb8338e --- /dev/null +++ b/web/src/assets/images/workflow/variable.svg @@ -0,0 +1,16 @@ + + + 聊天 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index 34472a2e..3ef69136 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:46:17 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-26 13:32:29 + * @Last Modified time: 2026-03-31 15:01:53 */ import { type FC, useRef, useEffect, useState } from 'react' import clsx from 'clsx' @@ -37,11 +37,27 @@ const ChatContent: FC = ({ const prevDataLengthRef = useRef(data.length); const isScrolledToBottomRef = useRef(true); const audioRef = useRef(null) - const [playingIndex, setPlayingIndex] = useState(null) + const [expandedReasoning, setExpandedReasoning] = useState>(new Set()) + const [manualToggledReasoning, setManualToggledReasoning] = useState>(new Set()) - const handlePlay = (index: number, audio_url: string, audio_status?: string) => { - if (audio_status !== 'completed' && !audio_status) return - if (playingIndex === index) { + const toggleReasoning = (index: number) => { + setManualToggledReasoning(prev => new Set(prev).add(index)) + 
setExpandedReasoning(prev => { + const next = new Set(prev) + next.has(index) ? next.delete(index) : next.add(index) + return next + }) + } + + const isReasoningExpanded = (index: number) => { + if (manualToggledReasoning.has(index)) return expandedReasoning.has(index) + return !data[index]?.content + } + const [playingIndex, setPlayingIndex] = useState(null) + + const handlePlay = (audio_url: string, audio_status?: string) => { + if (audio_status !== 'completed' && typeof audio_status === 'string') return + if (playingIndex === audio_url) { audioRef.current?.pause() setPlayingIndex(null) return @@ -52,7 +68,7 @@ const ChatContent: FC = ({ const audio = new Audio(audio_url) audioRef.current = audio audio.play() - setPlayingIndex(index) + setPlayingIndex(audio_url) audio.onended = () => setPlayingIndex(null) } @@ -79,12 +95,16 @@ const ChatContent: FC = ({ } }; }, []); - + // Auto-scroll to bottom when data changes to show latest messages // When data array length remains unchanged, if data is updated and user manually scrolled up, don't auto-scroll to bottom // When data array length changes, auto-scroll to bottom // If already scrolled to bottom, will auto-scroll to bottom useEffect(() => { + if (playingIndex && !data.some(item => item.meta_data?.audio_url === playingIndex)) { + audioRef.current?.pause() + setPlayingIndex(null) + } setTimeout(() => { if (scrollContainerRef.current) { // Auto-scroll if data length changed OR user is currently at bottom @@ -120,7 +140,7 @@ const ChatContent: FC = ({ {labelFormat(item)} } - {item.meta_data?.files && item.meta_data?.files.length > 0 && + {item.meta_data?.files && item.meta_data?.files.length > 0 && {item.meta_data?.files?.map((file) => { if (file.type.includes('image')) { return ( @@ -174,6 +194,22 @@ const ChatContent: FC = ({ 'rb:mt-1.5': labelPosition === 'top', 'rb:mb-1.5': labelPosition === 'bottom', })}> + {item.meta_data?.reasoning_content &&
+ toggleReasoning(index)} + > + {t('memoryConversation.reasoning_content')} +
+
+ {isReasoningExpanded(index) && } +
} {item.status &&
} {item.subContent && renderRuntime && renderRuntime(item, index)} {/* Render message content using Markdown component */} @@ -194,23 +230,26 @@ const ChatContent: FC = ({ key={idx} size="small" className="rb:text-[12px]!" - onClick={() => window.open(`/knowledge/${citation.knowledge_id}/document/${citation.document_id}`, '_blank')} + onClick={() => { + const params = new URLSearchParams({ documentId: citation.document_id, parentId: citation.knowledge_id }); + window.open(`/#/knowledge-base/${citation.knowledge_id}/DocumentDetails?${params}`, '_blank'); + }} >{citation.file_name} ))} } {item.meta_data?.audio_url && <> - {playingIndex !== index && item.meta_data?.audio_status === 'pending' + {playingIndex !== item.meta_data?.audio_url && item.meta_data?.audio_status === 'pending' ? - : playingIndex !== index + : playingIndex !== item.meta_data?.audio_url ? handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> + })} onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> :
handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} + onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> } diff --git a/web/src/components/Chat/ChatToolbar.tsx b/web/src/components/Chat/ChatToolbar.tsx index 3fbc0e3a..c5db0c4c 100644 --- a/web/src/components/Chat/ChatToolbar.tsx +++ b/web/src/components/Chat/ChatToolbar.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-03-17 14:22:25 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-23 17:42:38 + * @Last Modified time: 2026-03-27 17:54:47 */ // Toolbar component for chat input area, supporting file upload, audio recording, and variable configuration import { useRef, forwardRef, useImperativeHandle, type ReactNode, useEffect } from 'react' @@ -19,6 +19,7 @@ import type { UploadFileListModalRef } from '@/views/Conversation/types' import type { VariableConfigModalRef } from '@/views/Workflow/types' import type { Variable } from '@/views/Workflow/components/Properties/VariableList/types' import { getFileInfoByUrl } from '@/api/fileStorage'; +import { transform_file_type } from '@/views/Conversation/components/FileUpload' // Exposed methods via ref for parent components to access/set form state export interface ChatToolbarRef { @@ -126,7 +127,7 @@ const ChatToolbar = forwardRef(({ status: 'done', name: file_name, size: file_size, - type: content_type, + type: transform_file_type[content_type] || content_type, } : f) form.setFieldValue('files', updated) onFilesChange?.(updated) diff --git a/web/src/components/Chat/index.tsx b/web/src/components/Chat/index.tsx index 49feaf33..f7c0f32e 100644 --- a/web/src/components/Chat/index.tsx +++ b/web/src/components/Chat/index.tsx @@ -27,12 +27,14 @@ const Chat: FC = ({ fileList, fileChange, className, - renderRuntime + renderRuntime, + conversationId }) => { return (
{/* Chat content display area */} void; className?: string; renderRuntime?: (item: ChatItem, index: number) => ReactNode; + conversationId?: string | null; } /** diff --git a/web/src/components/DebounceSelect/index.tsx b/web/src/components/DebounceSelect/index.tsx new file mode 100644 index 00000000..ab8379ad --- /dev/null +++ b/web/src/components/DebounceSelect/index.tsx @@ -0,0 +1,106 @@ +import { useRef, useState, useCallback, useEffect, type FC } from 'react'; +import { Select, Spin, Avatar } from 'antd'; +import type { SelectProps, DefaultOptionType } from 'antd/es/select'; + +import { request } from '@/utils/request'; + +interface OptionType { + [key: string]: any; +} + +interface ApiResponse { + items?: T[]; +} + +export interface DebounceSelectProps extends Omit { + /** API endpoint URL — mutually exclusive with fetchOptions */ + url?: string; + /** Extra query params merged with the search keyword */ + params?: Record; + /** Key used as option value */ + valueKey?: string; + /** Key used as option label */ + labelKey?: string; + /** Key name sent to the API for the search keyword */ + searchKey?: string; + /** Custom fetch function — mutually exclusive with url */ + fetchOptions?: (search: string | null) => Promise; + /** Transform raw API items before rendering */ + format?: (items: OptionType[]) => OptionType[]; + debounceTimeout?: number; +} + +const DebounceSelect: FC = ({ + url, + params = { page: 1, pagesize: 20 }, + valueKey = 'value', + labelKey = 'label', + searchKey = 'search', + fetchOptions, + format, + debounceTimeout = 300, + ...props +}) => { + const [fetching, setFetching] = useState(false); + const [options, setOptions] = useState([]); + const fetchRef = useRef(0); + + const timerRef = useRef>(); + + // Load initial options on mount + useEffect(() => { + debounceFetcher(null); + }, []); + + const debounceFetcher = useCallback((keyword: string | null) => { + clearTimeout(timerRef.current); + timerRef.current = setTimeout(() => { + 
fetchRef.current += 1; + const fetchId = fetchRef.current; + setOptions([]); + setFetching(true); + + const promise: Promise = fetchOptions + ? fetchOptions(keyword) + : request + .get>(url!, { ...params, [searchKey]: keyword }) + .then((res) => { + const data: OptionType[] = Array.isArray(res) ? res : res?.items || []; + const formatted = format ? format(data) : data.map((item) => ({ + label: item[labelKey], + value: item[valueKey], + avatar: item.avatar, + raw: item, + })); + return formatted; + }); + + promise + .then((newOptions) => { + if (fetchId !== fetchRef.current) return; + setOptions(newOptions); + setFetching(false); + }) + .catch(() => setFetching(false)); + }, debounceTimeout); + }, [url, params, searchKey, fetchOptions, format, valueKey, labelKey, debounceTimeout]); + + return ( + ({ - value: type, - label: t(`application.${type}`), - }))} - placeholder={t('common.pleaseSelect')} - /> - - {t('application.aggregationStrategy')}} - className="rb:mb-0!" - > - ({ + value: type, + label: t(`application.${type}`), + }))} + placeholder={t('common.pleaseSelect')} + /> + + {t('application.aggregationStrategy')}} + className="rb:mb-0!" + > + ({ + (items as Data[]).map(item => ({ + ...item, + 'end_user.id': item.end_user?.id, + label: item.end_user?.other_name || item.end_user?.id, value: item.end_user?.id, - label: item?.name, }))} - filterOption={(inputValue, option) => option?.label?.toLowerCase().indexOf(inputValue.toLowerCase()) !== -1} - showSearch={true} - // filterOption={(inputValue, option) => option.label?.toLowerCase().indexOf(inputValue.toLowerCase()) !== -1} placeholder={t('memoryConversation.searchPlaceholder')} style={{ width: '100%', marginBottom: '16px' }} - onChange={setUserId} + onChange={(opt: DefaultOptionType) => setUserId(opt?.value as string)} variant="borderless" className="rb:bg-white rb:rounded-lg" + showSearch /> - + { headerType="borderless" headerClassName="rb:min-h-[52px]! 
rb:font-[MiSans-Bold] rb:font-bold" bodyClassName="rb:p-3! rb:pt-0! rb:h-[calc(100%-52px)]! rb:overflow-y-auto!" - className="rb:h-[calc(100vh-124px)]!" + className="rb:h-full!" > {loading ? diff --git a/web/src/views/MemoryExtractionEngine/components/Result.tsx b/web/src/views/MemoryExtractionEngine/components/Result.tsx index 46d05dcd..4d07aae9 100644 --- a/web/src/views/MemoryExtractionEngine/components/Result.tsx +++ b/web/src/views/MemoryExtractionEngine/components/Result.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:30:11 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 11:40:38 + * @Last Modified time: 2026-03-26 15:46:30 */ /** * Result Component @@ -267,7 +267,7 @@ const Result: FC = ({ loading, handleSave }) => { title={t('memoryExtractionEngine.exampleMemoryExtractionResults')} subTitle={t('memoryExtractionEngine.exampleMemoryExtractionResultsSubTitle')} headerClassName="rb:pb-0! rb:pt-4!" - bodyClassName="rb:h-[calc(100vh-163px)]! rb:overflow-y-auto rb:p-[16px_20px]!" + bodyClassName="rb:h-[calc(100%-50px)]! rb:overflow-y-auto rb:p-[16px_20px]!" extra={
} @@ -281,6 +281,7 @@ const Result: FC = ({ loading, handleSave }) => { onClick={handleRun} >{t('memoryExtractionEngine.debug')} } + className="rb:h-full!" > {/* } className="rb:mb-3!"> {t('memoryExtractionEngine.warning')} diff --git a/web/src/views/MemoryExtractionEngine/index.tsx b/web/src/views/MemoryExtractionEngine/index.tsx index 34596711..f7538bdc 100644 --- a/web/src/views/MemoryExtractionEngine/index.tsx +++ b/web/src/views/MemoryExtractionEngine/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:30:02 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 14:07:35 + * @Last Modified time: 2026-03-26 15:45:42 */ /** * Memory Extraction Engine Configuration Page @@ -13,7 +13,7 @@ import { type FC, useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' -import { Row, Col, Space, Select, InputNumber, App, Form, Input, Flex, Tooltip } from 'antd' +import { Row, Col, Space, Select, InputNumber, App, Form, Input, Flex, Tooltip, Divider } from 'antd' import clsx from 'clsx' import Card from './components/Card' @@ -123,10 +123,10 @@ const MemoryExtractionEngine: FC = () => { - - -
- + + + +
{ {config.meaning ? {t(`memoryExtractionEngine.${config.label}`)} - {t('memoryExtractionEngine.Meaning')}: {t(`memoryExtractionEngine.${config.meaning}`)}}> + + {t('memoryExtractionEngine.Meaning')}: {t(`memoryExtractionEngine.${config.meaning}`)} + + {config.label === 'intelligentSemanticPruningThreshold' && <> + + {t('memoryExtractionEngine.loose')} ← + + → {t('memoryExtractionEngine.strict')} + + + + + 0.0
+ |
+ {t('memoryExtractionEngine.onlyDelete')} + + + 0.3
+ |
+ {t('memoryExtractionEngine.semanticFiltering')} + + + 0.6
+ |
+ {t('memoryExtractionEngine.sceneFocus')} + + + 0.9
+ +
+ } + } + >
@@ -231,14 +267,16 @@ const MemoryExtractionEngine: FC = () => { options={config.options ? config.options.map(item => ({ ...item, label: t(`memoryExtractionEngine.${item.label}`) })) : []} /> : config.control === 'slider' - ? {t('emotionEngine.currentValue')}:} - inputClassName="rb:w-[155px]!" - /> + ? <> + {t('emotionEngine.currentValue')}:} + inputClassName="rb:w-[155px]!" + /> + : config.control === 'inputNumber' ? : config.control === 'text' @@ -259,7 +297,7 @@ const MemoryExtractionEngine: FC = () => {
- + {
- + {data.map((item) => (
{String(item.provider).charAt(0).toUpperCase() + String(item.provider).slice(1)}
- {item.tags.map(tag => {t(`modelNew.${tag}`)})} + + {item.tags.map(tag => {t(`modelNew.${tag}`)})}
+ }> + + {item.tags.map(tag => {t(`modelNew.${tag}`)})} + +
} isNeedTooltip={false} footer={ diff --git a/web/src/views/ModelManagement/components/CustomModelModal.tsx b/web/src/views/ModelManagement/components/CustomModelModal.tsx index abede886..01cc0fd6 100644 --- a/web/src/views/ModelManagement/components/CustomModelModal.tsx +++ b/web/src/views/ModelManagement/components/CustomModelModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:49:28 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 14:07:10 + * @Last Modified time: 2026-03-31 13:56:18 */ /** * Custom Model Modal @@ -11,7 +11,7 @@ */ import { forwardRef, useEffect, useImperativeHandle, useState } from 'react'; -import { Form, Input, App, Checkbox, Button } from 'antd'; +import { Form, Input, App, Checkbox, Button, Row, Col } from 'antd'; import { useTranslation } from 'react-i18next'; import type { CustomModelForm, ModelListItem, CustomModelModalRef, CustomModelModalProps } from '../types'; @@ -72,6 +72,7 @@ const CustomModelModal = forwardRef( is_vision: capability?.includes('vision') || false, is_video: capability?.includes('video') || false, is_audio: capability?.includes('audio') || false, + is_thinking: capability?.includes('thinking') || false, }); } else { setIsEdit(false); @@ -101,7 +102,7 @@ const CustomModelModal = forwardRef( form .validateFields() .then((values) => { - const { logo, type, is_vision, is_video, is_audio, is_omni, ...rest } = values; + const { logo, type, is_vision, is_video, is_audio, is_omni, is_thinking, ...rest } = values; const formData: CustomModelForm = { ...rest, type, @@ -120,6 +121,9 @@ const CustomModelModal = forwardRef( capability.push('video') } } + if (is_thinking) { + capability.push('thinking') + } formData.capability = capability formData.is_omni = is_omni @@ -238,21 +242,34 @@ const CustomModelModal = forwardRef( - {!['embedding', 'rerank'].includes(modelType as string) && - <> - - {t('modelNew.is_omni')} - - - {t('modelNew.is_vision')} - - - {t('modelNew.is_video')} - - - 
{t('modelNew.is_audio')} - - + {['llm', 'chat'].includes(modelType as string) && + + + + {t('modelNew.is_omni')} + + + + + {t('modelNew.is_vision')} + + + + + {t('modelNew.is_video')} + + + + + {t('modelNew.is_audio')} + + + + + {t('modelNew.is_thinking')} + + + } diff --git a/web/src/views/ModelManagement/components/ModelListDetail.tsx b/web/src/views/ModelManagement/components/ModelListDetail.tsx index a3b9bbc0..ffe7a17f 100644 --- a/web/src/views/ModelManagement/components/ModelListDetail.tsx +++ b/web/src/views/ModelManagement/components/ModelListDetail.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:49:45 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 12:28:07 + * @Last Modified time: 2026-03-27 18:06:23 */ /** * Model List Detail Drawer @@ -12,7 +12,7 @@ import { useState, useImperativeHandle, forwardRef, useRef, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { Button, Switch, Row, Col, Space, Tooltip } from 'antd' +import { Button, Switch, Row, Col, Space, Tooltip, Popover } from 'antd' import type { ProviderModelItem, ModelListItem, ModelListDetailRef, MultiKeyConfigModalRef } from '../types'; import RbDrawer from '@/components/RbDrawer'; @@ -136,11 +136,19 @@ const ModelListDetail = forwardRef(({ - {t(`modelNew.${item.type}`)} - {item.api_keys.length}{t('modelNew.apiKeyNum')} - {item.capability?.filter(item => item !=='video').map(vo => {t(`modelNew.${vo}`)})} - } + subTitle={ + + {t(`modelNew.${item.type}`)} + {item.api_keys.length}{t('modelNew.apiKeyNum')} + {item.capability?.map(vo => {t(`modelNew.${vo}`)})} + }> + + {t(`modelNew.${item.type}`)} + {item.api_keys.length}{t('modelNew.apiKeyNum')} + {item.capability?.map(vo => {t(`modelNew.${vo}`)})} + + } avatarUrl={getLogoUrl(item.logo)} avatar={
diff --git a/web/src/views/ModelManagement/index.tsx b/web/src/views/ModelManagement/index.tsx index 461f7cd5..0dcef68f 100644 --- a/web/src/views/ModelManagement/index.tsx +++ b/web/src/views/ModelManagement/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:50:05 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 12:28:07 + * @Last Modified time: 2026-03-26 15:51:08 */ /** * Model Management Main Page @@ -84,8 +84,8 @@ const tabKeys = ['group', 'list', 'square'] } return ( - - + <> + -
+
{activeTab === 'group' && } {activeTab === 'list' && customModelModalRef.current?.handleClose() } />} {activeTab === 'square' && } @@ -145,7 +145,7 @@ const tabKeys = ['group', 'list', 'square'] ref={customModelModalRef} refresh={handleRefresh} /> - + ) } diff --git a/web/src/views/ModelManagement/types.ts b/web/src/views/ModelManagement/types.ts index 1662775f..cafac4b3 100644 --- a/web/src/views/ModelManagement/types.ts +++ b/web/src/views/ModelManagement/types.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:50:18 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 12:28:10 + * @Last Modified time: 2026-03-31 15:48:02 */ /** * Type definitions for Model Management @@ -295,7 +295,8 @@ export interface CustomModelForm { is_video?: boolean; is_audio?: boolean; is_omni?: boolean; - capability?: string[]; + is_thinking?: boolean; + capability?: Capability[]; } /** @@ -324,7 +325,7 @@ export interface BaseRef { modelListDetailRefresh?: () => void; } -export type Capability = 'vision' | 'audio' | 'video'; +export type Capability = 'vision' | 'audio' | 'video' | 'thinking'; export interface Model { name: string; type: string; diff --git a/web/src/views/ModelManagement/utils.ts b/web/src/views/ModelManagement/utils.ts index 23e1e4f3..8180e194 100644 --- a/web/src/views/ModelManagement/utils.ts +++ b/web/src/views/ModelManagement/utils.ts @@ -2,18 +2,18 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:50:22 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 14:03:13 + * @Last Modified time: 2026-03-25 18:35:13 */ /** * Utility functions for Model Management */ -import bedrockIcon from '@/assets/images/model/bedrock.svg' +import bedrockIcon from '@/assets/images/model/bedrock.png' import dashscopeIcon from '@/assets/images/model/dashscope.png' import gpustackIcon from '@/assets/images/model/gpustack.png' -import ollamaIcon from '@/assets/images/model/ollama.svg' -import openaiIcon from '@/assets/images/model/openai.svg' -import 
xinferenceIcon from '@/assets/images/model/xinference.svg' +import ollamaIcon from '@/assets/images/model/ollama.png' +import openaiIcon from '@/assets/images/model/openai.png' +import xinferenceIcon from '@/assets/images/model/xinference.png' import volcanoIcon from '@/assets/images/model/volcano.png' /** diff --git a/web/src/views/Ontology/index.tsx b/web/src/views/Ontology/index.tsx index cd599b97..7b98efea 100644 --- a/web/src/views/Ontology/index.tsx +++ b/web/src/views/Ontology/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:10:15 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 13:38:59 + * @Last Modified time: 2026-03-27 15:03:09 */ import { type FC, useState, useRef } from 'react'; import type { MenuInfo } from 'rc-menu/lib/interface'; @@ -150,13 +150,13 @@ const Ontology: FC = () => { items: [ { key: 'edit', - icon:
, + icon:
, label: t('common.edit'), onClick: (e: MenuInfo) => handleEdit(item, e), }, { key: 'delete', - icon:
, + icon:
, label: t('common.delete'), onClick: (e: MenuInfo) => handleDelete(item, e), }, @@ -164,7 +164,7 @@ const Ontology: FC = () => { }} placement="bottomRight" > -
e.stopPropagation()} className="rb:cursor-pointer rb:size-6 rb:bg-[url('@/assets/images/common/more.svg')] rb:hover:bg-[url('@/assets/images/common/more_hover.svg')]">
+
e.stopPropagation()} className="rb:cursor-pointer rb:size-5.5 rb:bg-[url('@/assets/images/common/more.svg')] rb:hover:bg-[url('@/assets/images/common/more_hover.svg')]">
} diff --git a/web/src/views/Ontology/pages/Detail.tsx b/web/src/views/Ontology/pages/Detail.tsx index fdf33105..cf26867e 100644 --- a/web/src/views/Ontology/pages/Detail.tsx +++ b/web/src/views/Ontology/pages/Detail.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:10:20 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 12:16:23 + * @Last Modified time: 2026-03-26 18:55:37 */ import { type FC, useEffect, useState, useRef } from 'react' import { useParams, useNavigate } from 'react-router-dom'; @@ -100,58 +100,60 @@ const Detail: FC = () => { return ( <> - - {data.scene_name} - {data.is_system_default ? {t('common.default')} : undefined} - -
-
- } - extra={ - {data.is_system_default ? undefined : ( - - - )} - navigate(-1)}> -
- {t('common.return')} -
-
} - /> + + + {data.scene_name} + {data.is_system_default ? {t('common.default')} : undefined} + +
+
+ } + extra={ + {data.is_system_default ? undefined : ( + + + )} + navigate(-1)}> +
+ {t('common.return')} +
+
} + /> -
- - - setQuery({ class_name: value })} - className="rb:w-full!" - /> - - - - - {data.items?.map(item => ( - - handleDelete(item)} - >
)} - > - -
{item.class_description}
-
- - - ))} +
+ + + setQuery({ class_name: value })} + className="rb:w-full!" + /> + - -
+ + + {data.items?.map(item => ( + + handleDelete(item)} + >
)} + > + +
{item.class_description}
+
+ + + ))} + + +
+ = ({ return (
{title}
-
{desc}
+
{desc}
) } diff --git a/web/src/views/Prompt/index.tsx b/web/src/views/Prompt/index.tsx index 98718b12..13c09042 100644 --- a/web/src/views/Prompt/index.tsx +++ b/web/src/views/Prompt/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:44:15 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-20 13:52:09 + * @Last Modified time: 2026-03-27 15:14:58 */ /** * Prompt Editor Component @@ -188,12 +188,12 @@ const Prompt: FC = () => { } data={chatList || []} @@ -243,13 +243,14 @@ const Prompt: FC = () => { { {values?.current_prompt ? form.setFieldValue('current_prompt', value)} /> - : + : } diff --git a/web/src/views/Prompt/pages/History.tsx b/web/src/views/Prompt/pages/History.tsx index 7953af0f..19c033ed 100644 --- a/web/src/views/Prompt/pages/History.tsx +++ b/web/src/views/Prompt/pages/History.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:44:04 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 12:18:09 + * @Last Modified time: 2026-03-27 15:52:44 */ /** * Prompt History Component @@ -12,8 +12,7 @@ import React, { useRef, type MouseEvent } from 'react'; import { useTranslation } from 'react-i18next'; import { useNavigate } from 'react-router-dom'; -import { Tooltip, App, Flex, Form, Dropdown } from 'antd'; -import { DashOutlined } from '@ant-design/icons'; +import { Space, App, Flex, Form } from 'antd'; import type { HistoryQuery, HistoryItem, PromptDetailRef } from '../types'; import RbCard from '@/components/RbCard/Card' @@ -104,28 +103,33 @@ const History: React.FC = () => { renderItem={(item) => ( {item.title}} - headerClassName='rb:h-[38px]! rb:pt-3!' + title={item.title} + headerClassName='rb:min-h-[46px]!' headerType="borderless" + bodyClassName="rb:px-3! rb:py-0!" > - handleClick(key, item) - }} - > - - -
{formatDateTime(item.created_at, 'YYYY/MM/DD HH:mm')}
-
+
+ + +
{formatDateTime(item.created_at, 'YYYY/MM/DD HH:mm')}
+ + +
handleClick('detail', item)} + >
+
handleClick('edit', item)} + >
+
handleClick('delete', item)} + >
+
+
)} + heightClass="rb:h-[calc(100vh-126px)]!" /> { } return ( - - + + @@ -179,7 +179,7 @@ const SelfReflectionEngine: React.FC = () => { } headerType="borderless" headerClassName="rb:min-h-[54px]! rb:font-[MiSans-Bold] rb:font-bold" - className="rb:h-[calc(100vh-76px)]!" + className="rb:h-full!" bodyClassName="rb:h-[calc(100%-54px)] rb:overflow-y-auto! rb:p-4! rb:pt-0!" >
{
- - + + @@ -346,7 +346,7 @@ const SelfReflectionEngine: React.FC = () => { )} } - +
); diff --git a/web/src/views/SpaceConfig/index.tsx b/web/src/views/SpaceConfig/index.tsx index 86fef42a..f2cbe05c 100644 --- a/web/src/views/SpaceConfig/index.tsx +++ b/web/src/views/SpaceConfig/index.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 17:48:03 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 17:01:59 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-30 11:36:24 */ /** * Space Configuration Page @@ -15,9 +15,7 @@ import { useTranslation } from 'react-i18next'; import type { SpaceConfigData } from './types' import { getWorkspaceModels, updateWorkspaceModels } from '@/api/workspaces' -import { getModelListUrl } from '@/api/models' -import CustomSelect from '@/components/CustomSelect' -import RbAlert from '@/components/RbAlert'; +import ModelSelect from '@/components/ModelSelect'; const SpaceConfig: FC = () => { const { t } = useTranslation(); @@ -63,7 +61,9 @@ const SpaceConfig: FC = () => { } return ( -
+
+
{t('menu.spaceConfig')}
+
{t('space.configAlert')}
{pageLoading ? :
{ layout="vertical" > - - - - {t('space.configAlert')} - - - - +
}
diff --git a/web/src/views/SpaceManagement/index.tsx b/web/src/views/SpaceManagement/index.tsx index 0df3c653..781ee903 100644 --- a/web/src/views/SpaceManagement/index.tsx +++ b/web/src/views/SpaceManagement/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:48:59 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 15:33:38 + * @Last Modified time: 2026-03-26 14:43:20 */ /** * Space Management Page @@ -12,7 +12,7 @@ import React, { useEffect, useState, useRef } from 'react'; import { useNavigate } from 'react-router-dom'; import { useTranslation } from 'react-i18next'; -import { List, Button, Flex, Space as AntSpace, Tooltip } from 'antd'; +import { Button, Flex, Space as AntSpace, Tooltip, Row, Col } from 'antd'; import type { Space, SpaceModalRef } from './types'; import SpaceModal from './components/SpaceModal'; @@ -68,11 +68,12 @@ const SpaceManagement: React.FC = () => { {t('space.createSpace')} - ( - + + {data.map(item => ( + { } > - - )} - className="rb:h-[calc(100vh-124px)] rb:overflow-y-auto rb:overflow-x-hidden" - /> + + ))} + ReactNo return ( <> - ( - + + {data.map((item) => ( + @@ -88,14 +88,14 @@ const Custom = forwardRef ReactNo items: [ { key: 'edit', - icon:
, + icon:
, label: t('common.edit'), onClick: () => handleEdit(item), }, { key: 'delete', className: 'rb:text-[#FF5D34]!', - icon:
, + icon:
, label: t('common.delete'), onClick: () => handleDeleteService(item), }, @@ -103,13 +103,12 @@ const Custom = forwardRef ReactNo }} placement="bottomRight" > -
+
} isNeedTooltip={false} > - {item.tags?.length > 0 ? @@ -142,10 +141,9 @@ const Custom = forwardRef ReactNo - - )} - className="rb:h-[calc(100vh-178px)] rb:overflow-y-auto rb:overflow-x-hidden" - /> + + ))} + {/* 添加服务弹窗组件 */} diff --git a/web/src/views/ToolManagement/Inner.tsx b/web/src/views/ToolManagement/Inner.tsx index 689ce43e..b88428b0 100644 --- a/web/src/views/ToolManagement/Inner.tsx +++ b/web/src/views/ToolManagement/Inner.tsx @@ -1,6 +1,5 @@ import React, { useState, useRef, useEffect, type ReactNode } from 'react'; import { - List, Flex, Space, Tooltip, @@ -68,13 +67,14 @@ const Inner: React.FC<{ getStatusTag: (status: string) => ReactNode; keyword?: s } return ( -
+ <> - ( - + + {data.map((item) => ( + @@ -86,7 +86,7 @@ const Inner: React.FC<{ getStatusTag: (status: string) => ReactNode; keyword?: s
handleEdit(item)} /> @@ -130,21 +130,20 @@ const Inner: React.FC<{ getStatusTag: (status: string) => ReactNode; keyword?: s : item.config_data.tool_class === 'JsonTool' - ? + ?
{t('tool.jsonEg')}
{InnerConfigData[item.config_data.tool_class].eg} - : -
{t('configStatus')}
- {t(`tool.${item.status}_desc`)} - + : +
{t('configStatus')}
+ {t(`tool.${item.status}_desc`)} + } - - )} - className="rb:h-[calc(100vh-178px)] rb:overflow-y-auto rb:overflow-x-hidden" - /> + + ))} + ReactNode; keyword?: s ref={innerToolModalRef} refreshTable={getData} /> -
+ ); }; diff --git a/web/src/views/ToolManagement/Market.tsx b/web/src/views/ToolManagement/Market.tsx index 3e2ca456..2c2cb6a9 100644 --- a/web/src/views/ToolManagement/Market.tsx +++ b/web/src/views/ToolManagement/Market.tsx @@ -434,7 +434,7 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => footer={ {mcp.publisher && {mcp.publisher.startsWith('@') ? mcp.publisher : `@${mcp.publisher}`}} {mcp.view_count && -
+
{mcp.view_count.toLocaleString()}
}
} diff --git a/web/src/views/ToolManagement/Mcp.tsx b/web/src/views/ToolManagement/Mcp.tsx index eb9b45ad..d24eee3b 100644 --- a/web/src/views/ToolManagement/Mcp.tsx +++ b/web/src/views/ToolManagement/Mcp.tsx @@ -1,11 +1,11 @@ import { useState, useRef, useEffect, forwardRef, useImperativeHandle, type ReactNode } from 'react'; import { App, - List, Space, Tooltip, Dropdown, Flex, + Row, Col, } from 'antd'; import { useTranslation } from 'react-i18next'; @@ -84,11 +84,12 @@ const Mcp = forwardRef ReactNode; ke return ( <> - ( - + + {data.map((item) => ( + @@ -103,20 +104,20 @@ const Mcp = forwardRef ReactNode; ke items: [ { key: 'edit', - icon:
, + icon:
, label: t('common.edit'), onClick: () => handleEdit(item), }, { key: 'link', - icon:
, + icon:
, label: t('tool.testLink'), onClick: () => handleTestConnection(item), }, { key: 'delete', className: 'rb:text-[#FF5D34]!', - icon:
, + icon:
, label: t('common.delete'), onClick: () => handleDeleteService(item), }, @@ -124,7 +125,7 @@ const Mcp = forwardRef ReactNode; ke }} placement="bottomRight" > -
+
} @@ -137,12 +138,12 @@ const Mcp = forwardRef ReactNode; ke
{t('tool.last_health_check')}: {formatDateTime(item.config_data?.last_health_check)}
- + - - )} - className="rb:h-[calc(100vh-124px)] rb:overflow-y-auto rb:overflow-x-hidden" - /> + + ))} + + {/* 添加服务弹窗组件 */} diff --git a/web/src/views/UserManagement/index.tsx b/web/src/views/UserManagement/index.tsx index 4f7aea92..ed09994f 100644 --- a/web/src/views/UserManagement/index.tsx +++ b/web/src/views/UserManagement/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:51:08 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 16:45:18 + * @Last Modified time: 2026-03-26 14:53:41 */ /** * User Management Page @@ -142,9 +142,9 @@ const UserManagement: React.FC = () => { ]; return ( -
+
-
{t('user.userList')}
+
{t('user.userList')}
@@ -159,7 +159,7 @@ const UserManagement: React.FC = () => { columns={columns} rowKey="id" isScroll={true} - scrollY="calc(100vh - 256px)" + scrollY="calc(100vh - 248px)" /> (false); - const [data, setData] = useState([]); const [form] = Form.useForm() - const search = Form.useWatch(['search'], form) + const keyword = Form.useWatch(['keyword'], form) - /** Fetch user memory list */ - useEffect(() => { - getData() - }, []); + const scrollListRef = useRef(null) - /** Get data from API */ - const getData = () => { - setLoading(true) - getUserMemoryList().then((res) => { - setData(res as Data[] || []) - }) - .finally(() => { - setLoading(false) - }) - } /** Navigate to user memory detail */ const handleViewDetail = (id: string | number) => { switch (storageType) { @@ -64,25 +49,12 @@ export default function UserMemory() { navigate(`/memory`) } - /** Filter data by search term */ - const filterData = useMemo(() => { - if (search && search.trim() !== '') { - return data.filter((item) => { - const { end_user } = item as Data; - const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id - return name?.includes(search) - }) - } - - return data - }, [search, data]) - return (
- +
- {loading ? - - : filterData.length > 0 ? ( - - {filterData.map((item, index) => { - const { end_user, memory_num, memory_config } = item as Data; - const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id - return ( - - -
{name[0]}
- -
{name || '-'}
- } - headerType="border" - headerClassName="rb:h-[48px]! rb:mx-4!" - bodyClassName="rb:py-3! rb:px-4!" - className="rb:cursor-pointer" - onClick={() => handleViewDetail(end_user.id)} - > - - - - - - - - -
- - {t('userMemory.memory_config_name')} -
-
-
{memory_config?.memory_config_name || '-'}
-
-
+ + + ref={scrollListRef} + url={userMemoryListUrl} + query={{ keyword }} + column={3} + renderItem={(item) => { + const { end_user, memory_num, memory_config } = item as Data; + const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id + return ( + +
{name[0]}
+ +
{name || '-'}
+ } + headerType="border" + headerClassName="rb:h-[48px]! rb:mx-4!" + bodyClassName="rb:py-3! rb:px-4!" + className="rb:cursor-pointer" + onClick={() => handleViewDetail(end_user.id)} + > + + + - ) - })} - - ) : - } + + + +
+ +
+ + {t('userMemory.memory_config_name')} +
+
+
{memory_config?.memory_config_name || '-'}
+
+ + ) + }} + />
); } \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/Neo4j.tsx b/web/src/views/UserMemoryDetail/Neo4j.tsx index 6dde8d69..51be7c8d 100644 --- a/web/src/views/UserMemoryDetail/Neo4j.tsx +++ b/web/src/views/UserMemoryDetail/Neo4j.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:57:26 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 14:28:34 + * @Last Modified time: 2026-03-26 18:59:53 */ /** * Neo4j User Memory Detail View @@ -75,7 +75,7 @@ const Neo4j: FC = () => { return (
setSelectedKey(null)}> - + { }) } return ( - - + +
{name?.[0]}
@@ -174,7 +175,7 @@ const Rag: FC = () => {
- +
diff --git a/web/src/views/UserMemoryDetail/components/AboutMe.tsx b/web/src/views/UserMemoryDetail/components/AboutMe.tsx index f38ef30a..72b9af61 100644 --- a/web/src/views/UserMemoryDetail/components/AboutMe.tsx +++ b/web/src/views/UserMemoryDetail/components/AboutMe.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:34:23 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-11 15:03:05 + * @Last Modified time: 2026-03-27 11:09:52 */ /** * About Me Component @@ -67,12 +67,12 @@ const AboutMe = forwardRef(({ className }, title={t('userMemory.aboutMe')} headerClassName="rb:min-h-[46px]!! rb:font-medium!" className={clsx("rb:bg-[#FFFFFF]! rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.13)]! rb:absolute! rb:w-100 rb:top-29 rb:left-26", className)} - bodyClassName="rb:px-5! rb:pb-5! rb:pt-3.75! rb:max-h-[calc(100vh-176px)] rb:overflow-y-auto!" + bodyClassName="rb:px-5! rb:pb-5! rb:pt-3.75! rb:max-h-[calc(100vh-186px)]! rb:overflow-y-auto!" > {loading ? : Object.keys(data).filter(key => data[key] !== null).length > 0 - ? <> + ?
{data.user_summary &&
{data.user_summary} @@ -95,7 +95,7 @@ const AboutMe = forwardRef(({ className }, {data.one_sentence && {data.one_sentence} } - +
: } diff --git a/web/src/views/UserMemoryDetail/components/ActivationMetricsPieCard.tsx b/web/src/views/UserMemoryDetail/components/ActivationMetricsPieCard.tsx index 759a4bd7..211d5968 100644 --- a/web/src/views/UserMemoryDetail/components/ActivationMetricsPieCard.tsx +++ b/web/src/views/UserMemoryDetail/components/ActivationMetricsPieCard.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:34:16 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-16 11:36:02 + * @Last Modified time: 2026-03-27 11:22:10 */ import { type FC } from 'react' import { useTranslation } from 'react-i18next' @@ -32,7 +32,7 @@ const ActivationMetricsPieCard: FC = ({ chartData className="rb:h-full!" > {loading - ? + ? : = ({ src, fileName, fileSize }) => { >
-
+
diff --git a/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx b/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx index b8322c87..ccfbc14d 100644 --- a/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx +++ b/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx @@ -65,7 +65,7 @@ const CommunityNetwork: FC<{ onSelectCommunity?: (node: RawCommunityNode) => voi }, [id]) if (loading) { - return + return
diff --git a/web/src/views/UserMemoryDetail/components/ConversationMemory.tsx b/web/src/views/UserMemoryDetail/components/ConversationMemory.tsx index 641a8af6..c209274b 100644 --- a/web/src/views/UserMemoryDetail/components/ConversationMemory.tsx +++ b/web/src/views/UserMemoryDetail/components/ConversationMemory.tsx @@ -2,36 +2,64 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:34:04 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-20 11:04:52 + * @Last Modified time: 2026-03-31 15:35:13 */ -import { type FC } from 'react' +import { type FC, useState } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' +import { Divider, Flex } from 'antd' +import clsx from 'clsx' import RbCard from '@/components/RbCard/Card' import PageScrollList from '@/components/PageScrollList' import Markdown from '@/components/Markdown' import { getRagContentUrl } from '@/api/memory' +interface DataItem { + role: 'user' | 'assistant'; + content: string; +} + const ConversationMemory: FC = () => { const { t } = useTranslation() const { id } = useParams() + const [total, setTotal] = useState(0) return ( {t('userMemory.conversationMemory')}} + headerType="borderless" + headerClassName="rb:min-h-[54px]! rb:pt-0! rb:mb-0!" + bodyClassName="rb:p-4! rb:pt-0! rb:pb-1! rb:h-[calc(100%-54px)]!" + className="rb:h-full!" + extra={
{t('userMemory.totalRagMemory')}: {total}
} > - + url={getRagContentUrl} query={{ end_user_id: id }} column={1} - renderItem={(item: string) => ( -
- + gutter={0} + onTotalChange={setTotal} + renderItem={(item, index) => ( +
+ {index !== 0 && } + +
+
+
+ {item.role === 'assistant' ? t('userMemory.assistant') : t('userMemory.user')} +
+ +
+
)} className="rb:h-full!" diff --git a/web/src/views/UserMemoryDetail/components/EndUserProfile.tsx b/web/src/views/UserMemoryDetail/components/EndUserProfile.tsx index 0533efc8..c689bf72 100644 --- a/web/src/views/UserMemoryDetail/components/EndUserProfile.tsx +++ b/web/src/views/UserMemoryDetail/components/EndUserProfile.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:33:30 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 17:55:02 + * @Last Modified time: 2026-03-27 11:11:09 */ /** * End User Profile Component @@ -85,7 +85,7 @@ const EndUserProfile = forwardRef(({ cla } headerClassName="rb:min-h-[46px]!! rb:font-medium!" className={clsx("rb:bg-[#FFFFFF]! rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.13)]! rb:absolute! rb:w-80 rb:top-29 rb:left-26", className)} - bodyClassName="rb:px-5! rb:pb-5! rb:pt-3.75! rb:max-h-[calc(100vh-176px)] rb:overflow-auto" + bodyClassName="rb:px-5! rb:pb-5! rb:pt-3.75! rb:max-h-[calc(100vh-186px)] rb:overflow-auto" > {loading ? diff --git a/web/src/views/UserMemoryDetail/components/Habits.tsx b/web/src/views/UserMemoryDetail/components/Habits.tsx index f5ccab03..90b687ff 100644 --- a/web/src/views/UserMemoryDetail/components/Habits.tsx +++ b/web/src/views/UserMemoryDetail/components/Habits.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:33:06 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-16 14:05:10 + * @Last Modified time: 2026-03-27 11:18:47 */ import { useEffect, useState, forwardRef, useImperativeHandle } from 'react' import { useTranslation } from 'react-i18next' @@ -83,7 +83,7 @@ const Habits = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => { headerType="borderless" headerClassName="rb:min-h-[54px]! rb:font-[MiSans-Bold] rb:font-bold" bodyClassName="rb:p-3! rb:pt-0! rb:h-[calc(100%-54px)] rb:overflow-y-auto!" - className="rb:h-[calc(100vh-88px)]!" + className="rb:h-full!" > {loading ? 
diff --git a/web/src/views/UserMemoryDetail/components/InterestDistribution.tsx b/web/src/views/UserMemoryDetail/components/InterestDistribution.tsx index 8d5457cd..c10719de 100644 --- a/web/src/views/UserMemoryDetail/components/InterestDistribution.tsx +++ b/web/src/views/UserMemoryDetail/components/InterestDistribution.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:32:47 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-05 18:29:29 + * @Last Modified time: 2026-03-27 11:11:35 */ /** * Interest Distribution Component @@ -77,7 +77,7 @@ const InterestDistribution: FC<{ className?: string; }> = ({ className }) => { title={t('userMemory.interestDistribution')} headerClassName="rb:min-h-[46px]!! rb:font-medium!" className={clsx("rb:bg-[#FFFFFF]! rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.13)]! rb:absolute! rb:w-100 rb:top-29 rb:left-26", className)} - bodyClassName="rb:px-5! rb:pb-5! rb:pt-3.75! rb:max-h-[calc(100vh-176px)] rb:overflow-auto" + bodyClassName="rb:px-5! rb:pb-5! rb:pt-3.75! rb:max-h-[calc(100vh-186px)] rb:overflow-auto" > {loading ? diff --git a/web/src/views/UserMemoryDetail/components/MemoryInsight.tsx b/web/src/views/UserMemoryDetail/components/MemoryInsight.tsx index 578c8823..4b7274b1 100644 --- a/web/src/views/UserMemoryDetail/components/MemoryInsight.tsx +++ b/web/src/views/UserMemoryDetail/components/MemoryInsight.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:32:41 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-05 18:35:01 + * @Last Modified time: 2026-03-27 11:11:46 */ /** * Memory Insight Component @@ -66,7 +66,7 @@ const MemoryInsight = forwardRef(({ c title={t('userMemory.memoryInsight')} headerClassName="rb:min-h-[46px]!! rb:font-medium!" className={clsx("rb:bg-[#FFFFFF]! rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.13)]! rb:absolute! rb:w-100 rb:top-29 rb:left-26", className)} - bodyClassName="rb:px-5! rb:pb-5! rb:pt-3.75! 
rb:max-h-[calc(100vh-176px)] rb:overflow-auto" + bodyClassName="rb:px-5! rb:pb-5! rb:pt-3.75! rb:max-h-[calc(100vh-186px)] rb:overflow-auto" > {loading ? diff --git a/web/src/views/UserMemoryDetail/components/PerceptualLastInfo.tsx b/web/src/views/UserMemoryDetail/components/PerceptualLastInfo.tsx index 3330f5bc..69f252e0 100644 --- a/web/src/views/UserMemoryDetail/components/PerceptualLastInfo.tsx +++ b/web/src/views/UserMemoryDetail/components/PerceptualLastInfo.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:32:23 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 12:09:53 + * @Last Modified time: 2026-03-27 14:57:34 */ import { type FC, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -92,7 +92,7 @@ const PerceptualLastInfo: FC = () => { setData(response) setLoading(false) if (response.file_path) { - fetch(response.file_path, { method: 'HEAD' }) + fetch(response.file_path, { method: 'GET' }) .then(r => { const bytes = Number(r.headers.get('content-length')) if (!bytes) return @@ -119,7 +119,7 @@ const PerceptualLastInfo: FC = () => { headerType="borderless" headerClassName="rb:min-h-[50px]! rb:font-[MiSans-Bold] rb:font-bold" bodyClassName="rb:p-4! rb:pt-0! rb:h-[calc(100%-50px)] rb:overflow-y-auto" - className="rb:h-[calc(100vh-88px)]! rb:w-full!" + className="rb:h-full! rb:w-full!" 
> {Object.keys(KEYS).map(key => ( diff --git a/web/src/views/UserMemoryDetail/components/RecentTrendsLineCard.tsx b/web/src/views/UserMemoryDetail/components/RecentTrendsLineCard.tsx index d3170114..18de43a0 100644 --- a/web/src/views/UserMemoryDetail/components/RecentTrendsLineCard.tsx +++ b/web/src/views/UserMemoryDetail/components/RecentTrendsLineCard.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:32:07 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-16 11:49:29 + * @Last Modified time: 2026-03-27 11:23:11 */ import { type FC, useRef } from 'react' import { useTranslation } from 'react-i18next' @@ -72,7 +72,7 @@ const RecentTrendsLineCard: FC = ({ chartData, series className="rb:h-full!" > {loading - ? + ? : !chartData || chartData.length === 0 ? : { headerType="borderless" headerClassName="rb:min-h-[54px]! rb:font-[MiSans-Bold] rb:font-bold" bodyClassName="rb:pl-5! rb:pt-0! rb:pr-3! rb:pb-4! rb:h-[calc(100%-54px)] rb:overflow-y-auto" - className="rb:h-[calc(100vh-88px)]!" + className="rb:h-full!" > {loading ? diff --git a/web/src/views/UserMemoryDetail/pages/EpisodicDetail.tsx b/web/src/views/UserMemoryDetail/pages/EpisodicDetail.tsx index ca6d456e..647cc104 100644 --- a/web/src/views/UserMemoryDetail/pages/EpisodicDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/EpisodicDetail.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-01-08 19:46:02 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-16 15:03:50 + * @Last Modified time: 2026-03-27 11:20:40 */ import { type FC, useEffect, useState } from 'react' import clsx from 'clsx' @@ -148,15 +148,15 @@ const EpisodicDetail: FC = () => { } return ( - - + + {t('episodicDetail.curResult')} ({data.total || 0}{t('episodicDetail.unix')})
} headerType="borderless" - className="rb:h-[calc(100vh-88px)]!" + className="rb:h-full!" headerClassName="rb:min-h-[38px]! rb:pt-3! rb:mb-0!" bodyClassName="rb:p-3! rb:pb-0! rb:h-[calc(100%-38px)]!" > @@ -231,11 +231,11 @@ const EpisodicDetail: FC = () => { }
- + diff --git a/web/src/views/UserMemoryDetail/pages/ExplicitDetail.tsx b/web/src/views/UserMemoryDetail/pages/ExplicitDetail.tsx index 4795ef45..185ed02d 100644 --- a/web/src/views/UserMemoryDetail/pages/ExplicitDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/ExplicitDetail.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-01-10 17:35:17 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-16 15:05:06 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-27 11:19:38 */ import { type FC, useEffect, useState, useRef } from 'react' import { useTranslation } from 'react-i18next' @@ -132,14 +132,14 @@ const ExplicitDetail: FC = () => { return () => { chartInstance.current?.dispose(); chartInstance.current = null } }, [data.semantic_memories]) return ( - - + + {loading ? @@ -163,13 +163,13 @@ const ExplicitDetail: FC = () => { } - + {loading ? diff --git a/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx b/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx index 2510aaa9..04391107 100644 --- a/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx @@ -12,6 +12,7 @@ import { Row, Col, Progress, App, Table } from 'antd' import RbCard from '@/components/RbCard/Card' import { getForgetStats, + getForgetPendingNodesUrl, } from '@/api/memory' import type { ForgetData } from '../types' import ActivationMetricsPieCard from '../components/ActivationMetricsPieCard' @@ -19,6 +20,7 @@ import RecentTrendsLineCard from '../components/RecentTrendsLineCard' import { formatDateTime } from '@/utils/format' import StatusTag from '@/components/StatusTag' import ForgetRefreshModal from '../components/ForgetRefreshModal'; +import RbTable from '@/components/Table' /** Maps node type keys to StatusTag colour presets for the pending-nodes table. */ const statusTagColors: Record = { @@ -191,7 +193,9 @@ const ForgetDetail = forwardRef((_props, ref) => { bodyClassName="rb:p-3! rb:py-0! 
rb:h-[calc(100%-54px)]" className="rb:h-full!" > - { render: (activation_value) => {activation_value} }, ]} - pagination={{ - pageSize: 5, - showQuickJumper: true, - className: 'rb:mt-5! rb:mb-5.75!' - }} className="table-header-has-bg" /> diff --git a/web/src/views/UserMemoryDetail/pages/GraphDetail.tsx b/web/src/views/UserMemoryDetail/pages/GraphDetail.tsx index f8282687..19f49ff0 100644 --- a/web/src/views/UserMemoryDetail/pages/GraphDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/GraphDetail.tsx @@ -113,13 +113,13 @@ const GraphDetail = forwardRef((_props, ref) => { } /> - + @@ -127,13 +127,13 @@ const GraphDetail = forwardRef((_props, ref) => { - + ((_props, ref) => { }))} onChange={(key: string) => setActiveTab(key)} /> -
+
{timelineLoading ? : !activeContent || activeContent.length === 0 diff --git a/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx b/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx index 2448ff90..8dd71dc2 100644 --- a/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-01-08 19:46:02 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 11:56:55 + * @Last Modified time: 2026-03-27 11:18:50 */ import { forwardRef, useImperativeHandle, useRef, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -83,41 +83,39 @@ const ImplicitDetail = forwardRef<{ handleRefresh: () => void; }, { refresh: () } return ( -
- -
- - + + + + -
- {activeTab === 'preferences' - ? - :
- - -
- } -
-
- - - - - - +
+ {activeTab === 'preferences' + ? + :
+ + +
+ } +
+ + +
+ + + ) }) export default ImplicitDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/PerceptualDetail.tsx b/web/src/views/UserMemoryDetail/pages/PerceptualDetail.tsx index 8871d9c3..ce76c899 100644 --- a/web/src/views/UserMemoryDetail/pages/PerceptualDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/PerceptualDetail.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-01-08 19:46:02 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-16 15:09:12 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-27 11:13:19 */ import { type FC } from 'react' import { Row, Col } from 'antd' @@ -22,11 +22,11 @@ import Timeline from '../components/Timeline' const PerceptualDetail: FC = () => { return ( - - + + - + diff --git a/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx b/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx index 8b38aa54..07a27720 100644 --- a/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-01-08 19:46:02 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-16 15:09:49 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-27 11:17:22 */ import { type FC, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -87,8 +87,8 @@ const ShortTermDetail: FC = () => { } return ( - - + +
{(['retrieval_number', 'entity', 'long_term_number'] as const).map(key => ( @@ -115,10 +115,10 @@ const ShortTermDetail: FC = () => { )} headerType="borderless" headerClassName="rb:min-h-[54px]! rb:font-[MiSans-Bold] rb:font-bold" - bodyClassName="rb:p-3! rb:pt-0! rb:h-[calc(100%-54px)] rb:overflow-y-auto!" - className="rb:h-[calc(100vh-183px)]!" + bodyClassName="rb:p-3! rb:pt-0! rb:h-[calc(100%-54px)]" + className="rb:h-[calc(100%-94px)]!" > - + {loading ? : !data.short_term || data.short_term.length === 0 @@ -189,7 +189,7 @@ const ShortTermDetail: FC = () => { -
+ ( {t('shortTermDetail.longTermTitle')} @@ -200,7 +200,7 @@ const ShortTermDetail: FC = () => { headerType="borderless" headerClassName="rb:min-h-[54px]! rb:font-[MiSans-Bold] rb:font-bold" bodyClassName="rb:p-3! rb:pt-0! rb:h-[calc(100%-54px)] rb:overflow-y-auto!" - className="rb:h-[calc(100vh-88px)]!" + className="rb:h-full!" > {loading diff --git a/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx b/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx index d8feab1f..419cf3b1 100644 --- a/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2025-12-19 16:54:52 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-16 15:06:29 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-27 11:35:37 */ import { forwardRef, useImperativeHandle, useRef } from 'react' import { Row, Col } from 'antd'; @@ -46,8 +46,8 @@ const StatementDetail = forwardRef<{ handleRefresh: () => void },{ refresh: () = handleRefresh })); return ( - - + + @@ -60,7 +60,7 @@ const StatementDetail = forwardRef<{ handleRefresh: () => void },{ refresh: () = - + diff --git a/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx b/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx index bfe928be..2e288988 100644 --- a/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-01-12 14:42:02 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 11:55:36 + * @Last Modified time: 2026-03-27 11:15:05 */ import { type FC, useEffect, useState, useMemo, useRef } from 'react' import clsx from 'clsx' @@ -155,14 +155,14 @@ const WorkingDetail: FC = () => { : data.length === 0 ? :( - - + +
{ {selected && <> -
+
{timeRange}
@@ -226,13 +226,13 @@ const WorkingDetail: FC = () => { }
- + {detailLoading ? diff --git a/web/src/views/Workflow/components/Chat/Chat.tsx b/web/src/views/Workflow/components/Chat/Chat.tsx index da2d9620..dad34fc9 100644 --- a/web/src/views/Workflow/components/Chat/Chat.tsx +++ b/web/src/views/Workflow/components/Chat/Chat.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:10:56 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 13:57:40 + * @Last Modified time: 2026-03-27 17:30:47 */ /** * Workflow Chat Component @@ -41,7 +41,9 @@ import type { ChatToolbarRef } from '@/components/Chat/ChatToolbar' import Runtime from './Runtime'; import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types'; -const Chat = forwardRef(({ appId, graphRef, data }, ref) => { +const Chat = forwardRef(({ + appId, graphRef, features +}, ref) => { const { t } = useTranslation() const { message: messageApi } = App.useApp() const toolbarRef = useRef(null) @@ -58,7 +60,6 @@ const Chat = forwardRef(null) const [fileList, setFileList] = useState([]) const [message, setMessage] = useState(undefined) - const [features, setFeatures] = useState({} as FeaturesConfigForm) /** * Opens the chat drawer and loads workflow variables from the start node @@ -67,10 +68,6 @@ const Chat = forwardRef { - if (data?.features && open) setFeatures(data.features) - }, [open, data?.features]) - useEffect(() => { if (open && toolbarReady) { getVariables() @@ -434,7 +431,7 @@ const Chat = forwardRef diff --git a/web/src/views/Workflow/components/NodeLibrary.tsx b/web/src/views/Workflow/components/NodeLibrary.tsx index b825650a..a7b06fd1 100644 --- a/web/src/views/Workflow/components/NodeLibrary.tsx +++ b/web/src/views/Workflow/components/NodeLibrary.tsx @@ -10,7 +10,7 @@ const NodeLibrary: FC<{ collapsed: boolean; handleToggle: () => void }> = ({ col const { t } = useTranslation() return ( -
@@ -27,9 +27,9 @@ const NodeLibrary: FC<{ collapsed: boolean; handleToggle: () => void }> = ({ col 'rb:min-h-[52px]!': collapsed })} className="rb:h-full! rb:hover:shadow-none!" - bodyClassName={clsx('rb:overflow-y-auto! rb:h-[calc(100vh-126px)]! rb:pt-0! rb:pb-3!', { - 'rb:px-0!': collapsed, - 'rb:px-3!': !collapsed + bodyClassName={clsx('rb:overflow-y-auto! rb:pt-0! rb:pb-3!', { + 'rb:px-0! rb:h-[calc(100%-52px)]!': collapsed, + 'rb:px-3! rb:h-[calc(100%-42px)]!': !collapsed })} > @@ -70,7 +70,7 @@ const NodeLibrary: FC<{ collapsed: boolean; handleToggle: () => void }> = ({ col key={nodeIndex} align="center" gap={8} - className="rb:rounded-xl rb:p-2! rb-border rb:cursor-pointer rb:hover:border rb:hover:border-[#171719]!" + className="rb:rounded-xl rb:p-2! rb:border rb:border-[#EBEBEB] rb:cursor-pointer rb:hover:border rb:hover:border-[#171719]!" draggable onDragStart={(e) => { e.dataTransfer.setData('application/reactflow', node.type); @@ -87,8 +87,6 @@ const NodeLibrary: FC<{ collapsed: boolean; handleToggle: () => void }> = ({ col } - -
); }; diff --git a/web/src/views/Workflow/components/Nodes/AddNode.tsx b/web/src/views/Workflow/components/Nodes/AddNode.tsx index 9b9d2236..dd0ab23d 100644 --- a/web/src/views/Workflow/components/Nodes/AddNode.tsx +++ b/web/src/views/Workflow/components/Nodes/AddNode.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-09 18:31:30 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-06 11:43:58 + * @Last Modified time: 2026-03-30 11:55:10 */ import { useState } from 'react'; import { Popover, Flex } from 'antd'; @@ -173,7 +173,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => { align="center" justify="center" gap={4} - className={clsx('rb:text-[#212332] rb:font-medium rb:text-[12px] rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:border rb:rounded-lg rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:border-[#DFE4ED] rb:flex rb:items-center rb:justify-center', { + className={clsx('rb:text-[#212332] rb:font-medium rb:text-[12px] rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:border rb:rounded-lg rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:border-[#FCFCFD] rb:flex rb:items-center rb:justify-center', { 'rb:border-orange-500 rb:border-[3px] rb:bg-[#FCFCFD] rb:text-[#475467]': data.isSelected, 'rb:border-[#d1d5db] rb:bg-[#FCFCFD] rb:text-[#374151]': !data.isSelected })} diff --git a/web/src/views/Workflow/components/Nodes/ConditionNode.tsx b/web/src/views/Workflow/components/Nodes/ConditionNode.tsx index 12ae6ca0..516b5125 100644 --- a/web/src/views/Workflow/components/Nodes/ConditionNode.tsx +++ b/web/src/views/Workflow/components/Nodes/ConditionNode.tsx @@ -48,7 +48,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => { return (
diff --git a/web/src/views/Workflow/components/Nodes/GroupStartNode.tsx b/web/src/views/Workflow/components/Nodes/GroupStartNode.tsx index 0f963adc..4a29531f 100644 --- a/web/src/views/Workflow/components/Nodes/GroupStartNode.tsx +++ b/web/src/views/Workflow/components/Nodes/GroupStartNode.tsx @@ -3,7 +3,7 @@ import type { ReactShapeConfig } from '@antv/x6-react-shape'; const GroupStartNode: ReactShapeConfig['component'] = () => { return ( -
+
); diff --git a/web/src/views/Workflow/components/Nodes/LoopNode.tsx b/web/src/views/Workflow/components/Nodes/LoopNode.tsx index b8c2ea0c..29c683cc 100644 --- a/web/src/views/Workflow/components/Nodes/LoopNode.tsx +++ b/web/src/views/Workflow/components/Nodes/LoopNode.tsx @@ -122,7 +122,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => { return (
diff --git a/web/src/views/Workflow/components/Nodes/NormalNode.tsx b/web/src/views/Workflow/components/Nodes/NormalNode.tsx index 340e95dc..12e89cca 100644 --- a/web/src/views/Workflow/components/Nodes/NormalNode.tsx +++ b/web/src/views/Workflow/components/Nodes/NormalNode.tsx @@ -12,7 +12,7 @@ const NormalNode: ReactShapeConfig['component'] = ({ node }) => { return (
diff --git a/web/src/views/Workflow/components/PortClickHandler.tsx b/web/src/views/Workflow/components/PortClickHandler.tsx index 2cc0c3c5..13ad6b98 100644 --- a/web/src/views/Workflow/components/PortClickHandler.tsx +++ b/web/src/views/Workflow/components/PortClickHandler.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-09 18:30:28 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 11:11:56 + * @Last Modified time: 2026-03-30 15:14:02 */ import { useEffect, useState } from 'react'; import { Popover } from 'antd'; @@ -20,13 +20,15 @@ const PortClickHandler: React.FC = ({ graph }) => { const [sourceNode, setSourceNode] = useState(null); const [sourcePort, setSourcePort] = useState(''); const [tempElement, setTempElement] = useState(null); + const [edgeInsertion, setEdgeInsertion] = useState(null); useEffect(() => { const handlePortClick = (event: CustomEvent) => { - const { node, port, element, rect } = event.detail; + const { node, port, element, rect, edgeInsertion } = event.detail; setSourceNode(node); setSourcePort(port); setTempElement(element); + setEdgeInsertion(edgeInsertion || null); setPopoverPosition({ x: rect.left, y: rect.top }); setPopoverVisible(true); }; @@ -72,15 +74,47 @@ const PortClickHandler: React.FC = ({ graph }) => { const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort); const sourcePortGroup = sourcePortInfo?.group || sourcePort; - // If add-node position exists, use it; otherwise calculate new position + // Calculate new node position let newX, newY; - if (addNodePosition) { + if (edgeInsertion) { + // Edge insertion: place new node on the same row as target, between source and target + const targetBBox = edgeInsertion.targetCell.getBBox(); + const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width); + const requiredSpace = nodeWidth + horizontalSpacing * 4; + + // New node x: right after source + spacing + newX = sourceBBox.x + sourceBBox.width + horizontalSpacing; + // Same row as 
target node + newY = targetBBox.y + (targetBBox.height - nodeHeight) / 2; + + // If not enough space, shift target and all downstream nodes to the right + if (gap < requiredSpace) { + const shiftX = requiredSpace - gap; + const visited = new Set(); + const shiftDownstream = (cell: any) => { + const cellId = cell.id; + if (visited.has(cellId)) return; + visited.add(cellId); + const pos = cell.getPosition(); + cell.setPosition(pos.x + shiftX, pos.y); + // Recursively shift nodes connected from right ports + graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => { + const tId = e.getTargetCellId(); + if (tId && !visited.has(tId)) { + const tCell = graph.getCellById(tId); + if (tCell?.isNode()) shiftDownstream(tCell); + } + }); + }; + shiftDownstream(edgeInsertion.targetCell); + } + } else if (addNodePosition) { newX = addNodePosition.x; newY = addNodePosition.y; } else { // Determine node placement direction based on port position if (sourcePortGroup === 'left') { - // Left port: add node to the left + // Left port: add node to the left newX = sourceBBox.x - nodeWidth*2 - horizontalSpacing; newY = sourceBBox.y; } else { @@ -91,7 +125,7 @@ const PortClickHandler: React.FC = ({ graph }) => { // Check if position overlaps with existing nodes (only consider connected nodes) const checkOverlap = (x: number, y: number) => { - // Get nodes connected to the source node + // Get nodes connected to the source node const connectedNodes = new Set(); graph.getConnectedEdges(sourceNode).forEach((edge: any) => { const sourceId = edge.getSourceCellId(); @@ -108,7 +142,7 @@ const PortClickHandler: React.FC = ({ graph }) => { y + nodeHeight < bbox.y || y > bbox.y + bbox.height); }); }; - + // If position is occupied, search downward for empty space while (checkOverlap(newX, newY)) { newY += nodeHeight + verticalSpacing; @@ -140,28 +174,51 @@ const PortClickHandler: React.FC = ({ graph }) => { } } + // Edge insertion: remove old edge immediately before creating new edges 
+ if (edgeInsertion) { + const { edge: oldEdge } = edgeInsertion; + if (oldEdge.id && graph.getCellById(oldEdge.id)) { + graph.removeCell(oldEdge.id); + } else { + graph.removeEdge(oldEdge); + } + } + // Create edge connection setTimeout(() => { - const targetPorts = newNode.getPorts(); - let targetPort; - - if (sourcePortGroup === 'left') { + const newPorts = newNode.getPorts(); + + if (edgeInsertion) { + // Edge insertion: create source→new and new→target edges + const { targetCell, targetPort: origTargetPort } = edgeInsertion; + const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left'; + const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right'; + graph.addEdge({ + source: { cell: sourceNode.id, port: sourcePort }, + target: { cell: newNode.id, port: newLeftPort }, + ...edgeAttrs + }); + graph.addEdge({ + source: { cell: newNode.id, port: newRightPort }, + target: { cell: targetCell.id, port: origTargetPort }, + ...edgeAttrs + }); + setEdgeInsertion(null); + } else if (sourcePortGroup === 'left') { // Connect from left port to new node's right side - targetPort = targetPorts.find((port: any) => port.group === 'right')?.id || 'right'; + const targetPort = newPorts.find((port: any) => port.group === 'right')?.id || 'right'; graph.addEdge({ source: { cell: newNode.id, port: targetPort }, target: { cell: sourceNode.id, port: sourcePort }, ...edgeAttrs - // zIndex: sourceNodeData.cycle && sourceNodeType == 'cycle-start' ? 1 : sourceNodeData.cycle ? 2 : 0 }); } else { // Connect from right port to new node's left side - targetPort = targetPorts.find((port: any) => port.group === 'left')?.id || 'left'; + const targetPort = newPorts.find((port: any) => port.group === 'left')?.id || 'left'; graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: targetPort }, ...edgeAttrs - // zIndex: sourceNodeData.cycle && sourceNodeType == 'cycle-start' ? 1 : sourceNodeData.cycle ? 
2 : 0 }); } diff --git a/web/src/views/Workflow/components/Properties/VariableList/index.tsx b/web/src/views/Workflow/components/Properties/VariableList/index.tsx index 30b96505..6ff545eb 100644 --- a/web/src/views/Workflow/components/Properties/VariableList/index.tsx +++ b/web/src/views/Workflow/components/Properties/VariableList/index.tsx @@ -95,7 +95,7 @@ const VariableList: FC = ({ {config.sys?.map((vo, index) => - + sys.{vo.name} {vo.type} diff --git a/web/src/views/Workflow/components/Properties/index.tsx b/web/src/views/Workflow/components/Properties/index.tsx index bb80ed33..66b59075 100644 --- a/web/src/views/Workflow/components/Properties/index.tsx +++ b/web/src/views/Workflow/components/Properties/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:39:59 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 15:08:02 + * @Last Modified time: 2026-03-27 11:30:44 */ import { type FC, useEffect, useState, useMemo } from "react"; import clsx from 'clsx' @@ -431,7 +431,7 @@ const Properties: FC = ({ } return ( -
+
@@ -452,7 +452,7 @@ const Properties: FC = ({ headerType="borderless" headerClassName={clsx("rb:font-[MiSans-Bold] rb:font-bold rb:min-h-[48px]!")} className="rb:h-full! rb:hover:shadow-none!" - bodyClassName={clsx('rb:overflow-y-auto! rb:h-[calc(100vh-131px)]! rb:px-3! rb:pt-0! rb:pb-3!')} + bodyClassName={clsx('rb:overflow-y-auto! rb:h-[calc(100%-48px)]! rb:px-3! rb:pt-0! rb:pb-3!')} >
diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index cadff647..92773191 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:06:18 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 11:11:46 + * @Last Modified time: 2026-03-31 10:08:26 */ import LoopNode from './components/Nodes/LoopNode'; import NormalNode from './components/Nodes/NormalNode'; @@ -33,6 +33,7 @@ import assignerIcon from '@/assets/images/workflow/assigner.svg' import memoryReadIcon from '@/assets/images/workflow/memory-read.svg' import memoryWriteIcon from '@/assets/images/workflow/memory-write.svg' import unknownIcon from '@/assets/images/workflow/unknown.svg' +import documentExtractorIcon from '@/assets/images/workflow/document-extractor.svg' import { memoryConfigListUrl } from '@/api/memory' import type { NodeLibrary } from './types' @@ -473,8 +474,7 @@ export const nodeLibrary: NodeLibrary[] = [ }, } }, - { - type: "document-extractor", icon: codeExecutionIcon, + { type: "document-extractor", icon: documentExtractorIcon, config: { file_selector: { type: 'variableList', @@ -693,8 +693,59 @@ export const portArgs = { x: nodeWidth, y: portItemArgsY } const defaultPortGroup = { position: { name: 'absolute' }, - markup: portMarkup, - attrs: portAttrs + markup: [ + { tagName: 'rect', selector: 'body' }, + { tagName: 'circle', selector: 'hoverBody' }, + { tagName: 'text', selector: 'label' }, + ], + attrs: { + body: { + width: 1, + height: 8, + x: -1, + magnet: true, + stroke: port_color, + strokeWidth: edge_width, + fill: port_color, + }, + hoverBody: { + r: 6, + cy: 2, + magnet: true, + stroke: port_color, + strokeWidth: edge_width, + fill: port_color, + opacity: 0, + }, + label: { + text: '+', + fontSize: 12, + fontWeight: 'bold', + fill: '#FFFFFF', + textAnchor: 'middle', + textVerticalAnchor: 'middle', + pointerEvents: 'none', + y: '0.15em', + opacity: 
0, + }, + }, +} + +const leftPortGroup = { + position: { name: 'absolute' }, + markup: [{ tagName: 'rect', selector: 'body' }], + attrs: { + body: { + width: 1, + height: 8, + x: -1, + y: -4, + magnet: true, + stroke: port_color, + strokeWidth: edge_width, + fill: port_color, + }, + }, } /** @@ -703,7 +754,7 @@ const defaultPortGroup = { */ export const defaultAbsolutePortGroups = { right: defaultPortGroup, - left: defaultPortGroup, + left: leftPortGroup, } /** * Default port items for standard nodes @@ -797,7 +848,7 @@ export const graphNodeLibrary: Record = { height: 28, shape: 'add-node', ports: { - groups: { left: defaultPortGroup }, + groups: { left: leftPortGroup }, items: [{ group: 'left', args: { x: 0, y: 18 }}], }, }, @@ -824,7 +875,7 @@ export const graphNodeLibrary: Record = { height: 28, shape: 'add-node', ports: { - groups: { left: defaultPortGroup }, + groups: { left: leftPortGroup }, items: [{ group: 'left', args: { x: 0, y: 14 } }], }, }, @@ -833,7 +884,7 @@ export const graphNodeLibrary: Record = { height: 76, shape: 'normal-node', ports: { - groups: { left: defaultPortGroup }, + groups: { left: leftPortGroup }, items: [defaultPortItems[0]], }, }, @@ -877,11 +928,74 @@ export const edgeAttrs = { line: { stroke: edge_color, strokeWidth: edge_width, - targetMarker: { - name: 'block', - width: 4, - height: 4, + targetMarker: null, + sourceMarker: null, + }, + }, +} + +/** + * Edge hover tool: circular "+" button shown at midpoint on hover + */ +export const edgeHoverTool = { + name: 'button', + args: { + markup: [ + { + tagName: 'circle', + selector: 'button', + attrs: { + r: 6, + stroke: port_color, + strokeWidth: edge_width, + fill: port_color, + cursor: 'pointer', + }, }, + { + tagName: 'text', + textContent: '+', + selector: 'icon', + attrs: { + fontSize: 12, + fontWeight: 'bold', + fill: '#FFFFFF', + textAnchor: 'middle', + textVerticalAnchor: 'middle', + pointerEvents: 'none', + y: '0.3em', + }, + }, + ], + distance: 0.5, + offset: { x: 0, y: 0 
}, + onClick({ e, cell: edge }: any) { + e.stopPropagation(); + const graph = edge.model?.graph; + if (!graph) return; + const sourceCell = graph.getCellById(edge.getSourceCellId()); + const targetCell = graph.getCellById(edge.getTargetCellId()); + const sourcePort = edge.getSourcePortId(); + const targetPort = edge.getTargetPortId(); + if (!sourceCell || !targetCell) return; + const rect = (e.target as HTMLElement).getBoundingClientRect(); + const tempDiv = document.createElement('div'); + tempDiv.style.position = 'fixed'; + tempDiv.style.left = rect.left + 'px'; + tempDiv.style.top = rect.top + 'px'; + tempDiv.style.width = '1px'; + tempDiv.style.height = '1px'; + tempDiv.style.zIndex = '9999'; + document.body.appendChild(tempDiv); + window.dispatchEvent(new CustomEvent('port:click', { + detail: { + node: sourceCell, + port: sourcePort, + element: tempDiv, + rect, + edgeInsertion: { edge, sourceCell, targetCell, sourcePort, targetPort } + } + })); }, }, } \ No newline at end of file diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index 197e4e4b..c427788b 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:17:48 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 15:01:52 + * @Last Modified time: 2026-03-31 11:13:23 */ import { useRef, useEffect, useState } from 'react'; import { useParams } from 'react-router-dom'; @@ -12,7 +12,7 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from ' import { register } from '@antv/x6-react-shape'; import type { PortMetadata } from '@antv/x6/lib/model/port'; -import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, defaultPortItems, portItemArgsY, edge_width, 
conditionNodePortItemArgsY, conditionNodeItemHeight, conditionNodeHeight, notesConfig } from '../constant'; +import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, defaultPortItems, portItemArgsY, edge_width, conditionNodePortItemArgsY, conditionNodeItemHeight, conditionNodeHeight, notesConfig } from '../constant'; import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types'; import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application' import { useUser } from '@/store/user'; @@ -72,6 +72,7 @@ export interface UseWorkflowGraphReturn { handleAddNotes: () => void; handleSaveFeaturesConfig: (value: FeaturesConfigForm) => void; + features?: FeaturesConfigForm; } /** @@ -100,6 +101,7 @@ export const useWorkflowGraph = ({ const [isHandMode, setIsHandMode] = useState(true); const [config, setConfig] = useState(null); const [chatVariables, setChatVariables] = useState([]) + const featuresRef = useRef(undefined) useEffect(() => { getConfig() @@ -121,6 +123,7 @@ export const useWorkflowGraph = ({ }) setChatVariables(initChatVariables) setConfig({ ...rest, variables: initChatVariables }) + featuresRef.current = rest.features onFeaturesLoad?.(rest.features) }) } @@ -438,6 +441,7 @@ export const useWorkflowGraph = ({ setTimeout(() => { if (graphRef.current) { graphRef.current.centerContent() + graphRef.current.getNodes().forEach(node => node.toFront()); } }, 200) } @@ -519,7 +523,9 @@ export const useWorkflowGraph = ({ * @param edge - Clicked edge */ const edgeClick = ({ edge }: { edge: Edge }) => { + clearEdgeSelect(); edge.setAttrByPath('line/stroke', edge_selected_color); + edge.setData({ ...edge.getData(), isSelected: true }); clearNodeSelect(); }; /** @@ -544,6 +550,7 @@ export const useWorkflowGraph = ({ */ const clearEdgeSelect = () => { graphRef.current?.getEdges().forEach(e => { + 
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false }); e.setAttrByPath('line/stroke', edge_color); e.setAttrByPath('line/strokeWidth', edge_width); }); @@ -716,6 +723,7 @@ export const useWorkflowGraph = ({ }; const nodePortClickEvent = ({ e, node, port }: { e: MouseEvent, node: Node, port: string }) => { e.stopPropagation(); + e.preventDefault(); const portElement = e.target as HTMLElement; const rect = portElement.getBoundingClientRect(); @@ -835,15 +843,25 @@ export const useWorkflowGraph = ({ // 1. If both nodes have parent IDs, they must be same to connect // 2. If both have no parent ID, can connect normally // 3. If one has parent, one doesn't, cannot connect - console.log('sourceParentId', sourceParentId, targetParentId) if (sourceParentId && targetParentId) { // Child nodes under same parent can connect to each other - return sourceParentId === targetParentId; + if (sourceParentId !== targetParentId) return false; } else if (sourceParentId || targetParentId) { // One has parent, one doesn't, cannot connect return false; } - + + // Prevent duplicate connections between same ports + const sourcePortId = sourceMagnet?.getAttribute('port') ?? sourceMagnet?.closest('[port]')?.getAttribute('port'); + const targetPortId = targetMagnet?.getAttribute('port') ?? 
targetMagnet?.closest('[port]')?.getAttribute('port'); + const duplicate = graphRef.current?.getEdges().some(e => + e.getSourceCellId() === sourceCell?.id && + e.getTargetCellId() === targetCell?.id && + e.getSourcePortId() === sourcePortId && + e.getTargetPortId() === targetPortId + ); + if (duplicate) return false; + return true; }, }, @@ -878,12 +896,24 @@ export const useWorkflowGraph = ({ }); // Use plugins setupPlugins(); - // Listen to edge mouseleave event + // Listen to edge mouseenter event: show hover style and add button + graphRef.current.on('edge:mouseenter', ({ edge }: { edge: Edge }) => { + setTimeout(() => { + edge.addTools([edgeHoverTool]); + }, 0) + }); + // Listen to edge mouseleave event: revert style and remove add button graphRef.current.on('edge:mouseleave', ({ edge }: { edge: Edge }) => { - if (edge.getAttrByPath('line/stroke') !== edge_selected_color) { - edge.setAttrByPath('line/stroke', edge_color); - edge.setAttrByPath('line/strokeWidth', edge_width); + const data = edge.getData(); + if (!data?.isSelected) { + if (data?.isNodeHover) { + edge.setAttrByPath('line/stroke', edge_selected_color); + } else { + edge.setAttrByPath('line/stroke', edge_color); + edge.setAttrByPath('line/strokeWidth', edge_width); + } } + edge.removeTools(); }); // Listen to node selection event graphRef.current.on('node:click', nodeClick); @@ -891,13 +921,134 @@ export const useWorkflowGraph = ({ graphRef.current.on('edge:click', edgeClick); // Listen to port click event graphRef.current.on('node:port:click', nodePortClickEvent); + // Port hover: show circle style on right ports + graphRef.current.on('node:port:mouseenter', ({ node, port }) => { + console.log('node:port:mouseenter', port) + if (!port) return; + const portData = node.getPort(port); + if (portData?.group !== 'right') return; + node.toFront(); + node.setPortProp(port, 'attrs/body/opacity', 0); + node.setPortProp(port, 'attrs/hoverBody/opacity', 1); + node.setPortProp(port, 'attrs/label/opacity', 1); 
+ }); + graphRef.current.on('node:port:mouseleave', ({ node, port }) => { + if (!port) return; + const portData = node.getPort(port); + if (portData?.group !== 'right') return; + node.setPortProp(port, 'attrs/body/opacity', 1); + node.setPortProp(port, 'attrs/hoverBody/opacity', 0); + node.setPortProp(port, 'attrs/label/opacity', 0); + }); // Listen to canvas click event, cancel selection graphRef.current.on('blank:click', blankClick); + // Node hover: highlight connected edges + graphRef.current.on('node:mouseenter', ({ node }) => { + graphRef.current?.getEdges().forEach(edge => { + const view = graphRef.current?.findViewByCell(edge); + view?.removeTools(); + if (!edge.getData()?.isSelected && edge.getAttrByPath('line/stroke') === edge_selected_color) { + edge.setAttrByPath('line/stroke', edge_color); + } + }); + graphRef.current?.getConnectedEdges(node).forEach(edge => { + if (!edge.getData()?.isSelected) { + edge.setAttrByPath('line/stroke', edge_selected_color); + edge.setData({ ...edge.getData(), isNodeHover: true }); + } + }); + node.getPorts().filter(p => p.group === 'right').forEach(p => { + node.setPortProp(p.id!, 'attrs/body/opacity', 0); + node.setPortProp(p.id!, 'attrs/hoverBody/opacity', 1); + node.setPortProp(p.id!, 'attrs/label/opacity', 1); + }); + }); + graphRef.current.on('node:mouseleave', ({ node }) => { + graphRef.current?.getConnectedEdges(node).forEach(edge => { + if (!edge.getData()?.isSelected) { + edge.setAttrByPath('line/stroke', edge_color); + edge.setData({ ...edge.getData(), isNodeHover: false }); + } + }); + node.getPorts().filter(p => p.group === 'right').forEach(p => { + node.setPortProp(p.id!, 'attrs/body/opacity', 1); + node.setPortProp(p.id!, 'attrs/hoverBody/opacity', 0); + node.setPortProp(p.id!, 'attrs/label/opacity', 0); + }); + }); // Listen to zoom event graphRef.current.on('scale', scaleEvent); // Listen to node move event graphRef.current.on('node:moved', nodeMoved); graphRef.current.on('node:removed', blankClick) + // 
When edge connected, bring connected nodes' ports to front + graphRef.current.on('edge:connected', ({ isNew }) => { + graphRef.current?.getNodes().forEach(node => node.toFront()); + // Reset any port hover state left from dragging + if (isNew) { + graphRef.current?.getNodes().forEach(node => { + node.getPorts().filter(p => p.group === 'right').forEach(p => { + node.setPortProp(p.id!, 'attrs/body/opacity', 1); + node.setPortProp(p.id!, 'attrs/hoverBody/opacity', 0); + node.setPortProp(p.id!, 'attrs/label/opacity', 0); + }); + }); + } + }); + + // During edge dragging, manually detect port hover since the dragging edge blocks mouse events + let lastHoveredPort: { node: Node; portId: string } | null = null; + graphRef.current.on('edge:mousemove', ({ e }: { e: MouseEvent }) => { + if (!graphRef.current) return; + const { clientX, clientY } = e; + let found: { node: Node; portId: string } | null = null; + + for (const node of graphRef.current.getNodes()) { + for (const port of node.getPorts().filter(p => p.group === 'right')) { + const portView = graphRef.current.findViewByCell(node); + if (!portView) continue; + const portEl = (portView as any).findPortElem(port.id!, 'body') as SVGElement | null; + if (!portEl) continue; + const rect = portEl.getBoundingClientRect(); + const hitRadius = 16; + const cx = rect.left + rect.width / 2; + const cy = rect.top + rect.height / 2; + if (Math.abs(clientX - cx) <= hitRadius && Math.abs(clientY - cy) <= hitRadius) { + found = { node, portId: port.id! 
}; + break; + } + } + if (found) break; + } + + if (found?.node.id !== lastHoveredPort?.node.id || found?.portId !== lastHoveredPort?.portId) { + // Leave previous + if (lastHoveredPort) { + const { node, portId } = lastHoveredPort; + node.setPortProp(portId, 'attrs/body/opacity', 1); + node.setPortProp(portId, 'attrs/hoverBody/opacity', 0); + node.setPortProp(portId, 'attrs/label/opacity', 0); + } + // Enter new + if (found) { + const { node, portId } = found; + node.toFront(); + node.setPortProp(portId, 'attrs/body/opacity', 0); + node.setPortProp(portId, 'attrs/hoverBody/opacity', 1); + node.setPortProp(portId, 'attrs/label/opacity', 1); + } + lastHoveredPort = found; + } + }); + graphRef.current.on('edge:mouseup', () => { + if (lastHoveredPort) { + const { node, portId } = lastHoveredPort; + node.setPortProp(portId, 'attrs/body/opacity', 1); + node.setPortProp(portId, 'attrs/hoverBody/opacity', 0); + node.setPortProp(portId, 'attrs/label/opacity', 0); + lastHoveredPort = null; + } + }); // Listen to copy keyboard event graphRef.current.bindKey(['ctrl+c', 'cmd+c'], copyEvent); // Listen to paste keyboard event @@ -1016,6 +1167,7 @@ export const useWorkflowGraph = ({ const params = { ...config, + features: featuresRef.current, variables: chatVariables.map(v => { const { defaultValue, ...cleanV } = v return { @@ -1173,7 +1325,7 @@ export const useWorkflowGraph = ({ saveWorkflowConfig(config.app_id, params as WorkflowConfig) .then((res) => { if (flag) { - message.success(t('common.saveSuccess')) + message.success({ content: t('common.saveSuccess'), duration: 1 }) } resolve(res) }).catch(error => { @@ -1208,7 +1360,8 @@ export const useWorkflowGraph = ({ }); } const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => { - setConfig(prev => prev ? 
{ ...prev, features: value } as WorkflowConfig : prev) + featuresRef.current = value + onFeaturesLoad?.(value) } return { @@ -1230,6 +1383,7 @@ export const useWorkflowGraph = ({ chatVariables, setChatVariables, handleAddNotes, - handleSaveFeaturesConfig + handleSaveFeaturesConfig, + features: featuresRef.current, }; }; diff --git a/web/src/views/Workflow/index.tsx b/web/src/views/Workflow/index.tsx index d5d46a55..b698e857 100644 --- a/web/src/views/Workflow/index.tsx +++ b/web/src/views/Workflow/index.tsx @@ -34,7 +34,8 @@ const Workflow = forwardRef { @@ -55,12 +56,13 @@ const Workflow = forwardRef +
{/* 左侧节点面板 */} @@ -99,6 +101,7 @@ const Workflow = forwardRef