Merge branch 'refs/heads/develop' into fix/memory_bug_fix

# Conflicts:
#	api/app/services/memory_agent_service.py
This commit is contained in:
lixinyue
2026-01-23 14:57:25 +08:00
38 changed files with 685 additions and 353 deletions

View File

@@ -55,8 +55,8 @@ class AgentRegistry:
"""
# 构建查询
stmt = select(AgentConfig).join(App).where(
AgentConfig.is_active == True,
App.is_active == True
AgentConfig.is_active.is_(True),
App.is_active.is_(True)
)
# 工作空间过滤(同工作空间或公开)

View File

@@ -758,7 +758,7 @@ class AppService:
)
# 构建查询条件
filters = [App.is_active == True]
filters = [App.is_active.is_(True)]
if type:
filters.append(App.type == type)
if visibility:
@@ -873,7 +873,7 @@ class AppService:
self._validate_workspace_access(app, workspace_id)
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True)).order_by(
AgentConfig.updated_at.desc())
agent_cfg: Optional[AgentConfig] = self.db.scalars(stmt).first()
now = datetime.datetime.now()
@@ -1204,7 +1204,7 @@ class AppService:
default_model_config_id = None
if app.type == AppType.AGENT:
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True)).order_by(
AgentConfig.updated_at.desc())
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
@@ -1226,7 +1226,7 @@ class AppService:
select(MultiAgentConfig)
.where(
MultiAgentConfig.app_id == app_id,
MultiAgentConfig.is_active == True
MultiAgentConfig.is_active.is_(True)
)
.order_by(MultiAgentConfig.updated_at.desc())
)
@@ -1380,7 +1380,7 @@ class AppService:
stmt = (
select(AppRelease)
.where(AppRelease.app_id == app_id, AppRelease.is_active == True)
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
.order_by(AppRelease.version.desc())
)
return list(self.db.scalars(stmt).all())

View File

@@ -728,7 +728,7 @@ class DraftRunService:
select(ModelApiKey)
.where(
ModelApiKey.model_config_id == model_config_id,
ModelApiKey.is_active == True
ModelApiKey.is_active.is_(True)
)
.order_by(ModelApiKey.priority.desc())
.limit(1)

View File

@@ -175,10 +175,9 @@ class MemoryAgentService:
"""
logger.info("Reading log file")
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory
project_root = os.path.dirname(app_dir) # redbear-mem directory
# Get log file path - use project root directory
from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2]) # api directory
log_path = os.path.join(project_root, "logs", "agent_service.log")
summer = ''
@@ -217,9 +216,8 @@ class MemoryAgentService:
logger.info("Starting log content streaming")
# Get log file path - use project root directory
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory
project_root = os.path.dirname(app_dir) # redbear-mem directory
from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2]) # api directory
log_path = os.path.join(project_root, "logs", "agent_service.log")
# Check if file exists before starting stream
@@ -431,13 +429,15 @@ class MemoryAgentService:
audit_logger = None
config_load_start = time.time()
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
@@ -578,6 +578,8 @@ class MemoryAgentService:
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -668,6 +670,8 @@ class MemoryAgentService:
"""
logger.info("Classifying message type")
# Load configuration to get LLM model ID
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
@@ -681,10 +685,11 @@ class MemoryAgentService:
async def generate_summary_from_retrieve(
self,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: UUID,
config_id: str,
db: Session
) -> str:
"""
@@ -702,6 +707,18 @@ class MemoryAgentService:
Returns:
生成的答案文本
"""
if config_id is None:
try:
config_id = get_end_user_connected_config(end_user_id, db)
config_id = config_id.get('memory_config_id')
if config_id is None:
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
@@ -727,7 +744,7 @@ class MemoryAgentService:
state=state,
history=history,
retrieve_info=retrieve_info,
template_name='Retrieve_Summary_prompt.jinja2',
template_name='direct_summary_prompt.jinja2',
operation_name='retrieve_summary',
response_model=RetrieveSummaryResponse,
search_mode="1"
@@ -1075,9 +1092,8 @@ class MemoryAgentService:
logger.info("Starting log content streaming")
# Get log file path - use project root directory
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory
project_root = os.path.dirname(app_dir) # redbear-mem directory
from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2]) # api directory
log_path = os.path.join(project_root, "logs", "agent_service.log")
# Check if file exists before starting stream
@@ -1175,7 +1191,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
# 3. 从 config 中提取 memory_config_id
config = latest_release.config or {}
# 如果 config 是字符串,解析为字典
if isinstance(config, str):
import json
@@ -1184,7 +1200,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
except json.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
config = {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
@@ -1196,10 +1212,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"memory_config_id": memory_config_id
}
print(188*'*')
print(result)
print(188 * '*')
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
return result

View File

@@ -77,7 +77,10 @@ class MemoryAPIService:
)
# Verify end_user belongs to the workspace via App relationship
app = self.db.query(App).filter(App.id == end_user.app_id).first()
app = self.db.query(App).filter(
App.id == end_user.app_id,
App.is_active.is_(True)
).first()
if not app:
logger.warning(f"App not found for end_user: {end_user_id}")

View File

@@ -53,18 +53,28 @@ def get_workspace_end_users(
workspace_id: uuid.UUID,
current_user: User
) -> List[EndUser]:
"""获取工作空间的所有宿主"""
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
# 查询应用ORM并转换为 Pydantic 模型
# 查询应用ORM
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
apps = [AppSchema.model_validate(h) for h in apps_orm]
app_ids = [app.id for app in apps]
end_users = []
for app_id in app_ids:
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
end_users.extend([EndUserSchema.model_validate(h) for h in end_user_orm_list])
if not apps_orm:
business_logger.info("工作空间下没有应用")
return []
# 提取所有 app_id
app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询
from app.models.end_user_model import EndUser as EndUserModel
end_users_orm = db.query(EndUserModel).filter(
EndUserModel.app_id.in_(app_ids)
).all()
# 转换为 Pydantic 模型(只在需要时转换)
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return end_users
@@ -414,6 +424,67 @@ def get_current_user_total_chunk(
business_logger.error(f"获取用户总chunk数失败: end_user_id={end_user_id} - {str(e)}")
raise
def get_users_total_chunk_batch(
end_user_ids: List[str],
db: Session,
current_user: User
) -> dict:
"""
批量获取多个用户的总chunk数性能优化版本
Args:
end_user_ids: 用户ID列表
db: 数据库会话
current_user: 当前用户
Returns:
字典key为end_user_idvalue为chunk总数
格式: {"user_id_1": 100, "user_id_2": 50, ...}
"""
business_logger.info(f"批量获取 {len(end_user_ids)} 个用户的总chunk数, 操作者: {current_user.username}")
try:
from app.models.document_model import Document
from sqlalchemy import func, case
if not end_user_ids:
return {}
# 构造所有文件名
file_names = [f"{user_id}.txt" for user_id in end_user_ids]
# 一次查询获取所有用户的chunk总数
# 使用 GROUP BY file_name 来分组统计
results = db.query(
Document.file_name,
func.sum(Document.chunk_num).label('total_chunk')
).filter(
Document.file_name.in_(file_names)
).group_by(
Document.file_name
).all()
# 构建结果字典
chunk_map = {}
for file_name, total_chunk in results:
# 从文件名中提取 end_user_id (去掉 .txt 后缀)
user_id = file_name.replace('.txt', '')
chunk_map[user_id] = int(total_chunk or 0)
# 对于没有记录的用户设置为0
for user_id in end_user_ids:
if user_id not in chunk_map:
chunk_map[user_id] = 0
business_logger.info(f"成功批量获取 {len(chunk_map)} 个用户的总chunk数")
return chunk_map
except Exception as e:
business_logger.error(f"批量获取用户总chunk数失败: {str(e)}")
raise
def get_rag_content(
end_user_id: str,
limit: int,

View File

@@ -38,7 +38,10 @@ class WorkspaceAppService:
Returns:
Dictionary containing detailed application information
"""
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all()
apps = self.db.query(App).filter(
App.workspace_id == workspace_id,
App.is_active.is_(True)
).all()
app_ids = [str(app.id) for app in apps]
apps_detailed_info = []

View File

@@ -12,7 +12,11 @@ from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.hot_memory_tags import (
get_hot_memory_tags,
get_raw_tags_from_db,
filter_tags_with_llm,
)
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.models.user_model import User
from app.repositories.memory_config_repository import MemoryConfigRepository
@@ -237,7 +241,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
ValueError: 当配置无效或参数缺失时
RuntimeError: 当管线执行失败时
"""
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2])
try:
# 发出初始进度事件
@@ -512,27 +517,79 @@ async def analytics_hot_memory_tags(
) -> List[Dict[str, Any]]:
"""
获取热门记忆标签按数量排序并返回前N个
优化策略:
1. 先从所有用户收集原始标签不调用LLM
2. 聚合并合并相同标签的频率
3. 排序后取前N个
4. 只调用一次LLM进行筛选
"""
workspace_id = current_user.current_workspace_id
# 获取更多标签供LLM筛选获取limit*4个标签
raw_limit = limit * 4
from app.services.memory_dashboard_service import get_workspace_end_users
end_users = get_workspace_end_users(db, workspace_id, current_user)
# 使用 asyncio.to_thread 避免阻塞事件循环
end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user)
tags = []
for end_user in end_users:
tag = await get_hot_memory_tags(str(end_user.id), limit=raw_limit)
if tag:
# 将每个用户的标签列表展平到总列表中
tags.extend(tag)
# 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序)
sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True)
if not end_users:
return []
# 只返回前limit个
top_tags = sorted_tags[:limit]
return [{"name": t, "frequency": f} for t, f in top_tags]
# 步骤1: 收集所有用户的原始标签不调用LLM
connector = Neo4jConnector()
try:
all_raw_tags = []
for end_user in end_users:
raw_tags = await get_raw_tags_from_db(
connector,
str(end_user.id),
limit=raw_limit,
by_user=False
)
if raw_tags:
all_raw_tags.extend(raw_tags)
if not all_raw_tags:
return []
# 步骤2: 聚合相同标签的频率
tag_frequency_map = {}
for tag_name, frequency in all_raw_tags:
if tag_name in tag_frequency_map:
tag_frequency_map[tag_name] += frequency
else:
tag_frequency_map[tag_name] = frequency
# 步骤3: 按频率降序排序取前raw_limit个
sorted_tags = sorted(
tag_frequency_map.items(),
key=lambda x: x[1],
reverse=True
)[:raw_limit]
if not sorted_tags:
return []
# 步骤4: 只调用一次LLM进行筛选
tag_names = [tag for tag, _ in sorted_tags]
# 使用第一个用户的group_id来获取LLM配置
# 因为同一工作空间下的用户应该使用相同的配置
first_end_user_id = str(end_users[0].id)
filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id)
# 步骤5: 根据LLM筛选结果构建最终列表保留频率
final_tags = []
for tag, freq in sorted_tags:
if tag in filtered_tag_names:
final_tags.append((tag, freq))
# 步骤6: 只返回前limit个
top_tags = final_tags[:limit]
return [{"name": t, "frequency": f} for t, f in top_tags]
finally:
await connector.close()
async def analytics_recent_activity_stats() -> Dict[str, Any]:

View File

@@ -2548,7 +2548,7 @@ class MultiAgentOrchestrator:
# 获取 API Key 配置
api_key_config = self.db.query(ModelApiKey).filter(
ModelApiKey.model_config_id == default_model_config_id,
ModelApiKey.is_active == True
ModelApiKey.is_active.is_(True)
).first()
if not api_key_config:
@@ -2705,7 +2705,7 @@ class MultiAgentOrchestrator:
# 获取 API Key 配置
api_key_config = self.db.query(ModelApiKey).filter(
ModelApiKey.model_config_id == default_model_config_id,
ModelApiKey.is_active == True
ModelApiKey.is_active.is_(True)
).first()
if not api_key_config:

View File

@@ -74,7 +74,7 @@ class MultiAgentService:
select(MultiAgentConfig)
.where(
MultiAgentConfig.app_id == app_id,
MultiAgentConfig.is_active == True
MultiAgentConfig.is_active.is_(True)
)
.order_by(MultiAgentConfig.updated_at.desc())
).first()
@@ -144,7 +144,7 @@ class MultiAgentService:
select(MultiAgentConfig)
.where(
MultiAgentConfig.app_id == app_id,
MultiAgentConfig.is_active == True
MultiAgentConfig.is_active.is_(True)
)
.order_by(MultiAgentConfig.updated_at.desc())
).first()

View File

@@ -168,7 +168,7 @@ class SharedChatService:
select(ModelApiKey)
.where(
ModelApiKey.model_config_id == model_config_id,
ModelApiKey.is_active == True
ModelApiKey.is_active.is_(True)
)
.order_by(ModelApiKey.priority.desc())
.limit(1)
@@ -362,7 +362,7 @@ class SharedChatService:
select(ModelApiKey)
.where(
ModelApiKey.model_config_id == model_config_id,
ModelApiKey.is_active == True
ModelApiKey.is_active.is_(True)
)
.order_by(ModelApiKey.priority.desc())
.limit(1)
@@ -598,7 +598,7 @@ class SharedChatService:
# 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active == True
MultiAgentConfig.is_active.is_(True)
).first()
if not multi_agent_config:
@@ -695,7 +695,7 @@ class SharedChatService:
# 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active == True
MultiAgentConfig.is_active.is_(True)
).first()
if not multi_agent_config:

View File

@@ -761,7 +761,10 @@ class WorkflowService:
# 4. 获取工作空间 ID从 app 获取)
from app.models import App
app = self.db.query(App).filter(App.id == app_id).first()
app = self.db.query(App).filter(
App.id == app_id,
App.is_active.is_(True)
).first()
if not app:
raise BusinessException(
code=BizCode.NOT_FOUND,