refactor(memory): restructure memory agent and config management
- Reorganize imports and remove unused dependencies across memory agent controllers - Extract config validation logic into dedicated validators module - Create new memory_config_model and memory_config_schema for configuration management - Implement memory_config_service for centralized config handling - Add embedder_utils module for embedding model utilities - Refactor memory agent service to use new config validation framework - Clean up configuration files (remove config.json, testdata.json, dbrun.json) - Remove deprecated hybrid_chatbot.py and config overrides - Update logging configuration and error handling across memory modules - Consolidate LLM and embedding model validation into validators - Improve code organization and reduce duplication in memory storage services - Enhance type classification and verification tools with better error handling
This commit is contained in:
@@ -4,45 +4,44 @@ Memory Agent Service
|
||||
Handles business logic for memory agent operations including read/write services,
|
||||
health checks, and message type classification.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.db import get_db
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.schemas.memory_storage_schema import ApiResponse, ok, fail
|
||||
from app.db import get_db
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.services.memory_konwledges_server import memory_konwledges_up, SimpleUser, find_document_id_by_kb_and_filename
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
# Initialize Neo4j connector for analytics functions
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
@@ -56,6 +55,27 @@ class MemoryAgentService:
|
||||
self.user_locks: Dict[str, Lock] = {}
|
||||
self.locks_lock = Lock()
|
||||
|
||||
def load_memory_config(self, config_id: int) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method delegates to the centralized MemoryConfigService to avoid
|
||||
code duplication with other services.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
return MemoryConfigService.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
|
||||
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
|
||||
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
countext = re.findall(r'"status": "(.*?)",', messages)[0]
|
||||
@@ -277,19 +297,20 @@ class MemoryAgentService:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 如果 config_id 为 None,使用默认值 "17"
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}"
|
||||
# Load configuration from database only
|
||||
try:
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 记录失败的操作
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation( operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg )
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
|
||||
mcp_config = get_mcp_server_config()
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
@@ -300,20 +321,43 @@ class MemoryAgentService:
|
||||
async with client.session("data_flow") as session:
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
tools = await load_mcp_tools(session)
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
# Pass config_id to the graph workflow
|
||||
async with make_write_graph(group_id, tools, group_id, group_id, config_id=config_id) as graph:
|
||||
# Pass memory_config to the graph workflow
|
||||
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
|
||||
logger.debug("Write graph created successfully")
|
||||
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
|
||||
async for event in graph.astream(
|
||||
{"messages": message, "config_id": config_id},
|
||||
{"messages": message, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
messages = event.get('messages')
|
||||
return self.writer_messages_deal(messages,start_time,group_id,config_id,message)
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.error(f"Write workflow failed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_details
|
||||
)
|
||||
|
||||
raise ValueError(f"Write workflow failed: {error_details}")
|
||||
|
||||
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
@@ -365,15 +409,15 @@ class MemoryAgentService:
|
||||
group_lock = self.get_group_lock(group_id)
|
||||
|
||||
with group_lock:
|
||||
# Step 1: Load configuration from database
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}"
|
||||
# Step 1: Load configuration from database only
|
||||
try:
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 记录失败的操作
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
@@ -387,8 +431,6 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
|
||||
|
||||
# Step 2: Prepare history
|
||||
history.append({"role": "user", "content": message})
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
@@ -404,45 +446,52 @@ class MemoryAgentService:
|
||||
intermediate_outputs = []
|
||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||
|
||||
# Pass config_id to the graph workflow
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, config_id=config_id,storage_type=storage_type,user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
# Pass memory_config to the graph workflow
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
start = time.time()
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
async for event in graph.astream(
|
||||
{"messages": history, "config_id": config_id},
|
||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
for msg in messages:
|
||||
msg_content = msg.content
|
||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
||||
outputs.append({
|
||||
"role": msg.__class__.__name__.lower().replace("message", ""),
|
||||
"role": msg_role,
|
||||
"content": msg_content
|
||||
})
|
||||
|
||||
# Extract intermediate outputs
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
# Debug: log message type and content preview
|
||||
msg_type = msg.__class__.__name__
|
||||
content_preview = str(msg_content)[:200] if msg_content else "empty"
|
||||
logger.debug(f"Processing message type={msg_type}, content preview={content_preview}")
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
content_to_parse = msg_content
|
||||
if isinstance(msg_content, list):
|
||||
for block in msg_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
content_to_parse = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
# Try to parse content as JSON
|
||||
if isinstance(msg_content, str):
|
||||
if isinstance(content_to_parse, str):
|
||||
try:
|
||||
parsed = json.loads(msg_content)
|
||||
parsed = json.loads(content_to_parse)
|
||||
if isinstance(parsed, dict):
|
||||
# Debug: log what keys are in parsed
|
||||
logger.debug(f"Parsed dict keys: {list(parsed.keys())}")
|
||||
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in parsed:
|
||||
intermediate_data = parsed['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
logger.debug(f"Found _intermediate: {intermediate_data.get('type', 'unknown')}")
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
@@ -450,34 +499,14 @@ class MemoryAgentService:
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in parsed:
|
||||
logger.debug(f"Found _intermediates list with {len(parsed['_intermediates'])} items")
|
||||
for intermediate_data in parsed['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
logger.debug(f"Processing intermediate: {intermediate_data.get('type', 'unknown')}")
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
elif isinstance(msg_content, dict):
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in msg_content:
|
||||
intermediate_data = msg_content['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in msg_content:
|
||||
for intermediate_data in msg_content['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
@@ -489,18 +518,57 @@ class MemoryAgentService:
|
||||
for messages in outputs:
|
||||
if messages['role'] == 'tool':
|
||||
message = messages['content']
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(message, list):
|
||||
# Extract text from MCP content blocks
|
||||
for block in message:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
message = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
try:
|
||||
message = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(message, dict) and message.get('status') != '':
|
||||
summary_result = message.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
parsed = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(parsed, dict):
|
||||
if parsed.get('status') == 'success':
|
||||
summary_result = parsed.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# 记录成功的操作
|
||||
total_duration = time.time() - start_time
|
||||
if audit_logger:
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=total_duration,
|
||||
error=error_details,
|
||||
details={
|
||||
"search_switch": search_switch,
|
||||
"history_length": len(history),
|
||||
"intermediate_outputs_count": len(intermediate_outputs),
|
||||
"has_answer": bool(final_answer),
|
||||
"errors": workflow_errors
|
||||
}
|
||||
)
|
||||
|
||||
# Raise error if no answer was produced
|
||||
if not final_answer:
|
||||
raise ValueError(f"Read workflow failed: {error_details}")
|
||||
|
||||
if audit_logger and not workflow_errors:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
@@ -612,19 +680,25 @@ class MemoryAgentService:
|
||||
else:
|
||||
return output
|
||||
|
||||
async def classify_message_type(self, message: str) -> Dict:
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
message: User message to classify
|
||||
config_id: Configuration ID to load LLM model from database
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
logger.info("Classifying message type")
|
||||
|
||||
status = await status_typle(message)
|
||||
# Load configuration to get LLM model ID
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
|
||||
status = await status_typle(message, memory_config.llm_model_id)
|
||||
logger.debug(f"Message type: {status}")
|
||||
return status
|
||||
|
||||
@@ -790,7 +864,8 @@ class MemoryAgentService:
|
||||
async def get_user_profile(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user_id: Optional[str] = None
|
||||
current_user_id: Optional[str] = None,
|
||||
llm_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -801,6 +876,7 @@ class MemoryAgentService:
|
||||
参数:
|
||||
- end_user_id: 用户ID(可选)
|
||||
- current_user_id: 当前登录用户的ID(保留参数)
|
||||
- llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成)
|
||||
|
||||
返回格式:
|
||||
{
|
||||
@@ -862,15 +938,17 @@ class MemoryAgentService:
|
||||
|
||||
await connector.close()
|
||||
|
||||
if not statements:
|
||||
if not statements or not llm_id:
|
||||
result["tags"] = []
|
||||
if not llm_id and statements:
|
||||
logger.warning("llm_id not provided, skipping tag generation")
|
||||
else:
|
||||
# 构建摘要文本
|
||||
summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
|
||||
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
|
||||
|
||||
# 使用LLM提取标签
|
||||
llm_client = get_llm_client()
|
||||
llm_client = get_llm_client(llm_id)
|
||||
|
||||
# 定义标签提取的结构
|
||||
class UserTags(BaseModel):
|
||||
|
||||
264
api/app/services/memory_config_service.py
Normal file
264
api/app/services/memory_config_service.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Memory Configuration Service
|
||||
|
||||
Centralized configuration loading and management for memory services.
|
||||
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
|
||||
Database session management is handled internally.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.validators.memory_config_validators import (
|
||||
validate_and_resolve_model_id,
|
||||
validate_embedding_model,
|
||||
validate_model_exists_and_active,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.schemas.memory_config_schema import (
|
||||
ConfigurationError,
|
||||
InvalidConfigError,
|
||||
MemoryConfig,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
|
||||
def _validate_config_id(config_id):
|
||||
"""Validate configuration ID format."""
|
||||
if config_id is None:
|
||||
raise InvalidConfigError(
|
||||
"Configuration ID cannot be None",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
if isinstance(config_id, int):
|
||||
if config_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {config_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return config_id
|
||||
|
||||
if isinstance(config_id, str):
|
||||
try:
|
||||
parsed_id = int(config_id.strip())
|
||||
if parsed_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {parsed_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return parsed_id
|
||||
except ValueError as e:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid configuration ID format: '{config_id}'",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
raise InvalidConfigError(
|
||||
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
class MemoryConfigService:
|
||||
"""
|
||||
Centralized service for memory configuration loading and validation.
|
||||
|
||||
This class provides a single implementation of configuration loading logic
|
||||
that can be shared across multiple services, eliminating code duplication.
|
||||
Database session management is handled internally.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load_memory_config(
|
||||
config_id: int,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method manages its own database session internally.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
service_name: Name of the calling service (for logging purposes)
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
from app.db import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
try:
|
||||
return MemoryConfigService._load_memory_config_with_db(
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
service_name=service_name,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@staticmethod
|
||||
def _load_memory_config_with_db(
|
||||
config_id: int,
|
||||
db: Session,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""Internal method that loads memory configuration with an existing db session."""
|
||||
start_time = time.time()
|
||||
|
||||
config_logger.info(
|
||||
"Starting memory configuration loading",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": config_id,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Loading memory configuration from database: config_id={config_id}")
|
||||
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
|
||||
result = DataConfigRepository.get_config_with_workspace(db, validated_config_id)
|
||||
if not result:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
"Configuration not found in database",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"config_id": validated_config_id,
|
||||
"load_result": "not_found",
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"service": service_name,
|
||||
},
|
||||
)
|
||||
raise ConfigurationError(
|
||||
f"Configuration {validated_config_id} not found in database"
|
||||
)
|
||||
|
||||
memory_config, workspace = result
|
||||
|
||||
# Validate embedding model
|
||||
embedding_uuid = validate_embedding_model(
|
||||
validated_config_id,
|
||||
memory_config.embedding_id,
|
||||
db,
|
||||
workspace.tenant_id,
|
||||
workspace.id,
|
||||
)
|
||||
|
||||
# Resolve LLM model
|
||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||
memory_config.llm_id,
|
||||
"llm",
|
||||
db,
|
||||
workspace.tenant_id,
|
||||
required=True,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
# Resolve optional rerank model
|
||||
rerank_uuid = None
|
||||
rerank_name = None
|
||||
if memory_config.rerank_id:
|
||||
rerank_uuid, rerank_name = validate_and_resolve_model_id(
|
||||
memory_config.rerank_id,
|
||||
"rerank",
|
||||
db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
# Get embedding model name
|
||||
embedding_name, _ = validate_model_exists_and_active(
|
||||
embedding_uuid,
|
||||
"embedding",
|
||||
db,
|
||||
workspace.tenant_id,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
config_id=memory_config.config_id,
|
||||
config_name=memory_config.config_name,
|
||||
workspace_id=workspace.id,
|
||||
workspace_name=workspace.name,
|
||||
tenant_id=workspace.tenant_id,
|
||||
llm_model_id=llm_uuid,
|
||||
llm_model_name=llm_name,
|
||||
embedding_model_id=embedding_uuid,
|
||||
embedding_model_name=embedding_name,
|
||||
rerank_model_id=rerank_uuid,
|
||||
rerank_model_name=rerank_name,
|
||||
storage_type=workspace.storage_type or "neo4j",
|
||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||
reflexion_iteration_period=int(memory_config.iteration_period or "3"),
|
||||
reflexion_range=memory_config.reflexion_range or "retrieval",
|
||||
reflexion_baseline=memory_config.baseline or "time",
|
||||
loaded_at=datetime.now(),
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
config_logger.info(
|
||||
"Memory configuration loaded successfully",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": validated_config_id,
|
||||
"config_name": config.config_name,
|
||||
"workspace_id": str(config.workspace_id),
|
||||
"load_result": "success",
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Memory configuration loaded successfully: {config.config_name}")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
config_logger.error(
|
||||
"Failed to load memory configuration",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": config_id,
|
||||
"load_result": "error",
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.error(f"Failed to load memory configuration {config_id}: {e}")
|
||||
if isinstance(e, (ConfigurationError, ValueError)):
|
||||
raise
|
||||
else:
|
||||
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
|
||||
@@ -4,39 +4,40 @@ Memory Storage Service
|
||||
Handles business logic for memory storage operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.core.logging_config import get_logger
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
ConfigPilotRun,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
)
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
# Load environment variables for Neo4j connector
|
||||
load_dotenv()
|
||||
@@ -48,6 +49,27 @@ class MemoryStorageService:
|
||||
|
||||
def __init__(self):
|
||||
logger.info("MemoryStorageService initialized")
|
||||
|
||||
def load_memory_config(self, config_id: int, db: Session) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method delegates to the centralized MemoryConfigService to avoid
|
||||
code duplication with other services.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
return MemoryConfigService.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryStorageService"
|
||||
)
|
||||
|
||||
async def get_storage_info(self) -> dict:
|
||||
"""
|
||||
@@ -248,7 +270,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
RuntimeError: 当管线执行失败时
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
|
||||
|
||||
try:
|
||||
# 发出初始进度事件
|
||||
@@ -257,24 +278,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
# 步骤 1: 配置加载和验证(复用现有逻辑)
|
||||
# 步骤 1: 配置加载和验证(数据库优先)
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
raise ValueError("未提供 payload.config_id,禁止启动试运行")
|
||||
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
@@ -282,12 +291,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
|
||||
# 应用内存覆写并刷新常量
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
# Load configuration from database only using centralized manager
|
||||
try:
|
||||
memory_config = MemoryConfigService.load_memory_config(
|
||||
config_id=int(cid),
|
||||
service_name="MemoryStorageService.pilot_run_stream"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
raise RuntimeError(f"Configuration loading failed: {e}")
|
||||
|
||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||
# 使用队列在回调和生成器之间传递进度事件
|
||||
|
||||
Reference in New Issue
Block a user