refactor(memory): restructure memory system and improve configuration management

- Remove deprecated main.py entry point from memory module
- Reorganize imports across controllers and services for consistency
- Update emotion controller to pass db session instead of config_id to services
- Enhance memory agent controller with db session parameter for status_type and user_profile endpoints
- Refactor memory agent service to accept db parameter in classify_message_type method
- Improve configuration handling in celery_app by removing automatic database reload
- Update all memory-related services to use centralized config management
- Standardize import ordering and remove unused imports across 50+ files
- Add pilot_run_service for new pilot execution workflow
- Refactor extraction engine, reflection engine, and search services for better modularity
- Update LLM utilities and embedder configuration for improved flexibility
- Enhance type classifier and verification tools with better error handling
- Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
This commit is contained in:
Ke Sun
2025-12-23 17:17:04 +08:00
parent 258b88276f
commit 283c64a358
58 changed files with 2171 additions and 1797 deletions

View File

@@ -21,8 +21,8 @@ from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.db import get_db
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
@@ -45,8 +45,7 @@ config_logger = get_config_logger()
# Initialize Neo4j connector for analytics functions
_neo4j_connector = Neo4jConnector()
db_gen = get_db()
db = next(db_gen)
class MemoryAgentService:
"""Service for memory agent operations"""
@@ -55,27 +54,6 @@ class MemoryAgentService:
self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock()
def load_memory_config(self, config_id: int) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method delegates to the centralized MemoryConfigService to avoid
code duplication with other services.
Args:
config_id: Configuration ID from database
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
"""
return MemoryConfigService.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
countext = re.findall(r'"status": "(.*?)",', messages)[0]
@@ -277,14 +255,17 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str:
async def write_memory(self, group_id: str, message: str, config_id: str, db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Write operation result status
@@ -292,14 +273,24 @@ class MemoryAgentService:
Raises:
ValueError: If config loading fails or write operation fails
"""
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
import time
start_time = time.time()
# Load configuration from database only
try:
memory_config = self.load_memory_config(config_id)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
@@ -366,6 +357,7 @@ class MemoryAgentService:
history: List[Dict],
search_switch: str,
config_id: str,
db: Session,
storage_type: str,
user_rag_memory_id: str
) -> Dict:
@@ -378,11 +370,14 @@ class MemoryAgentService:
- "2": Direct answer based on context
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: User message
history: Conversation history
search_switch: Search mode switch
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Dict with 'answer' and 'intermediate_outputs' keys
@@ -394,8 +389,13 @@ class MemoryAgentService:
import time
start_time = time.time()
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
@@ -411,7 +411,11 @@ class MemoryAgentService:
with group_lock:
# Step 1: Load configuration from database only
try:
memory_config = self.load_memory_config(config_id)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
@@ -696,7 +700,11 @@ class MemoryAgentService:
logger.info("Classifying message type")
# Load configuration to get LLM model ID
memory_config = self.load_memory_config(config_id)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
@@ -865,7 +873,8 @@ class MemoryAgentService:
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None
llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]:
"""
获取用户详情,包含:
@@ -877,6 +886,7 @@ class MemoryAgentService:
- end_user_id: 用户ID可选
- current_user_id: 当前登录用户的ID保留参数
- llm_id: LLM模型ID用于生成标签可选如果不提供则跳过标签生成
- db: 数据库会话(可选)
返回格式:
{
@@ -893,7 +903,7 @@ class MemoryAgentService:
# 1. 根据 end_user_id 获取 end_user_name
try:
if end_user_id:
if end_user_id and db:
from app.repositories import end_user_repository
from app.schemas.end_user_schema import EndUser as EndUserSchema
@@ -948,7 +958,9 @@ class MemoryAgentService:
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
# 使用LLM提取标签
llm_client = get_llm_client(llm_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_id)
# 定义标签提取的结构
class UserTags(BaseModel):
@@ -1110,4 +1122,69 @@ class MemoryAgentService:
# "msg": "解析失败",
# "error_code": "DOC_PARSE_ERROR",
# "data": {"error": str(e)}
# }
# }
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
"""
获取终端用户关联的记忆配置
通过以下流程获取配置:
1. 根据 end_user_id 获取用户的 app_id
2. 获取该应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id
Args:
end_user_id: 终端用户ID
db: 数据库会话
Returns:
包含 memory_config_id 和相关信息的字典
Raises:
ValueError: 当终端用户不存在或应用未发布时
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from sqlalchemy import select
logger.info(f"Getting connected config for end_user: {end_user_id}")
# 1. 获取 end_user 及其 app_id
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
logger.warning(f"End user not found: {end_user_id}")
raise ValueError(f"终端用户不存在: {end_user_id}")
app_id = end_user.app_id
logger.debug(f"Found end_user app_id: {app_id}")
# 2. 获取该应用的最新发布版本
stmt = (
select(AppRelease)
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
.order_by(AppRelease.version.desc())
)
latest_release = db.scalars(stmt).first()
if not latest_release:
logger.warning(f"No active release found for app: {app_id}")
raise ValueError(f"应用未发布: {app_id}")
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
# 3. 从 config 中提取 memory_config_id
config = latest_release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(latest_release.id),
"release_version": latest_release.version,
"memory_config_id": memory_config_id
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
return result