feat(memory): add workspace_id fallback support for memory config resolution

- Add workspace_id fallback parameter to memory config loading across all services
- Update hot_memory_tags.py to pass workspace_id when resolving memory configuration
- Enhance emotion_analytics_service.py to support workspace_id as fallback for config resolution
- Improve implicit_memory_service.py with workspace_id fallback in config loading
- Update memory_agent_service.py to handle workspace_id resolution and add refactoring TODO
- Enhance preference_analysis.jinja2 prompt with critical guidance on supporting_evidence extraction
- Add validation to check both config_id and workspace_id before raising configuration errors
- Improve error handling and logging for memory configuration resolution across services
- This enables more flexible memory configuration resolution when config_id is unavailable
This commit is contained in:
Ke Sun
2026-02-06 14:48:58 +08:00
parent 7a78f15a90
commit 5c10f11681
8 changed files with 195 additions and 134 deletions

View File

@@ -39,16 +39,20 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if not config_id:
if not config_id and not workspace_id:
raise ValueError(
f"No memory_config_id found for end_user_id: {end_user_id}. "
"Please ensure the user has a valid memory configuration."
)
# Use the config_id to get the proper LLM client
# Use the config_id to get the proper LLM client with workspace fallback
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
if not memory_config.llm_model_id:
raise ValueError(

View File

@@ -16,6 +16,7 @@ Summary {{ loop.index }}:
3. DO NOT use long phrases - use short nouns or noun phrases
4. Only include preferences with confidence_score >= 0.3
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
6. **CRITICAL: supporting_evidence must be DIRECT QUOTES or paraphrases from the user's actual statements. DO NOT reference summary numbers (e.g., "Summary 1", "摘要1"). DO NOT describe what the summary contains. Extract the actual user behavior or statement as evidence.**
## Output Format
{
@@ -38,6 +39,16 @@ Summary {{ loop.index }}:
]
}
## BAD supporting_evidence examples (DO NOT do this):
- "Summary 1西湖为核心景区" ❌
- "摘要2中提到喜欢咖啡" ❌
- "Based on Summary 3" ❌
## GOOD supporting_evidence examples:
- "去过西湖断桥、苏堤" ✓
- "每天早上喝咖啡" ✓
- "mentioned visiting the lake twice" ✓
## Example (English input → English output)
{
"preferences": [

View File

@@ -538,14 +538,16 @@ class EmotionAnalyticsService:
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is not None:
workspace_id = connected_config.get("workspace_id")
config_id = resolve_config_id(config_id, db) if config_id else None
if config_id is not None or workspace_id is not None:
from app.services.memory_config_service import (
MemoryConfigService,
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=(config_id),
config_id=config_id,
workspace_id=workspace_id,
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

View File

@@ -90,13 +90,17 @@ class ImplicitMemoryService:
# Get user's connected config
connected_config = get_end_user_connected_config(self.end_user_id, self.db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id is None:
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user: {self.end_user_id}")
# Load the memory configuration
# Load the memory configuration with workspace fallback
config_service = MemoryConfigService(self.db)
memory_config = config_service.load_memory_config(config_id)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
logger.info(f"Loaded memory config {config_id} for end_user: {self.end_user_id}")
return memory_config

View File

@@ -3,6 +3,13 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services,
health checks, and message type classification.
TODO: Refactor get_end_user_connected_config
----------------------------------------------
1. Move get_end_user_connected_config to memory_config_service.py
2. Change return type from Dict[str, Any] (with config_id string) to full MemoryConfig model
3. This will eliminate the need for callers to call load_memory_config separately
4. Update all callers to use the new unified function
"""
import json
import os
@@ -283,12 +290,14 @@ class MemoryAgentService:
Raises:
ValueError: If config loading fails or write operation fails
"""
# Resolve config_id if None using end_user's connected config
# Resolve config_id and workspace_id
workspace_id = None
if config_id is None:
try:
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
workspace_id = connected_config.get("workspace_id")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
@@ -299,11 +308,12 @@ class MemoryAgentService:
import time
start_time = time.time()
# Load configuration from database only
# Load configuration from database with workspace fallback
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
@@ -410,12 +420,14 @@ class MemoryAgentService:
start_time = time.time()
ori_message= message
# Resolve config_id if None using end_user's connected config
# Resolve config_id and workspace_id
workspace_id = None
if config_id is None:
try:
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
workspace_id = connected_config.get("workspace_id")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
@@ -437,6 +449,7 @@ class MemoryAgentService:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
config_load_time = time.time() - config_load_start
@@ -659,7 +672,13 @@ class MemoryAgentService:
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(self, message: str, config_id: UUID, db: Session) -> Dict:
async def classify_message_type(
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
) -> Dict:
"""
Determine the type of user message (read or write)
Updated to eliminate global variables in favor of explicit parameters.
@@ -668,18 +687,18 @@ class MemoryAgentService:
message: User message to classify
config_id: Configuration ID to load LLM model from database
db: Database session
workspace_id: Workspace ID for fallback lookup (optional)
Returns:
Type classification result
"""
logger.info("Classifying message type")
# Load configuration to get LLM model ID
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
@@ -712,9 +731,10 @@ class MemoryAgentService:
"""
if config_id is None:
try:
config_id = get_end_user_connected_config(end_user_id, db)
config_id = config_id.get('memory_config_id')
if config_id is None:
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get('memory_config_id')
workspace_id = connected_config.get('workspace_id')
if config_id is None and workspace_id is None:
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
@@ -722,6 +742,9 @@ class MemoryAgentService:
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
else:
workspace_id = None
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
@@ -729,6 +752,7 @@ class MemoryAgentService:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
@@ -1158,7 +1182,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
db: 数据库会话
Returns:
包含 memory_config_id 和相关信息的字典
包含 memory_config_id, workspace_id 和相关信息的字典
Raises:
ValueError: 当终端用户不存在或应用未发布时
@@ -1194,18 +1218,17 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
workspace_id=app.workspace_id
)
memory_obj = config.get('memory', {})
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
memory_config_id = str(memory_config.config_id) if memory_config else None
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(app.current_release_id),
"memory_config_id": memory_config_id
"memory_config_id": memory_config_id,
"workspace_id": str(app.workspace_id)
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result
@@ -1213,10 +1236,9 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
"""
批量获取多个终端用户关联的记忆配置(优化版本,减少数据库查询次数)
通过以下流程获取配置:
1. 批量查询所有 end_user_id 对应的 app_id
2. 批量获取这些应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id
使用与 get_end_user_connected_config 相同的逻辑:
1. 优先使用 end_user.memory_config_id
2. 如果没有,回退到工作空间默认配置
Args:
end_user_ids: 终端用户ID列表
@@ -1230,119 +1252,90 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
...
}
"""
from sqlalchemy import select
from app.models.app_release_model import AppRelease
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.memory_config_model import MemoryConfig
from app.services.memory_config_service import MemoryConfigService
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
result = {}
# 如果列表为空,直接返回空字典
if not end_user_ids:
return result
# 1. 批量查询所有 end_user 及其 app_id
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 创建 end_user_id -> app_id 的映射
user_to_app = {str(eu.id): eu.app_id for eu in end_users}
# 创建映射
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
# 记录未找到的用户
found_user_ids = set(user_to_app.keys())
found_user_ids = set(user_data.keys())
missing_user_ids = set(end_user_ids) - found_user_ids
if missing_user_ids:
logger.warning(f"End users not found: {missing_user_ids}")
for user_id in missing_user_ids:
result[user_id] = {"memory_config_id": None, "memory_config_name": None}
# 2. 批量获取所有相关应用的最新发布版本
app_ids = list(set(user_to_app.values()))
# 2. 批量获取所有相关应用以获取 workspace_id
app_ids = list(set(data["app_id"] for data in user_data.values()))
if not app_ids:
return result
# 查询所有活跃的发布版本
stmt = (
select(AppRelease)
.where(AppRelease.app_id.in_(app_ids), AppRelease.is_active.is_(True))
.order_by(AppRelease.app_id, AppRelease.version.desc())
)
releases = db.scalars(stmt).all()
apps = db.query(App).filter(App.id.in_(app_ids)).all()
app_to_workspace = {app.id: app.workspace_id for app in apps}
# 创建 app_id -> latest_release 的映射(每个 app 只保留最新版本)
app_to_release = {}
for release in releases:
if release.app_id not in app_to_release:
app_to_release[release.app_id] = release
# 3. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
direct_config_ids = []
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
for end_user_id, data in user_data.items():
if data["memory_config_id"]:
direct_config_ids.append(data["memory_config_id"])
else:
workspace_id = app_to_workspace.get(data["app_id"])
if workspace_id:
workspace_fallback_users.append((end_user_id, workspace_id))
# 3. 收集所有 memory_config_id 并批量查询配置名称
memory_config_ids = []
old_config_ids = [] # 存储旧的整数ID
for end_user_id, app_id in user_to_app.items():
release = app_to_release.get(app_id)
if release:
config = release.config or {}
memory_obj = config.get('memory', {})
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
if memory_config_id:
# 判断是否为UUID格式
if len(str(memory_config_id))>=5:
uuid.UUID(str(memory_config_id))
memory_config_ids.append(memory_config_id)
else:
old_config_ids.append(str(memory_config_id))
# 批量查询 memory_config_name
config_id_to_name = {}
# 记录分类结果
if memory_config_ids or old_config_ids:
logger.info(f"Collected {len(memory_config_ids)} UUID config_ids and {len(old_config_ids)} old integer config_ids")
if old_config_ids:
logger.debug(f"Old config IDs: {old_config_ids}")
# 查询新的UUID格式的config_id
if memory_config_ids:
memory_configs = db.query(MemoryConfig).filter(MemoryConfig.config_id.in_(memory_config_ids)).all()
config_id_to_name.update({str(mc.config_id): mc.config_name for mc in memory_configs})
# 查询旧的整数ID,通过config_id_old字段
if old_config_ids:
old_memory_configs = db.query(MemoryConfig).filter(MemoryConfig.config_id_old.in_(old_config_ids)).all()
# 使用config_id_old作为key,这样后面查找时能匹配上
config_id_to_name.update({str(mc.config_id_old): mc.config_name for mc in old_memory_configs})
# 同时也添加config_id作为key,方便后续使用
for mc in old_memory_configs:
if mc.config_id_old:
config_id_to_name[str(mc.config_id)] = mc.config_name
logger.info(f"Found {len(old_memory_configs)} configs for old IDs")
# 4. 批量查询直接分配的配置
config_id_to_config = {}
if direct_config_ids:
configs = db.query(MemoryConfig).filter(MemoryConfig.config_id.in_(direct_config_ids)).all()
config_id_to_config = {mc.config_id: mc for mc in configs}
# 4. 构建最终结果
for end_user_id, app_id in user_to_app.items():
release = app_to_release.get(app_id)
# 5. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
workspace_default_configs = {}
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
if unique_workspace_ids:
config_service = MemoryConfigService(db)
for workspace_id in unique_workspace_ids:
default_config = config_service.get_workspace_default_config(workspace_id)
if default_config:
workspace_default_configs[workspace_id] = default_config
if not release:
logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})")
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
continue
# 从 config 中提取 memory_config_id
config = release.config or {}
memory_obj = config.get('memory', {})
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 6. 构建最终结果
for end_user_id, data in user_data.items():
memory_config = None
# 获取配置名称,使用字符串形式的ID进行查找,兼容新旧格式
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None
# 优先使用 end_user 直接分配的配置
if data["memory_config_id"]:
memory_config = config_id_to_config.get(data["memory_config_id"])
# 回退到工作空间默认配置
if not memory_config:
workspace_id = app_to_workspace.get(data["app_id"])
if workspace_id:
memory_config = workspace_default_configs.get(workspace_id)
result[end_user_id] = {
"memory_config_id": memory_config_id,
"memory_config_name": memory_config_name
}
if memory_config:
result[end_user_id] = {
"memory_config_id": str(memory_config.config_id),
"memory_config_name": memory_config.config_name
}
else:
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
logger.info(f"Successfully retrieved {len(result)} connected configs")
return result

View File

@@ -131,21 +131,27 @@ class MemoryConfigService:
def load_memory_config(
self,
config_id: UUID,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
Load memory configuration from database with optional fallback.
If config_id is provided, attempts to load that config directly.
If config_id is None or not found and workspace_id is provided,
falls back to the workspace's default configuration.
Args:
config_id: Configuration ID (UUID) from database
config_id: Configuration ID (UUID) from database (optional)
workspace_id: Workspace ID for fallback lookup (optional)
service_name: Name of the calling service (for logging purposes)
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
ConfigurationError: If no valid configuration can be found
"""
start_time = time.time()
@@ -154,34 +160,59 @@ class MemoryConfigService:
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": str(config_id),
"config_id": str(config_id) if config_id else None,
"workspace_id": str(workspace_id) if workspace_id else None,
},
)
logger.info(f"Loading memory configuration from database: config_id={config_id}")
logger.info(f"Loading memory configuration from database: config_id={config_id}, workspace_id={workspace_id}")
try:
validated_config_id = _validate_config_id(config_id, self.db)
# Step 1: Get config and workspace
db_query_start = time.time()
result = MemoryConfigRepository.get_config_with_workspace(self.db, validated_config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
# Use get_config_with_fallback if workspace_id is provided
memory_config = None
if workspace_id:
validated_config_id = None
if config_id:
try:
validated_config_id = _validate_config_id(config_id, self.db)
except Exception:
validated_config_id = None
memory_config = self.get_config_with_fallback(
memory_config_id=validated_config_id,
workspace_id=workspace_id
)
elif config_id:
validated_config_id = _validate_config_id(config_id, self.db)
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
if not memory_config:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Configuration not found in database",
extra={
"operation": "load_memory_config",
"config_id": str(config_id),
"config_id": str(config_id) if config_id else None,
"workspace_id": str(workspace_id) if workspace_id else None,
"load_result": "not_found",
"elapsed_ms": elapsed_ms,
"service": service_name,
},
)
raise ConfigurationError(
f"Configuration {config_id} not found in database"
f"Configuration not found: config_id={config_id}, workspace_id={workspace_id}"
)
# Get workspace for the config
db_query_start = time.time()
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
raise ConfigurationError(
f"Workspace not found for config {memory_config.config_id}"
)
memory_config, workspace = result

View File

@@ -62,10 +62,14 @@ def _get_llm_client_for_user(user_id: str):
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id:
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
else:

View File

@@ -1253,10 +1253,22 @@ def long_term_storage_window_task(
# Save to Redis buffer first
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# Load memory config
# Get workspace_id from end_user for fallback
from app.models.app_model import App
from app.models.end_user_model import EndUser
workspace_id = None
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if end_user:
app = db.query(App).filter(App.id == end_user.app_id).first()
if app:
workspace_id = app.workspace_id
# Load memory config with workspace fallback
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="LongTermStorageTask"
)