refactor(memory): restructure memory system and improve configuration management

- Remove deprecated main.py entry point from memory module
- Reorganize imports across controllers and services for consistency
- Update emotion controller to pass db session instead of config_id to services
- Enhance memory agent controller with db session parameter for status_type and user_profile endpoints
- Refactor memory agent service to accept db parameter in classify_message_type method
- Improve configuration handling in celery_app by removing automatic database reload
- Update all memory-related services to use centralized config management
- Standardize import ordering and remove unused imports across 50+ files
- Add pilot_run_service for new pilot execution workflow
- Refactor extraction engine, reflection engine, and search services for better modularity
- Update LLM utilities and embedder configuration for improved flexibility
- Enhance type classifier and verification tools with better error handling
- Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
This commit is contained in:
Ke Sun
2025-12-23 17:17:04 +08:00
parent 258b88276f
commit 283c64a358
58 changed files with 2171 additions and 1797 deletions

View File

@@ -3,26 +3,26 @@
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
"""
import time
import uuid
import json
import asyncio
import datetime
from typing import Dict, Any, Optional, List, AsyncGenerator
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy import select
import json
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.models import AgentConfig, ModelConfig, ModelApiKey
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.core.rag.nlp.search import knowledge_retrieval
from app.services.langchain_tool_server import Search
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"""
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
message=question,
history=[],
search_switch="1",
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
from app.db import get_db
db = next(get_db())
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
message=question,
history=[],
search_switch="1",
config_id=config_id,
db=db,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
)
)
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
@@ -713,9 +719,9 @@ class DraftRunService:
Raises:
BusinessException: 当指定的会话不存在时
"""
from app.services.conversation_service import ConversationService
from app.schemas.conversation_schema import ConversationCreate
from app.models import Conversation as ConversationModel
from app.schemas.conversation_schema import ConversationCreate
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db)

View File

@@ -7,14 +7,15 @@ Classes:
EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能
"""
from typing import Dict, Any, Optional, List
import statistics
import json
from pydantic import BaseModel, Field
import statistics
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_business_logger
from app.repositories.neo4j.emotion_repository import EmotionRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.logging_config import get_business_logger
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = get_business_logger()
@@ -454,7 +455,7 @@ class EmotionAnalyticsService:
async def generate_emotion_suggestions(
self,
end_user_id: str,
config_id: Optional[int] = None
db: Session,
) -> Dict[str, Any]:
"""生成个性化情绪建议
@@ -462,7 +463,7 @@ class EmotionAnalyticsService:
Args:
end_user_id: 宿主ID用户组ID
config_id: 配置ID可选用于从数据库加载LLM配置
db: 数据库会话
Returns:
Dict: 包含个性化建议的响应:
@@ -470,14 +471,32 @@ class EmotionAnalyticsService:
- suggestions: 建议列表3-5条
"""
try:
logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}")
logger.info(f"生成个性化情绪建议: user={end_user_id}")
# 1. 如果提供了 config_id从数据库加载配置
if config_id is not None:
from app.core.memory.utils.config.definitions import reload_configuration_from_database
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置")
# 1. 从 end_user_id 获取关联的 memory_config_id
llm_client = None
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is not None:
from app.services.memory_config_service import (
MemoryConfigService,
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=int(config_id),
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
)
from app.core.memory.client_factory import MemoryClientFactory
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
except Exception as e:
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
# 2. 获取情绪健康数据
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
@@ -498,8 +517,9 @@ class EmotionAnalyticsService:
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
# 7. 调用LLM生成建议使用配置中的LLM
from app.core.memory.utils.llm.llm_utils import get_llm_client
llm_client = get_llm_client()
if llm_client is None:
# 无法获取配置时,抛出错误而不是使用默认配置
raise ValueError("无法获取LLM配置请确保end_user关联了有效的memory_config")
# 将 prompt 转换为 messages 格式
messages = [
@@ -598,7 +618,9 @@ class EmotionAnalyticsService:
Returns:
str: LLM prompt
"""
from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_suggestions_prompt,
)
prompt = await render_emotion_suggestions_prompt(
health_data=health_data,

View File

@@ -9,10 +9,12 @@ Classes:
import logging
from typing import Optional
from app.core.memory.models.emotion_models import EmotionExtraction
from app.models.data_config_model import DataConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.core.memory.models.emotion_models import EmotionExtraction
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.data_config_model import DataConfig
logger = logging.getLogger(__name__)
@@ -50,7 +52,9 @@ class EmotionExtractionService:
"""
if self.llm_client is None or model_id:
effective_model_id = model_id or self.llm_id
self.llm_client = get_llm_client(effective_model_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(effective_model_id)
return self.llm_client
async def extract_emotion(
@@ -142,7 +146,9 @@ class EmotionExtractionService:
Returns:
Formatted prompt string for LLM
"""
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_extraction_prompt,
)
prompt = await render_emotion_extraction_prompt(
statement=statement,

View File

@@ -21,8 +21,8 @@ from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.db import get_db
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
@@ -45,8 +45,7 @@ config_logger = get_config_logger()
# Initialize Neo4j connector for analytics functions
_neo4j_connector = Neo4jConnector()
db_gen = get_db()
db = next(db_gen)
class MemoryAgentService:
"""Service for memory agent operations"""
@@ -55,27 +54,6 @@ class MemoryAgentService:
self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock()
def load_memory_config(self, config_id: int) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method delegates to the centralized MemoryConfigService to avoid
code duplication with other services.
Args:
config_id: Configuration ID from database
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
"""
return MemoryConfigService.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
countext = re.findall(r'"status": "(.*?)",', messages)[0]
@@ -277,14 +255,17 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str:
async def write_memory(self, group_id: str, message: str, config_id: str, db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Write operation result status
@@ -292,14 +273,24 @@ class MemoryAgentService:
Raises:
ValueError: If config loading fails or write operation fails
"""
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
import time
start_time = time.time()
# Load configuration from database only
try:
memory_config = self.load_memory_config(config_id)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
@@ -366,6 +357,7 @@ class MemoryAgentService:
history: List[Dict],
search_switch: str,
config_id: str,
db: Session,
storage_type: str,
user_rag_memory_id: str
) -> Dict:
@@ -378,11 +370,14 @@ class MemoryAgentService:
- "2": Direct answer based on context
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: User message
history: Conversation history
search_switch: Search mode switch
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Dict with 'answer' and 'intermediate_outputs' keys
@@ -394,8 +389,13 @@ class MemoryAgentService:
import time
start_time = time.time()
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
@@ -411,7 +411,11 @@ class MemoryAgentService:
with group_lock:
# Step 1: Load configuration from database only
try:
memory_config = self.load_memory_config(config_id)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
@@ -696,7 +700,11 @@ class MemoryAgentService:
logger.info("Classifying message type")
# Load configuration to get LLM model ID
memory_config = self.load_memory_config(config_id)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
@@ -865,7 +873,8 @@ class MemoryAgentService:
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None
llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]:
"""
获取用户详情,包含:
@@ -877,6 +886,7 @@ class MemoryAgentService:
- end_user_id: 用户ID可选
- current_user_id: 当前登录用户的ID保留参数
- llm_id: LLM模型ID用于生成标签可选如果不提供则跳过标签生成
- db: 数据库会话(可选)
返回格式:
{
@@ -893,7 +903,7 @@ class MemoryAgentService:
# 1. 根据 end_user_id 获取 end_user_name
try:
if end_user_id:
if end_user_id and db:
from app.repositories import end_user_repository
from app.schemas.end_user_schema import EndUser as EndUserSchema
@@ -948,7 +958,9 @@ class MemoryAgentService:
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
# 使用LLM提取标签
llm_client = get_llm_client(llm_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_id)
# 定义标签提取的结构
class UserTags(BaseModel):
@@ -1110,4 +1122,69 @@ class MemoryAgentService:
# "msg": "解析失败",
# "error_code": "DOC_PARSE_ERROR",
# "data": {"error": str(e)}
# }
# }
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
"""
获取终端用户关联的记忆配置
通过以下流程获取配置:
1. 根据 end_user_id 获取用户的 app_id
2. 获取该应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id
Args:
end_user_id: 终端用户ID
db: 数据库会话
Returns:
包含 memory_config_id 和相关信息的字典
Raises:
ValueError: 当终端用户不存在或应用未发布时
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from sqlalchemy import select
logger.info(f"Getting connected config for end_user: {end_user_id}")
# 1. 获取 end_user 及其 app_id
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
logger.warning(f"End user not found: {end_user_id}")
raise ValueError(f"终端用户不存在: {end_user_id}")
app_id = end_user.app_id
logger.debug(f"Found end_user app_id: {app_id}")
# 2. 获取该应用的最新发布版本
stmt = (
select(AppRelease)
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
.order_by(AppRelease.version.desc())
)
latest_release = db.scalars(stmt).first()
if not latest_release:
logger.warning(f"No active release found for app: {app_id}")
raise ValueError(f"应用未发布: {app_id}")
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
# 3. 从 config 中提取 memory_config_id
config = latest_release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(latest_release.id),
"release_version": latest_release.version,
"memory_config_id": memory_config_id
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
return result

View File

@@ -3,7 +3,6 @@ Memory Configuration Service
Centralized configuration loading and management for memory services.
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
Database session management is handled internally.
"""
import time
@@ -57,7 +56,7 @@ def _validate_config_id(config_id):
invalid_value=config_id,
)
return parsed_id
except ValueError as e:
except ValueError:
raise InvalidConfigError(
f"Invalid configuration ID format: '{config_id}'",
field_name="config_id",
@@ -77,19 +76,29 @@ class MemoryConfigService:
This class provides a single implementation of configuration loading logic
that can be shared across multiple services, eliminating code duplication.
Database session management is handled internally.
Usage:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
model_config = config_service.get_model_config(model_id)
"""
@staticmethod
def __init__(self, db: Session):
"""Initialize the service with a database session.
Args:
db: SQLAlchemy database session
"""
self.db = db
def load_memory_config(
self,
config_id: int,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method manages its own database session internally.
Args:
config_id: Configuration ID from database
service_name: Name of the calling service (for logging purposes)
@@ -100,27 +109,6 @@ class MemoryConfigService:
Raises:
ConfigurationError: If validation fails
"""
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
try:
return MemoryConfigService._load_memory_config_with_db(
config_id=config_id,
db=db,
service_name=service_name,
)
finally:
db.close()
@staticmethod
def _load_memory_config_with_db(
config_id: int,
db: Session,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""Internal method that loads memory configuration with an existing db session."""
start_time = time.time()
config_logger.info(
@@ -137,7 +125,7 @@ class MemoryConfigService:
try:
validated_config_id = _validate_config_id(config_id)
result = DataConfigRepository.get_config_with_workspace(db, validated_config_id)
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
if not result:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
@@ -160,7 +148,7 @@ class MemoryConfigService:
embedding_uuid = validate_embedding_model(
validated_config_id,
memory_config.embedding_id,
db,
self.db,
workspace.tenant_id,
workspace.id,
)
@@ -169,7 +157,7 @@ class MemoryConfigService:
llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id,
"llm",
db,
self.db,
workspace.tenant_id,
required=True,
config_id=validated_config_id,
@@ -183,7 +171,7 @@ class MemoryConfigService:
rerank_uuid, rerank_name = validate_and_resolve_model_id(
memory_config.rerank_id,
"rerank",
db,
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
@@ -194,7 +182,7 @@ class MemoryConfigService:
embedding_name, _ = validate_model_exists_and_active(
embedding_uuid,
"embedding",
db,
self.db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
@@ -220,6 +208,25 @@ class MemoryConfigService:
reflexion_range=memory_config.reflexion_range or "retrieval",
reflexion_baseline=memory_config.baseline or "time",
loaded_at=datetime.now(),
# Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
)
elapsed_ms = (time.time() - start_time) * 1000
@@ -262,3 +269,131 @@ class MemoryConfigService:
raise
else:
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
def get_model_config(self, model_id: str) -> dict:
"""Get LLM model configuration by ID.
Args:
model_id: Model ID to look up
Returns:
Dict with model configuration including api_key, base_url, etc.
"""
from app.core.config import settings
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
if not config:
logger.warning(f"Model ID {model_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": api_config.model_config_id,
"type": config.type,
"timeout": settings.LLM_TIMEOUT,
"max_retries": settings.LLM_MAX_RETRIES,
}
def get_embedder_config(self, embedding_id: str) -> dict:
"""Get embedding model configuration by ID.
Args:
embedding_id: Embedding model ID to look up
Returns:
Dict with embedder configuration including api_key, base_url, etc.
"""
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
if not config:
logger.warning(f"Embedding model ID {embedding_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": api_config.model_config_id,
"type": config.type,
"timeout": 120.0,
"max_retries": 5,
}
@staticmethod
def get_pipeline_config(memory_config: MemoryConfig):
"""Build ExtractionPipelineConfig from MemoryConfig.
Args:
memory_config: MemoryConfig object containing all pipeline settings.
Returns:
ExtractionPipelineConfig with deduplication, statement extraction,
and forgetting engine settings.
"""
from app.core.memory.models.variate_config import (
DedupConfig,
ExtractionPipelineConfig,
ForgettingEngineConfig,
StatementExtractionConfig,
)
dedup_config = DedupConfig(
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
fuzzy_name_threshold_strict=memory_config.t_name_strict,
fuzzy_type_threshold_strict=memory_config.t_type_strict,
fuzzy_overall_threshold=memory_config.t_overall,
)
stmt_config = StatementExtractionConfig(
statement_granularity=memory_config.statement_granularity,
include_dialogue_context=memory_config.include_dialogue_context,
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
)
forget_config = ForgettingEngineConfig(
offset=memory_config.offset,
lambda_time=memory_config.lambda_time,
lambda_mem=memory_config.lambda_mem,
)
return ExtractionPipelineConfig(
statement_extraction=stmt_config,
deduplication=dedup_config,
forgetting_engine=forget_config,
)
@staticmethod
def get_pruning_config(memory_config: MemoryConfig) -> dict:
"""Retrieve semantic pruning config from MemoryConfig.
Args:
memory_config: MemoryConfig object containing pruning settings.
Returns:
Dict suitable for PruningConfig.model_validate with keys:
- pruning_switch: bool
- pruning_scene: str
- pruning_threshold: float
"""
return {
"pruning_switch": memory_config.pruning_enabled,
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
}

View File

@@ -49,27 +49,6 @@ class MemoryStorageService:
def __init__(self):
logger.info("MemoryStorageService initialized")
def load_memory_config(self, config_id: int, db: Session) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
This method delegates to the centralized MemoryConfigService to avoid
code duplication with other services.
Args:
config_id: Configuration ID from database
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
"""
return MemoryConfigService.load_memory_config(
config_id=config_id,
service_name="MemoryStorageService"
)
async def get_storage_info(self) -> dict:
"""
@@ -293,7 +272,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# Load configuration from database only using centralized manager
try:
memory_config = MemoryConfigService.load_memory_config(
config_service = MemoryConfigService(self.db)
memory_config = config_service.load_memory_config(
config_id=int(cid),
service_name="MemoryStorageService.pilot_run_stream"
)
@@ -320,13 +300,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
async def run_pipeline():
"""在后台执行管线并捕获异常"""
try:
from app.core.memory.main import main as pipeline_main
from app.services.pilot_run_service import run_pilot_extraction
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
await pipeline_main(
dialogue_text=dialogue_text,
is_pilot_run=True,
progress_callback=progress_callback
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await run_pilot_extraction(
memory_config=memory_config,
dialogue_text=dialogue_text,
db=self.db,
progress_callback=progress_callback,
)
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")

View File

@@ -0,0 +1,219 @@
"""
Pilot Run Service - 试运行服务
用于执行记忆系统的试运行流程,不保存到 Neo4j。
"""
import os
import re
import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from sqlalchemy.orm import Session
logger = get_memory_logger(__name__)
async def run_pilot_extraction(
memory_config: MemoryConfig,
dialogue_text: str,
db: Session,
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
) -> None:
"""
执行试运行模式的知识提取流水线。
Args:
memory_config: 从数据库加载的内存配置对象
dialogue_text: 输入的对话文本
progress_callback: 可选的进度回调函数
- 参数1 (stage): 当前处理阶段标识符
- 参数2 (message): 人类可读的进度消息
- 参数3 (data): 可选的附加数据字典
"""
log_file = "logs/time.log"
os.makedirs(os.path.dirname(log_file), exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
pipeline_start = time.time()
neo4j_connector = None
try:
# 步骤 1: 初始化客户端
logger.info("Initializing clients...")
step_start = time.time()
client_factory = MemoryClientFactory(db)
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 解析对话文本
logger.info("Parsing dialogue text...")
step_start = time.time()
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
messages = [
ConversationMessage(role=r, msg=c.strip())
for r, c in matches
if c.strip()
]
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
if not messages:
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
context = ConversationContext(msgs=messages)
dialog = DialogData(
context=context,
ref_id="pilot_dialog_1",
group_id=str(memory_config.workspace_id),
user_id=str(memory_config.tenant_id),
apply_id=str(memory_config.config_id),
metadata={"source": "pilot_run", "input_type": "frontend_text"},
)
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
chunker_strategy=memory_config.chunker_strategy,
llm_client=llm_client,
)
logger.info(f"Processed dialogue text: {len(messages)} messages")
# 进度回调:输出每个分块的结果
if progress_callback:
for dlg in chunked_dialogs:
for i, chunk in enumerate(dlg.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dlg.id,
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
preprocessing_summary = {
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化流水线编排器
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
config = get_pipeline_config(memory_config)
logger.info(
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# 步骤 4: 执行知识提取流水线
logger.info("Running extraction pipeline...")
step_start = time.time()
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
# 解包 extraction_result tuple (与 main.py 保持一致)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
try:
logger.info("Generating memory summaries...")
step_start = time.time()
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
summaries = await memory_summary_generation(
chunked_dialogs,
llm_client=llm_client,
embedder_client=embedder_client,
)
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
logger.info("Pilot run completed: Skipping Neo4j save")
except Exception as e:
logger.error(f"Pilot run failed: {e}", exc_info=True)
raise
finally:
if neo4j_connector:
try:
await neo4j_connector.close()
except Exception:
pass
total_time = time.time() - pipeline_start
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")
logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")