Merge remote-tracking branch 'origin/release/v0.2.9' into develop

This commit is contained in:
Ke Sun
2026-03-31 19:16:13 +08:00
55 changed files with 1482 additions and 570 deletions

View File

@@ -1,6 +1,8 @@
import asyncio
import json
import logging
import os
import threading
from typing import Dict, Any, Optional
import redis.asyncio as redis
@@ -21,6 +23,50 @@ pool = ConnectionPool.from_url(
)
aio_redis = redis.StrictRedis(connection_pool=pool)
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
# Thread-local storage for connection pools.
# Each thread (and each forked process) gets its own pool to avoid
# "Future attached to a different loop" errors in Celery --pool=threads
# and stale connections after fork in --pool=prefork.
_thread_local = threading.local()
def get_thread_safe_redis() -> redis.StrictRedis:
"""Return a Redis client whose connection pool is bound to the current
thread, process **and** event loop.
The pool is recreated when:
- The PID changes (fork, Celery --pool=prefork)
- The thread has no pool yet (Celery --pool=threads)
- The previously-cached event loop has been closed (Celery tasks call
``_shutdown_loop_gracefully`` which closes the loop after each run)
"""
current_pid = os.getpid()
cached_loop = getattr(_thread_local, "loop", None)
loop_stale = cached_loop is not None and cached_loop.is_closed()
if not hasattr(_thread_local, "pool") \
or getattr(_thread_local, "pid", None) != current_pid \
or loop_stale:
_thread_local.pid = current_pid
# Python 3.10+: get_event_loop() raises RuntimeError in threads
# where no loop has been set yet (e.g. Celery --pool=threads).
try:
_thread_local.loop = asyncio.get_event_loop()
except RuntimeError:
_thread_local.loop = None
_thread_local.pool = ConnectionPool.from_url(
_REDIS_URL,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD,
decode_responses=True,
max_connections=5,
health_check_interval=30,
)
return redis.StrictRedis(connection_pool=_thread_local.pool)
async def get_redis_connection():
"""获取Redis连接"""
@@ -44,10 +90,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
val = json.dumps(val, ensure_ascii=False)
if expire is not None:
# 设置带过期时间的键值
await aio_redis.set(key, val, ex=expire)
else:
# 设置永久键值
await aio_redis.set(key, val)
except Exception as e:
logger.error(f"Redis set错误: {str(e)}")

View File

@@ -10,7 +10,7 @@ import logging
from typing import Optional, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
from app.aioRedis import get_thread_safe_redis
logger = logging.getLogger(__name__)
@@ -68,7 +68,7 @@ class ActivityStatsCache:
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
await get_thread_safe_redis().set(key, value, ex=expire)
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
@@ -90,7 +90,7 @@ class ActivityStatsCache:
"""
try:
key = cls._get_key(workspace_id)
value = await aio_redis.get(key)
value = await get_thread_safe_redis().get(key)
if value:
payload = json.loads(value)
logger.info(f"命中活动统计缓存: {key}")
@@ -116,7 +116,7 @@ class ActivityStatsCache:
"""
try:
key = cls._get_key(workspace_id)
result = await aio_redis.delete(key)
result = await get_thread_safe_redis().delete(key)
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:

View File

@@ -9,7 +9,7 @@ import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
from app.aioRedis import get_thread_safe_redis
logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class InterestMemoryCache:
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
await get_thread_safe_redis().set(key, value, ex=expire)
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
@@ -86,7 +86,7 @@ class InterestMemoryCache:
"""
try:
key = cls._get_key(end_user_id, language)
value = await aio_redis.get(key)
value = await get_thread_safe_redis().get(key)
if value:
payload = json.loads(value)
logger.info(f"命中兴趣分布缓存: {key}")
@@ -114,7 +114,7 @@ class InterestMemoryCache:
"""
try:
key = cls._get_key(end_user_id, language)
result = await aio_redis.delete(key)
result = await get_thread_safe_redis().delete(key)
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:

View File

@@ -57,7 +57,6 @@ def list_apps(
page: int = 1,
pagesize: int = 10,
ids: Optional[str] = None,
api_key: Optional[str] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
@@ -66,7 +65,7 @@ def list_apps(
- 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
- 当提供 api_key 参数时,查找该 API Key 关联的应用
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
"""
from sqlalchemy import select as sa_select
from app.models.api_key_model import ApiKey
@@ -74,23 +73,34 @@ def list_apps(
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
# 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程
if api_key:
matched_id = db.execute(
sa_select(ApiKey.resource_id).where(
ApiKey.workspace_id == workspace_id,
ApiKey.api_key == api_key,
ApiKey.resource_id.isnot(None),
)
).scalar_one_or_none()
ids = str(matched_id) if matched_id else ""
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
if search:
search = search.strip()
# 尝试作为 API Key 精确匹配API Key 通常较长)
if len(search) >= 10:
matched_id = db.execute(
sa_select(ApiKey.resource_id).where(
ApiKey.workspace_id == workspace_id,
ApiKey.api_key == search,
ApiKey.resource_id.isnot(None),
)
).scalar_one_or_none()
if matched_id:
# 找到 API Key直接返回关联的应用
ids = str(matched_id)
# 当 ids 存在且不为 None 时,根据 ids 获取应用
# 当 ids 存在时,根据 ids 获取应用(不分页)
if ids is not None:
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
return success(data=items)
if app_ids:
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
# 返回标准分页格式
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
return success(data=PageData(page=meta, items=items))
# ids 为空时,返回空列表
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
return success(data=PageData(page=meta, items=[]))
# 正常分页查询
items_orm, total = app_service.list_apps(

View File

@@ -3,17 +3,16 @@ import uuid
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy import select, desc, func
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.conversation_model import Conversation, Message
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
from app.schemas.response_schema import PageData, PageMeta
from app.services.app_service import AppService
from app.services.app_log_service import AppLogService
router = APIRouter(prefix="/apps", tags=["App Logs"])
logger = get_business_logger()
@@ -25,52 +24,35 @@ def list_app_logs(
app_id: uuid.UUID,
page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100),
user_id: Optional[str] = None,
is_draft: Optional[bool] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""查看应用下所有会话记录(分页)
- 支持按 user_id 筛选
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
- 按最新更新时间倒序排列
- 所有人(包括共享者和被共享者)都只能查看自己的会话记录
"""
workspace_id = current_user.current_workspace_id
# 验证应用访问权限
service = AppService(db)
service.get_app(app_id, workspace_id)
app_service = AppService(db)
app_service.get_app(app_id, workspace_id)
stmt = select(Conversation).where(
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active.is_(True),
# 使用 Service 层查询
log_service = AppLogService(db)
conversations, total = log_service.list_conversations(
app_id=app_id,
workspace_id=workspace_id,
page=page,
pagesize=pagesize,
is_draft=is_draft
)
if user_id:
stmt = stmt.where(Conversation.user_id == user_id)
if is_draft is not None:
stmt = stmt.where(Conversation.is_draft == is_draft)
total = int(db.execute(
select(func.count()).select_from(stmt.subquery())
).scalar_one())
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
conversations = list(db.scalars(stmt).all())
items = [AppLogConversation.model_validate(c) for c in conversations]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
logger.info(
"查询应用日志会话列表",
extra={"app_id": str(app_id), "total": total, "page": page}
)
return success(data=PageData(page=meta, items=items))
@@ -86,44 +68,22 @@ def get_app_log_detail(
- 返回会话基本信息 + 所有消息(按时间正序)
- 消息 meta_data 包含模型名、token 用量等信息
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
"""
workspace_id = current_user.current_workspace_id
# 验证应用访问权限
service = AppService(db)
service.get_app(app_id, workspace_id)
app_service = AppService(db)
app_service.get_app(app_id, workspace_id)
# 查询会话(确保属于该应用和工作空间)
conversation = db.scalars(
select(Conversation).where(
Conversation.id == conversation_id,
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active.is_(True),
)
).first()
if not conversation:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("会话", str(conversation_id))
# 查询消息(按时间正序)
messages = list(db.scalars(
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.created_at)
).all())
detail = AppLogConversationDetail.model_validate(conversation)
detail.messages = [AppLogMessage.model_validate(m) for m in messages]
logger.info(
"查询应用日志会话详情",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"message_count": len(messages)
}
# 使用 Service 层查询
log_service = AppLogService(db)
conversation = log_service.get_conversation_detail(
app_id=app_id,
conversation_id=conversation_id,
workspace_id=workspace_id
)
detail = AppLogConversationDetail.model_validate(conversation)
return success(data=detail)

View File

@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
ForgettingCurveRequest,
ForgettingCurveResponse,
ForgettingCurvePoint,
PendingNodesResponse,
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
@router.get("/pending-nodes", response_model=ApiResponse)
async def get_pending_nodes(
end_user_id: str,
page: int = 1,
pagesize: int = 10,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
获取待遗忘节点列表(独立分页接口)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
此接口独立分页,与 /stats 接口分离。
Args:
end_user_id: 组ID即 end_user_id必填
page: 页码从1开始默认1
pagesize: 每页数量默认10
current_user: 当前用户
db: 数据库会话
Returns:
ApiResponse: 包含待遗忘节点列表和分页信息的响应
Examples:
- 第1页每页10条GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
- 第2页每页20条GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
Notes:
- page 从1开始pagesize 必须大于0
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 验证 end_user_id 必填
if not end_user_id:
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
# 通过 end_user_id 获取关联的 config_id
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
except Exception as e:
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
# 验证分页参数
if page < 1:
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
if pagesize < 1:
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
)
try:
# 调用服务层获取待遗忘节点列表
result = await forget_service.get_pending_nodes(
db=db,
end_user_id=end_user_id,
config_id=config_id,
page=page,
pagesize=pagesize
)
# 构建响应
response_data = PendingNodesResponse(**result)
return success(data=response_data.model_dump(), msg="查询成功")
except Exception as e:
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
@router.post("/forgetting_curve", response_model=ApiResponse)
async def get_forgetting_curve(
request: ForgettingCurveRequest,

View File

@@ -27,6 +27,7 @@ from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.workflow_service import WorkflowService
from app.models.file_metadata_model import FileMetadata
from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release
@@ -259,8 +260,41 @@ def get_conversation(
conv_service = ConversationService(db)
messages = conv_service.get_messages(conversation_id)
# 构建响应
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
file_ids = []
message_file_id_map = {}
# 第一次遍历:解析 audio_url收集所有有效的 file_id
for idx, m in enumerate(messages):
if m.role == "assistant" and m.meta_data:
audio_url = m.meta_data.get("audio_url")
if not audio_url:
continue
try:
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
except (ValueError, IndexError):
# audio_url 无法解析为 UUID标记为 unknown
m.meta_data["audio_status"] = "unknown"
continue
file_ids.append(file_id)
message_file_id_map[idx] = file_id
# 批量查询所有相关的 FileMetadata
file_status_map = {}
if file_ids:
file_metas = (
db.query(FileMetadata)
.filter(FileMetadata.id.in_(set(file_ids)))
.all()
)
file_status_map = {fm.id: fm.status for fm in file_metas}
# 第二次遍历:将查询结果映射回消息
for idx, file_id in message_file_id_map.items():
m = messages[idx]
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
conv_dict["messages"] = [
conversation_schema.Message.model_validate(m) for m in messages
]
@@ -320,6 +354,16 @@ async def chat(
other_id=other_id,
original_user_id=user_id
)
# Only extract and set memory_config_id when the end user doesn't have one yet
if not new_end_user.memory_config_id:
from app.services.memory_config_service import MemoryConfigService
memory_config_service = MemoryConfigService(db)
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
if memory_config_id:
new_end_user.memory_config_id = memory_config_id
db.commit()
db.refresh(new_end_user)
end_user_id = str(new_end_user.id)
# appid = share.app_id

View File

@@ -91,7 +91,7 @@ async def chat(
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
other_id = payload.user_id
workspace_id = app.workspace_id
workspace_id = api_key_auth.workspace_id
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app.id,

View File

@@ -111,6 +111,18 @@ def get_current_user_info(
break
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
# 设置权限:如果用户来自 SSO Source则使用该 Source 的 permissions否则返回 "all" 表示拥有所有权限
if current_user.external_source:
from premium.sso.models import SSOSource
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
if source and source.permissions:
result_schema.permissions = source.permissions
else:
result_schema.permissions = []
else:
result_schema.permissions = ["all"]
return success(data=result_schema, msg=t("users.info.get_success"))
@@ -135,7 +147,6 @@ def get_tenant_superusers(
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
@router.get("/{user_id}", response_model=ApiResponse)
def get_user_info_by_id(
user_id: uuid.UUID,

View File

@@ -151,11 +151,6 @@ async def write(
# Step 3: Save all data to Neo4j database
step_start = time.time()
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
try:
await create_fulltext_indexes()
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制
max_retries = 3
@@ -279,5 +274,21 @@ async def write(
except Exception as cache_err:
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
# Close LLM/Embedder underlying httpx clients to prevent
# 'RuntimeError: Event loop is closed' during garbage collection
for client_obj in (llm_client, embedder_client):
try:
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
if underlying is None:
continue
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
inner = getattr(underlying, '_model', underlying)
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
http_client = getattr(inner, 'async_client', None)
if http_client is not None and hasattr(http_client, 'aclose'):
await http_client.aclose()
except Exception:
pass
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")

View File

@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
type=type_
)
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
"""

View File

@@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
logger = logging.getLogger(__name__)
def message_has_files(message: "ConversationMessage") -> bool:
"""检查消息是否包含文件。
Args:
message: 待检查的消息对象
Returns:
bool: 如果消息包含文件则返回 True否则返回 False
"""
return message.files and len(message.files) > 0
class DialogExtractionResponse(BaseModel):
"""对话级一次性抽取的结构化返回,用于加速剪枝。
@@ -128,7 +140,7 @@ class SemanticPruner:
1. 空消息
2. 场景特定填充词库精确匹配
3. 常见寒暄精确匹配
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢""同学你好""明白了"
4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢""同学你好""明白了"
5. 纯表情/标点
"""
t = message.msg.strip()
@@ -482,6 +494,11 @@ class SemanticPruner:
"""
to_delete_ids: set = set()
for m in msgs:
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
continue
# 填充检测优先:先判断是否为填充,再看 LLM 保护
if self._is_filler_message(m):
to_delete_ids.add(id(m))
@@ -549,6 +566,11 @@ class SemanticPruner:
to_delete_ids: set = set()
for m in msgs:
msg_text = m.msg.strip()
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
continue
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
if self._is_filler_message(m):
@@ -801,6 +823,12 @@ class SemanticPruner:
for idx, m in enumerate(msgs):
msg_text = m.msg.strip()
# 最高优先级保护:带有文件的消息一律保留,不参与分类
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
llm_protected_msgs.append((idx, m)) # 放入保护列表
continue
if self._msg_matches_tokens(m, preserve_tokens):
llm_protected_msgs.append((idx, m))

View File

@@ -182,7 +182,7 @@ class ExtractionOrchestrator:
list[StatementEntityEdge],
list[EntityEntityEdge],
list[PerceptualEdge],
dict
list[DialogData]
]:
"""
运行完整的知识提取流水线(优化版:并行执行)
@@ -295,6 +295,7 @@ class ExtractionOrchestrator:
statement_entity_edges,
entity_entity_edges,
dialog_data_list,
dedup_details,
) = await self._run_dedup_and_write_summary(
dialogue_nodes,
chunk_nodes,
@@ -306,6 +307,11 @@ class ExtractionOrchestrator:
dialog_data_list,
)
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
if not is_pilot_run:
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
logger.info(f"知识提取流水线运行完成({mode_str}")
return (
dialogue_nodes,
@@ -1399,7 +1405,8 @@ class ExtractionOrchestrator:
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
else:
first_alias = current_aliases[0].strip() if current_aliases else ""
if first_alias:
# 确保 first_alias 不是占位名称
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
db.add(EndUserInfo(
end_user_id=end_user_uuid,
other_name=first_alias,
@@ -1415,29 +1422,33 @@ class ExtractionOrchestrator:
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
USER_PLACEHOLDER_NAMES = {'用户', '', 'User', 'I'}
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
这个方法直接返回 LLM 提取的别名列表,不做任何修改
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户""""User""I"
第一个别名将被用作 other_name。
Args:
entity_nodes: 实体节点列表
Returns:
别名列表(保持 LLM 提取的原始顺序)
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称
"""
USER_NAMES = {'用户', '', 'User', 'I'}
for entity in entity_nodes:
if getattr(entity, 'name', '').strip() in USER_NAMES:
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
aliases = getattr(entity, 'aliases', []) or []
logger.debug(f"提取到用户别名(原始顺序): {aliases}")
return aliases
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
return filtered
return []
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
"""从 Neo4j 查询用户实体的完整 aliases 列表"""
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
cypher = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '', 'User', 'I']
@@ -1451,7 +1462,10 @@ class ExtractionOrchestrator:
aliases = result[0].get('aliases') or []
if not aliases:
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
return aliases
return []
# 过滤掉占位名称,防止历史脏数据传播
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
return filtered
def _resolve_other_name(
self,
@@ -1463,14 +1477,25 @@ class ExtractionOrchestrator:
决定 other_name 是否需要更新,返回新值;无需更新返回 None。
决策规则:
- 为空 → 用本次对话第一个别名
- 为空或为占位名称 → 用本次对话第一个别名
- 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除)
- 否则 → 保持不变(返回 None
注意:返回值不允许是占位名称("用户""""User""I"
"""
if not current or not current.strip():
return current_aliases[0].strip() if current_aliases else None
# 当前值为空或为占位名称时,需要更新
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
candidate = current_aliases[0].strip() if current_aliases else None
# 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
if current not in neo4j_aliases:
return neo4j_aliases[0].strip() if neo4j_aliases else None
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
# 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
return None
@@ -1492,6 +1517,7 @@ class ExtractionOrchestrator:
list[StatementChunkEdge],
list[StatementEntityEdge],
list[EntityEntityEdge],
list[DialogData],
dict
]:
"""
@@ -1555,6 +1581,8 @@ class ExtractionOrchestrator:
statement_chunk_edges,
dedup_statement_entity_edges,
dedup_entity_entity_edges,
dialog_data_list,
dedup_details,
)
final_entity_nodes = dedup_entity_nodes
@@ -1562,7 +1590,16 @@ class ExtractionOrchestrator:
final_entity_entity_edges = dedup_entity_entity_edges
else:
# 正式模式:执行完整的两阶段去重
result_tuple = await dedup_layers_and_merge_and_return(
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
final_entity_nodes,
statement_chunk_edges,
final_statement_entity_edges,
final_entity_entity_edges,
dedup_details,
) = await dedup_layers_and_merge_and_return(
dialogue_nodes,
chunk_nodes,
statement_nodes,
@@ -1576,21 +1613,21 @@ class ExtractionOrchestrator:
llm_client=self.llm_client,
)
# 解包返回值
(
_,
_,
_,
final_entity_nodes,
_,
final_statement_entity_edges,
final_entity_entity_edges,
dedup_details,
) = result_tuple
# 保存去重消歧的详细记录到实例变量
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
result_tuple = (
dialogue_nodes,
chunk_nodes,
statement_nodes,
final_entity_nodes,
statement_chunk_edges,
final_statement_entity_edges,
final_entity_entity_edges,
dialog_data_list,
dedup_details,
)
logger.info(
f"去重后: {len(final_entity_nodes)} 个实体节点, "
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "

View File

@@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement.
{% if language == "zh" %}
- 用户实体的 name 字段:使用 "用户" 或 "我"
- 用户的真实姓名:放入 aliases
- **🚨 禁止将 "用户"、"我" 放入 aliases 中aliases 只能包含用户的真实姓名、昵称等**
- 示例:
* "我叫李明" → name="用户", aliases=["李明"]
* ❌ 错误aliases=["用户", "李明"]"用户"不是真实姓名,禁止放入 aliases
* ❌ 错误aliases=["我", "李明"]"我"不是真实姓名,禁止放入 aliases
{% else %}
- User entity name field: use "User" or "I"
- User's real name: put in aliases
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
- Examples:
* "I'm John" → name="User", aliases=["John"]
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
{% endif %}

View File

@@ -44,6 +44,8 @@ class OSSStorage(StorageBackend):
access_key_id: str,
access_key_secret: str,
bucket_name: str,
connect_timeout: int = 30,
multipart_threshold: int = 10 * 1024 * 1024, # 10MB
):
"""
Initialize the OSSStorage backend.
@@ -53,6 +55,8 @@ class OSSStorage(StorageBackend):
access_key_id: The Aliyun access key ID.
access_key_secret: The Aliyun access key secret.
bucket_name: The name of the OSS bucket.
connect_timeout: Connection timeout in seconds (default: 30).
multipart_threshold: File size threshold for multipart upload (default: 10MB).
Raises:
StorageConfigError: If any required configuration is missing.
@@ -69,10 +73,17 @@ class OSSStorage(StorageBackend):
self.endpoint = endpoint
self.bucket_name = bucket_name
self.multipart_threshold = multipart_threshold
try:
auth = oss2.Auth(access_key_id, access_key_secret)
self.bucket = oss2.Bucket(auth, endpoint, bucket_name)
# 设置超时和重试
self.bucket = oss2.Bucket(
auth,
endpoint,
bucket_name,
connect_timeout=connect_timeout
)
logger.info(
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
)
@@ -108,21 +119,38 @@ class OSSStorage(StorageBackend):
if content_type:
headers["Content-Type"] = content_type
self.bucket.put_object(file_key, content, headers=headers if headers else None)
# 大文件使用分片上传
if len(content) > self.multipart_threshold:
logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)")
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id
parts = []
part_size = 5 * 1024 * 1024 # 5MB per part
part_num = 1
for offset in range(0, len(content), part_size):
chunk = content[offset:offset + part_size]
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
parts.append(oss2.models.PartInfo(part_num, result.etag))
part_num += 1
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
else:
self.bucket.put_object(file_key, content, headers=headers if headers else None)
logger.info(f"File uploaded to OSS successfully: {file_key}")
return file_key
except OssError as e:
logger.error(f"OSS error uploading file {file_key}: {e}")
raise StorageUploadError(
message=f"Failed to upload file to OSS: {e.message}",
message=f"Failed to upload file to OSS: {str(e)}",
file_key=file_key,
cause=e,
)
except Exception as e:
logger.error(f"Failed to upload file to OSS {file_key}: {e}")
raise StorageUploadError(
message=f"Failed to upload file to OSS: {e}",
message=f"Failed to upload file to OSS: {str(e)}",
file_key=file_key,
cause=e,
)
@@ -135,28 +163,73 @@ class OSSStorage(StorageBackend):
) -> int:
"""Upload from async stream to OSS. Returns total bytes written."""
buf = io.BytesIO()
headers = {"Content-Type": content_type} if content_type else None
upload_id = None
try:
# 收集流数据
total_size = 0
async for chunk in stream:
if not chunk:
continue
buf.write(chunk)
total_size += len(chunk)
content = buf.getvalue()
headers = {"Content-Type": content_type} if content_type else None
self.bucket.put_object(file_key, content, headers=headers)
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
return len(content)
if not content:
raise StorageUploadError(
message="Empty stream content",
file_key=file_key,
)
# 大文件使用分片上传
if len(content) > self.multipart_threshold:
logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)")
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id
parts = []
part_size = 5 * 1024 * 1024 # 5MB
part_num = 1
for offset in range(0, len(content), part_size):
chunk = content[offset:offset + part_size]
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
parts.append(oss2.models.PartInfo(part_num, result.etag))
part_num += 1
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
else:
self.bucket.put_object(file_key, content, headers=headers)
logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)")
return total_size
except OssError as e:
if upload_id:
try:
self.bucket.abort_multipart_upload(file_key, upload_id)
except:
pass
logger.error(f"OSS error stream uploading file {file_key}: {e}")
raise StorageUploadError(
message=f"Failed to stream upload file to OSS: {e.message}",
message=f"Failed to stream upload file to OSS: {str(e)}",
file_key=file_key,
cause=e,
)
except Exception as e:
if upload_id:
try:
self.bucket.abort_multipart_upload(file_key, upload_id)
except:
pass
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
raise StorageUploadError(
message=f"Failed to stream upload file to OSS: {e}",
message=f"Failed to stream upload file to OSS: {str(e)}",
file_key=file_key,
cause=e,
)
finally:
buf.close()
async def download(self, file_key: str) -> bytes:
"""
@@ -182,14 +255,14 @@ class OSSStorage(StorageBackend):
except OssError as e:
logger.error(f"OSS error downloading file {file_key}: {e}")
raise StorageDownloadError(
message=f"Failed to download file from OSS: {e.message}",
message=f"Failed to download file from OSS: {str(e)}",
file_key=file_key,
cause=e,
)
except Exception as e:
logger.error(f"Failed to download file from OSS {file_key}: {e}")
raise StorageDownloadError(
message=f"Failed to download file from OSS: {e}",
message=f"Failed to download file from OSS: {str(e)}",
file_key=file_key,
cause=e,
)
@@ -215,14 +288,14 @@ class OSSStorage(StorageBackend):
except OssError as e:
logger.error(f"OSS error deleting file {file_key}: {e}")
raise StorageDeleteError(
message=f"Failed to delete file from OSS: {e.message}",
message=f"Failed to delete file from OSS: {str(e)}",
file_key=file_key,
cause=e,
)
except Exception as e:
logger.error(f"Failed to delete file from OSS {file_key}: {e}")
raise StorageDeleteError(
message=f"Failed to delete file from OSS: {e}",
message=f"Failed to delete file from OSS: {str(e)}",
file_key=file_key,
cause=e,
)

View File

@@ -99,7 +99,7 @@ class SimpleMCPClient:
# 建立 SSE 连接
response = await self._session.get(self.server_url)
if response.status != 200:
if response.status not in (200, 202):
error_text = await response.text()
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
@@ -190,7 +190,9 @@ class SimpleMCPClient:
try:
async with self._session.post(self._endpoint_url, json=request) as response:
if response.status != 200:
# MCP SSE 协议POST 请求返回 200 或 202 均为正常
# 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回
if response.status not in (200, 202):
error_text = await response.text()
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
@@ -205,7 +207,7 @@ class SimpleMCPClient:
raise MCPConnectionError("endpoint URL 未初始化")
async with self._session.post(self._endpoint_url, json=notification) as response:
if response.status != 200:
if response.status not in (200, 202):
logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self):

View File

@@ -1,5 +1,6 @@
import os
import subprocess
from app.repositories.neo4j.create_indexes import create_all_indexes
from contextlib import asynccontextmanager
from fastapi import FastAPI, APIRouter
@@ -60,8 +61,10 @@ async def lifespan(app: FastAPI):
logger.warning(f"加载预定义模型时出错: {str(e)}")
else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
await create_all_indexes()
logger.info("应用程序启动完成")
yield
# 应用关闭事件
logger.info("应用程序正在关闭")

View File

@@ -19,9 +19,12 @@ class User(Base):
last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空
# SSO 外部关联字段
external_id = Column(String(100), nullable=True) # 外部用户ID
external_id = Column(String(100), nullable=True) # 外部用户 ID
external_source = Column(String(50), nullable=True) # 来源系统
# 用户联系方式
phone = Column(String(50), nullable=True) # 用户电话
# 用户语言偏好
preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文

View File

@@ -199,6 +199,96 @@ class ConversationRepository:
)
return conversations, total
def list_app_conversations(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
is_draft: Optional[bool] = None,
page: int = 1,
pagesize: int = 20
) -> tuple[list[Conversation], int]:
"""
查询应用日志会话列表(带分页和过滤)
Args:
app_id: 应用 ID
workspace_id: 工作空间 ID
is_draft: 是否草稿会话None 表示不过滤)
page: 页码(从 1 开始)
pagesize: 每页数量
Returns:
Tuple[List[Conversation], int]: (会话列表,总数)
"""
stmt = select(Conversation).where(
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active.is_(True)
)
if is_draft is not None:
stmt = stmt.where(Conversation.is_draft == is_draft)
# Calculate total number of records
total = int(self.db.execute(
select(func.count()).select_from(stmt.subquery())
).scalar_one())
# Apply pagination
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
conversations = list(self.db.scalars(stmt).all())
logger.info(
"Listed app conversations successfully",
extra={
"app_id": str(app_id),
"workspace_id": str(workspace_id),
"returned": len(conversations),
"total": total
}
)
return conversations, total
def get_conversation_for_app_log(
self,
conversation_id: uuid.UUID,
app_id: uuid.UUID,
workspace_id: uuid.UUID
) -> Conversation:
"""
查询应用日志的会话详情
Args:
conversation_id: 会话 ID
app_id: 应用 ID
workspace_id: 工作空间 ID
Returns:
Conversation: 会话对象
Raises:
ResourceNotFoundException: 当会话不存在时
"""
logger.info(f"Fetching conversation for app log: {conversation_id}")
stmt = select(Conversation).where(
Conversation.id == conversation_id,
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active.is_(True)
)
conversation = self.db.scalars(stmt).first()
if not conversation:
logger.warning(f"Conversation not found: {conversation_id}")
raise ResourceNotFoundException("会话", str(conversation_id))
logger.info(f"Conversation fetched successfully: {conversation_id}")
return conversation
def soft_delete_conversation_by_conversation_id(
self,
conversation_id: uuid.UUID,
@@ -290,6 +380,34 @@ class MessageRepository:
self.db.add(message)
return message
def get_messages_by_conversation(
self,
conversation_id: uuid.UUID
) -> list[Message]:
"""
查询会话的所有消息(按时间正序)
Args:
conversation_id: 会话 ID
Returns:
List[Message]: 消息列表
"""
stmt = select(Message).where(
Message.conversation_id == conversation_id
).order_by(Message.created_at)
messages = list(self.db.scalars(stmt).all())
logger.info(
"Fetched messages for conversation",
extra={
"conversation_id": str(conversation_id),
"message_count": len(messages)
}
)
return messages
def get_message_by_conversation_id(
self,
conversation_id: uuid.UUID,

View File

@@ -132,6 +132,82 @@ class EndUserRepository:
db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
raise
def get_or_create_end_user_with_config(
self,
app_id: Optional[uuid.UUID],
workspace_id: uuid.UUID,
other_id: str,
memory_config_id: Optional[uuid.UUID] = None,
other_name: Optional[str] = None
) -> EndUser:
"""获取或创建终端用户,并在单次事务中关联记忆配置。
与 get_or_create_end_user 类似,但额外支持在创建/获取时
一并设置 memory_config_id避免多次提交。
Args:
app_id: 应用ID可为 None
workspace_id: 工作空间ID
other_id: 第三方ID
memory_config_id: 记忆配置ID可选仅在用户尚无配置时设置
other_name: 用户名称(用于创建 EndUserInfo
Returns:
EndUser: 终端用户对象(已关联记忆配置)
"""
try:
end_user = (
self.db.query(EndUser)
.filter(
EndUser.workspace_id == workspace_id,
EndUser.other_id == other_id
)
.order_by(EndUser.created_at.asc())
.first()
)
if end_user:
db_logger.debug(f"找到现有终端用户: workspace_id={workspace_id}, other_id={other_id}")
if app_id is not None:
end_user.app_id = app_id
if memory_config_id and not end_user.memory_config_id:
end_user.memory_config_id = memory_config_id
self.db.commit()
self.db.refresh(end_user)
return end_user
# 创建新用户
end_user = EndUser(
app_id=app_id,
workspace_id=workspace_id,
other_id=other_id,
memory_config_id=memory_config_id,
)
self.db.add(end_user)
self.db.flush()
end_user_info = EndUserInfo(
end_user_id=end_user.id,
other_name=other_name or "",
aliases=[],
meta_data={}
)
self.db.add(end_user_info)
self.db.commit()
self.db.refresh(end_user)
db_logger.info(
f"创建新终端用户及其信息: (other_id: {other_id}) for workspace {workspace_id}, "
f"memory_config_id={memory_config_id}"
)
return end_user
except Exception as e:
self.db.rollback()
db_logger.error(f"获取或创建终端用户(含配置)时出错: {str(e)}")
raise
def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
"""根据ID获取终端用户用于缓存操作
@@ -515,6 +591,51 @@ class EndUserRepository:
)
raise
def batch_update_memory_config_id_by_app(
self,
app_id: uuid.UUID,
memory_config_id: uuid.UUID
) -> int:
"""批量更新应用下所有终端用户的 memory_config_id
Args:
app_id: 应用ID
memory_config_id: 新的记忆配置ID
Returns:
int: 更新的终端用户数量
Raises:
Exception: 数据库操作失败时抛出
"""
try:
from sqlalchemy import update
stmt = (
update(EndUser)
.where(EndUser.app_id == app_id)
.values(memory_config_id=memory_config_id)
)
result = self.db.execute(stmt)
self.db.commit()
updated_count = result.rowcount
db_logger.info(
f"批量更新终端用户记忆配置: app_id={app_id}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
)
return updated_count
except Exception as e:
self.db.rollback()
db_logger.error(
f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
f"memory_config_id={memory_config_id}, error={str(e)}"
)
raise
def count_by_memory_config_id(
self,
memory_config_id: uuid.UUID

View File

@@ -1,62 +1,47 @@
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes():
"""Create full-text indexes for keyword search with BM25 scoring."""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Creating Full-Text Indexes (for keyword search)")
print("=" * 70)
# 创建 Statements 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: statementsFulltext")
""")
# # 创建 Dialogues 索引
# await connector.execute_query("""
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
# OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
# """)
# 创建 Entities 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: entitiesFulltext")
""")
# 创建 Chunks 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: chunksFulltext")
""")
# 创建 MemorySummary 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: summariesFulltext")
""")
# 创建 Community 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: communitiesFulltext")
print("\nFull-text indexes created successfully with BM25 support.")
except Exception as e:
print(f"✗ Error creating full-text indexes: {e}")
finally:
await connector.close()
async def create_vector_indexes():
"""Create vector indexes for fast embedding similarity search.
@@ -65,12 +50,7 @@ async def create_vector_indexes():
"""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Creating Vector Indexes (for embedding search)")
print("=" * 70)
print("Note: Adjust vector.dimensions if using different embedding model")
print(" Current setting: 1024 dimensions (for bge-m3)")
print()
# Statement embedding index
await connector.execute_query("""
@@ -82,7 +62,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: statement_embedding_index")
# Chunk embedding index
await connector.execute_query("""
@@ -94,7 +74,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: chunk_embedding_index")
# Entity name embedding index
await connector.execute_query("""
@@ -106,7 +86,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: entity_embedding_index")
# Memory summary embedding index
await connector.execute_query("""
@@ -118,8 +98,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: summary_embedding_index")
# Community summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
@@ -129,8 +108,7 @@ async def create_vector_indexes():
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: community_summary_embedding_index")
""")
# Dialogue embedding index (optional)
await connector.execute_query("""
@@ -142,91 +120,15 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: dialogue_embedding_index")
# Community summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
FOR (c:Community)
ON c.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: community_summary_embedding_index")
print("\nVector indexes created successfully!")
print("\nExpected performance improvement:")
print(" Before: ~1.4s for embedding search")
print(" After: ~0.05-0.2s for embedding search (10-30x faster!)")
except Exception as e:
print(f"✗ Error creating vector indexes: {e}")
finally:
await connector.close()
async def create_config_id_indexes():
"""Create indexes on config_id fields for improved query performance.
These indexes enable fast filtering of nodes by configuration ID,
which is essential for configuration isolation and multi-tenant scenarios.
"""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Creating Config ID Indexes")
print("=" * 70)
# Dialogue.config_id index
await connector.execute_query("""
CREATE INDEX dialogue_config_id_index IF NOT EXISTS
FOR (d:Dialogue) ON (d.config_id)
""")
print("✓ Created: dialogue_config_id_index")
# Statement.config_id index
await connector.execute_query("""
CREATE INDEX statement_config_id_index IF NOT EXISTS
FOR (s:Statement) ON (s.config_id)
""")
print("✓ Created: statement_config_id_index")
# ExtractedEntity.config_id index
await connector.execute_query("""
CREATE INDEX entity_config_id_index IF NOT EXISTS
FOR (e:ExtractedEntity) ON (e.config_id)
""")
print("✓ Created: entity_config_id_index")
# MemorySummary.config_id index
await connector.execute_query("""
CREATE INDEX summary_config_id_index IF NOT EXISTS
FOR (m:MemorySummary) ON (m.config_id)
""")
print("✓ Created: summary_config_id_index")
print("\nConfig ID indexes created successfully!")
print("These indexes enable fast filtering by configuration ID.")
except Exception as e:
print(f"✗ Error creating config_id indexes: {e}")
finally:
await connector.close()
async def create_unique_constraints():
"""Create uniqueness constraints for core node identifiers.
Ensures concurrent MERGE operations remain safe and prevents duplicates.
"""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Creating Unique Constraints")
print("=" * 70)
try:
# Dialogue.id unique
await connector.execute_query(
"""
@@ -234,8 +136,7 @@ async def create_unique_constraints():
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
"""
)
print("✓ Created: dialog_id_unique")
# Statement.id unique
await connector.execute_query(
"""
@@ -243,8 +144,7 @@ async def create_unique_constraints():
FOR (s:Statement) REQUIRE s.id IS UNIQUE
"""
)
print("✓ Created: statement_id_unique")
# Chunk.id unique
await connector.execute_query(
"""
@@ -252,112 +152,13 @@ async def create_unique_constraints():
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
"""
)
print("✓ Created: chunk_id_unique")
print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.")
except Exception as e:
print(f"✗ Error creating unique constraints: {e}")
finally:
await connector.close()
async def create_all_indexes():
"""Create all indexes and constraints in one go."""
print("\n" + "=" * 70)
print("Neo4j Index & Constraint Setup")
print("=" * 70)
print("This will create:")
print(" 1. Full-text indexes (for keyword/BM25 search)")
print(" 2. Vector indexes (for embedding similarity search)")
print(" 3. Config ID indexes (for configuration isolation)")
print(" 4. Unique constraints (for data integrity)")
print("=" * 70)
await create_fulltext_indexes()
await create_vector_indexes()
await create_config_id_indexes()
await create_unique_constraints()
print("\n" + "=" * 70)
print("✓ All indexes and constraints created successfully!")
print("=" * 70)
print("\nTo verify, run in Neo4j Browser:")
print(" SHOW INDEXES")
print(" SHOW CONSTRAINTS")
print()
async def check_indexes():
"""Check what indexes currently exist."""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Checking Existing Indexes")
print("=" * 70)
query = "SHOW INDEXES"
result = await connector.execute_query(query)
fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT']
vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR']
range_indexes = [idx for idx in result if idx.get('type') == 'RANGE']
print(f"\nFull-text indexes: {len(fulltext_indexes)}")
for idx in fulltext_indexes:
print(f"{idx.get('name')}")
print(f"\nVector indexes: {len(vector_indexes)}")
for idx in vector_indexes:
print(f"{idx.get('name')}")
print(f"\nRange indexes (including config_id): {len(range_indexes)}")
for idx in range_indexes:
print(f"{idx.get('name')}")
if not vector_indexes:
print("\n⚠️ WARNING: No vector indexes found!")
print(" Embedding search will be VERY SLOW (~1.4s)")
print(" Run: python create_indexes.py")
# Check for config_id indexes
config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')]
if len(config_id_indexes) < 4:
print("\n⚠️ WARNING: Not all config_id indexes found!")
print(f" Expected 4, found {len(config_id_indexes)}")
print(" Run: python create_indexes.py config_id")
print("=" * 70)
finally:
await connector.close()
if __name__ == "__main__":
import asyncio
import sys
if len(sys.argv) > 1:
command = sys.argv[1]
if command == "check":
asyncio.run(check_indexes())
elif command == "fulltext":
asyncio.run(create_fulltext_indexes())
elif command == "vector":
asyncio.run(create_vector_indexes())
elif command == "config_id":
asyncio.run(create_config_id_indexes())
elif command == "constraints":
asyncio.run(create_unique_constraints())
else:
print(f"Unknown command: {command}")
print("\nUsage:")
print(" python create_indexes.py # Create all indexes")
print(" python create_indexes.py check # Check existing indexes")
print(" python create_indexes.py fulltext # Create only full-text indexes")
print(" python create_indexes.py vector # Create only vector indexes")
print(" python create_indexes.py config_id # Create only config_id indexes")
print(" python create_indexes.py constraints # Create only constraints")
else:
asyncio.run(create_all_indexes())

View File

@@ -340,17 +340,22 @@ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
WITH e, score
UNION
MATCH (e:ExtractedEntity)
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
AND e.aliases IS NOT NULL
AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q))
WITH e,
WITH collect({entity: e, score: score}) AS fulltextResults
OPTIONAL MATCH (ae:ExtractedEntity)
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
AND ae.aliases IS NOT NULL
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q))
WITH fulltextResults, collect(ae) AS aliasEntities
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
CASE
WHEN ANY(alias IN e.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
WHEN ANY(alias IN e.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
ELSE 0.8
END AS score
END
}]) AS row
WITH row.entity AS e, row.score AS score
WITH DISTINCT e, MAX(score) AS score
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)

View File

@@ -158,22 +158,26 @@ class UserRepository:
raise
def get_users_by_tenant(
self,
tenant_id: uuid.UUID,
skip: int = 0,
self,
tenant_id: uuid.UUID,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None
) -> List[User]:
"""获取租户下的用户列表"""
db_logger.debug(f"查询租户用户: tenant_id={tenant_id}")
try:
query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id)
if is_active is not None:
query = query.filter(User.is_active == is_active)
if is_superuser is not None:
query = query.filter(User.is_superuser == is_superuser)
if search:
query = query.filter(
or_(
@@ -181,7 +185,7 @@ class UserRepository:
User.email.ilike(f"%{search}%")
)
)
users = query.offset(skip).limit(limit).all()
db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}")
return users
@@ -190,18 +194,22 @@ class UserRepository:
raise
def count_users_by_tenant(
self,
self,
tenant_id: uuid.UUID,
is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None
) -> int:
"""统计租户下的用户数量"""
try:
query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
if is_active is not None:
query = query.filter(User.is_active == is_active)
if is_superuser is not None:
query = query.filter(User.is_superuser == is_superuser)
if search:
query = query.filter(
or_(
@@ -209,7 +217,7 @@ class UserRepository:
User.email.ilike(f"%{search}%")
)
)
return query.scalar()
except Exception as e:
db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}")

View File

@@ -276,7 +276,7 @@ class AgentConfigCreate(BaseModel):
# 记忆配置
memory: MemoryConfig = Field(
default_factory=lambda: MemoryConfig(enabled=True),
default_factory=lambda: MemoryConfig(enabled=False),
description="对话历史记忆配置"
)

View File

@@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel):
last_access_time: int = Field(..., description="最后访问时间Unix时间戳")
class PageInfo(BaseModel):
"""分页信息模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
page: int = Field(..., description="当前页码从1开始")
pagesize: int = Field(..., description="每页数量")
total: int = Field(..., description="总记录数")
hasnext: bool = Field(..., description="是否有下一页")
class PendingNodesResponse(BaseModel):
"""待遗忘节点列表响应模型(独立分页接口)"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表")
page: PageInfo = Field(..., description="分页信息")
class ForgettingStatsResponse(BaseModel):
"""遗忘引擎统计信息响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel):
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
description="最近7个日期的遗忘趋势数据每天取最后一次执行")
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点")
timestamp: int = Field(..., description="统计时间(时间戳)")

View File

@@ -1,6 +1,6 @@
from dataclasses import field
from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict
from typing import Optional
from typing import Optional, List
import datetime
import uuid
@@ -20,6 +20,7 @@ class UserCreate(UserBase):
class UserUpdate(BaseModel):
username: Optional[str] = None
email: Optional[EmailStr] = None
phone: Optional[str] = None
is_active: Optional[bool] = None
is_superuser: Optional[bool] = None
@@ -85,6 +86,8 @@ class User(UserBase):
current_workspace_name: Optional[str] = None
role: Optional[WorkspaceRole] = None
preferred_language: Optional[str] = "zh" # 用户语言偏好
phone: Optional[str] = None # 用户电话
permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制
# 将 datetime 转换为毫秒时间戳
@validator("created_at", pre=True)

View File

@@ -141,13 +141,13 @@ class AppChatService:
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
is_new_conversation = len(history) == 0
if is_new_conversation:
opening = self.agent_service._get_opening_statement(features_config, True, variables)
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
if opening:
self.conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=opening,
meta_data={}
meta_data={"suggested_questions": suggested_questions}
)
# 重新加载历史(包含刚写入的开场白)
history = await self.conversation_service.get_conversation_history(
@@ -378,13 +378,13 @@ class AppChatService:
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
is_new_conversation = len(history) == 0
if is_new_conversation:
opening = self.agent_service._get_opening_statement(features_config, True, variables)
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
if opening:
self.conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=opening,
meta_data={}
meta_data={"suggested_questions": suggested_questions}
)
# 重新加载历史(包含刚写入的开场白)
history = await self.conversation_service.get_conversation_history(

View File

@@ -0,0 +1,128 @@
"""应用日志服务层"""
import uuid
from typing import Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger
from app.models.conversation_model import Conversation, Message
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
logger = get_business_logger()
class AppLogService:
"""应用日志服务"""
def __init__(self, db: Session):
self.db = db
self.conversation_repository = ConversationRepository(db)
self.message_repository = MessageRepository(db)
def list_conversations(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
page: int = 1,
pagesize: int = 20,
is_draft: Optional[bool] = None,
) -> Tuple[list[Conversation], int]:
"""
查询应用日志会话列表
Args:
app_id: 应用 ID
workspace_id: 工作空间 ID
page: 页码(从 1 开始)
pagesize: 每页数量
is_draft: 是否草稿会话None 表示不过滤)
Returns:
Tuple[list[Conversation], int]: (会话列表,总数)
"""
logger.info(
"查询应用日志会话列表",
extra={
"app_id": str(app_id),
"workspace_id": str(workspace_id),
"page": page,
"pagesize": pagesize,
"is_draft": is_draft
}
)
# 使用 Repository 查询
conversations, total = self.conversation_repository.list_app_conversations(
app_id=app_id,
workspace_id=workspace_id,
is_draft=is_draft,
page=page,
pagesize=pagesize
)
logger.info(
"查询应用日志会话列表成功",
extra={
"app_id": str(app_id),
"total": total,
"returned": len(conversations)
}
)
return conversations, total
def get_conversation_detail(
self,
app_id: uuid.UUID,
conversation_id: uuid.UUID,
workspace_id: uuid.UUID
) -> Conversation:
"""
查询会话详情(包含消息)
Args:
app_id: 应用 ID
conversation_id: 会话 ID
workspace_id: 工作空间 ID
Returns:
Conversation: 包含消息的会话对象
Raises:
ResourceNotFoundException: 当会话不存在时
"""
logger.info(
"查询应用日志会话详情",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"workspace_id": str(workspace_id)
}
)
# 查询会话
conversation = self.conversation_repository.get_conversation_for_app_log(
conversation_id=conversation_id,
app_id=app_id,
workspace_id=workspace_id
)
# 查询消息(按时间正序)
messages = self.message_repository.get_messages_by_conversation(
conversation_id=conversation_id
)
# 将消息附加到会话对象
conversation.messages = messages
logger.info(
"查询应用日志会话详情成功",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"message_count": len(messages)
}
)
return conversation

View File

@@ -1084,7 +1084,6 @@ class AppService:
if not exists:
cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None)
cleaned["enabled"] = False
return cleaned
exists = self.db.query(
@@ -1096,7 +1095,6 @@ class AppService:
if not exists:
cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None)
cleaned["enabled"] = False
return cleaned
@@ -1684,15 +1682,15 @@ class AppService:
return config.config_id
def _update_endusers_memory_config_by_workspace(
def _update_endusers_memory_config_by_app(
self,
workspace_id: uuid.UUID,
app_id: uuid.UUID,
memory_config_id: uuid.UUID
) -> int:
"""批量更新应用下所有终端用户的 memory_config_id
Args:
workspace_id: 工作空间ID
app_id: 应用ID
memory_config_id: 新的记忆配置ID
Returns:
@@ -1701,8 +1699,8 @@ class AppService:
from app.repositories.end_user_repository import EndUserRepository
repo = EndUserRepository(self.db)
updated_count = repo.batch_update_memory_config_id_by_workspace(
workspace_id=workspace_id,
updated_count = repo.batch_update_memory_config_id_by_app(
app_id=app_id,
memory_config_id=memory_config_id
)
@@ -1753,12 +1751,16 @@ class AppService:
miss_params = []
if agent_cfg.default_model_config_id is None:
miss_params.append("model config")
miss_params.append("模型配置")
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
miss_params.append("memory config")
miss_params.append("记忆配置")
if miss_params:
raise BusinessException(f"{', '.join(miss_params)} is required")
raise BusinessException(
f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。",
BizCode.CONFIG_MISSING,
context={"missing_params": miss_params},
)
config = {
"system_prompt": agent_cfg.system_prompt,
@@ -1877,8 +1879,8 @@ class AppService:
if memory_config_id:
app = self.db.query(App).filter(App.id == app_id).first()
if app:
updated_count = self._update_endusers_memory_config_by_workspace(
app.workspace_id, memory_config_id
updated_count = self._update_endusers_memory_config_by_app(
app_id, memory_config_id
)
logger.info(
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
@@ -2014,7 +2016,7 @@ class AppService:
if memory_config_id:
updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id)
updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id)
logger.info(
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"

View File

@@ -214,7 +214,7 @@ class ConversationService:
conversation.message_count += 1
if conversation.message_count == 1 and role == "user":
if conversation.message_count <= 2 and role == "user":
conversation.title = (
content[:50] + ("..." if len(content) > 50 else "")
)

View File

@@ -448,15 +448,16 @@ class AgentRunService:
features_config: Dict[str, Any],
is_new_conversation: bool,
variables: Optional[Dict[str, Any]] = None
) -> Optional[str]:
) -> tuple[Any, Any]:
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
if not is_new_conversation:
return None
return None, None
opening = features_config.get("opening_statement", {})
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
return None
return None, None
statement = opening["statement"]
suggested_questions = opening["suggested_questions"]
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
if variables:
@@ -464,7 +465,7 @@ class AgentRunService:
placeholder = f"{{{{{var_name}}}}}"
statement = statement.replace(placeholder, str(var_value))
return statement
return statement, suggested_questions
@staticmethod
def _filter_citations(
@@ -598,13 +599,16 @@ class AgentRunService:
# 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
opening, suggested_questions = None, None
if not sub_agent:
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
conversation_id = await self._ensure_conversation(
conversation_id=conversation_id,
app_id=agent_config.app_id,
workspace_id=workspace_id,
user_id=user_id,
opening_statement=opening
opening_statement=opening,
suggested_questions=suggested_questions
)
model_info = ModelInfo(
@@ -839,14 +843,17 @@ class AgentRunService:
# 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
opening, suggested_questions = None, None
if not sub_agent:
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
conversation_id = await self._ensure_conversation(
conversation_id=conversation_id,
app_id=agent_config.app_id,
workspace_id=workspace_id,
user_id=user_id,
sub_agent=sub_agent,
opening_statement=opening
opening_statement=opening,
suggested_questions=suggested_questions
)
model_info = ModelInfo(
@@ -1050,7 +1057,8 @@ class AgentRunService:
workspace_id: uuid.UUID,
user_id: Optional[str],
sub_agent: bool = False,
opening_statement: Optional[str] = None
opening_statement: Optional[str] = None,
suggested_questions: Optional[List[str]] = None
) -> str:
"""确保会话存在(创建或验证)
@@ -1061,6 +1069,7 @@ class AgentRunService:
user_id: 用户ID
sub_agent: 是否为子代理
opening_statement: 开场白(新会话时作为第一条消息写入)
suggested_questions: 预设问题列表
Returns:
str: 会话ID
@@ -1104,7 +1113,7 @@ class AgentRunService:
conversation_id=uuid.UUID(new_conv_id),
role="assistant",
content=opening_statement,
meta_data={}
meta_data={"suggested_questions": suggested_questions}
)
logger.debug(f"已保存开场白到会话 {new_conv_id}")

View File

@@ -204,30 +204,35 @@ class MemoryForgetService:
end_user_id: str,
forgetting_threshold: float,
min_days_since_access: int,
limit: int = 20
) -> list[Dict[str, Any]]:
page: Optional[int] = None,
pagesize: Optional[int] = None
) -> Dict[str, Any]:
"""
获取待遗忘节点列表
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
Args:
connector: Neo4j 连接器
end_user_id: 组ID
forgetting_threshold: 遗忘阈值
min_days_since_access: 最小未访问天数
limit: 返回节点数量限制
page: 页码可选从1开始
pagesize: 每页数量(可选)
Returns:
list: 待遗忘节点列表
dict: 包含待遗忘节点列表和分页信息的字典
- items: 待遗忘节点列表
- page: 分页信息(分页时)
"""
from datetime import timedelta
# 计算最小访问时间ISO 8601 格式字符串,使用 UTC 时区)
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
query = """
# 基础查询(用于获取总数)
count_query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.end_user_id = $end_user_id
@@ -235,10 +240,22 @@ class MemoryForgetService:
AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL
AND datetime(n.last_access_time) < datetime($min_access_time_str)
RETURN
RETURN count(n) as total
"""
# 数据查询
data_query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.end_user_id = $end_user_id
AND n.activation_value IS NOT NULL
AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL
AND datetime(n.last_access_time) < datetime($min_access_time_str)
RETURN
elementId(n) as node_id,
labels(n)[0] as node_type,
CASE
CASE
WHEN n:Statement THEN n.statement
WHEN n:ExtractedEntity THEN n.name
WHEN n:MemorySummary THEN n.content
@@ -247,18 +264,32 @@ class MemoryForgetService:
n.activation_value as activation_value,
n.last_access_time as last_access_time
ORDER BY n.activation_value ASC
LIMIT $limit
"""
# 如果启用分页,添加 SKIP 和 LIMIT
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
data_query += " SKIP $skip LIMIT $limit"
params = {
'end_user_id': end_user_id,
'threshold': forgetting_threshold,
'min_access_time_str': min_access_time_str,
'limit': limit
'min_access_time_str': min_access_time_str
}
results = await connector.execute_query(query, **params)
# 获取总数(分页时需要)
total = 0
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
count_results = await connector.execute_query(count_query, **params)
if count_results:
total = count_results[0]['total']
# 添加分页参数
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
params['skip'] = (page - 1) * pagesize
params['limit'] = pagesize
results = await connector.execute_query(data_query, **params)
pending_nodes = []
for result in results:
# 将节点类型标签转换为小写
@@ -267,7 +298,7 @@ class MemoryForgetService:
node_type_label = 'entity'
elif node_type_label == 'memorysummary':
node_type_label = 'summary'
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
last_access_time = result['last_access_time']
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
@@ -278,7 +309,7 @@ class MemoryForgetService:
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
else:
last_access_timestamp = 0
pending_nodes.append({
'node_id': str(result['node_id']),
'node_type': node_type_label,
@@ -286,8 +317,20 @@ class MemoryForgetService:
'activation_value': result['activation_value'],
'last_access_time': last_access_timestamp
})
return pending_nodes
# 构建返回结果
result: Dict[str, Any] = {'items': pending_nodes}
# 如果启用分页,添加分页信息
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
result['page'] = {
'page': page,
'pagesize': pagesize,
'total': total,
'hasnext': (page * pagesize) < total
}
return result
async def trigger_forgetting_cycle(
self,
@@ -636,7 +679,7 @@ class MemoryForgetService:
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
# 失败时返回空列表,不影响主流程
# 获取待遗忘节点列表前20个满足遗忘条件的节点
# 获取待遗忘节点列表
pending_nodes = []
try:
if end_user_id:
@@ -652,8 +695,7 @@ class MemoryForgetService:
connector=connector,
end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days),
limit=20
min_days_since_access=int(min_days)
)
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
@@ -661,24 +703,79 @@ class MemoryForgetService:
except Exception as e:
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
# 失败时返回空列表,不影响主流程
# 构建统计信息
# 构建统计信息(不包含 pending_nodes已分离到独立接口
stats = {
'activation_metrics': activation_metrics,
'node_distribution': node_distribution,
'recent_trends': recent_trends,
'pending_nodes': pending_nodes,
'timestamp': int(datetime.now().timestamp() * 1000)
}
api_logger.info(
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}"
f"trend_days={len(recent_trends)}"
)
return stats
async def get_pending_nodes(
self,
db: Session,
end_user_id: str,
config_id: Optional[UUID] = None,
page: int = 1,
pagesize: int = 10
) -> Dict[str, Any]:
"""
获取待遗忘节点列表(独立分页接口)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
Args:
db: 数据库会话
end_user_id: 组ID必填
config_id: 配置ID可选用于获取遗忘阈值
page: 页码从1开始默认1
pagesize: 每页数量默认10
Returns:
dict: 包含待遗忘节点列表和分页信息的字典
- items: 待遗忘节点列表
- page: 分页信息
"""
# 获取遗忘引擎组件
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
connector = forgetting_scheduler.connector
forgetting_threshold = config['forgetting_threshold']
# 验证 min_days_since_access 配置值
min_days = config.get('min_days_since_access')
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
api_logger.warning(
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
)
min_days = 7
# 调用内部方法获取分页数据
pending_nodes_result = await self._get_pending_forgetting_nodes(
connector=connector,
end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days),
page=page,
pagesize=pagesize
)
api_logger.info(
f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
)
return pending_nodes_result
async def get_forgetting_curve(
self,
db: Session,

View File

@@ -12,6 +12,9 @@ import base64
import csv
import io
import json
import re
import olefile
import struct
import zipfile
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
@@ -602,31 +605,75 @@ class MultimodalService:
try:
word_file = io.BytesIO(file_content)
doc = Document(word_file)
return '\n'.join(p.text for p in doc.paragraphs)
text_lines = []
for p in doc.paragraphs:
text = p.text.strip()
if text:
text_lines.append(text)
for table in doc.tables:
for row in table.rows:
for cell in row.cells:
text = cell.text.strip()
if text:
text_lines.append(text)
full_text = "\n".join(text_lines)
return full_text.strip() or "[docx 文件无文本内容]"
except Exception as e:
logger.error(f"提取 docx 文本失败: {e}")
logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True)
return f"[docx 提取失败: {str(e)}]"
# 旧版 .docOLE2 格式)
# 旧版 .docOLE2/CFB 格式),按 Word Binary Format 规范解析 piece table
try:
import olefile
ole = olefile.OleFileIO(io.BytesIO(file_content))
if not ole.exists('WordDocument'):
return "[doc 提取失败: 未找到 WordDocument 流]"
# 读取 WordDocument 流,提取可见 ASCII/Unicode 文本
stream = ole.openstream('WordDocument').read()
# Word Binary Format: 文本在流中以 UTF-16-LE 编码存储
# 简单提取:过滤出可打印字符段
try:
text = stream.decode('utf-16-le', errors='ignore')
except Exception:
text = stream.decode('latin-1', errors='ignore')
# 过滤控制字符,保留可打印内容
import re
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
text = re.sub(r' +', ' ', text).strip()
word_stream = ole.openstream('WordDocument').read()
# FIB offset 0xA bit9 决定使用 0Table 还是 1Table
fib_flags = struct.unpack_from('<H', word_stream, 0xA)[0]
table_name = '1Table' if (fib_flags & 0x0200) else '0Table'
table_stream = ole.openstream(table_name).read()
# 从 FIB 读取 fcClx/lcbClx 定位 piece table
fc_clx, lcb_clx = struct.unpack_from("<II", word_stream, 0x1A2)
clx = table_stream[fc_clx: fc_clx + lcb_clx]
# 解析 CLX找到 PlcPcdpiece table
i, plc_pcd = 0, None
while i < len(clx):
clxt = clx[i]
if clxt == 0x01:
i += 3 + struct.unpack_from('<H', clx, i + 1)[0]
elif clxt == 0x02:
cb = struct.unpack_from('<I', clx, i + 1)[0]
plc_pcd = clx[i + 5: i + 5 + cb]
break
else:
break
if plc_pcd is None:
raise ValueError("PlcPcd not found")
# PlcPcd: (n+1) 个 CP4字节+ n 个 PCD8字节
n_pieces = (len(plc_pcd) - 4) // 12
cp_array = [struct.unpack_from('<I', plc_pcd, k * 4)[0] for k in range(n_pieces + 1)]
parts = []
for k in range(n_pieces):
fc_value = struct.unpack_from('<I', plc_pcd, (n_pieces + 1) * 4 + k * 8 + 2)[0]
is_ansi = bool(fc_value & 0x40000000)
fc = fc_value & 0x3FFFFFFF
char_count = cp_array[k + 1] - cp_array[k]
if is_ansi:
parts.append(word_stream[fc: fc + char_count].decode('cp1252', errors='replace'))
else:
parts.append(word_stream[fc: fc + char_count * 2].decode('utf-16-le', errors='replace'))
ole.close()
return text
result = re.sub(r'[\x00-\x1f\x7f]', '', ''.join(parts))
return result.strip()
except Exception as e:
logger.error(f"提取 doc 文本失败: {e}")
return f"[doc 提取失败: {str(e)}]"

View File

@@ -138,7 +138,7 @@ class TenantService:
except Exception as e:
business_logger.error(f"删除租户失败: {str(e)}")
raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR)
raise BusinessException(f"删除租户失败{str(e)}", code=BizCode.DB_ERROR)
# 租户用户管理
def get_tenant_users(
@@ -147,6 +147,7 @@ class TenantService:
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None
) -> List[UserModel]:
"""获取租户下的用户列表"""
@@ -155,6 +156,7 @@ class TenantService:
skip=skip,
limit=limit,
is_active=is_active,
is_superuser=is_superuser,
search=search
)
@@ -162,12 +164,14 @@ class TenantService:
self,
tenant_id: uuid.UUID,
is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None
) -> int:
"""统计租户下的用户数量"""
return self.user_repo.count_users_by_tenant(
tenant_id=tenant_id,
is_active=is_active,
is_superuser=is_superuser,
search=search
)

View File

@@ -472,6 +472,21 @@ class UserMemoryService:
# 定义允许更新的字段白名单
allowed_fields = {'other_name', 'aliases', 'meta_data'}
# 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中
_user_placeholder_names = {'用户', '', 'User', 'I'}
# 过滤 other_name不允许设置为占位名称
if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name")
del update_data['other_name']
# 过滤 aliases移除占位名称和非字符串值
if 'aliases' in update_data and update_data['aliases']:
update_data['aliases'] = [
a for a in update_data['aliases']
if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names
]
# 检查是否更新了 aliases 字段
aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases

View File

@@ -1,5 +1,4 @@
import asyncio
import hashlib
import os
import re
import shutil
@@ -38,12 +37,10 @@ from app.db import get_db, get_db_context
from app.models import Document, File, Knowledge
from app.models.end_user_model import EndUser
from app.schemas import document_schema, file_schema
from app.schemas.model_schema import ModelInfo
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
from app.services.memory_forget_service import MemoryForgetService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.utils.config_utils import resolve_config_id
from app.utils.redis_lock import RedisLock
from app.utils.redis_lock import RedisFairLock
logger = get_logger(__name__)
@@ -104,7 +101,12 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
def set_asyncio_event_loop():
"""Set the asyncio event loop for the current thread."""
"""Ensure an open asyncio event loop exists for the current thread.
Reuses the existing event loop if one is available and still open.
Creates and installs a new event loop only when the current one is
closed or missing (e.g. after ``_shutdown_loop_gracefully``).
"""
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
@@ -116,6 +118,30 @@ def set_asyncio_event_loop():
return loop
def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop):
"""Gracefully shutdown pending async generators and tasks on the event loop.
This prevents 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
by giving pending aclose() coroutines a chance to run before the loop is discarded.
Note: This only tears down the given loop. Callers that need a fresh event
loop afterwards should use ``set_asyncio_event_loop()`` explicitly.
"""
try:
# Cancel and collect all remaining tasks
all_tasks = asyncio.all_tasks(loop)
if all_tasks:
for task in all_tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*all_tasks, return_exceptions=True))
# Shutdown async generators (triggers __aclose__ on httpx clients etc.)
loop.run_until_complete(loop.shutdown_asyncgens())
except Exception:
pass
finally:
loop.close()
@celery_app.task(name="tasks.process_item")
def process_item(item: dict):
"""
@@ -1148,8 +1174,28 @@ def write_message_task(
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result
redis_client = get_sync_redis_client()
lock = None
if redis_client is not None:
lock = RedisFairLock(
key=f"memory_write:{end_user_id}",
redis_client=redis_client,
expire=600,
timeout=3600,
auto_renewal=True,
)
if not lock.acquire():
logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}")
return {
"status": "SKIPPED",
"error": "acquire lock timeout",
"end_user_id": end_user_id,
"config_id": str(config_id),
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}
try:
# 尝试获取现有事件循环,如果不存在则创建新的
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
@@ -1158,7 +1204,6 @@ def write_message_task(
logger.info(f"[CELERY WRITE] Task completed successfully "
f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
try:
_r = get_sync_redis_client()
if _r is not None:
@@ -1199,6 +1244,15 @@ def write_message_task(
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
finally:
if lock is not None:
try:
lock.release()
except Exception as e:
logger.warning(f"[CELERY WRITE] 释放锁失败: {e}")
# Gracefully shutdown the event loop to prevent
# 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
_shutdown_loop_gracefully(loop)
# unused task
@@ -2879,3 +2933,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}
# unused task

View File

@@ -1,6 +1,7 @@
import redis
import uuid
import time
import threading
UNLOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
@@ -10,45 +11,136 @@ else
end
"""
RENEW_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("expire", KEYS[1], ARGV[2])
else
return 0
end
"""
class RedisLock:
CLEANUP_DEAD_HEAD_SCRIPT = """
local queue_key = KEYS[1]
local lock_key = KEYS[2]
local first = redis.call("lindex", queue_key, 0)
if not first then
return 0
end
if redis.call("exists", lock_key) == 1 then
return 0
end
redis.call("lpop", queue_key)
return 1
"""
SAFE_RELEASE_QUEUE_SCRIPT = """
local queue_key = KEYS[1]
local value = ARGV[1]
local first = redis.call("lindex", queue_key, 0)
if first == value then
redis.call("lpop", queue_key)
return 1
end
return 0
"""
def _ensure_str(val):
"""统一将 Redis 返回值转为 str兼容 decode_responses=True/False"""
if val is None:
return None
if isinstance(val, bytes):
return val.decode("utf-8")
return str(val)
class RedisFairLock:
def __init__(
self,
key: str,
redis_client: redis.StrictRedis,
expire: int = 60,
retry_interval: float = 0.1,
timeout: float = 30
expire: int = 30,
retry_interval: float = 0.05,
timeout: float = 600,
auto_renewal: bool = True
):
self.key = key
self.expire = expire
self.queue_key = f"{key}:queue"
self.value = str(uuid.uuid4())
self._locked = False
self.expire = expire
self.retry_interval = retry_interval
self.timeout = timeout
self.redis_client = redis_client
self.redis = redis_client
self._locked = False
self.auto_renewal = auto_renewal
self._renew_thread = None
self._stop_renew = threading.Event()
def acquire(self) -> bool:
def acquire(self):
start = time.time()
self.redis.rpush(self.queue_key, self.value)
while True:
ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True)
if ok:
self._locked = True
return True
if time.time() - start >= self.timeout:
first = _ensure_str(self.redis.lindex(self.queue_key, 0))
if first == self.value:
ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire)
if ok:
self._locked = True
if self.auto_renewal:
self._start_renewal()
return True
if first:
self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key)
if time.time() - start > self.timeout:
self.redis.lrem(self.queue_key, 0, self.value)
return False
time.sleep(self.retry_interval)
def _renewal_loop(self):
while not self._stop_renew.is_set():
time.sleep(self.expire / 3)
if self._stop_renew.is_set():
break
self.redis.eval(
RENEW_SCRIPT,
1,
self.key,
self.value,
str(self.expire)
)
def _start_renewal(self):
self._stop_renew = threading.Event()
self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True)
self._renew_thread.start()
def _stop_renewal(self):
self._stop_renew.set()
if self._renew_thread:
self._renew_thread.join(timeout=1)
def release(self):
if not self._locked:
return
self.redis_client.eval(
UNLOCK_SCRIPT,
1,
self.key,
self.value
)
if self.auto_renewal:
self._stop_renewal()
self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value)
self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value)
self._locked = False
def __enter__(self):
@@ -59,3 +151,4 @@ class RedisLock:
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()

View File

@@ -0,0 +1,30 @@
"""202603271515
Revision ID: 4e89970f9e7c
Revises: 6b8a461148ff
Create Date: 2026-03-27 15:12:27.518344
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '4e89970f9e7c'
down_revision: Union[str, None] = '6b8a461148ff'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('users', sa.Column('phone', sa.String(length=50), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('users', 'phone')
# ### end Alembic commands ###

View File

@@ -154,6 +154,8 @@ export const analyticsRefresh = (end_user_id: string) => {
export const getForgetStats = (end_user_id: string) => {
return request.get(`/memory/forget-memory/stats`, { end_user_id })
}
// 获取带遗忘节点列表
export const getForgetPendingNodesUrl = '/memory/forget-memory/pending-nodes'
// Implicit Memory - Preferences
export const getImplicitPreferences = (end_user_id: string) => {
return request.get(`/memory/implicit-memory/preferences/${end_user_id}`)

View File

@@ -37,11 +37,11 @@ const ChatContent: FC<ChatContentProps> = ({
const prevDataLengthRef = useRef(data.length);
const isScrolledToBottomRef = useRef(true);
const audioRef = useRef<HTMLAudioElement | null>(null)
const [playingIndex, setPlayingIndex] = useState<number | null>(null)
const [playingIndex, setPlayingIndex] = useState<string | null>(null)
const handlePlay = (index: number, audio_url: string, audio_status?: string) => {
if (audio_status !== 'completed' && !audio_status) return
if (playingIndex === index) {
const handlePlay = (audio_url: string, audio_status?: string) => {
if (audio_status !== 'completed' && typeof audio_status === 'string') return
if (playingIndex === audio_url) {
audioRef.current?.pause()
setPlayingIndex(null)
return
@@ -52,7 +52,7 @@ const ChatContent: FC<ChatContentProps> = ({
const audio = new Audio(audio_url)
audioRef.current = audio
audio.play()
setPlayingIndex(index)
setPlayingIndex(audio_url)
audio.onended = () => setPlayingIndex(null)
}
@@ -79,12 +79,16 @@ const ChatContent: FC<ChatContentProps> = ({
}
};
}, []);
// Auto-scroll to bottom when data changes to show latest messages
// When data array length remains unchanged, if data is updated and user manually scrolled up, don't auto-scroll to bottom
// When data array length changes, auto-scroll to bottom
// If already scrolled to bottom, will auto-scroll to bottom
useEffect(() => {
if (playingIndex && !data.some(item => item.meta_data?.audio_url === playingIndex)) {
audioRef.current?.pause()
setPlayingIndex(null)
}
setTimeout(() => {
if (scrollContainerRef.current) {
// Auto-scroll if data length changed OR user is currently at bottom
@@ -204,16 +208,16 @@ const ChatContent: FC<ChatContentProps> = ({
{item.meta_data?.audio_url && <>
<Divider className="rb:my-3!" />
<Space size={12} className="rb:pb-2 rb:pl-1">
{playingIndex !== index && item.meta_data?.audio_status === 'pending'
{playingIndex !== item.meta_data?.audio_url && item.meta_data?.audio_status === 'pending'
? <Spin />
: playingIndex !== index
: playingIndex !== item.meta_data?.audio_url
? <SoundOutlined className={clsx("rb:cursor-pointer rb:size-5.5", {
'rb:text-[#FF5D34]': item.meta_data?.audio_status === 'error',
'rb:hover:text-[#155EEF]!': !item.meta_data?.audio_status || !['pending', 'error'].includes(item.meta_data?.audio_status)
})} onClick={() => handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} />
})} onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)} />
: <div
className="rb:size-5.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/audio_ing.gif')]"
onClick={() => handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)}
onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)}
/>
}
</Space>

View File

@@ -24,16 +24,17 @@ interface BodyWrapperProps {
/** Whether to show loading state */
loading?: boolean
/** Whether the content is empty */
empty: boolean
empty: boolean;
className?: string;
}
const BodyWrapper: FC<BodyWrapperProps> = ({ children, loading = false, empty }) => {
const BodyWrapper: FC<BodyWrapperProps> = ({ children, loading = false, empty, className = 'rb:max-h-[calc(100%-48px)]!' }) => {
// Show loading spinner while data is being fetched
if (loading) {
return <PageLoading />
return <PageLoading className={className} />
}
// Show empty state when no data is available
if (!loading && empty) {
return <PageEmpty />
return <PageEmpty className={className} />
}
// Render actual content when data is loaded and available
return children

View File

@@ -128,6 +128,7 @@ const Menu: FC<{
/** Filter menus based on user role and source */
useEffect(() => {
if (!user) return
let menuList: MenuItem[] = []
if (user.role === 'member' && source === 'space') {
@@ -136,7 +137,7 @@ const Menu: FC<{
menuList = allMenus[source] || []
}
const noAuthList = ['user', 'pricing'].filter(vo => !user.permissions?.includes(vo) && !user.permissions?.includes('all'))
const noAuthList = ['user', 'pricing'].filter(vo => (Array.isArray(user.permissions) && !user.permissions?.includes(vo) && !user.permissions?.includes('all')) || !Array.isArray(user.permissions))
if (noAuthList && !noAuthList?.includes('all')) {
const filterMenus = (list: MenuItem[]): MenuItem[] =>{

View File

@@ -1509,7 +1509,7 @@ export const en = {
EPISODIC_MEMORY: 'Episodic Memory',
FORGET_MEMORY: 'Forget Memory',
endUserProfile: 'Profile',
endUserProfile: 'Permanent Memory',
editEndUserProfile: 'Edit',
other_name: 'Name',
position: 'Position',
@@ -1828,6 +1828,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
memoryTipTitle: 'Are you sure you want to enable conversation memory? Conversations will be saved to the memory store.',
stopAudioRecorder: 'Stop Recording',
startAudioRecorder: 'Start Recording',
citations: 'Citations',
},
login: {
title: 'Red Bear Memory Science',

View File

@@ -1507,7 +1507,7 @@ export const zh = {
EPISODIC_MEMORY: '情景记忆',
FORGET_MEMORY: '遗忘记忆',
endUserProfile: '核心档案',
endUserProfile: '永久记忆',
editEndUserProfile: '编辑',
other_name: '名称',
position: '职位',

View File

@@ -377,9 +377,10 @@ body {
.ant-input-filled,
.ant-select-filled:not(.ant-select-customize-input) .ant-select-selector {
background-color: #FFFFFF;
border-color: #FFFFFF;
}
.ant-input-filled:hover,
.ant-select-filled:not(.ant-select-customize-input) .ant-select-selector {
.ant-select-filled:not(.ant-select-disabled):not(.ant-select-customize-input):not(.ant-pagination-size-changer):hover .ant-select-selector {
background-color: #FFFFFF;
border-color: #171719;
}
@@ -402,7 +403,7 @@ body {
color: #FF5D34;
}
.spin.ant-spin-nested-loading .ant-spin-container::after {
.spin .ant-spin-nested-loading .ant-spin-container::after {
background: transparent;
}
.upload-block,

View File

@@ -34,16 +34,16 @@ const Statistics: FC = () => {
className: 'rb:text-[#212332]'
},
{
title: t('user.createTime'),
title: t('application.created_at'),
dataIndex: 'created_at',
key: 'created_at',
render: (createdAt: string) => formatDateTime(createdAt, 'YYYY-MM-DD HH:mm:ss'),
},
{
title: t('user.lastLoginTime'),
dataIndex: 'last_login_at',
key: 'last_login_at',
render: (lastLoginAt: string) => lastLoginAt ? formatDateTime(lastLoginAt, 'YYYY-MM-DD HH:mm:ss') : '-',
title: t('common.updated_at'),
dataIndex: 'updated_at',
key: 'updated_at',
render: (updatedAt: string) => updatedAt ? formatDateTime(updatedAt, 'YYYY-MM-DD HH:mm:ss') : '-',
},
{
title: t('common.operation'),

View File

@@ -220,31 +220,31 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
/>
<Popover content={t('workflow.clear')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
<div
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/clear.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/clear.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
onClick={clear}
></div>
</Popover>
<Popover content={t('workflow.addvariable')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
<div
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/variable.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/variable.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
onClick={addvariable}
></div>
</Popover>
<Popover content={t('workflow.run')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
<div
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/run.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/run.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
onClick={run}
></div>
</Popover>
<Popover content={t('workflow.save')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
<div
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/save.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/save.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
onClick={save}
></div>
</Popover>
<Popover content={t('common.return')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
<div
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/return.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/return.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
onClick={goToApplication}
></div>
</Popover>

View File

@@ -49,7 +49,7 @@ const FeaturesConfig: FC<FeaturesConfigProps> = ({
?
<Popover content={t('application.features')} classNames={{ body: 'rb:py-0.5! rb:px-1! rb:rounded-[6px]! rb:text-[12px]!' }}>
<div
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('src/assets/images/workflow/features.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
className="rb:cursor-pointer rb:size-7.5 rb:border rb:border-[#EBEBEB] rb:hover:bg-[#F6F6F6] rb:rounded-[10px] rb:bg-[url('@/assets/images/workflow/features.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat"
onClick={handleFeaturesConfig}
></div>
</Popover>

View File

@@ -216,7 +216,7 @@ const ApplicationManagement: React.FC = () => {
'rb:text-[#155EEF]': key === 'type',
})}>
{key === 'source' && item.is_shared
? t('application.shared')
? item.source_workspace_name
: key === 'source' && !item.is_shared
? t('application.configuration')
: key === 'created_at'

View File

@@ -64,6 +64,13 @@ const Conversation: FC = () => {
const [config, setConfig] = useState<Record<string, any>>({})
const [audioStatusMap, setAudioStatusMap] = useState<Record<string, string>>({})
useEffect(() => {
return () => {
audioPollingRef.current.forEach((timer) => clearInterval(timer))
audioPollingRef.current.clear()
}
}, [])
useEffect(() => {
const shareToken = localStorage.getItem(`shareToken_${token}`)
setShareToken(shareToken)
@@ -144,13 +151,29 @@ const Conversation: FC = () => {
}
useEffect(() => {
audioPollingRef.current.forEach((timer) => clearInterval(timer))
audioPollingRef.current.clear()
if (conversation_id) {
getConversationDetail(token as string, conversation_id)
.then(res => {
const response = res as { messages: ChatItem[] }
setChatList(response?.messages || [])
const messages = response?.messages || []
const historyAudioUrls = new Set(messages.map(m => m.meta_data?.audio_url).filter(Boolean))
audioPollingRef.current.forEach((timer, key) => {
if (!historyAudioUrls.has(key)) {
clearInterval(timer)
audioPollingRef.current.delete(key)
}
})
messages.forEach(msg => {
if (msg.role === 'assistant' && msg.meta_data?.audio_url && msg.meta_data?.audio_status === 'pending') {
startAudioPolling(msg.meta_data.audio_url, msg.meta_data.audio_url)
}
})
setChatList(messages.map(msg => {
if (msg.role === 'assistant' && msg.meta_data?.audio_url && audioPollingRef.current.has(msg.meta_data.audio_url)) {
return { ...msg, meta_data: { ...msg.meta_data, audio_status: 'pending' } }
}
return msg
}))
})
} else {
if (features?.opening_statement?.statement) {
@@ -228,6 +251,28 @@ const Conversation: FC = () => {
}))
}, [audioStatusMap, chatList.length])
const startAudioPolling = (audioUrl: string, idToPoll: string) => {
if (audioPollingRef.current.has(idToPoll)) return
const fileId = audioUrl.split('/').pop()
if (!fileId) return
const timer = setInterval(() => {
getFileStatusById(fileId)
.then(res => {
const { status } = res as { status: string }
if (status && status !== 'pending') {
setAudioStatusMap(prev => ({ ...prev, [idToPoll]: status }))
clearInterval(audioPollingRef.current.get(idToPoll))
audioPollingRef.current.delete(idToPoll)
}
})
.catch(() => {
clearInterval(audioPollingRef.current.get(idToPoll))
audioPollingRef.current.delete(idToPoll)
})
}, 2000)
audioPollingRef.current.set(idToPoll, timer)
}
/** Send message and handle streaming response */
const handleSend = (msg?: string) => {
if (!token || !shareToken) return
@@ -287,35 +332,8 @@ const Conversation: FC = () => {
const { file_id } = item.data as { file_id?: string }
const idToPoll = file_id || audio_url || ''
const fileId = audio_url.split('/').pop()
if (fileId && idToPoll && !audioPollingRef.current.has(idToPoll)) {
const timer = setInterval(() => {
getFileStatusById(fileId)
.then(res => {
const { status } = res as { status: string }
if (status && status !== 'pending') {
setAudioStatusMap(prev => ({
...prev,
[idToPoll]: status
}))
clearInterval(audioPollingRef.current.get(idToPoll))
audioPollingRef.current.delete(idToPoll)
getHistory(true)
if (currentConversationId && currentConversationId !== conversation_id) {
setConversationId(currentConversationId)
}
}
})
.catch(() => {
clearInterval(audioPollingRef.current.get(idToPoll))
audioPollingRef.current.delete(idToPoll)
getHistory(true)
if (currentConversationId && currentConversationId !== conversation_id) {
setConversationId(currentConversationId)
}
})
}, 2000)
audioPollingRef.current.set(idToPoll, timer)
if (fileId && idToPoll) {
startAudioPolling(audio_url, idToPoll)
}
} else {
getHistory(true)
@@ -327,6 +345,10 @@ const Conversation: FC = () => {
updateAssistantMessage(content, audio_url, undefined, citations)
}
setLoading(false)
getHistory(true)
if (currentConversationId && currentConversationId !== conversation_id) {
setConversationId(currentConversationId)
}
break
}
})

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 16:37:12
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-04 10:05:39
* @Last Modified time: 2026-03-27 22:22:18
*/
/**
* Invite Register Page
@@ -144,7 +144,7 @@ const InviteRegister: React.FC = () => {
}).then((res) => {
const response = res as LoginInfo;
updateLoginInfo(response);
navigate('/');
navigate('/', { replace: true });
}).finally(() => {
setLoading(false);
});

View File

@@ -243,6 +243,7 @@ const Prompt: FC = () => {
<ModelSelect
params={{ type: 'llm,chat' }}
className={`rb:w-75! ${styles.select}`}
variant="filled"
/>
</Form.Item>
<Button className="rb:border-none!" onClick={handleJump}>{t('prompt.history')}</Button>

View File

@@ -116,13 +116,13 @@ const History: React.FC = () => {
<div className="rb:text-[12px] rb:text-[#5B6167] rb:leading-4.5">{formatDateTime(item.created_at, 'YYYY/MM/DD HH:mm')}</div>
<Space size={8}>
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('src/assets/images/prompt/eye.svg')] rb:hover:bg-[url('src/assets/images/prompt/eye_bg.svg')]"
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('@/assets/images/prompt/eye.svg')] rb:hover:bg-[url('@/assets/images/prompt/eye_bg.svg')]"
onClick={() => handleClick('detail', item)}
></div>
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('src/assets/images/prompt/edit.svg')] rb:hover:bg-[url('src/assets/images/prompt/edit_bg.svg')]"
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('@/assets/images/prompt/edit.svg')] rb:hover:bg-[url('@/assets/images/prompt/edit_bg.svg')]"
onClick={() => handleClick('edit', item)}
></div>
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('src/assets/images/prompt/delete.svg')] rb:hover:bg-[url('src/assets/images/prompt/delete_hover.svg')]"
<div className="rb:size-4.5 rb:bg-cover rb:bg-[url('@/assets/images/prompt/delete.svg')] rb:hover:bg-[url('@/assets/images/prompt/delete_hover.svg')]"
onClick={() => handleClick('delete', item)}
></div>
</Space>

View File

@@ -65,8 +65,8 @@ const CommunityNetwork: FC<{ onSelectCommunity?: (node: RawCommunityNode) => voi
}, [id])
if (loading) {
return <Flex align="center" justify="center" className="rb:w-full rb:h-full">
<Spin tip={t('userMemory.communityLoadingTip')} size="large" className="rb:text-[#5B6167]! spin">
return <Flex align="center" justify="center" className="rb:w-full rb:h-full spin">
<Spin tip={t('userMemory.communityLoadingTip')} size="large" className="rb:text-[#5B6167]!">
<div className="rb:w-64 rb:h-64" />
</Spin>
</Flex>

View File

@@ -12,6 +12,7 @@ import { Row, Col, Progress, App, Table } from 'antd'
import RbCard from '@/components/RbCard/Card'
import {
getForgetStats,
getForgetPendingNodesUrl,
} from '@/api/memory'
import type { ForgetData } from '../types'
import ActivationMetricsPieCard from '../components/ActivationMetricsPieCard'
@@ -19,6 +20,7 @@ import RecentTrendsLineCard from '../components/RecentTrendsLineCard'
import { formatDateTime } from '@/utils/format'
import StatusTag from '@/components/StatusTag'
import ForgetRefreshModal from '../components/ForgetRefreshModal';
import RbTable from '@/components/Table'
/** Maps node type keys to StatusTag colour presets for the pending-nodes table. */
const statusTagColors: Record<string, 'success' | 'purple' | 'default' | 'warning' | 'error' | 'lightBlue'> = {
@@ -191,7 +193,9 @@ const ForgetDetail = forwardRef((_props, ref) => {
bodyClassName="rb:p-3! rb:py-0! rb:h-[calc(100%-54px)]"
className="rb:h-full!"
>
<Table
<RbTable
apiUrl={getForgetPendingNodesUrl}
apiParams={{ end_user_id: id }}
rowKey='node_id'
dataSource={data.pending_nodes ?? []}
columns={[
@@ -225,11 +229,6 @@ const ForgetDetail = forwardRef((_props, ref) => {
render: (activation_value) => <span className="rb:text-[#5B6167]">{activation_value}</span>
},
]}
pagination={{
pageSize: 5,
showQuickJumper: true,
className: 'rb:mt-5! rb:mb-5.75!'
}}
className="table-header-has-bg"
/>
</RbCard>