feat(memory): add workspace_id fallback support for memory config resolution

- Add workspace_id fallback parameter to memory config loading across all services
- Update hot_memory_tags.py to pass workspace_id when resolving memory configuration
- Enhance emotion_analytics_service.py to support workspace_id as fallback for config resolution
- Improve implicit_memory_service.py with workspace_id fallback in config loading
- Update memory_agent_service.py to handle workspace_id resolution and add refactoring TODO
- Enhance preference_analysis.jinja2 prompt with critical guidance on supporting_evidence extraction
- Add validation to check both config_id and workspace_id before raising configuration errors
- Improve error handling and logging for memory configuration resolution across services
- This enables more flexible memory configuration resolution when config_id is unavailable
This commit is contained in:
Ke Sun
2026-02-06 14:48:58 +08:00
parent 7a78f15a90
commit 5c10f11681
8 changed files with 195 additions and 134 deletions

View File

@@ -538,14 +538,16 @@ class EmotionAnalyticsService:
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is not None:
workspace_id = connected_config.get("workspace_id")
config_id = resolve_config_id(config_id, db) if config_id else None
if config_id is not None or workspace_id is not None:
from app.services.memory_config_service import (
MemoryConfigService,
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=(config_id),
config_id=config_id,
workspace_id=workspace_id,
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

View File

@@ -90,13 +90,17 @@ class ImplicitMemoryService:
# Get user's connected config
connected_config = get_end_user_connected_config(self.end_user_id, self.db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id is None:
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user: {self.end_user_id}")
# Load the memory configuration
# Load the memory configuration with workspace fallback
config_service = MemoryConfigService(self.db)
memory_config = config_service.load_memory_config(config_id)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
logger.info(f"Loaded memory config {config_id} for end_user: {self.end_user_id}")
return memory_config

View File

@@ -3,6 +3,13 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services,
health checks, and message type classification.
TODO: Refactor get_end_user_connected_config
----------------------------------------------
1. Move get_end_user_connected_config to memory_config_service.py
2. Change return type from Dict[str, Any] (with config_id string) to full MemoryConfig model
3. This will eliminate the need for callers to call load_memory_config separately
4. Update all callers to use the new unified function
"""
import json
import os
@@ -283,12 +290,14 @@ class MemoryAgentService:
Raises:
ValueError: If config loading fails or write operation fails
"""
# Resolve config_id if None using end_user's connected config
# Resolve config_id and workspace_id
workspace_id = None
if config_id is None:
try:
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
workspace_id = connected_config.get("workspace_id")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
@@ -299,11 +308,12 @@ class MemoryAgentService:
import time
start_time = time.time()
# Load configuration from database only
# Load configuration from database with workspace fallback
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
@@ -410,12 +420,14 @@ class MemoryAgentService:
start_time = time.time()
ori_message= message
# Resolve config_id if None using end_user's connected config
# Resolve config_id and workspace_id
workspace_id = None
if config_id is None:
try:
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
workspace_id = connected_config.get("workspace_id")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
@@ -437,6 +449,7 @@ class MemoryAgentService:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
config_load_time = time.time() - config_load_start
@@ -659,7 +672,13 @@ class MemoryAgentService:
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(self, message: str, config_id: UUID, db: Session) -> Dict:
async def classify_message_type(
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
) -> Dict:
"""
Determine the type of user message (read or write)
Updated to eliminate global variables in favor of explicit parameters.
@@ -668,18 +687,18 @@ class MemoryAgentService:
message: User message to classify
config_id: Configuration ID to load LLM model from database
db: Database session
workspace_id: Workspace ID for fallback lookup (optional)
Returns:
Type classification result
"""
logger.info("Classifying message type")
# Load configuration to get LLM model ID
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
@@ -712,9 +731,10 @@ class MemoryAgentService:
"""
if config_id is None:
try:
config_id = get_end_user_connected_config(end_user_id, db)
config_id = config_id.get('memory_config_id')
if config_id is None:
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get('memory_config_id')
workspace_id = connected_config.get('workspace_id')
if config_id is None and workspace_id is None:
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
@@ -722,6 +742,9 @@ class MemoryAgentService:
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
else:
workspace_id = None
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
@@ -729,6 +752,7 @@ class MemoryAgentService:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
@@ -1158,7 +1182,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
db: 数据库会话
Returns:
包含 memory_config_id 和相关信息的字典
包含 memory_config_id, workspace_id 和相关信息的字典
Raises:
ValueError: 当终端用户不存在或应用未发布时
@@ -1194,18 +1218,17 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
workspace_id=app.workspace_id
)
memory_obj = config.get('memory', {})
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
memory_config_id = str(memory_config.config_id) if memory_config else None
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(app.current_release_id),
"memory_config_id": memory_config_id
"memory_config_id": memory_config_id,
"workspace_id": str(app.workspace_id)
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result
@@ -1213,10 +1236,9 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
"""
批量获取多个终端用户关联的记忆配置(优化版本,减少数据库查询次数)
通过以下流程获取配置
1. 批量查询所有 end_user_id 对应的 app_id
2. 批量获取这些应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id
使用与 get_end_user_connected_config 相同的逻辑
1. 优先使用 end_user.memory_config_id
2. 如果没有,回退到工作空间默认配置
Args:
end_user_ids: 终端用户ID列表
@@ -1230,119 +1252,90 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
...
}
"""
from sqlalchemy import select
from app.models.app_release_model import AppRelease
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.memory_config_model import MemoryConfig
from app.services.memory_config_service import MemoryConfigService
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
result = {}
# 如果列表为空,直接返回空字典
if not end_user_ids:
return result
# 1. 批量查询所有 end_user 及其 app_id
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 创建 end_user_id -> app_id 的映射
user_to_app = {str(eu.id): eu.app_id for eu in end_users}
# 创建映射
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
# 记录未找到的用户
found_user_ids = set(user_to_app.keys())
found_user_ids = set(user_data.keys())
missing_user_ids = set(end_user_ids) - found_user_ids
if missing_user_ids:
logger.warning(f"End users not found: {missing_user_ids}")
for user_id in missing_user_ids:
result[user_id] = {"memory_config_id": None, "memory_config_name": None}
# 2. 批量获取所有相关应用的最新发布版本
app_ids = list(set(user_to_app.values()))
# 2. 批量获取所有相关应用以获取 workspace_id
app_ids = list(set(data["app_id"] for data in user_data.values()))
if not app_ids:
return result
# 查询所有活跃的发布版本
stmt = (
select(AppRelease)
.where(AppRelease.app_id.in_(app_ids), AppRelease.is_active.is_(True))
.order_by(AppRelease.app_id, AppRelease.version.desc())
)
releases = db.scalars(stmt).all()
apps = db.query(App).filter(App.id.in_(app_ids)).all()
app_to_workspace = {app.id: app.workspace_id for app in apps}
# 创建 app_id -> latest_release 的映射(每个 app 只保留最新版本)
app_to_release = {}
for release in releases:
if release.app_id not in app_to_release:
app_to_release[release.app_id] = release
# 3. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
direct_config_ids = []
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
for end_user_id, data in user_data.items():
if data["memory_config_id"]:
direct_config_ids.append(data["memory_config_id"])
else:
workspace_id = app_to_workspace.get(data["app_id"])
if workspace_id:
workspace_fallback_users.append((end_user_id, workspace_id))
# 3. 收集所有 memory_config_id 并批量查询配置名称
memory_config_ids = []
old_config_ids = [] # 存储旧的整数ID
for end_user_id, app_id in user_to_app.items():
release = app_to_release.get(app_id)
if release:
config = release.config or {}
memory_obj = config.get('memory', {})
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
if memory_config_id:
# 判断是否为UUID格式
if len(str(memory_config_id))>=5:
uuid.UUID(str(memory_config_id))
memory_config_ids.append(memory_config_id)
else:
old_config_ids.append(str(memory_config_id))
# 批量查询 memory_config_name
config_id_to_name = {}
# 记录分类结果
if memory_config_ids or old_config_ids:
logger.info(f"Collected {len(memory_config_ids)} UUID config_ids and {len(old_config_ids)} old integer config_ids")
if old_config_ids:
logger.debug(f"Old config IDs: {old_config_ids}")
# 查询新的UUID格式的config_id
if memory_config_ids:
memory_configs = db.query(MemoryConfig).filter(MemoryConfig.config_id.in_(memory_config_ids)).all()
config_id_to_name.update({str(mc.config_id): mc.config_name for mc in memory_configs})
# 查询旧的整数ID通过config_id_old字段
if old_config_ids:
old_memory_configs = db.query(MemoryConfig).filter(MemoryConfig.config_id_old.in_(old_config_ids)).all()
# 使用config_id_old作为key这样后面查找时能匹配上
config_id_to_name.update({str(mc.config_id_old): mc.config_name for mc in old_memory_configs})
# 同时也添加config_id作为key方便后续使用
for mc in old_memory_configs:
if mc.config_id_old:
config_id_to_name[str(mc.config_id)] = mc.config_name
logger.info(f"Found {len(old_memory_configs)} configs for old IDs")
# 4. 批量查询直接分配的配置
config_id_to_config = {}
if direct_config_ids:
configs = db.query(MemoryConfig).filter(MemoryConfig.config_id.in_(direct_config_ids)).all()
config_id_to_config = {mc.config_id: mc for mc in configs}
# 4. 构建最终结果
for end_user_id, app_id in user_to_app.items():
release = app_to_release.get(app_id)
# 5. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
workspace_default_configs = {}
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
if unique_workspace_ids:
config_service = MemoryConfigService(db)
for workspace_id in unique_workspace_ids:
default_config = config_service.get_workspace_default_config(workspace_id)
if default_config:
workspace_default_configs[workspace_id] = default_config
if not release:
logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})")
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
continue
# 从 config 中提取 memory_config_id
config = release.config or {}
memory_obj = config.get('memory', {})
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 6. 构建最终结果
for end_user_id, data in user_data.items():
memory_config = None
# 获取配置名称使用字符串形式的ID进行查找兼容新旧格式
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None
# 优先使用 end_user 直接分配的配置
if data["memory_config_id"]:
memory_config = config_id_to_config.get(data["memory_config_id"])
# 回退到工作空间默认配置
if not memory_config:
workspace_id = app_to_workspace.get(data["app_id"])
if workspace_id:
memory_config = workspace_default_configs.get(workspace_id)
result[end_user_id] = {
"memory_config_id": memory_config_id,
"memory_config_name": memory_config_name
}
if memory_config:
result[end_user_id] = {
"memory_config_id": str(memory_config.config_id),
"memory_config_name": memory_config.config_name
}
else:
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
logger.info(f"Successfully retrieved {len(result)} connected configs")
return result

View File

@@ -131,21 +131,27 @@ class MemoryConfigService:
def load_memory_config(
self,
config_id: UUID,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
Load memory configuration from database with optional fallback.
If config_id is provided, attempts to load that config directly.
If config_id is None or not found and workspace_id is provided,
falls back to the workspace's default configuration.
Args:
config_id: Configuration ID (UUID) from database
config_id: Configuration ID (UUID) from database (optional)
workspace_id: Workspace ID for fallback lookup (optional)
service_name: Name of the calling service (for logging purposes)
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
ConfigurationError: If no valid configuration can be found
"""
start_time = time.time()
@@ -154,34 +160,59 @@ class MemoryConfigService:
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": str(config_id),
"config_id": str(config_id) if config_id else None,
"workspace_id": str(workspace_id) if workspace_id else None,
},
)
logger.info(f"Loading memory configuration from database: config_id={config_id}")
logger.info(f"Loading memory configuration from database: config_id={config_id}, workspace_id={workspace_id}")
try:
validated_config_id = _validate_config_id(config_id, self.db)
# Step 1: Get config and workspace
db_query_start = time.time()
result = MemoryConfigRepository.get_config_with_workspace(self.db, validated_config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
# Use get_config_with_fallback if workspace_id is provided
memory_config = None
if workspace_id:
validated_config_id = None
if config_id:
try:
validated_config_id = _validate_config_id(config_id, self.db)
except Exception:
validated_config_id = None
memory_config = self.get_config_with_fallback(
memory_config_id=validated_config_id,
workspace_id=workspace_id
)
elif config_id:
validated_config_id = _validate_config_id(config_id, self.db)
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
if not memory_config:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Configuration not found in database",
extra={
"operation": "load_memory_config",
"config_id": str(config_id),
"config_id": str(config_id) if config_id else None,
"workspace_id": str(workspace_id) if workspace_id else None,
"load_result": "not_found",
"elapsed_ms": elapsed_ms,
"service": service_name,
},
)
raise ConfigurationError(
f"Configuration {config_id} not found in database"
f"Configuration not found: config_id={config_id}, workspace_id={workspace_id}"
)
# Get workspace for the config
db_query_start = time.time()
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
raise ConfigurationError(
f"Workspace not found for config {memory_config.config_id}"
)
memory_config, workspace = result

View File

@@ -62,10 +62,14 @@ def _get_llm_client_for_user(user_id: str):
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id:
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
else: