Merge remote-tracking branch 'origin/release/v0.2.9' into develop
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
@@ -21,6 +23,50 @@ pool = ConnectionPool.from_url(
|
||||
)
|
||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||
|
||||
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
||||
|
||||
# Thread-local storage for connection pools.
|
||||
# Each thread (and each forked process) gets its own pool to avoid
|
||||
# "Future attached to a different loop" errors in Celery --pool=threads
|
||||
# and stale connections after fork in --pool=prefork.
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
def get_thread_safe_redis() -> redis.StrictRedis:
|
||||
"""Return a Redis client whose connection pool is bound to the current
|
||||
thread, process **and** event loop.
|
||||
|
||||
The pool is recreated when:
|
||||
- The PID changes (fork, Celery --pool=prefork)
|
||||
- The thread has no pool yet (Celery --pool=threads)
|
||||
- The previously-cached event loop has been closed (Celery tasks call
|
||||
``_shutdown_loop_gracefully`` which closes the loop after each run)
|
||||
"""
|
||||
current_pid = os.getpid()
|
||||
cached_loop = getattr(_thread_local, "loop", None)
|
||||
loop_stale = cached_loop is not None and cached_loop.is_closed()
|
||||
|
||||
if not hasattr(_thread_local, "pool") \
|
||||
or getattr(_thread_local, "pid", None) != current_pid \
|
||||
or loop_stale:
|
||||
_thread_local.pid = current_pid
|
||||
# Python 3.10+: get_event_loop() raises RuntimeError in threads
|
||||
# where no loop has been set yet (e.g. Celery --pool=threads).
|
||||
try:
|
||||
_thread_local.loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
_thread_local.loop = None
|
||||
_thread_local.pool = ConnectionPool.from_url(
|
||||
_REDIS_URL,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
decode_responses=True,
|
||||
max_connections=5,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
return redis.StrictRedis(connection_pool=_thread_local.pool)
|
||||
|
||||
|
||||
async def get_redis_connection():
|
||||
"""获取Redis连接"""
|
||||
@@ -44,10 +90,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
||||
val = json.dumps(val, ensure_ascii=False)
|
||||
|
||||
if expire is not None:
|
||||
# 设置带过期时间的键值
|
||||
await aio_redis.set(key, val, ex=expire)
|
||||
else:
|
||||
# 设置永久键值
|
||||
await aio_redis.set(key, val)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set错误: {str(e)}")
|
||||
|
||||
8
api/app/cache/memory/activity_stats_cache.py
vendored
8
api/app/cache/memory/activity_stats_cache.py
vendored
@@ -10,7 +10,7 @@ import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -68,7 +68,7 @@ class ActivityStatsCache:
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -90,7 +90,7 @@ class ActivityStatsCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
value = await aio_redis.get(key)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中活动统计缓存: {key}")
|
||||
@@ -116,7 +116,7 @@ class ActivityStatsCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
result = await aio_redis.delete(key)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
|
||||
8
api/app/cache/memory/interest_memory.py
vendored
8
api/app/cache/memory/interest_memory.py
vendored
@@ -9,7 +9,7 @@ import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,7 +62,7 @@ class InterestMemoryCache:
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -86,7 +86,7 @@ class InterestMemoryCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
value = await aio_redis.get(key)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中兴趣分布缓存: {key}")
|
||||
@@ -114,7 +114,7 @@ class InterestMemoryCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
result = await aio_redis.delete(key)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
|
||||
@@ -57,7 +57,6 @@ def list_apps(
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
@@ -66,7 +65,7 @@ def list_apps(
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||
- 当提供 api_key 参数时,查找该 API Key 关联的应用
|
||||
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
|
||||
"""
|
||||
from sqlalchemy import select as sa_select
|
||||
from app.models.api_key_model import ApiKey
|
||||
@@ -74,23 +73,34 @@ def list_apps(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
# 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程
|
||||
if api_key:
|
||||
matched_id = db.execute(
|
||||
sa_select(ApiKey.resource_id).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.api_key == api_key,
|
||||
ApiKey.resource_id.isnot(None),
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
ids = str(matched_id) if matched_id else ""
|
||||
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
|
||||
if search:
|
||||
search = search.strip()
|
||||
# 尝试作为 API Key 精确匹配(API Key 通常较长)
|
||||
if len(search) >= 10:
|
||||
matched_id = db.execute(
|
||||
sa_select(ApiKey.resource_id).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.api_key == search,
|
||||
ApiKey.resource_id.isnot(None),
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if matched_id:
|
||||
# 找到 API Key,直接返回关联的应用
|
||||
ids = str(matched_id)
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
# 当 ids 存在时,根据 ids 获取应用(不分页)
|
||||
if ids is not None:
|
||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
if app_ids:
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
# 返回标准分页格式
|
||||
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
# ids 为空时,返回空列表
|
||||
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
|
||||
return success(data=PageData(page=meta, items=[]))
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
|
||||
@@ -3,17 +3,16 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models.conversation_model import Conversation, Message
|
||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
|
||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.app_service import AppService
|
||||
from app.services.app_log_service import AppLogService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
||||
logger = get_business_logger()
|
||||
@@ -25,52 +24,35 @@ def list_app_logs(
|
||||
app_id: uuid.UUID,
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
user_id: Optional[str] = None,
|
||||
is_draft: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""查看应用下所有会话记录(分页)
|
||||
|
||||
- 支持按 user_id 筛选
|
||||
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
|
||||
- 按最新更新时间倒序排列
|
||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话记录
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
service = AppService(db)
|
||||
service.get_app(app_id, workspace_id)
|
||||
app_service = AppService(db)
|
||||
app_service.get_app(app_id, workspace_id)
|
||||
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active.is_(True),
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversations, total = log_service.list_conversations(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
is_draft=is_draft
|
||||
)
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(Conversation.user_id == user_id)
|
||||
|
||||
if is_draft is not None:
|
||||
stmt = stmt.where(Conversation.is_draft == is_draft)
|
||||
|
||||
total = int(db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
).scalar_one())
|
||||
|
||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||
|
||||
conversations = list(db.scalars(stmt).all())
|
||||
|
||||
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
|
||||
logger.info(
|
||||
"查询应用日志会话列表",
|
||||
extra={"app_id": str(app_id), "total": total, "page": page}
|
||||
)
|
||||
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@@ -86,44 +68,22 @@ def get_app_log_detail(
|
||||
|
||||
- 返回会话基本信息 + 所有消息(按时间正序)
|
||||
- 消息 meta_data 包含模型名、token 用量等信息
|
||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
service = AppService(db)
|
||||
service.get_app(app_id, workspace_id)
|
||||
app_service = AppService(db)
|
||||
app_service.get_app(app_id, workspace_id)
|
||||
|
||||
# 查询会话(确保属于该应用和工作空间)
|
||||
conversation = db.scalars(
|
||||
select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active.is_(True),
|
||||
)
|
||||
).first()
|
||||
|
||||
if not conversation:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("会话", str(conversation_id))
|
||||
|
||||
# 查询消息(按时间正序)
|
||||
messages = list(db.scalars(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at)
|
||||
).all())
|
||||
|
||||
detail = AppLogConversationDetail.model_validate(conversation)
|
||||
detail.messages = [AppLogMessage.model_validate(m) for m in messages]
|
||||
|
||||
logger.info(
|
||||
"查询应用日志会话详情",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_count": len(messages)
|
||||
}
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversation = log_service.get_conversation_detail(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
detail = AppLogConversationDetail.model_validate(conversation)
|
||||
|
||||
return success(data=detail)
|
||||
|
||||
@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
|
||||
ForgettingCurveRequest,
|
||||
ForgettingCurveResponse,
|
||||
ForgettingCurvePoint,
|
||||
PendingNodesResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||
async def get_pending_nodes(
|
||||
end_user_id: str,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取待遗忘节点列表(独立分页接口)
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||
此接口独立分页,与 /stats 接口分离。
|
||||
|
||||
Args:
|
||||
end_user_id: 组ID(即 end_user_id,必填)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||
|
||||
Examples:
|
||||
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||
|
||||
Notes:
|
||||
- page 从1开始,pagesize 必须大于0
|
||||
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 验证 end_user_id 必填
|
||||
if not end_user_id:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||
|
||||
# 通过 end_user_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
# 验证分页参数
|
||||
if page < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||
if pagesize < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取待遗忘节点列表
|
||||
result = await forget_service.get_pending_nodes(
|
||||
db=db,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = PendingNodesResponse(**result)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||
|
||||
|
||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||
async def get_forgetting_curve(
|
||||
request: ForgettingCurveRequest,
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
@@ -259,8 +260,41 @@ def get_conversation(
|
||||
conv_service = ConversationService(db)
|
||||
messages = conv_service.get_messages(conversation_id)
|
||||
|
||||
# 构建响应
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
||||
file_ids = []
|
||||
message_file_id_map = {}
|
||||
|
||||
# 第一次遍历:解析 audio_url,收集所有有效的 file_id
|
||||
for idx, m in enumerate(messages):
|
||||
if m.role == "assistant" and m.meta_data:
|
||||
audio_url = m.meta_data.get("audio_url")
|
||||
if not audio_url:
|
||||
continue
|
||||
try:
|
||||
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
|
||||
except (ValueError, IndexError):
|
||||
# audio_url 无法解析为 UUID,标记为 unknown
|
||||
m.meta_data["audio_status"] = "unknown"
|
||||
continue
|
||||
|
||||
file_ids.append(file_id)
|
||||
message_file_id_map[idx] = file_id
|
||||
|
||||
# 批量查询所有相关的 FileMetadata
|
||||
file_status_map = {}
|
||||
if file_ids:
|
||||
file_metas = (
|
||||
db.query(FileMetadata)
|
||||
.filter(FileMetadata.id.in_(set(file_ids)))
|
||||
.all()
|
||||
)
|
||||
file_status_map = {fm.id: fm.status for fm in file_metas}
|
||||
|
||||
# 第二次遍历:将查询结果映射回消息
|
||||
for idx, file_id in message_file_id_map.items():
|
||||
m = messages[idx]
|
||||
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
|
||||
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
|
||||
conv_dict["messages"] = [
|
||||
conversation_schema.Message.model_validate(m) for m in messages
|
||||
]
|
||||
@@ -320,6 +354,16 @@ async def chat(
|
||||
other_id=other_id,
|
||||
original_user_id=user_id
|
||||
)
|
||||
|
||||
# Only extract and set memory_config_id when the end user doesn't have one yet
|
||||
if not new_end_user.memory_config_id:
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
memory_config_service = MemoryConfigService(db)
|
||||
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
|
||||
if memory_config_id:
|
||||
new_end_user.memory_config_id = memory_config_id
|
||||
db.commit()
|
||||
db.refresh(new_end_user)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
# appid = share.app_id
|
||||
|
||||
@@ -91,7 +91,7 @@ async def chat(
|
||||
|
||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||
other_id = payload.user_id
|
||||
workspace_id = app.workspace_id
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
|
||||
@@ -111,6 +111,18 @@ def get_current_user_info(
|
||||
break
|
||||
|
||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||
|
||||
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||
if current_user.external_source:
|
||||
from premium.sso.models import SSOSource
|
||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||
if source and source.permissions:
|
||||
result_schema.permissions = source.permissions
|
||||
else:
|
||||
result_schema.permissions = []
|
||||
else:
|
||||
result_schema.permissions = ["all"]
|
||||
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@@ -135,7 +147,6 @@ def get_tenant_superusers(
|
||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
|
||||
@@ -151,11 +151,6 @@ async def write(
|
||||
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
@@ -279,5 +274,21 @@ async def write(
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
# Close LLM/Embedder underlying httpx clients to prevent
|
||||
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||
for client_obj in (llm_client, embedder_client):
|
||||
try:
|
||||
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
|
||||
if underlying is None:
|
||||
continue
|
||||
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
|
||||
inner = getattr(underlying, '_model', underlying)
|
||||
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
|
||||
http_client = getattr(inner, 'async_client', None)
|
||||
if http_client is not None and hasattr(http_client, 'aclose'):
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
|
||||
type=type_
|
||||
)
|
||||
|
||||
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||
"""
|
||||
|
||||
@@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def message_has_files(message: "ConversationMessage") -> bool:
|
||||
"""检查消息是否包含文件。
|
||||
|
||||
Args:
|
||||
message: 待检查的消息对象
|
||||
|
||||
Returns:
|
||||
bool: 如果消息包含文件则返回 True,否则返回 False
|
||||
"""
|
||||
return message.files and len(message.files) > 0
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||
|
||||
@@ -128,7 +140,7 @@ class SemanticPruner:
|
||||
1. 空消息
|
||||
2. 场景特定填充词库精确匹配
|
||||
3. 常见寒暄精确匹配
|
||||
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||
4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||
5. 纯表情/标点
|
||||
"""
|
||||
t = message.msg.strip()
|
||||
@@ -482,6 +494,11 @@ class SemanticPruner:
|
||||
"""
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
|
||||
continue
|
||||
|
||||
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
||||
if self._is_filler_message(m):
|
||||
to_delete_ids.add(id(m))
|
||||
@@ -549,6 +566,11 @@ class SemanticPruner:
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
|
||||
continue
|
||||
|
||||
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
||||
if self._is_filler_message(m):
|
||||
@@ -801,6 +823,12 @@ class SemanticPruner:
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与分类
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
|
||||
llm_protected_msgs.append((idx, m)) # 放入保护列表
|
||||
continue
|
||||
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
llm_protected_msgs.append((idx, m))
|
||||
|
||||
@@ -182,7 +182,7 @@ class ExtractionOrchestrator:
|
||||
list[StatementEntityEdge],
|
||||
list[EntityEntityEdge],
|
||||
list[PerceptualEdge],
|
||||
dict
|
||||
list[DialogData]
|
||||
]:
|
||||
"""
|
||||
运行完整的知识提取流水线(优化版:并行执行)
|
||||
@@ -295,6 +295,7 @@ class ExtractionOrchestrator:
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
dialog_data_list,
|
||||
dedup_details,
|
||||
) = await self._run_dedup_and_write_summary(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -306,6 +307,11 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
|
||||
if not is_pilot_run:
|
||||
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
|
||||
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
|
||||
|
||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||
return (
|
||||
dialogue_nodes,
|
||||
@@ -1399,7 +1405,8 @@ class ExtractionOrchestrator:
|
||||
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
|
||||
else:
|
||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||
if first_alias:
|
||||
# 确保 first_alias 不是占位名称
|
||||
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
|
||||
db.add(EndUserInfo(
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=first_alias,
|
||||
@@ -1415,29 +1422,33 @@ class ExtractionOrchestrator:
|
||||
|
||||
|
||||
|
||||
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
|
||||
|
||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
|
||||
|
||||
这个方法直接返回 LLM 提取的别名列表,不做任何修改。
|
||||
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。
|
||||
第一个别名将被用作 other_name。
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表
|
||||
|
||||
Returns:
|
||||
别名列表(保持 LLM 提取的原始顺序)
|
||||
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称)
|
||||
"""
|
||||
USER_NAMES = {'用户', '我', 'User', 'I'}
|
||||
for entity in entity_nodes:
|
||||
if getattr(entity, 'name', '').strip() in USER_NAMES:
|
||||
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
aliases = getattr(entity, 'aliases', []) or []
|
||||
logger.debug(f"提取到用户别名(原始顺序): {aliases}")
|
||||
return aliases
|
||||
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
|
||||
return filtered
|
||||
return []
|
||||
|
||||
|
||||
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
|
||||
"""从 Neo4j 查询用户实体的完整 aliases 列表"""
|
||||
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
|
||||
@@ -1451,7 +1462,10 @@ class ExtractionOrchestrator:
|
||||
aliases = result[0].get('aliases') or []
|
||||
if not aliases:
|
||||
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
||||
return aliases
|
||||
return []
|
||||
# 过滤掉占位名称,防止历史脏数据传播
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
return filtered
|
||||
|
||||
def _resolve_other_name(
|
||||
self,
|
||||
@@ -1463,14 +1477,25 @@ class ExtractionOrchestrator:
|
||||
决定 other_name 是否需要更新,返回新值;无需更新返回 None。
|
||||
|
||||
决策规则:
|
||||
- 为空 → 用本次对话第一个别名
|
||||
- 为空或为占位名称 → 用本次对话第一个别名
|
||||
- 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除)
|
||||
- 否则 → 保持不变(返回 None)
|
||||
|
||||
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
||||
"""
|
||||
if not current or not current.strip():
|
||||
return current_aliases[0].strip() if current_aliases else None
|
||||
# 当前值为空或为占位名称时,需要更新
|
||||
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
candidate = current_aliases[0].strip() if current_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
if current not in neo4j_aliases:
|
||||
return neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
@@ -1492,6 +1517,7 @@ class ExtractionOrchestrator:
|
||||
list[StatementChunkEdge],
|
||||
list[StatementEntityEdge],
|
||||
list[EntityEntityEdge],
|
||||
list[DialogData],
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
@@ -1555,6 +1581,8 @@ class ExtractionOrchestrator:
|
||||
statement_chunk_edges,
|
||||
dedup_statement_entity_edges,
|
||||
dedup_entity_entity_edges,
|
||||
dialog_data_list,
|
||||
dedup_details,
|
||||
)
|
||||
|
||||
final_entity_nodes = dedup_entity_nodes
|
||||
@@ -1562,7 +1590,16 @@ class ExtractionOrchestrator:
|
||||
final_entity_entity_edges = dedup_entity_entity_edges
|
||||
else:
|
||||
# 正式模式:执行完整的两阶段去重
|
||||
result_tuple = await dedup_layers_and_merge_and_return(
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
final_entity_nodes,
|
||||
statement_chunk_edges,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dedup_details,
|
||||
) = await dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
@@ -1576,21 +1613,21 @@ class ExtractionOrchestrator:
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
|
||||
# 解包返回值
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
final_entity_nodes,
|
||||
_,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dedup_details,
|
||||
) = result_tuple
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
||||
|
||||
result_tuple = (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
final_entity_nodes,
|
||||
statement_chunk_edges,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dialog_data_list,
|
||||
dedup_details,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
||||
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "
|
||||
|
||||
@@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement.
|
||||
{% if language == "zh" %}
|
||||
- 用户实体的 name 字段:使用 "用户" 或 "我"
|
||||
- 用户的真实姓名:放入 aliases
|
||||
- **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等**
|
||||
- 示例:
|
||||
* "我叫李明" → name="用户", aliases=["李明"]
|
||||
* ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases)
|
||||
* ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases)
|
||||
{% else %}
|
||||
- User entity name field: use "User" or "I"
|
||||
- User's real name: put in aliases
|
||||
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
|
||||
- Examples:
|
||||
* "I'm John" → name="User", aliases=["John"]
|
||||
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
|
||||
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
|
||||
{% endif %}
|
||||
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ class OSSStorage(StorageBackend):
|
||||
access_key_id: str,
|
||||
access_key_secret: str,
|
||||
bucket_name: str,
|
||||
connect_timeout: int = 30,
|
||||
multipart_threshold: int = 10 * 1024 * 1024, # 10MB
|
||||
):
|
||||
"""
|
||||
Initialize the OSSStorage backend.
|
||||
@@ -53,6 +55,8 @@ class OSSStorage(StorageBackend):
|
||||
access_key_id: The Aliyun access key ID.
|
||||
access_key_secret: The Aliyun access key secret.
|
||||
bucket_name: The name of the OSS bucket.
|
||||
connect_timeout: Connection timeout in seconds (default: 30).
|
||||
multipart_threshold: File size threshold for multipart upload (default: 10MB).
|
||||
|
||||
Raises:
|
||||
StorageConfigError: If any required configuration is missing.
|
||||
@@ -69,10 +73,17 @@ class OSSStorage(StorageBackend):
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.bucket_name = bucket_name
|
||||
self.multipart_threshold = multipart_threshold
|
||||
|
||||
try:
|
||||
auth = oss2.Auth(access_key_id, access_key_secret)
|
||||
self.bucket = oss2.Bucket(auth, endpoint, bucket_name)
|
||||
# 设置超时和重试
|
||||
self.bucket = oss2.Bucket(
|
||||
auth,
|
||||
endpoint,
|
||||
bucket_name,
|
||||
connect_timeout=connect_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
|
||||
)
|
||||
@@ -108,21 +119,38 @@ class OSSStorage(StorageBackend):
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
|
||||
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
||||
# 大文件使用分片上传
|
||||
if len(content) > self.multipart_threshold:
|
||||
logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)")
|
||||
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id
|
||||
parts = []
|
||||
part_size = 5 * 1024 * 1024 # 5MB per part
|
||||
part_num = 1
|
||||
|
||||
for offset in range(0, len(content), part_size):
|
||||
chunk = content[offset:offset + part_size]
|
||||
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||
part_num += 1
|
||||
|
||||
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||
else:
|
||||
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
||||
|
||||
logger.info(f"File uploaded to OSS successfully: {file_key}")
|
||||
return file_key
|
||||
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error uploading file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to upload file to OSS: {e.message}",
|
||||
message=f"Failed to upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload file to OSS {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to upload file to OSS: {e}",
|
||||
message=f"Failed to upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
@@ -135,28 +163,73 @@ class OSSStorage(StorageBackend):
|
||||
) -> int:
|
||||
"""Upload from async stream to OSS. Returns total bytes written."""
|
||||
buf = io.BytesIO()
|
||||
headers = {"Content-Type": content_type} if content_type else None
|
||||
upload_id = None
|
||||
|
||||
try:
|
||||
# 收集流数据
|
||||
total_size = 0
|
||||
async for chunk in stream:
|
||||
if not chunk:
|
||||
continue
|
||||
buf.write(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
content = buf.getvalue()
|
||||
headers = {"Content-Type": content_type} if content_type else None
|
||||
self.bucket.put_object(file_key, content, headers=headers)
|
||||
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
|
||||
return len(content)
|
||||
|
||||
if not content:
|
||||
raise StorageUploadError(
|
||||
message="Empty stream content",
|
||||
file_key=file_key,
|
||||
)
|
||||
|
||||
# 大文件使用分片上传
|
||||
if len(content) > self.multipart_threshold:
|
||||
logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)")
|
||||
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id
|
||||
parts = []
|
||||
part_size = 5 * 1024 * 1024 # 5MB
|
||||
part_num = 1
|
||||
|
||||
for offset in range(0, len(content), part_size):
|
||||
chunk = content[offset:offset + part_size]
|
||||
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||
part_num += 1
|
||||
|
||||
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||
else:
|
||||
self.bucket.put_object(file_key, content, headers=headers)
|
||||
|
||||
logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)")
|
||||
return total_size
|
||||
|
||||
except OssError as e:
|
||||
if upload_id:
|
||||
try:
|
||||
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||
except:
|
||||
pass
|
||||
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e.message}",
|
||||
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
if upload_id:
|
||||
try:
|
||||
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||
except:
|
||||
pass
|
||||
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e}",
|
||||
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
finally:
|
||||
buf.close()
|
||||
|
||||
async def download(self, file_key: str) -> bytes:
|
||||
"""
|
||||
@@ -182,14 +255,14 @@ class OSSStorage(StorageBackend):
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error downloading file {file_key}: {e}")
|
||||
raise StorageDownloadError(
|
||||
message=f"Failed to download file from OSS: {e.message}",
|
||||
message=f"Failed to download file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download file from OSS {file_key}: {e}")
|
||||
raise StorageDownloadError(
|
||||
message=f"Failed to download file from OSS: {e}",
|
||||
message=f"Failed to download file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
@@ -215,14 +288,14 @@ class OSSStorage(StorageBackend):
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error deleting file {file_key}: {e}")
|
||||
raise StorageDeleteError(
|
||||
message=f"Failed to delete file from OSS: {e.message}",
|
||||
message=f"Failed to delete file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file from OSS {file_key}: {e}")
|
||||
raise StorageDeleteError(
|
||||
message=f"Failed to delete file from OSS: {e}",
|
||||
message=f"Failed to delete file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
@@ -99,7 +99,7 @@ class SimpleMCPClient:
|
||||
# 建立 SSE 连接
|
||||
response = await self._session.get(self.server_url)
|
||||
|
||||
if response.status != 200:
|
||||
if response.status not in (200, 202):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -190,7 +190,9 @@ class SimpleMCPClient:
|
||||
|
||||
try:
|
||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||
if response.status != 200:
|
||||
# MCP SSE 协议:POST 请求返回 200 或 202 均为正常
|
||||
# 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回
|
||||
if response.status not in (200, 202):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -205,7 +207,7 @@ class SimpleMCPClient:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||
if response.status != 200:
|
||||
if response.status not in (200, 202):
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
from app.repositories.neo4j.create_indexes import create_all_indexes
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
@@ -60,8 +61,10 @@ async def lifespan(app: FastAPI):
|
||||
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
||||
else:
|
||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||
|
||||
await create_all_indexes()
|
||||
logger.info("应用程序启动完成")
|
||||
|
||||
|
||||
yield
|
||||
# 应用关闭事件
|
||||
logger.info("应用程序正在关闭")
|
||||
|
||||
@@ -19,9 +19,12 @@ class User(Base):
|
||||
last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空
|
||||
|
||||
# SSO 外部关联字段
|
||||
external_id = Column(String(100), nullable=True) # 外部用户ID
|
||||
external_id = Column(String(100), nullable=True) # 外部用户 ID
|
||||
external_source = Column(String(50), nullable=True) # 来源系统
|
||||
|
||||
# 用户联系方式
|
||||
phone = Column(String(50), nullable=True) # 用户电话
|
||||
|
||||
# 用户语言偏好
|
||||
preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文
|
||||
|
||||
|
||||
@@ -199,6 +199,96 @@ class ConversationRepository:
|
||||
)
|
||||
return conversations, total
|
||||
|
||||
def list_app_conversations(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
is_draft: Optional[bool] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 20
|
||||
) -> tuple[list[Conversation], int]:
|
||||
"""
|
||||
查询应用日志会话列表(带分页和过滤)
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
workspace_id: 工作空间 ID
|
||||
is_draft: 是否草稿会话(None 表示不过滤)
|
||||
page: 页码(从 1 开始)
|
||||
pagesize: 每页数量
|
||||
|
||||
Returns:
|
||||
Tuple[List[Conversation], int]: (会话列表,总数)
|
||||
"""
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active.is_(True)
|
||||
)
|
||||
|
||||
if is_draft is not None:
|
||||
stmt = stmt.where(Conversation.is_draft == is_draft)
|
||||
|
||||
# Calculate total number of records
|
||||
total = int(self.db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
).scalar_one())
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
|
||||
logger.info(
|
||||
"Listed app conversations successfully",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"returned": len(conversations),
|
||||
"total": total
|
||||
}
|
||||
)
|
||||
return conversations, total
|
||||
|
||||
def get_conversation_for_app_log(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> Conversation:
|
||||
"""
|
||||
查询应用日志的会话详情
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
app_id: 应用 ID
|
||||
workspace_id: 工作空间 ID
|
||||
|
||||
Returns:
|
||||
Conversation: 会话对象
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: 当会话不存在时
|
||||
"""
|
||||
logger.info(f"Fetching conversation for app log: {conversation_id}")
|
||||
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active.is_(True)
|
||||
)
|
||||
|
||||
conversation = self.db.scalars(stmt).first()
|
||||
|
||||
if not conversation:
|
||||
logger.warning(f"Conversation not found: {conversation_id}")
|
||||
raise ResourceNotFoundException("会话", str(conversation_id))
|
||||
|
||||
logger.info(f"Conversation fetched successfully: {conversation_id}")
|
||||
return conversation
|
||||
|
||||
def soft_delete_conversation_by_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
@@ -290,6 +380,34 @@ class MessageRepository:
|
||||
self.db.add(message)
|
||||
return message
|
||||
|
||||
def get_messages_by_conversation(
|
||||
self,
|
||||
conversation_id: uuid.UUID
|
||||
) -> list[Message]:
|
||||
"""
|
||||
查询会话的所有消息(按时间正序)
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
List[Message]: 消息列表
|
||||
"""
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conversation_id
|
||||
).order_by(Message.created_at)
|
||||
|
||||
messages = list(self.db.scalars(stmt).all())
|
||||
|
||||
logger.info(
|
||||
"Fetched messages for conversation",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_count": len(messages)
|
||||
}
|
||||
)
|
||||
return messages
|
||||
|
||||
def get_message_by_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
|
||||
@@ -132,6 +132,82 @@ class EndUserRepository:
|
||||
db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_or_create_end_user_with_config(
|
||||
self,
|
||||
app_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID,
|
||||
other_id: str,
|
||||
memory_config_id: Optional[uuid.UUID] = None,
|
||||
other_name: Optional[str] = None
|
||||
) -> EndUser:
|
||||
"""获取或创建终端用户,并在单次事务中关联记忆配置。
|
||||
|
||||
与 get_or_create_end_user 类似,但额外支持在创建/获取时
|
||||
一并设置 memory_config_id,避免多次提交。
|
||||
|
||||
Args:
|
||||
app_id: 应用ID(可为 None)
|
||||
workspace_id: 工作空间ID
|
||||
other_id: 第三方ID
|
||||
memory_config_id: 记忆配置ID(可选,仅在用户尚无配置时设置)
|
||||
other_name: 用户名称(用于创建 EndUserInfo)
|
||||
|
||||
Returns:
|
||||
EndUser: 终端用户对象(已关联记忆配置)
|
||||
"""
|
||||
try:
|
||||
end_user = (
|
||||
self.db.query(EndUser)
|
||||
.filter(
|
||||
EndUser.workspace_id == workspace_id,
|
||||
EndUser.other_id == other_id
|
||||
)
|
||||
.order_by(EndUser.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if end_user:
|
||||
db_logger.debug(f"找到现有终端用户: workspace_id={workspace_id}, other_id={other_id}")
|
||||
if app_id is not None:
|
||||
end_user.app_id = app_id
|
||||
if memory_config_id and not end_user.memory_config_id:
|
||||
end_user.memory_config_id = memory_config_id
|
||||
self.db.commit()
|
||||
self.db.refresh(end_user)
|
||||
return end_user
|
||||
|
||||
# 创建新用户
|
||||
end_user = EndUser(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
memory_config_id=memory_config_id,
|
||||
)
|
||||
self.db.add(end_user)
|
||||
self.db.flush()
|
||||
|
||||
end_user_info = EndUserInfo(
|
||||
end_user_id=end_user.id,
|
||||
other_name=other_name or "",
|
||||
aliases=[],
|
||||
meta_data={}
|
||||
)
|
||||
self.db.add(end_user_info)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(end_user)
|
||||
|
||||
db_logger.info(
|
||||
f"创建新终端用户及其信息: (other_id: {other_id}) for workspace {workspace_id}, "
|
||||
f"memory_config_id={memory_config_id}"
|
||||
)
|
||||
return end_user
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"获取或创建终端用户(含配置)时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||
"""根据ID获取终端用户(用于缓存操作)
|
||||
|
||||
@@ -515,6 +591,51 @@ class EndUserRepository:
|
||||
)
|
||||
raise
|
||||
|
||||
def batch_update_memory_config_id_by_app(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
memory_config_id: 新的记忆配置ID
|
||||
|
||||
Returns:
|
||||
int: 更新的终端用户数量
|
||||
|
||||
Raises:
|
||||
Exception: 数据库操作失败时抛出
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
stmt = (
|
||||
update(EndUser)
|
||||
.where(EndUser.app_id == app_id)
|
||||
.values(memory_config_id=memory_config_id)
|
||||
)
|
||||
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
|
||||
db_logger.info(
|
||||
f"批量更新终端用户记忆配置: app_id={app_id}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
)
|
||||
|
||||
return updated_count
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(
|
||||
f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
|
||||
f"memory_config_id={memory_config_id}, error={str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def count_by_memory_config_id(
|
||||
self,
|
||||
memory_config_id: uuid.UUID
|
||||
|
||||
@@ -1,62 +1,47 @@
|
||||
import asyncio
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def create_fulltext_indexes():
|
||||
"""Create full-text indexes for keyword search with BM25 scoring."""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Creating Full-Text Indexes (for keyword search)")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
# 创建 Statements 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: statementsFulltext")
|
||||
""")
|
||||
|
||||
# # 创建 Dialogues 索引
|
||||
# await connector.execute_query("""
|
||||
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
|
||||
# OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
# """)
|
||||
|
||||
# 创建 Entities 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: entitiesFulltext")
|
||||
""")
|
||||
|
||||
# 创建 Chunks 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: chunksFulltext")
|
||||
""")
|
||||
|
||||
# 创建 MemorySummary 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: summariesFulltext")
|
||||
|
||||
""")
|
||||
# 创建 Community 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: communitiesFulltext")
|
||||
|
||||
print("\nFull-text indexes created successfully with BM25 support.")
|
||||
except Exception as e:
|
||||
print(f"✗ Error creating full-text indexes: {e}")
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_vector_indexes():
|
||||
"""Create vector indexes for fast embedding similarity search.
|
||||
|
||||
@@ -65,12 +50,7 @@ async def create_vector_indexes():
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Creating Vector Indexes (for embedding search)")
|
||||
print("=" * 70)
|
||||
print("Note: Adjust vector.dimensions if using different embedding model")
|
||||
print(" Current setting: 1024 dimensions (for bge-m3)")
|
||||
print()
|
||||
|
||||
|
||||
# Statement embedding index
|
||||
await connector.execute_query("""
|
||||
@@ -82,7 +62,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: statement_embedding_index")
|
||||
|
||||
|
||||
# Chunk embedding index
|
||||
await connector.execute_query("""
|
||||
@@ -94,7 +74,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: chunk_embedding_index")
|
||||
|
||||
|
||||
# Entity name embedding index
|
||||
await connector.execute_query("""
|
||||
@@ -106,7 +86,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: entity_embedding_index")
|
||||
|
||||
|
||||
# Memory summary embedding index
|
||||
await connector.execute_query("""
|
||||
@@ -118,8 +98,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: summary_embedding_index")
|
||||
|
||||
|
||||
# Community summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||
@@ -129,8 +108,7 @@ async def create_vector_indexes():
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: community_summary_embedding_index")
|
||||
""")
|
||||
|
||||
# Dialogue embedding index (optional)
|
||||
await connector.execute_query("""
|
||||
@@ -142,91 +120,15 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: dialogue_embedding_index")
|
||||
|
||||
# Community summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||
FOR (c:Community)
|
||||
ON c.summary_embedding
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: community_summary_embedding_index")
|
||||
|
||||
print("\nVector indexes created successfully!")
|
||||
print("\nExpected performance improvement:")
|
||||
print(" Before: ~1.4s for embedding search")
|
||||
print(" After: ~0.05-0.2s for embedding search (10-30x faster!)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error creating vector indexes: {e}")
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_config_id_indexes():
|
||||
"""Create indexes on config_id fields for improved query performance.
|
||||
|
||||
These indexes enable fast filtering of nodes by configuration ID,
|
||||
which is essential for configuration isolation and multi-tenant scenarios.
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Creating Config ID Indexes")
|
||||
print("=" * 70)
|
||||
|
||||
# Dialogue.config_id index
|
||||
await connector.execute_query("""
|
||||
CREATE INDEX dialogue_config_id_index IF NOT EXISTS
|
||||
FOR (d:Dialogue) ON (d.config_id)
|
||||
""")
|
||||
print("✓ Created: dialogue_config_id_index")
|
||||
|
||||
# Statement.config_id index
|
||||
await connector.execute_query("""
|
||||
CREATE INDEX statement_config_id_index IF NOT EXISTS
|
||||
FOR (s:Statement) ON (s.config_id)
|
||||
""")
|
||||
print("✓ Created: statement_config_id_index")
|
||||
|
||||
# ExtractedEntity.config_id index
|
||||
await connector.execute_query("""
|
||||
CREATE INDEX entity_config_id_index IF NOT EXISTS
|
||||
FOR (e:ExtractedEntity) ON (e.config_id)
|
||||
""")
|
||||
print("✓ Created: entity_config_id_index")
|
||||
|
||||
# MemorySummary.config_id index
|
||||
await connector.execute_query("""
|
||||
CREATE INDEX summary_config_id_index IF NOT EXISTS
|
||||
FOR (m:MemorySummary) ON (m.config_id)
|
||||
""")
|
||||
print("✓ Created: summary_config_id_index")
|
||||
|
||||
print("\nConfig ID indexes created successfully!")
|
||||
print("These indexes enable fast filtering by configuration ID.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error creating config_id indexes: {e}")
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_unique_constraints():
|
||||
"""Create uniqueness constraints for core node identifiers.
|
||||
|
||||
Ensures concurrent MERGE operations remain safe and prevents duplicates.
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Creating Unique Constraints")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
# Dialogue.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -234,8 +136,7 @@ async def create_unique_constraints():
|
||||
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
print("✓ Created: dialog_id_unique")
|
||||
|
||||
|
||||
# Statement.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -243,8 +144,7 @@ async def create_unique_constraints():
|
||||
FOR (s:Statement) REQUIRE s.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
print("✓ Created: statement_id_unique")
|
||||
|
||||
|
||||
# Chunk.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -252,112 +152,13 @@ async def create_unique_constraints():
|
||||
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
print("✓ Created: chunk_id_unique")
|
||||
|
||||
print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.")
|
||||
except Exception as e:
|
||||
print(f"✗ Error creating unique constraints: {e}")
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_all_indexes():
|
||||
"""Create all indexes and constraints in one go."""
|
||||
print("\n" + "=" * 70)
|
||||
print("Neo4j Index & Constraint Setup")
|
||||
print("=" * 70)
|
||||
print("This will create:")
|
||||
print(" 1. Full-text indexes (for keyword/BM25 search)")
|
||||
print(" 2. Vector indexes (for embedding similarity search)")
|
||||
print(" 3. Config ID indexes (for configuration isolation)")
|
||||
print(" 4. Unique constraints (for data integrity)")
|
||||
print("=" * 70)
|
||||
|
||||
await create_fulltext_indexes()
|
||||
await create_vector_indexes()
|
||||
await create_config_id_indexes()
|
||||
await create_unique_constraints()
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("✓ All indexes and constraints created successfully!")
|
||||
print("=" * 70)
|
||||
print("\nTo verify, run in Neo4j Browser:")
|
||||
print(" SHOW INDEXES")
|
||||
print(" SHOW CONSTRAINTS")
|
||||
print()
|
||||
|
||||
|
||||
async def check_indexes():
|
||||
"""Check what indexes currently exist."""
|
||||
connector = Neo4jConnector()
|
||||
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Checking Existing Indexes")
|
||||
print("=" * 70)
|
||||
|
||||
query = "SHOW INDEXES"
|
||||
result = await connector.execute_query(query)
|
||||
|
||||
fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT']
|
||||
vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR']
|
||||
range_indexes = [idx for idx in result if idx.get('type') == 'RANGE']
|
||||
|
||||
print(f"\nFull-text indexes: {len(fulltext_indexes)}")
|
||||
for idx in fulltext_indexes:
|
||||
print(f" ✓ {idx.get('name')}")
|
||||
|
||||
print(f"\nVector indexes: {len(vector_indexes)}")
|
||||
for idx in vector_indexes:
|
||||
print(f" ✓ {idx.get('name')}")
|
||||
|
||||
print(f"\nRange indexes (including config_id): {len(range_indexes)}")
|
||||
for idx in range_indexes:
|
||||
print(f" ✓ {idx.get('name')}")
|
||||
|
||||
if not vector_indexes:
|
||||
print("\n⚠️ WARNING: No vector indexes found!")
|
||||
print(" Embedding search will be VERY SLOW (~1.4s)")
|
||||
print(" Run: python create_indexes.py")
|
||||
|
||||
# Check for config_id indexes
|
||||
config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')]
|
||||
if len(config_id_indexes) < 4:
|
||||
print("\n⚠️ WARNING: Not all config_id indexes found!")
|
||||
print(f" Expected 4, found {len(config_id_indexes)}")
|
||||
print(" Run: python create_indexes.py config_id")
|
||||
|
||||
print("=" * 70)
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
command = sys.argv[1]
|
||||
if command == "check":
|
||||
asyncio.run(check_indexes())
|
||||
elif command == "fulltext":
|
||||
asyncio.run(create_fulltext_indexes())
|
||||
elif command == "vector":
|
||||
asyncio.run(create_vector_indexes())
|
||||
elif command == "config_id":
|
||||
asyncio.run(create_config_id_indexes())
|
||||
elif command == "constraints":
|
||||
asyncio.run(create_unique_constraints())
|
||||
else:
|
||||
print(f"Unknown command: {command}")
|
||||
print("\nUsage:")
|
||||
print(" python create_indexes.py # Create all indexes")
|
||||
print(" python create_indexes.py check # Check existing indexes")
|
||||
print(" python create_indexes.py fulltext # Create only full-text indexes")
|
||||
print(" python create_indexes.py vector # Create only vector indexes")
|
||||
print(" python create_indexes.py config_id # Create only config_id indexes")
|
||||
print(" python create_indexes.py constraints # Create only constraints")
|
||||
else:
|
||||
asyncio.run(create_all_indexes())
|
||||
|
||||
|
||||
@@ -340,17 +340,22 @@ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
WITH e, score
|
||||
UNION
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
AND e.aliases IS NOT NULL
|
||||
AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q))
|
||||
WITH e,
|
||||
WITH collect({entity: e, score: score}) AS fulltextResults
|
||||
|
||||
OPTIONAL MATCH (ae:ExtractedEntity)
|
||||
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
|
||||
AND ae.aliases IS NOT NULL
|
||||
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q))
|
||||
WITH fulltextResults, collect(ae) AS aliasEntities
|
||||
|
||||
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
|
||||
CASE
|
||||
WHEN ANY(alias IN e.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
|
||||
WHEN ANY(alias IN e.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
|
||||
ELSE 0.8
|
||||
END AS score
|
||||
END
|
||||
}]) AS row
|
||||
WITH row.entity AS e, row.score AS score
|
||||
WITH DISTINCT e, MAX(score) AS score
|
||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||
|
||||
@@ -158,22 +158,26 @@ class UserRepository:
|
||||
raise
|
||||
|
||||
def get_users_by_tenant(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
skip: int = 0,
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[User]:
|
||||
"""获取租户下的用户列表"""
|
||||
db_logger.debug(f"查询租户用户: tenant_id={tenant_id}")
|
||||
|
||||
|
||||
try:
|
||||
query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id)
|
||||
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(User.is_active == is_active)
|
||||
|
||||
|
||||
if is_superuser is not None:
|
||||
query = query.filter(User.is_superuser == is_superuser)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
@@ -181,7 +185,7 @@ class UserRepository:
|
||||
User.email.ilike(f"%{search}%")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
users = query.offset(skip).limit(limit).all()
|
||||
db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}")
|
||||
return users
|
||||
@@ -190,18 +194,22 @@ class UserRepository:
|
||||
raise
|
||||
|
||||
def count_users_by_tenant(
|
||||
self,
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> int:
|
||||
"""统计租户下的用户数量"""
|
||||
try:
|
||||
query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
|
||||
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(User.is_active == is_active)
|
||||
|
||||
|
||||
if is_superuser is not None:
|
||||
query = query.filter(User.is_superuser == is_superuser)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
@@ -209,7 +217,7 @@ class UserRepository:
|
||||
User.email.ilike(f"%{search}%")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
return query.scalar()
|
||||
except Exception as e:
|
||||
db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}")
|
||||
|
||||
@@ -276,7 +276,7 @@ class AgentConfigCreate(BaseModel):
|
||||
|
||||
# 记忆配置
|
||||
memory: MemoryConfig = Field(
|
||||
default_factory=lambda: MemoryConfig(enabled=True),
|
||||
default_factory=lambda: MemoryConfig(enabled=False),
|
||||
description="对话历史记忆配置"
|
||||
)
|
||||
|
||||
|
||||
@@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel):
|
||||
last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)")
|
||||
|
||||
|
||||
class PageInfo(BaseModel):
|
||||
"""分页信息模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
page: int = Field(..., description="当前页码(从1开始)")
|
||||
pagesize: int = Field(..., description="每页数量")
|
||||
total: int = Field(..., description="总记录数")
|
||||
hasnext: bool = Field(..., description="是否有下一页")
|
||||
|
||||
|
||||
class PendingNodesResponse(BaseModel):
|
||||
"""待遗忘节点列表响应模型(独立分页接口)"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表")
|
||||
page: PageInfo = Field(..., description="分页信息")
|
||||
|
||||
|
||||
class ForgettingStatsResponse(BaseModel):
|
||||
"""遗忘引擎统计信息响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
@@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel):
|
||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
||||
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
||||
timestamp: int = Field(..., description="统计时间(时间戳)")
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import field
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
@@ -20,6 +20,7 @@ class UserCreate(UserBase):
|
||||
class UserUpdate(BaseModel):
|
||||
username: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
phone: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_superuser: Optional[bool] = None
|
||||
|
||||
@@ -85,6 +86,8 @@ class User(UserBase):
|
||||
current_workspace_name: Optional[str] = None
|
||||
role: Optional[WorkspaceRole] = None
|
||||
preferred_language: Optional[str] = "zh" # 用户语言偏好
|
||||
phone: Optional[str] = None # 用户电话
|
||||
permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制
|
||||
|
||||
# 将 datetime 转换为毫秒时间戳
|
||||
@validator("created_at", pre=True)
|
||||
|
||||
@@ -141,13 +141,13 @@ class AppChatService:
|
||||
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
||||
is_new_conversation = len(history) == 0
|
||||
if is_new_conversation:
|
||||
opening = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
if opening:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=opening,
|
||||
meta_data={}
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
# 重新加载历史(包含刚写入的开场白)
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
@@ -378,13 +378,13 @@ class AppChatService:
|
||||
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
||||
is_new_conversation = len(history) == 0
|
||||
if is_new_conversation:
|
||||
opening = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
if opening:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=opening,
|
||||
meta_data={}
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
# 重新加载历史(包含刚写入的开场白)
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
|
||||
128
api/app/services/app_log_service.py
Normal file
128
api/app/services/app_log_service.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""应用日志服务层"""
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.conversation_model import Conversation, Message
|
||||
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AppLogService:
|
||||
"""应用日志服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_repository = ConversationRepository(db)
|
||||
self.message_repository = MessageRepository(db)
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
pagesize: int = 20,
|
||||
is_draft: Optional[bool] = None,
|
||||
) -> Tuple[list[Conversation], int]:
|
||||
"""
|
||||
查询应用日志会话列表
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
workspace_id: 工作空间 ID
|
||||
page: 页码(从 1 开始)
|
||||
pagesize: 每页数量
|
||||
is_draft: 是否草稿会话(None 表示不过滤)
|
||||
|
||||
Returns:
|
||||
Tuple[list[Conversation], int]: (会话列表,总数)
|
||||
"""
|
||||
logger.info(
|
||||
"查询应用日志会话列表",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"is_draft": is_draft
|
||||
}
|
||||
)
|
||||
|
||||
# 使用 Repository 查询
|
||||
conversations, total = self.conversation_repository.list_app_conversations(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
is_draft=is_draft,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"查询应用日志会话列表成功",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"total": total,
|
||||
"returned": len(conversations)
|
||||
}
|
||||
)
|
||||
|
||||
return conversations, total
|
||||
|
||||
def get_conversation_detail(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> Conversation:
|
||||
"""
|
||||
查询会话详情(包含消息)
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
conversation_id: 会话 ID
|
||||
workspace_id: 工作空间 ID
|
||||
|
||||
Returns:
|
||||
Conversation: 包含消息的会话对象
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: 当会话不存在时
|
||||
"""
|
||||
logger.info(
|
||||
"查询应用日志会话详情",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"conversation_id": str(conversation_id),
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
|
||||
# 查询会话
|
||||
conversation = self.conversation_repository.get_conversation_for_app_log(
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 查询消息(按时间正序)
|
||||
messages = self.message_repository.get_messages_by_conversation(
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
# 将消息附加到会话对象
|
||||
conversation.messages = messages
|
||||
|
||||
logger.info(
|
||||
"查询应用日志会话详情成功",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_count": len(messages)
|
||||
}
|
||||
)
|
||||
|
||||
return conversation
|
||||
@@ -1084,7 +1084,6 @@ class AppService:
|
||||
if not exists:
|
||||
cleaned["memory_config_id"] = None
|
||||
cleaned.pop("memory_content", None)
|
||||
cleaned["enabled"] = False
|
||||
return cleaned
|
||||
|
||||
exists = self.db.query(
|
||||
@@ -1096,7 +1095,6 @@ class AppService:
|
||||
if not exists:
|
||||
cleaned["memory_config_id"] = None
|
||||
cleaned.pop("memory_content", None)
|
||||
cleaned["enabled"] = False
|
||||
|
||||
return cleaned
|
||||
|
||||
@@ -1684,15 +1682,15 @@ class AppService:
|
||||
|
||||
return config.config_id
|
||||
|
||||
def _update_endusers_memory_config_by_workspace(
|
||||
def _update_endusers_memory_config_by_app(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
app_id: 应用ID
|
||||
memory_config_id: 新的记忆配置ID
|
||||
|
||||
Returns:
|
||||
@@ -1701,8 +1699,8 @@ class AppService:
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
repo = EndUserRepository(self.db)
|
||||
updated_count = repo.batch_update_memory_config_id_by_workspace(
|
||||
workspace_id=workspace_id,
|
||||
updated_count = repo.batch_update_memory_config_id_by_app(
|
||||
app_id=app_id,
|
||||
memory_config_id=memory_config_id
|
||||
)
|
||||
|
||||
@@ -1753,12 +1751,16 @@ class AppService:
|
||||
|
||||
miss_params = []
|
||||
if agent_cfg.default_model_config_id is None:
|
||||
miss_params.append("model config")
|
||||
miss_params.append("模型配置")
|
||||
|
||||
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
|
||||
miss_params.append("memory config")
|
||||
miss_params.append("记忆配置")
|
||||
if miss_params:
|
||||
raise BusinessException(f"{', '.join(miss_params)} is required")
|
||||
raise BusinessException(
|
||||
f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。",
|
||||
BizCode.CONFIG_MISSING,
|
||||
context={"missing_params": miss_params},
|
||||
)
|
||||
|
||||
config = {
|
||||
"system_prompt": agent_cfg.system_prompt,
|
||||
@@ -1877,8 +1879,8 @@ class AppService:
|
||||
if memory_config_id:
|
||||
app = self.db.query(App).filter(App.id == app_id).first()
|
||||
if app:
|
||||
updated_count = self._update_endusers_memory_config_by_workspace(
|
||||
app.workspace_id, memory_config_id
|
||||
updated_count = self._update_endusers_memory_config_by_app(
|
||||
app_id, memory_config_id
|
||||
)
|
||||
logger.info(
|
||||
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
|
||||
@@ -2014,7 +2016,7 @@ class AppService:
|
||||
|
||||
if memory_config_id:
|
||||
|
||||
updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id)
|
||||
updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id)
|
||||
logger.info(
|
||||
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
|
||||
@@ -214,7 +214,7 @@ class ConversationService:
|
||||
|
||||
conversation.message_count += 1
|
||||
|
||||
if conversation.message_count == 1 and role == "user":
|
||||
if conversation.message_count <= 2 and role == "user":
|
||||
conversation.title = (
|
||||
content[:50] + ("..." if len(content) > 50 else "")
|
||||
)
|
||||
|
||||
@@ -448,15 +448,16 @@ class AgentRunService:
|
||||
features_config: Dict[str, Any],
|
||||
is_new_conversation: bool,
|
||||
variables: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
) -> tuple[Any, Any]:
|
||||
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
|
||||
if not is_new_conversation:
|
||||
return None
|
||||
return None, None
|
||||
opening = features_config.get("opening_statement", {})
|
||||
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
|
||||
return None
|
||||
return None, None
|
||||
|
||||
statement = opening["statement"]
|
||||
suggested_questions = opening["suggested_questions"]
|
||||
|
||||
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
|
||||
if variables:
|
||||
@@ -464,7 +465,7 @@ class AgentRunService:
|
||||
placeholder = f"{{{{{var_name}}}}}"
|
||||
statement = statement.replace(placeholder, str(var_value))
|
||||
|
||||
return statement
|
||||
return statement, suggested_questions
|
||||
|
||||
@staticmethod
|
||||
def _filter_citations(
|
||||
@@ -598,13 +599,16 @@ class AgentRunService:
|
||||
|
||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||
is_new_conversation = not conversation_id
|
||||
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
opening, suggested_questions = None, None
|
||||
if not sub_agent:
|
||||
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
conversation_id = await self._ensure_conversation(
|
||||
conversation_id=conversation_id,
|
||||
app_id=agent_config.app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
opening_statement=opening
|
||||
opening_statement=opening,
|
||||
suggested_questions=suggested_questions
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
@@ -839,14 +843,17 @@ class AgentRunService:
|
||||
|
||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||
is_new_conversation = not conversation_id
|
||||
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
opening, suggested_questions = None, None
|
||||
if not sub_agent:
|
||||
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
conversation_id = await self._ensure_conversation(
|
||||
conversation_id=conversation_id,
|
||||
app_id=agent_config.app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
sub_agent=sub_agent,
|
||||
opening_statement=opening
|
||||
opening_statement=opening,
|
||||
suggested_questions=suggested_questions
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
@@ -1050,7 +1057,8 @@ class AgentRunService:
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str],
|
||||
sub_agent: bool = False,
|
||||
opening_statement: Optional[str] = None
|
||||
opening_statement: Optional[str] = None,
|
||||
suggested_questions: Optional[List[str]] = None
|
||||
) -> str:
|
||||
"""确保会话存在(创建或验证)
|
||||
|
||||
@@ -1061,6 +1069,7 @@ class AgentRunService:
|
||||
user_id: 用户ID
|
||||
sub_agent: 是否为子代理
|
||||
opening_statement: 开场白(新会话时作为第一条消息写入)
|
||||
suggested_questions: 预设问题列表
|
||||
|
||||
Returns:
|
||||
str: 会话ID
|
||||
@@ -1104,7 +1113,7 @@ class AgentRunService:
|
||||
conversation_id=uuid.UUID(new_conv_id),
|
||||
role="assistant",
|
||||
content=opening_statement,
|
||||
meta_data={}
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
logger.debug(f"已保存开场白到会话 {new_conv_id}")
|
||||
|
||||
|
||||
@@ -204,30 +204,35 @@ class MemoryForgetService:
|
||||
end_user_id: str,
|
||||
forgetting_threshold: float,
|
||||
min_days_since_access: int,
|
||||
limit: int = 20
|
||||
) -> list[Dict[str, Any]]:
|
||||
page: Optional[int] = None,
|
||||
pagesize: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取待遗忘节点列表
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)
|
||||
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器
|
||||
end_user_id: 组ID
|
||||
forgetting_threshold: 遗忘阈值
|
||||
min_days_since_access: 最小未访问天数
|
||||
limit: 返回节点数量限制
|
||||
|
||||
page: 页码(可选,从1开始)
|
||||
pagesize: 每页数量(可选)
|
||||
|
||||
Returns:
|
||||
list: 待遗忘节点列表
|
||||
dict: 包含待遗忘节点列表和分页信息的字典
|
||||
- items: 待遗忘节点列表
|
||||
- page: 分页信息(分页时)
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
# 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区)
|
||||
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
|
||||
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
||||
|
||||
query = """
|
||||
|
||||
# 基础查询(用于获取总数)
|
||||
count_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.end_user_id = $end_user_id
|
||||
@@ -235,10 +240,22 @@ class MemoryForgetService:
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||
RETURN
|
||||
RETURN count(n) as total
|
||||
"""
|
||||
|
||||
# 数据查询
|
||||
data_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.end_user_id = $end_user_id
|
||||
AND n.activation_value IS NOT NULL
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||
RETURN
|
||||
elementId(n) as node_id,
|
||||
labels(n)[0] as node_type,
|
||||
CASE
|
||||
CASE
|
||||
WHEN n:Statement THEN n.statement
|
||||
WHEN n:ExtractedEntity THEN n.name
|
||||
WHEN n:MemorySummary THEN n.content
|
||||
@@ -247,18 +264,32 @@ class MemoryForgetService:
|
||||
n.activation_value as activation_value,
|
||||
n.last_access_time as last_access_time
|
||||
ORDER BY n.activation_value ASC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
|
||||
# 如果启用分页,添加 SKIP 和 LIMIT
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
data_query += " SKIP $skip LIMIT $limit"
|
||||
|
||||
params = {
|
||||
'end_user_id': end_user_id,
|
||||
'threshold': forgetting_threshold,
|
||||
'min_access_time_str': min_access_time_str,
|
||||
'limit': limit
|
||||
'min_access_time_str': min_access_time_str
|
||||
}
|
||||
|
||||
results = await connector.execute_query(query, **params)
|
||||
|
||||
|
||||
# 获取总数(分页时需要)
|
||||
total = 0
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
count_results = await connector.execute_query(count_query, **params)
|
||||
if count_results:
|
||||
total = count_results[0]['total']
|
||||
|
||||
# 添加分页参数
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
params['skip'] = (page - 1) * pagesize
|
||||
params['limit'] = pagesize
|
||||
|
||||
results = await connector.execute_query(data_query, **params)
|
||||
|
||||
pending_nodes = []
|
||||
for result in results:
|
||||
# 将节点类型标签转换为小写
|
||||
@@ -267,7 +298,7 @@ class MemoryForgetService:
|
||||
node_type_label = 'entity'
|
||||
elif node_type_label == 'memorysummary':
|
||||
node_type_label = 'summary'
|
||||
|
||||
|
||||
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
|
||||
last_access_time = result['last_access_time']
|
||||
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
|
||||
@@ -278,7 +309,7 @@ class MemoryForgetService:
|
||||
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
|
||||
else:
|
||||
last_access_timestamp = 0
|
||||
|
||||
|
||||
pending_nodes.append({
|
||||
'node_id': str(result['node_id']),
|
||||
'node_type': node_type_label,
|
||||
@@ -286,8 +317,20 @@ class MemoryForgetService:
|
||||
'activation_value': result['activation_value'],
|
||||
'last_access_time': last_access_timestamp
|
||||
})
|
||||
|
||||
return pending_nodes
|
||||
|
||||
# 构建返回结果
|
||||
result: Dict[str, Any] = {'items': pending_nodes}
|
||||
|
||||
# 如果启用分页,添加分页信息
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
result['page'] = {
|
||||
'page': page,
|
||||
'pagesize': pagesize,
|
||||
'total': total,
|
||||
'hasnext': (page * pagesize) < total
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
async def trigger_forgetting_cycle(
|
||||
self,
|
||||
@@ -636,7 +679,7 @@ class MemoryForgetService:
|
||||
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
|
||||
# 失败时返回空列表,不影响主流程
|
||||
|
||||
# 获取待遗忘节点列表(前20个满足遗忘条件的节点)
|
||||
# 获取待遗忘节点列表
|
||||
pending_nodes = []
|
||||
try:
|
||||
if end_user_id:
|
||||
@@ -652,8 +695,7 @@ class MemoryForgetService:
|
||||
connector=connector,
|
||||
end_user_id=end_user_id,
|
||||
forgetting_threshold=forgetting_threshold,
|
||||
min_days_since_access=int(min_days),
|
||||
limit=20
|
||||
min_days_since_access=int(min_days)
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
|
||||
@@ -661,24 +703,79 @@ class MemoryForgetService:
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
|
||||
# 失败时返回空列表,不影响主流程
|
||||
|
||||
# 构建统计信息
|
||||
|
||||
# 构建统计信息(不包含 pending_nodes,已分离到独立接口)
|
||||
stats = {
|
||||
'activation_metrics': activation_metrics,
|
||||
'node_distribution': node_distribution,
|
||||
'recent_trends': recent_trends,
|
||||
'pending_nodes': pending_nodes,
|
||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
|
||||
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
|
||||
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}"
|
||||
f"trend_days={len(recent_trends)}"
|
||||
)
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def get_pending_nodes(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
config_id: Optional[UUID] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取待遗忘节点列表(独立分页接口)
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 组ID(必填)
|
||||
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
|
||||
Returns:
|
||||
dict: 包含待遗忘节点列表和分页信息的字典
|
||||
- items: 待遗忘节点列表
|
||||
- page: 分页信息
|
||||
"""
|
||||
# 获取遗忘引擎组件
|
||||
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
|
||||
|
||||
connector = forgetting_scheduler.connector
|
||||
forgetting_threshold = config['forgetting_threshold']
|
||||
|
||||
# 验证 min_days_since_access 配置值
|
||||
min_days = config.get('min_days_since_access')
|
||||
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
|
||||
api_logger.warning(
|
||||
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
|
||||
)
|
||||
min_days = 7
|
||||
|
||||
# 调用内部方法获取分页数据
|
||||
pending_nodes_result = await self._get_pending_forgetting_nodes(
|
||||
connector=connector,
|
||||
end_user_id=end_user_id,
|
||||
forgetting_threshold=forgetting_threshold,
|
||||
min_days_since_access=int(min_days),
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
|
||||
f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
|
||||
)
|
||||
|
||||
return pending_nodes_result
|
||||
|
||||
async def get_forgetting_curve(
|
||||
self,
|
||||
db: Session,
|
||||
|
||||
@@ -12,6 +12,9 @@ import base64
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import olefile
|
||||
import struct
|
||||
import zipfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
@@ -602,31 +605,75 @@ class MultimodalService:
|
||||
try:
|
||||
word_file = io.BytesIO(file_content)
|
||||
doc = Document(word_file)
|
||||
return '\n'.join(p.text for p in doc.paragraphs)
|
||||
text_lines = []
|
||||
for p in doc.paragraphs:
|
||||
text = p.text.strip()
|
||||
if text:
|
||||
text_lines.append(text)
|
||||
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
for cell in row.cells:
|
||||
text = cell.text.strip()
|
||||
if text:
|
||||
text_lines.append(text)
|
||||
|
||||
full_text = "\n".join(text_lines)
|
||||
return full_text.strip() or "[docx 文件无文本内容]"
|
||||
except Exception as e:
|
||||
logger.error(f"提取 docx 文本失败: {e}")
|
||||
logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True)
|
||||
return f"[docx 提取失败: {str(e)}]"
|
||||
|
||||
# 旧版 .doc(OLE2 格式)
|
||||
# 旧版 .doc(OLE2/CFB 格式),按 Word Binary Format 规范解析 piece table
|
||||
try:
|
||||
import olefile
|
||||
ole = olefile.OleFileIO(io.BytesIO(file_content))
|
||||
if not ole.exists('WordDocument'):
|
||||
return "[doc 提取失败: 未找到 WordDocument 流]"
|
||||
# 读取 WordDocument 流,提取可见 ASCII/Unicode 文本
|
||||
stream = ole.openstream('WordDocument').read()
|
||||
# Word Binary Format: 文本在流中以 UTF-16-LE 编码存储
|
||||
# 简单提取:过滤出可打印字符段
|
||||
try:
|
||||
text = stream.decode('utf-16-le', errors='ignore')
|
||||
except Exception:
|
||||
text = stream.decode('latin-1', errors='ignore')
|
||||
# 过滤控制字符,保留可打印内容
|
||||
import re
|
||||
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
|
||||
text = re.sub(r' +', ' ', text).strip()
|
||||
word_stream = ole.openstream('WordDocument').read()
|
||||
|
||||
# FIB offset 0xA bit9 决定使用 0Table 还是 1Table
|
||||
fib_flags = struct.unpack_from('<H', word_stream, 0xA)[0]
|
||||
table_name = '1Table' if (fib_flags & 0x0200) else '0Table'
|
||||
table_stream = ole.openstream(table_name).read()
|
||||
|
||||
# 从 FIB 读取 fcClx/lcbClx 定位 piece table
|
||||
fc_clx, lcb_clx = struct.unpack_from("<II", word_stream, 0x1A2)
|
||||
clx = table_stream[fc_clx: fc_clx + lcb_clx]
|
||||
|
||||
# 解析 CLX,找到 PlcPcd(piece table)
|
||||
i, plc_pcd = 0, None
|
||||
while i < len(clx):
|
||||
clxt = clx[i]
|
||||
if clxt == 0x01:
|
||||
i += 3 + struct.unpack_from('<H', clx, i + 1)[0]
|
||||
elif clxt == 0x02:
|
||||
cb = struct.unpack_from('<I', clx, i + 1)[0]
|
||||
plc_pcd = clx[i + 5: i + 5 + cb]
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
if plc_pcd is None:
|
||||
raise ValueError("PlcPcd not found")
|
||||
|
||||
# PlcPcd: (n+1) 个 CP(4字节)+ n 个 PCD(8字节)
|
||||
n_pieces = (len(plc_pcd) - 4) // 12
|
||||
cp_array = [struct.unpack_from('<I', plc_pcd, k * 4)[0] for k in range(n_pieces + 1)]
|
||||
|
||||
parts = []
|
||||
for k in range(n_pieces):
|
||||
fc_value = struct.unpack_from('<I', plc_pcd, (n_pieces + 1) * 4 + k * 8 + 2)[0]
|
||||
is_ansi = bool(fc_value & 0x40000000)
|
||||
fc = fc_value & 0x3FFFFFFF
|
||||
char_count = cp_array[k + 1] - cp_array[k]
|
||||
|
||||
if is_ansi:
|
||||
parts.append(word_stream[fc: fc + char_count].decode('cp1252', errors='replace'))
|
||||
else:
|
||||
parts.append(word_stream[fc: fc + char_count * 2].decode('utf-16-le', errors='replace'))
|
||||
|
||||
ole.close()
|
||||
return text
|
||||
result = re.sub(r'[\x00-\x1f\x7f]', '', ''.join(parts))
|
||||
return result.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取 doc 文本失败: {e}")
|
||||
return f"[doc 提取失败: {str(e)}]"
|
||||
|
||||
@@ -138,7 +138,7 @@ class TenantService:
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"删除租户失败: {str(e)}")
|
||||
raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
raise BusinessException(f"删除租户失败:{str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
# 租户用户管理
|
||||
def get_tenant_users(
|
||||
@@ -147,6 +147,7 @@ class TenantService:
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[UserModel]:
|
||||
"""获取租户下的用户列表"""
|
||||
@@ -155,6 +156,7 @@ class TenantService:
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
search=search
|
||||
)
|
||||
|
||||
@@ -162,12 +164,14 @@ class TenantService:
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> int:
|
||||
"""统计租户下的用户数量"""
|
||||
return self.user_repo.count_users_by_tenant(
|
||||
tenant_id=tenant_id,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
search=search
|
||||
)
|
||||
|
||||
|
||||
@@ -472,6 +472,21 @@ class UserMemoryService:
|
||||
# 定义允许更新的字段白名单
|
||||
allowed_fields = {'other_name', 'aliases', 'meta_data'}
|
||||
|
||||
# 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中
|
||||
_user_placeholder_names = {'用户', '我', 'User', 'I'}
|
||||
|
||||
# 过滤 other_name:不允许设置为占位名称
|
||||
if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
|
||||
logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name")
|
||||
del update_data['other_name']
|
||||
|
||||
# 过滤 aliases:移除占位名称和非字符串值
|
||||
if 'aliases' in update_data and update_data['aliases']:
|
||||
update_data['aliases'] = [
|
||||
a for a in update_data['aliases']
|
||||
if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names
|
||||
]
|
||||
|
||||
# 检查是否更新了 aliases 字段
|
||||
aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -38,12 +37,10 @@ from app.db import get_db, get_db_context
|
||||
from app.models import Document, File, Knowledge
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
from app.utils.redis_lock import RedisLock
|
||||
from app.utils.redis_lock import RedisFairLock
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -104,7 +101,12 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
|
||||
|
||||
def set_asyncio_event_loop():
|
||||
"""Set the asyncio event loop for the current thread."""
|
||||
"""Ensure an open asyncio event loop exists for the current thread.
|
||||
|
||||
Reuses the existing event loop if one is available and still open.
|
||||
Creates and installs a new event loop only when the current one is
|
||||
closed or missing (e.g. after ``_shutdown_loop_gracefully``).
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
@@ -116,6 +118,30 @@ def set_asyncio_event_loop():
|
||||
return loop
|
||||
|
||||
|
||||
def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop):
|
||||
"""Gracefully shutdown pending async generators and tasks on the event loop.
|
||||
|
||||
This prevents 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
|
||||
by giving pending aclose() coroutines a chance to run before the loop is discarded.
|
||||
|
||||
Note: This only tears down the given loop. Callers that need a fresh event
|
||||
loop afterwards should use ``set_asyncio_event_loop()`` explicitly.
|
||||
"""
|
||||
try:
|
||||
# Cancel and collect all remaining tasks
|
||||
all_tasks = asyncio.all_tasks(loop)
|
||||
if all_tasks:
|
||||
for task in all_tasks:
|
||||
task.cancel()
|
||||
loop.run_until_complete(asyncio.gather(*all_tasks, return_exceptions=True))
|
||||
# Shutdown async generators (triggers __aclose__ on httpx clients etc.)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.process_item")
|
||||
def process_item(item: dict):
|
||||
"""
|
||||
@@ -1148,8 +1174,28 @@ def write_message_task(
|
||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||
return result
|
||||
|
||||
redis_client = get_sync_redis_client()
|
||||
lock = None
|
||||
if redis_client is not None:
|
||||
lock = RedisFairLock(
|
||||
key=f"memory_write:{end_user_id}",
|
||||
redis_client=redis_client,
|
||||
expire=600,
|
||||
timeout=3600,
|
||||
auto_renewal=True,
|
||||
)
|
||||
if not lock.acquire():
|
||||
logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}")
|
||||
return {
|
||||
"status": "SKIPPED",
|
||||
"error": "acquire lock timeout",
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": str(config_id),
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
try:
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
@@ -1158,7 +1204,6 @@ def write_message_task(
|
||||
logger.info(f"[CELERY WRITE] Task completed successfully "
|
||||
f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||
|
||||
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
|
||||
try:
|
||||
_r = get_sync_redis_client()
|
||||
if _r is not None:
|
||||
@@ -1199,6 +1244,15 @@ def write_message_task(
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
finally:
|
||||
if lock is not None:
|
||||
try:
|
||||
lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"[CELERY WRITE] 释放锁失败: {e}")
|
||||
# Gracefully shutdown the event loop to prevent
|
||||
# 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
|
||||
_shutdown_loop_gracefully(loop)
|
||||
|
||||
|
||||
# unused task
|
||||
@@ -2879,3 +2933,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
|
||||
# unused task
|
||||
@@ -1,6 +1,7 @@
|
||||
import redis
|
||||
import uuid
|
||||
import time
|
||||
import threading
|
||||
|
||||
UNLOCK_SCRIPT = """
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
@@ -10,45 +11,136 @@ else
|
||||
end
|
||||
"""
|
||||
|
||||
RENEW_SCRIPT = """
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("expire", KEYS[1], ARGV[2])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
"""
|
||||
|
||||
class RedisLock:
|
||||
CLEANUP_DEAD_HEAD_SCRIPT = """
|
||||
local queue_key = KEYS[1]
|
||||
local lock_key = KEYS[2]
|
||||
|
||||
local first = redis.call("lindex", queue_key, 0)
|
||||
if not first then
|
||||
return 0
|
||||
end
|
||||
|
||||
if redis.call("exists", lock_key) == 1 then
|
||||
return 0
|
||||
end
|
||||
|
||||
redis.call("lpop", queue_key)
|
||||
return 1
|
||||
"""
|
||||
|
||||
SAFE_RELEASE_QUEUE_SCRIPT = """
|
||||
local queue_key = KEYS[1]
|
||||
local value = ARGV[1]
|
||||
|
||||
local first = redis.call("lindex", queue_key, 0)
|
||||
if first == value then
|
||||
redis.call("lpop", queue_key)
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
"""
|
||||
|
||||
|
||||
def _ensure_str(val):
|
||||
"""统一将 Redis 返回值转为 str,兼容 decode_responses=True/False"""
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, bytes):
|
||||
return val.decode("utf-8")
|
||||
return str(val)
|
||||
|
||||
|
||||
class RedisFairLock:
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
redis_client: redis.StrictRedis,
|
||||
expire: int = 60,
|
||||
retry_interval: float = 0.1,
|
||||
timeout: float = 30
|
||||
|
||||
expire: int = 30,
|
||||
retry_interval: float = 0.05,
|
||||
timeout: float = 600,
|
||||
auto_renewal: bool = True
|
||||
):
|
||||
self.key = key
|
||||
self.expire = expire
|
||||
self.queue_key = f"{key}:queue"
|
||||
self.value = str(uuid.uuid4())
|
||||
self._locked = False
|
||||
self.expire = expire
|
||||
self.retry_interval = retry_interval
|
||||
self.timeout = timeout
|
||||
self.redis_client = redis_client
|
||||
self.redis = redis_client
|
||||
self._locked = False
|
||||
self.auto_renewal = auto_renewal
|
||||
self._renew_thread = None
|
||||
self._stop_renew = threading.Event()
|
||||
|
||||
def acquire(self) -> bool:
|
||||
def acquire(self):
|
||||
start = time.time()
|
||||
|
||||
self.redis.rpush(self.queue_key, self.value)
|
||||
|
||||
while True:
|
||||
ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True)
|
||||
if ok:
|
||||
self._locked = True
|
||||
return True
|
||||
if time.time() - start >= self.timeout:
|
||||
first = _ensure_str(self.redis.lindex(self.queue_key, 0))
|
||||
|
||||
if first == self.value:
|
||||
ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire)
|
||||
if ok:
|
||||
self._locked = True
|
||||
|
||||
if self.auto_renewal:
|
||||
self._start_renewal()
|
||||
return True
|
||||
|
||||
if first:
|
||||
self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key)
|
||||
|
||||
if time.time() - start > self.timeout:
|
||||
self.redis.lrem(self.queue_key, 0, self.value)
|
||||
return False
|
||||
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
def _renewal_loop(self):
|
||||
while not self._stop_renew.is_set():
|
||||
time.sleep(self.expire / 3)
|
||||
if self._stop_renew.is_set():
|
||||
break
|
||||
|
||||
self.redis.eval(
|
||||
RENEW_SCRIPT,
|
||||
1,
|
||||
self.key,
|
||||
self.value,
|
||||
str(self.expire)
|
||||
)
|
||||
|
||||
def _start_renewal(self):
|
||||
self._stop_renew = threading.Event()
|
||||
self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True)
|
||||
self._renew_thread.start()
|
||||
|
||||
def _stop_renewal(self):
|
||||
self._stop_renew.set()
|
||||
if self._renew_thread:
|
||||
self._renew_thread.join(timeout=1)
|
||||
|
||||
def release(self):
|
||||
if not self._locked:
|
||||
return
|
||||
self.redis_client.eval(
|
||||
UNLOCK_SCRIPT,
|
||||
1,
|
||||
self.key,
|
||||
self.value
|
||||
)
|
||||
|
||||
if self.auto_renewal:
|
||||
self._stop_renewal()
|
||||
|
||||
self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value)
|
||||
|
||||
self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value)
|
||||
|
||||
self._locked = False
|
||||
|
||||
def __enter__(self):
|
||||
@@ -59,3 +151,4 @@ class RedisLock:
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.release()
|
||||
|
||||
|
||||
30
api/migrations/versions/4e89970f9e7c_202603271515.py
Normal file
30
api/migrations/versions/4e89970f9e7c_202603271515.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""202603271515
|
||||
|
||||
Revision ID: 4e89970f9e7c
|
||||
Revises: 6b8a461148ff
|
||||
Create Date: 2026-03-27 15:12:27.518344
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '4e89970f9e7c'
|
||||
down_revision: Union[str, None] = '6b8a461148ff'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('users', sa.Column('phone', sa.String(length=50), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('users', 'phone')
|
||||
# ### end Alembic commands ###
|
||||
@@ -154,6 +154,8 @@ export const analyticsRefresh = (end_user_id: string) => {
|
||||
export const getForgetStats = (end_user_id: string) => {
|
||||
return request.get(`/memory/forget-memory/stats`, { end_user_id })
|
||||
}
|
||||
// 获取带遗忘节点列表
|
||||
export const getForgetPendingNodesUrl = '/memory/forget-memory/pending-nodes'
|
||||
// Implicit Memory - Preferences
|
||||
export const getImplicitPreferences = (end_user_id: string) => {
|
||||
return request.get(`/memory/implicit-memory/preferences/${end_user_id}`)
|
||||
|
||||
@@ -37,11 +37,11 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
const prevDataLengthRef = useRef(data.length);
|
||||
const isScrolledToBottomRef = useRef(true);
|
||||
const audioRef = useRef<HTMLAudioElement | null>(null)
|
||||
const [playingIndex, setPlayingIndex] = useState<number | null>(null)
|
||||
const [playingIndex, setPlayingIndex] = useState<string | null>(null)
|
||||
|
||||
const handlePlay = (index: number, audio_url: string, audio_status?: string) => {
|
||||
if (audio_status !== 'completed' && !audio_status) return
|
||||
if (playingIndex === index) {
|
||||
const handlePlay = (audio_url: string, audio_status?: string) => {
|
||||
if (audio_status !== 'completed' && typeof audio_status === 'string') return
|
||||
if (playingIndex === audio_url) {
|
||||
audioRef.current?.pause()
|
||||
setPlayingIndex(null)
|
||||
return
|
||||
@@ -52,7 +52,7 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
const audio = new Audio(audio_url)
|
||||
audioRef.current = audio
|
||||
audio.play()
|
||||
setPlayingIndex(index)
|
||||
setPlayingIndex(audio_url)
|
||||
audio.onended = () => setPlayingIndex(null)
|
||||
}
|
||||
|
||||
@@ -79,12 +79,16 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
|
||||
// Auto-scroll to bottom when data changes to show latest messages
|
||||
// When data array length remains unchanged, if data is updated and user manually scrolled up, don't auto-scroll to bottom
|
||||
// When data array length changes, auto-scroll to bottom
|
||||
// If already scrolled to bottom, will auto-scroll to bottom
|
||||
useEffect(() => {
|
||||
if (playingIndex && !data.some(item => item.meta_data?.audio_url === playingIndex)) {
|
||||
audioRef.current?.pause()
|
||||
setPlayingIndex(null)
|
||||
}
|
||||
setTimeout(() => {
|
||||
if (scrollContainerRef.current) {
|
||||
// Auto-scroll if data length changed OR user is currently at bottom
|
||||
@@ -204,16 +208,16 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
{item.meta_data?.audio_url && <>
|
||||
<Divider className="rb:my-3!" />
|
||||
<Space size={12} className="rb:pb-2 rb:pl-1">
|
||||
{playingIndex !== index && item.meta_data?.audio_status === 'pending'
|
||||
{playingIndex !== item.meta_data?.audio_url && item.meta_data?.audio_status === 'pending'
|
||||
? <Spin />
|
||||
: playingIndex !== index
|
||||
: playingIndex !== item.meta_data?.audio_url
|
||||
? <SoundOutlined className={clsx("rb:cursor-pointer rb:size-5.5", {
|
||||
'rb:text-[#FF5D34]': item.meta_data?.audio_status === 'error',
|
||||
'rb:hover:text-[#155EEF]!': !item.meta_data?.audio_status || !['pending', 'error'].includes(item.meta_data?.audio_status)
|
||||
})} onClick={() => handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} />
|
||||
})} onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)} />
|
||||
: <div
|
||||
className="rb:size-5.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/audio_ing.gif')]"
|
||||
onClick={() => handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)}
|
||||
onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)}
|
||||
/>
|
||||
}
|
||||
</Space>
|
||||
|
||||
@@ -24,16 +24,17 @@ interface BodyWrapperProps {
|
||||
/** Whether to show loading state */
|
||||
loading?: boolean
|
||||
/** Whether the content is empty */
|
||||
empty: boolean
|
||||
empty: boolean;
|
||||
className?: string;
|
||||
}
|
||||
const BodyWrapper: FC<BodyWrapperProps> = ({ children, loading = false, empty }) => {
|
||||
const BodyWrapper: FC<BodyWrapperProps> = ({ children, loading = false, empty, className = 'rb:max-h-[calc(100%-48px)]!' }) => {
|
||||
// Show loading spinner while data is being fetched
|
||||
if (loading) {
|
||||
return <PageLoading />
|
||||
return <PageLoading className={className} />
|
||||
}
|
||||
// Show empty state when no data is available
|
||||
if (!loading && empty) {
|
||||
return <PageEmpty />
|
||||
return <PageEmpty className={className} />
|
||||
}
|
||||
// Render actual content when data is loaded and available
|
||||
return children
|
||||
|
||||
@@ -128,6 +128,7 @@ const Menu: FC<{
|
||||
|
||||
/** Filter menus based on user role and source */
|
||||
useEffect(() => {
|
||||
if (!user) return
|
||||
let menuList: MenuItem[] = []
|
||||
|
||||
if (user.role === 'member' && source === 'space') {
|
||||
@@ -136,7 +137,7 @@ const Menu: FC<{
|
||||
menuList = allMenus[source] || []
|
||||
}
|
||||
|
||||
const noAuthList = ['user', 'pricing'].filter(vo => !user.permissions?.includes(vo) && !user.permissions?.includes('all'))
|
||||
const noAuthList = ['user', 'pricing'].filter(vo => (Array.isArray(user.permissions) && !user.permissions?.includes(vo) && !user.permissions?.includes('all')) || !Array.isArray(user.permissions))
|
||||
|
||||
if (noAuthList && !noAuthList?.includes('all')) {
|
||||
const filterMenus = (list: MenuItem[]): MenuItem[] =>{
|
||||
|
||||
@@ -1509,7 +1509,7 @@ export const en = {
|
||||
EPISODIC_MEMORY: 'Episodic Memory',
|
||||
FORGET_MEMORY: 'Forget Memory',
|
||||
|
||||
endUserProfile: 'Profile',
|
||||
endUserProfile: 'Permanent Memory',
|
||||
editEndUserProfile: 'Edit',
|
||||
other_name: 'Name',
|
||||
position: 'Position',
|
||||
@@ -1828,6 +1828,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
memoryTipTitle: 'Are you sure you want to enable conversation memory? Conversations will be saved to the memory store.',
|
||||
stopAudioRecorder: 'Stop Recording',
|
||||
startAudioRecorder: 'Start Recording',
|
||||
citations: 'Citations',
|
||||
},
|
||||
login: {
|
||||
title: 'Red Bear Memory Science',
|
||||
|
||||
@@ -1507,7 +1507,7 @@ export const zh = {
|
||||
EPISODIC_MEMORY: '情景记忆',
|
||||
FORGET_MEMORY: '遗忘记忆',
|
||||
|
||||
endUserProfile: '核心档案',
|
||||
endUserProfile: '永久记忆',
|
||||
editEndUserProfile: '编辑',
|
||||
other_name: '名称',
|
||||
position: '职位',
|
||||
|
||||
@@ -377,9 +377,10 @@ body {
|
||||
.ant-input-filled,
|
||||
.ant-select-filled:not(.ant-select-customize-input) .ant-select-selector {
|
||||
background-color: #FFFFFF;
|
||||
border-color: #FFFFFF;
|
||||
}
|
||||
.ant-input-filled:hover,
|
||||
.ant-select-filled:not(.ant-select-customize-input) .ant-select-selector {
|
||||
.ant-select-filled:not(.ant-select-disabled):not(.ant-select-customize-input):not(.ant-pagination-size-changer):hover .ant-select-selector {
|
||||
background-color: #FFFFFF;
|
||||
border-color: #171719;
|
||||
}
|
||||
@@ -402,7 +403,7 @@ body {
|
||||
color: #FF5D34;
|
||||
}
|
||||
|
||||
.spin.ant-spin-nested-loading .ant-spin-container::after {
|
||||
.spin .ant-spin-nested-loading .ant-spin-container::after {
|
||||
background: transparent;
|
||||
}
|
||||
.upload-block,
|
||||
|
||||
@@ -34,16 +34,16 @@ const Statistics: FC = () => {
|
||||
className: 'rb:text-[#212332]'
|
||||
},
|
||||
{
|
||||
title: t('user.createTime'),
|
||||
title: t('application.created_at'),
|
||||
dataIndex: 'created_at',
|
||||
key: 'created_at',
|
||||
render: (createdAt: string) => formatDateTime(createdAt, 'YYYY-MM-DD HH:mm:ss'),
|
||||
},
|
||||
{
|
||||
title: t('user.lastLoginTime'),
|
||||
dataIndex: 'last_login_at',
|
||||
key: 'last_login_at',
|
||||
render: (lastLoginAt: string) => lastLoginAt ? formatDateTime(lastLoginAt, 'YYYY-MM-DD HH:mm:ss') : '-',
|
||||
title: t('common.updated_at'),
|
||||
dataIndex: 'updated_at',
|
||||
key: 'updated_at',
|
||||
render: (updatedAt: string) => updatedAt ? formatDateTime(updatedAt, 'YYYY-MM-DD HH:mm:ss') : '-',
|
||||
},
|
||||
{
|
||||
title: t('common.operation'),
|
||||
|
||||
@@ -220,31 +220,31 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
|
||||
/>
|
||||
<Popover content={t('workflow.clear')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
|
||||
<div
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/clear.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/clear.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
onClick={clear}
|
||||
></div>
|
||||
</Popover>
|
||||
<Popover content={t('workflow.addvariable')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
|
||||
<div
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/variable.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/variable.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
onClick={addvariable}
|
||||
></div>
|
||||
</Popover>
|
||||
<Popover content={t('workflow.run')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
|
||||
<div
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/run.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/run.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
onClick={run}
|
||||
></div>
|
||||
</Popover>
|
||||
<Popover content={t('workflow.save')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
|
||||
<div
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/save.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/save.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
onClick={save}
|
||||
></div>
|
||||
</Popover>
|
||||
<Popover content={t('common.return')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
|
||||
<div
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/return.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/return.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
onClick={goToApplication}
|
||||
></div>
|
||||
</Popover>
|
||||
|
||||
@@ -49,7 +49,7 @@ const FeaturesConfig: FC<FeaturesConfigProps> = ({
|
||||
?
|
||||
<Popover content={t('application.features')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
|
||||
<div
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/features.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/features.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
|
||||
onClick={handleFeaturesConfig}
|
||||
></div>
|
||||
</Popover>
|
||||
|
||||
@@ -216,7 +216,7 @@ const ApplicationManagement: React.FC = () => {
|
||||
'rb:text-[#155EEF]': key === 'type',
|
||||
})}>
|
||||
{key === 'source' && item.is_shared
|
||||
? t('application.shared')
|
||||
? item.source_workspace_name
|
||||
: key === 'source' && !item.is_shared
|
||||
? t('application.configuration')
|
||||
: key === 'created_at'
|
||||
|
||||
@@ -64,6 +64,13 @@ const Conversation: FC = () => {
|
||||
const [config, setConfig] = useState<Record<string, any>>({})
|
||||
const [audioStatusMap, setAudioStatusMap] = useState<Record<string, string>>({})
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
audioPollingRef.current.forEach((timer) => clearInterval(timer))
|
||||
audioPollingRef.current.clear()
|
||||
}
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
const shareToken = localStorage.getItem(`shareToken_${token}`)
|
||||
setShareToken(shareToken)
|
||||
@@ -144,13 +151,29 @@ const Conversation: FC = () => {
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
audioPollingRef.current.forEach((timer) => clearInterval(timer))
|
||||
audioPollingRef.current.clear()
|
||||
if (conversation_id) {
|
||||
getConversationDetail(token as string, conversation_id)
|
||||
.then(res => {
|
||||
const response = res as { messages: ChatItem[] }
|
||||
setChatList(response?.messages || [])
|
||||
const messages = response?.messages || []
|
||||
const historyAudioUrls = new Set(messages.map(m => m.meta_data?.audio_url).filter(Boolean))
|
||||
audioPollingRef.current.forEach((timer, key) => {
|
||||
if (!historyAudioUrls.has(key)) {
|
||||
clearInterval(timer)
|
||||
audioPollingRef.current.delete(key)
|
||||
}
|
||||
})
|
||||
messages.forEach(msg => {
|
||||
if (msg.role === 'assistant' && msg.meta_data?.audio_url && msg.meta_data?.audio_status === 'pending') {
|
||||
startAudioPolling(msg.meta_data.audio_url, msg.meta_data.audio_url)
|
||||
}
|
||||
})
|
||||
setChatList(messages.map(msg => {
|
||||
if (msg.role === 'assistant' && msg.meta_data?.audio_url && audioPollingRef.current.has(msg.meta_data.audio_url)) {
|
||||
return { ...msg, meta_data: { ...msg.meta_data, audio_status: 'pending' } }
|
||||
}
|
||||
return msg
|
||||
}))
|
||||
})
|
||||
} else {
|
||||
if (features?.opening_statement?.statement) {
|
||||
@@ -228,6 +251,28 @@ const Conversation: FC = () => {
|
||||
}))
|
||||
}, [audioStatusMap, chatList.length])
|
||||
|
||||
const startAudioPolling = (audioUrl: string, idToPoll: string) => {
|
||||
if (audioPollingRef.current.has(idToPoll)) return
|
||||
const fileId = audioUrl.split('/').pop()
|
||||
if (!fileId) return
|
||||
const timer = setInterval(() => {
|
||||
getFileStatusById(fileId)
|
||||
.then(res => {
|
||||
const { status } = res as { status: string }
|
||||
if (status && status !== 'pending') {
|
||||
setAudioStatusMap(prev => ({ ...prev, [idToPoll]: status }))
|
||||
clearInterval(audioPollingRef.current.get(idToPoll))
|
||||
audioPollingRef.current.delete(idToPoll)
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
clearInterval(audioPollingRef.current.get(idToPoll))
|
||||
audioPollingRef.current.delete(idToPoll)
|
||||
})
|
||||
}, 2000)
|
||||
audioPollingRef.current.set(idToPoll, timer)
|
||||
}
|
||||
|
||||
/** Send message and handle streaming response */
|
||||
const handleSend = (msg?: string) => {
|
||||
if (!token || !shareToken) return
|
||||
@@ -287,35 +332,8 @@ const Conversation: FC = () => {
|
||||
const { file_id } = item.data as { file_id?: string }
|
||||
const idToPoll = file_id || audio_url || ''
|
||||
const fileId = audio_url.split('/').pop()
|
||||
if (fileId && idToPoll && !audioPollingRef.current.has(idToPoll)) {
|
||||
|
||||
const timer = setInterval(() => {
|
||||
getFileStatusById(fileId)
|
||||
.then(res => {
|
||||
const { status } = res as { status: string }
|
||||
if (status && status !== 'pending') {
|
||||
setAudioStatusMap(prev => ({
|
||||
...prev,
|
||||
[idToPoll]: status
|
||||
}))
|
||||
clearInterval(audioPollingRef.current.get(idToPoll))
|
||||
audioPollingRef.current.delete(idToPoll)
|
||||
getHistory(true)
|
||||
if (currentConversationId && currentConversationId !== conversation_id) {
|
||||
setConversationId(currentConversationId)
|
||||
}
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
clearInterval(audioPollingRef.current.get(idToPoll))
|
||||
audioPollingRef.current.delete(idToPoll)
|
||||
getHistory(true)
|
||||
if (currentConversationId && currentConversationId !== conversation_id) {
|
||||
setConversationId(currentConversationId)
|
||||
}
|
||||
})
|
||||
}, 2000)
|
||||
audioPollingRef.current.set(idToPoll, timer)
|
||||
if (fileId && idToPoll) {
|
||||
startAudioPolling(audio_url, idToPoll)
|
||||
}
|
||||
} else {
|
||||
getHistory(true)
|
||||
@@ -327,6 +345,10 @@ const Conversation: FC = () => {
|
||||
updateAssistantMessage(content, audio_url, undefined, citations)
|
||||
}
|
||||
setLoading(false)
|
||||
getHistory(true)
|
||||
if (currentConversationId && currentConversationId !== conversation_id) {
|
||||
setConversationId(currentConversationId)
|
||||
}
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:37:12
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-04 10:05:39
|
||||
* @Last Modified time: 2026-03-27 22:22:18
|
||||
*/
|
||||
/**
|
||||
* Invite Register Page
|
||||
@@ -144,7 +144,7 @@ const InviteRegister: React.FC = () => {
|
||||
}).then((res) => {
|
||||
const response = res as LoginInfo;
|
||||
updateLoginInfo(response);
|
||||
navigate('/');
|
||||
navigate('/', { replace: true });
|
||||
}).finally(() => {
|
||||
setLoading(false);
|
||||
});
|
||||
|
||||
@@ -243,6 +243,7 @@ const Prompt: FC = () => {
|
||||
<ModelSelect
|
||||
params={{ type: 'llm,chat' }}
|
||||
className={`rb:w-75! ${styles.select}`}
|
||||
variant="filled"
|
||||
/>
|
||||
</Form.Item>
|
||||
<Button className="rb:border-none!" onClick={handleJump}>{t('prompt.history')}</Button>
|
||||
|
||||
@@ -116,13 +116,13 @@ const History: React.FC = () => {
|
||||
<div className="rb:text-[12px] rb:text-[#5B6167] rb:leading-4.5">{formatDateTime(item.created_at, 'YYYY/MM/DD HH:mm')}</div>
|
||||
|
||||
<Space size={8}>
|
||||
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('src/assets/images/prompt/eye.svg')] rb:hover:bg-[url('src/assets/images/prompt/eye_bg.svg')]"
|
||||
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('@/assets/images/prompt/eye.svg')] rb:hover:bg-[url('@/assets/images/prompt/eye_bg.svg')]"
|
||||
onClick={() => handleClick('detail', item)}
|
||||
></div>
|
||||
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('src/assets/images/prompt/edit.svg')] rb:hover:bg-[url('src/assets/images/prompt/edit_bg.svg')]"
|
||||
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('@/assets/images/prompt/edit.svg')] rb:hover:bg-[url('@/assets/images/prompt/edit_bg.svg')]"
|
||||
onClick={() => handleClick('edit', item)}
|
||||
></div>
|
||||
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('src/assets/images/prompt/delete.svg')] rb:hover:bg-[url('src/assets/images/prompt/delete_hover.svg')]"
|
||||
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('@/assets/images/prompt/delete.svg')] rb:hover:bg-[url('@/assets/images/prompt/delete_hover.svg')]"
|
||||
onClick={() => handleClick('delete', item)}
|
||||
></div>
|
||||
</Space>
|
||||
|
||||
@@ -65,8 +65,8 @@ const CommunityNetwork: FC<{ onSelectCommunity?: (node: RawCommunityNode) => voi
|
||||
}, [id])
|
||||
|
||||
if (loading) {
|
||||
return <Flex align="center" justify="center" className="rb:w-full rb:h-full">
|
||||
<Spin tip={t('userMemory.communityLoadingTip')} size="large" className="rb:text-[#5B6167]! spin">
|
||||
return <Flex align="center" justify="center" className="rb:w-full rb:h-full spin">
|
||||
<Spin tip={t('userMemory.communityLoadingTip')} size="large" className="rb:text-[#5B6167]!">
|
||||
<div className="rb:w-64 rb:h-64" />
|
||||
</Spin>
|
||||
</Flex>
|
||||
|
||||
@@ -12,6 +12,7 @@ import { Row, Col, Progress, App, Table } from 'antd'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import {
|
||||
getForgetStats,
|
||||
getForgetPendingNodesUrl,
|
||||
} from '@/api/memory'
|
||||
import type { ForgetData } from '../types'
|
||||
import ActivationMetricsPieCard from '../components/ActivationMetricsPieCard'
|
||||
@@ -19,6 +20,7 @@ import RecentTrendsLineCard from '../components/RecentTrendsLineCard'
|
||||
import { formatDateTime } from '@/utils/format'
|
||||
import StatusTag from '@/components/StatusTag'
|
||||
import ForgetRefreshModal from '../components/ForgetRefreshModal';
|
||||
import RbTable from '@/components/Table'
|
||||
|
||||
/** Maps node type keys to StatusTag colour presets for the pending-nodes table. */
|
||||
const statusTagColors: Record<string, 'success' | 'purple' | 'default' | 'warning' | 'error' | 'lightBlue'> = {
|
||||
@@ -191,7 +193,9 @@ const ForgetDetail = forwardRef((_props, ref) => {
|
||||
bodyClassName="rb:p-3! rb:py-0! rb:h-[calc(100%-54px)]"
|
||||
className="rb:h-full!"
|
||||
>
|
||||
<Table
|
||||
<RbTable
|
||||
apiUrl={getForgetPendingNodesUrl}
|
||||
apiParams={{ end_user_id: id }}
|
||||
rowKey='node_id'
|
||||
dataSource={data.pending_nodes ?? []}
|
||||
columns={[
|
||||
@@ -225,11 +229,6 @@ const ForgetDetail = forwardRef((_props, ref) => {
|
||||
render: (activation_value) => <span className="rb:text-[#5B6167]">{activation_value}</span>
|
||||
},
|
||||
]}
|
||||
pagination={{
|
||||
pageSize: 5,
|
||||
showQuickJumper: true,
|
||||
className: 'rb:mt-5! rb:mb-5.75!'
|
||||
}}
|
||||
className="table-header-has-bg"
|
||||
/>
|
||||
</RbCard>
|
||||
|
||||
Reference in New Issue
Block a user