refactor(memory): restructure memory system and improve configuration management
- Remove deprecated main.py entry point from memory module - Reorganize imports across controllers and services for consistency - Update emotion controller to pass db session instead of config_id to services - Enhance memory agent controller with db session parameter for status_type and user_profile endpoints - Refactor memory agent service to accept db parameter in classify_message_type method - Improve configuration handling in celery_app by removing automatic database reload - Update all memory-related services to use centralized config management - Standardize import ordering and remove unused imports across 50+ files - Add pilot_run_service for new pilot execution workflow - Refactor extraction engine, reflection engine, and search services for better modularity - Update LLM utilities and embedder configuration for improved flexibility - Enhance type classifier and verification tools with better error handling - Improve memory evaluation modules (LOCOMO, LongMemEval, MemSciQA) with consistent patterns
This commit is contained in:
@@ -3,26 +3,26 @@
|
||||
|
||||
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Dict, Any, Optional, List, AsyncGenerator
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.models import AgentConfig, ModelConfig, ModelApiKey
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.services.langchain_tool_server import Search
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
"""
|
||||
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="1",
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
from app.db import get_db
|
||||
db = next(get_db())
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="1",
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
)
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
||||
|
||||
@@ -713,9 +719,9 @@ class DraftRunService:
|
||||
Raises:
|
||||
BusinessException: 当指定的会话不存在时
|
||||
"""
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.schemas.conversation_schema import ConversationCreate
|
||||
from app.models import Conversation as ConversationModel
|
||||
from app.schemas.conversation_schema import ConversationCreate
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
conversation_service = ConversationService(self.db)
|
||||
|
||||
|
||||
@@ -7,14 +7,15 @@ Classes:
|
||||
EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import statistics
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
import statistics
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories.neo4j.emotion_repository import EmotionRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.logging_config import get_business_logger
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -454,7 +455,7 @@ class EmotionAnalyticsService:
|
||||
async def generate_emotion_suggestions(
|
||||
self,
|
||||
end_user_id: str,
|
||||
config_id: Optional[int] = None
|
||||
db: Session,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成个性化情绪建议
|
||||
|
||||
@@ -462,7 +463,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID(用户组ID)
|
||||
config_id: 配置ID(可选,用于从数据库加载LLM配置)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 包含个性化建议的响应:
|
||||
@@ -470,14 +471,32 @@ class EmotionAnalyticsService:
|
||||
- suggestions: 建议列表(3-5条)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}")
|
||||
logger.info(f"生成个性化情绪建议: user={end_user_id}")
|
||||
|
||||
# 1. 如果提供了 config_id,从数据库加载配置
|
||||
if config_id is not None:
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置")
|
||||
# 1. 从 end_user_id 获取关联的 memory_config_id
|
||||
llm_client = None
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is not None:
|
||||
from app.services.memory_config_service import (
|
||||
MemoryConfigService,
|
||||
)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(config_id),
|
||||
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
|
||||
)
|
||||
from app.core.memory.client_factory import MemoryClientFactory
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
|
||||
|
||||
# 2. 获取情绪健康数据
|
||||
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
|
||||
@@ -498,8 +517,9 @@ class EmotionAnalyticsService:
|
||||
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
|
||||
|
||||
# 7. 调用LLM生成建议(使用配置中的LLM)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
llm_client = get_llm_client()
|
||||
if llm_client is None:
|
||||
# 无法获取配置时,抛出错误而不是使用默认配置
|
||||
raise ValueError("无法获取LLM配置,请确保end_user关联了有效的memory_config")
|
||||
|
||||
# 将 prompt 转换为 messages 格式
|
||||
messages = [
|
||||
@@ -598,7 +618,9 @@ class EmotionAnalyticsService:
|
||||
Returns:
|
||||
str: LLM prompt
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt
|
||||
from app.core.memory.utils.prompt.prompt_utils import (
|
||||
render_emotion_suggestions_prompt,
|
||||
)
|
||||
|
||||
prompt = await render_emotion_suggestions_prompt(
|
||||
health_data=health_data,
|
||||
|
||||
@@ -9,10 +9,12 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.data_config_model import DataConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,7 +52,9 @@ class EmotionExtractionService:
|
||||
"""
|
||||
if self.llm_client is None or model_id:
|
||||
effective_model_id = model_id or self.llm_id
|
||||
self.llm_client = get_llm_client(effective_model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(effective_model_id)
|
||||
return self.llm_client
|
||||
|
||||
async def extract_emotion(
|
||||
@@ -142,7 +146,9 @@ class EmotionExtractionService:
|
||||
Returns:
|
||||
Formatted prompt string for LLM
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
|
||||
from app.core.memory.utils.prompt.prompt_utils import (
|
||||
render_emotion_extraction_prompt,
|
||||
)
|
||||
|
||||
prompt = await render_emotion_extraction_prompt(
|
||||
statement=statement,
|
||||
|
||||
@@ -21,8 +21,8 @@ from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.db import get_db
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
@@ -45,8 +45,7 @@ config_logger = get_config_logger()
|
||||
|
||||
# Initialize Neo4j connector for analytics functions
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
|
||||
class MemoryAgentService:
|
||||
"""Service for memory agent operations"""
|
||||
@@ -55,27 +54,6 @@ class MemoryAgentService:
|
||||
self.user_locks: Dict[str, Lock] = {}
|
||||
self.locks_lock = Lock()
|
||||
|
||||
def load_memory_config(self, config_id: int) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method delegates to the centralized MemoryConfigService to avoid
|
||||
code duplication with other services.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
return MemoryConfigService.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
|
||||
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
|
||||
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
countext = re.findall(r'"status": "(.*?)",', messages)[0]
|
||||
@@ -277,14 +255,17 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str:
|
||||
async def write_memory(self, group_id: str, message: str, config_id: str, db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
Args:
|
||||
group_id: Group identifier
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
|
||||
Returns:
|
||||
Write operation result status
|
||||
@@ -292,14 +273,24 @@ class MemoryAgentService:
|
||||
Raises:
|
||||
ValueError: If config loading fails or write operation fails
|
||||
"""
|
||||
if config_id==None:
|
||||
config_id = os.getenv("config_id")
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Load configuration from database only
|
||||
try:
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
@@ -366,6 +357,7 @@ class MemoryAgentService:
|
||||
history: List[Dict],
|
||||
search_switch: str,
|
||||
config_id: str,
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
) -> Dict:
|
||||
@@ -378,11 +370,14 @@ class MemoryAgentService:
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
group_id: Group identifier
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: User message
|
||||
history: Conversation history
|
||||
search_switch: Search mode switch
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
|
||||
Returns:
|
||||
Dict with 'answer' and 'intermediate_outputs' keys
|
||||
@@ -394,8 +389,13 @@ class MemoryAgentService:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
if config_id==None:
|
||||
config_id = os.getenv("config_id")
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
|
||||
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
|
||||
|
||||
@@ -411,7 +411,11 @@ class MemoryAgentService:
|
||||
with group_lock:
|
||||
# Step 1: Load configuration from database only
|
||||
try:
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
@@ -696,7 +700,11 @@ class MemoryAgentService:
|
||||
logger.info("Classifying message type")
|
||||
|
||||
# Load configuration to get LLM model ID
|
||||
memory_config = self.load_memory_config(config_id)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
|
||||
status = await status_typle(message, memory_config.llm_model_id)
|
||||
logger.debug(f"Message type: {status}")
|
||||
@@ -865,7 +873,8 @@ class MemoryAgentService:
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user_id: Optional[str] = None,
|
||||
llm_id: Optional[str] = None
|
||||
llm_id: Optional[str] = None,
|
||||
db: Session = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -877,6 +886,7 @@ class MemoryAgentService:
|
||||
- end_user_id: 用户ID(可选)
|
||||
- current_user_id: 当前登录用户的ID(保留参数)
|
||||
- llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成)
|
||||
- db: 数据库会话(可选)
|
||||
|
||||
返回格式:
|
||||
{
|
||||
@@ -893,7 +903,7 @@ class MemoryAgentService:
|
||||
|
||||
# 1. 根据 end_user_id 获取 end_user_name
|
||||
try:
|
||||
if end_user_id:
|
||||
if end_user_id and db:
|
||||
from app.repositories import end_user_repository
|
||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||
|
||||
@@ -948,7 +958,9 @@ class MemoryAgentService:
|
||||
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
|
||||
|
||||
# 使用LLM提取标签
|
||||
llm_client = get_llm_client(llm_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(llm_id)
|
||||
|
||||
# 定义标签提取的结构
|
||||
class UserTags(BaseModel):
|
||||
@@ -1110,4 +1122,69 @@ class MemoryAgentService:
|
||||
# "msg": "解析失败",
|
||||
# "error_code": "DOC_PARSE_ERROR",
|
||||
# "data": {"error": str(e)}
|
||||
# }
|
||||
# }
|
||||
|
||||
|
||||
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
|
||||
通过以下流程获取配置:
|
||||
1. 根据 end_user_id 获取用户的 app_id
|
||||
2. 获取该应用的最新发布版本
|
||||
3. 从发布版本的 config 字段中提取 memory_config_id
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
包含 memory_config_id 和相关信息的字典
|
||||
|
||||
Raises:
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
# 1. 获取 end_user 及其 app_id
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if not end_user:
|
||||
logger.warning(f"End user not found: {end_user_id}")
|
||||
raise ValueError(f"终端用户不存在: {end_user_id}")
|
||||
|
||||
app_id = end_user.app_id
|
||||
logger.debug(f"Found end_user app_id: {app_id}")
|
||||
|
||||
# 2. 获取该应用的最新发布版本
|
||||
stmt = (
|
||||
select(AppRelease)
|
||||
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
|
||||
.order_by(AppRelease.version.desc())
|
||||
)
|
||||
latest_release = db.scalars(stmt).first()
|
||||
|
||||
if not latest_release:
|
||||
logger.warning(f"No active release found for app: {app_id}")
|
||||
raise ValueError(f"应用未发布: {app_id}")
|
||||
|
||||
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
|
||||
|
||||
# 3. 从 config 中提取 memory_config_id
|
||||
config = latest_release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
"app_id": str(app_id),
|
||||
"release_id": str(latest_release.id),
|
||||
"release_version": latest_release.version,
|
||||
"memory_config_id": memory_config_id
|
||||
}
|
||||
|
||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
||||
return result
|
||||
@@ -3,7 +3,6 @@ Memory Configuration Service
|
||||
|
||||
Centralized configuration loading and management for memory services.
|
||||
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
|
||||
Database session management is handled internally.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -57,7 +56,7 @@ def _validate_config_id(config_id):
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return parsed_id
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid configuration ID format: '{config_id}'",
|
||||
field_name="config_id",
|
||||
@@ -77,19 +76,29 @@ class MemoryConfigService:
|
||||
|
||||
This class provides a single implementation of configuration loading logic
|
||||
that can be shared across multiple services, eliminating code duplication.
|
||||
Database session management is handled internally.
|
||||
|
||||
Usage:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __init__(self, db: Session):
|
||||
"""Initialize the service with a database session.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: int,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method manages its own database session internally.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
service_name: Name of the calling service (for logging purposes)
|
||||
@@ -100,27 +109,6 @@ class MemoryConfigService:
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
from app.db import get_db
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
try:
|
||||
return MemoryConfigService._load_memory_config_with_db(
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
service_name=service_name,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@staticmethod
|
||||
def _load_memory_config_with_db(
|
||||
config_id: int,
|
||||
db: Session,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""Internal method that loads memory configuration with an existing db session."""
|
||||
start_time = time.time()
|
||||
|
||||
config_logger.info(
|
||||
@@ -137,7 +125,7 @@ class MemoryConfigService:
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
|
||||
result = DataConfigRepository.get_config_with_workspace(db, validated_config_id)
|
||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
if not result:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
@@ -160,7 +148,7 @@ class MemoryConfigService:
|
||||
embedding_uuid = validate_embedding_model(
|
||||
validated_config_id,
|
||||
memory_config.embedding_id,
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
workspace.id,
|
||||
)
|
||||
@@ -169,7 +157,7 @@ class MemoryConfigService:
|
||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||
memory_config.llm_id,
|
||||
"llm",
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=True,
|
||||
config_id=validated_config_id,
|
||||
@@ -183,7 +171,7 @@ class MemoryConfigService:
|
||||
rerank_uuid, rerank_name = validate_and_resolve_model_id(
|
||||
memory_config.rerank_id,
|
||||
"rerank",
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
@@ -194,7 +182,7 @@ class MemoryConfigService:
|
||||
embedding_name, _ = validate_model_exists_and_active(
|
||||
embedding_uuid,
|
||||
"embedding",
|
||||
db,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
@@ -220,6 +208,25 @@ class MemoryConfigService:
|
||||
reflexion_range=memory_config.reflexion_range or "retrieval",
|
||||
reflexion_baseline=memory_config.baseline or "time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
@@ -262,3 +269,131 @@ class MemoryConfigService:
|
||||
raise
|
||||
else:
|
||||
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
|
||||
|
||||
def get_model_config(self, model_id: str) -> dict:
|
||||
"""Get LLM model configuration by ID.
|
||||
|
||||
Args:
|
||||
model_id: Model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with model configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.core.config import settings
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
|
||||
if not config:
|
||||
logger.warning(f"Model ID {model_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": settings.LLM_TIMEOUT,
|
||||
"max_retries": settings.LLM_MAX_RETRIES,
|
||||
}
|
||||
|
||||
def get_embedder_config(self, embedding_id: str) -> dict:
|
||||
"""Get embedding model configuration by ID.
|
||||
|
||||
Args:
|
||||
embedding_id: Embedding model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with embedder configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
|
||||
if not config:
|
||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": 120.0,
|
||||
"max_retries": 5,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_pipeline_config(memory_config: MemoryConfig):
|
||||
"""Build ExtractionPipelineConfig from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing all pipeline settings.
|
||||
|
||||
Returns:
|
||||
ExtractionPipelineConfig with deduplication, statement extraction,
|
||||
and forgetting engine settings.
|
||||
"""
|
||||
from app.core.memory.models.variate_config import (
|
||||
DedupConfig,
|
||||
ExtractionPipelineConfig,
|
||||
ForgettingEngineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
|
||||
dedup_config = DedupConfig(
|
||||
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
|
||||
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
|
||||
fuzzy_name_threshold_strict=memory_config.t_name_strict,
|
||||
fuzzy_type_threshold_strict=memory_config.t_type_strict,
|
||||
fuzzy_overall_threshold=memory_config.t_overall,
|
||||
)
|
||||
|
||||
stmt_config = StatementExtractionConfig(
|
||||
statement_granularity=memory_config.statement_granularity,
|
||||
include_dialogue_context=memory_config.include_dialogue_context,
|
||||
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
|
||||
)
|
||||
|
||||
forget_config = ForgettingEngineConfig(
|
||||
offset=memory_config.offset,
|
||||
lambda_time=memory_config.lambda_time,
|
||||
lambda_mem=memory_config.lambda_mem,
|
||||
)
|
||||
|
||||
return ExtractionPipelineConfig(
|
||||
statement_extraction=stmt_config,
|
||||
deduplication=dedup_config,
|
||||
forgetting_engine=forget_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_pruning_config(memory_config: MemoryConfig) -> dict:
|
||||
"""Retrieve semantic pruning config from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing pruning settings.
|
||||
|
||||
Returns:
|
||||
Dict suitable for PruningConfig.model_validate with keys:
|
||||
- pruning_switch: bool
|
||||
- pruning_scene: str
|
||||
- pruning_threshold: float
|
||||
"""
|
||||
return {
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
}
|
||||
|
||||
@@ -49,27 +49,6 @@ class MemoryStorageService:
|
||||
|
||||
def __init__(self):
|
||||
logger.info("MemoryStorageService initialized")
|
||||
|
||||
def load_memory_config(self, config_id: int, db: Session) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
This method delegates to the centralized MemoryConfigService to avoid
|
||||
code duplication with other services.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
return MemoryConfigService.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryStorageService"
|
||||
)
|
||||
|
||||
async def get_storage_info(self) -> dict:
|
||||
"""
|
||||
@@ -293,7 +272,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
|
||||
# Load configuration from database only using centralized manager
|
||||
try:
|
||||
memory_config = MemoryConfigService.load_memory_config(
|
||||
config_service = MemoryConfigService(self.db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(cid),
|
||||
service_name="MemoryStorageService.pilot_run_stream"
|
||||
)
|
||||
@@ -320,13 +300,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
async def run_pipeline():
|
||||
"""在后台执行管线并捕获异常"""
|
||||
try:
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
from app.services.pilot_run_service import run_pilot_extraction
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(
|
||||
dialogue_text=dialogue_text,
|
||||
is_pilot_run=True,
|
||||
progress_callback=progress_callback
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
|
||||
await run_pilot_extraction(
|
||||
memory_config=memory_config,
|
||||
dialogue_text=dialogue_text,
|
||||
db=self.db,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
|
||||
|
||||
|
||||
219
api/app/services/pilot_run_service.py
Normal file
219
api/app/services/pilot_run_service.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Pilot Run Service - 试运行服务
|
||||
|
||||
用于执行记忆系统的试运行流程,不保存到 Neo4j。
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_pipeline_config,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
async def run_pilot_extraction(
|
||||
memory_config: MemoryConfig,
|
||||
dialogue_text: str,
|
||||
db: Session,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
执行试运行模式的知识提取流水线。
|
||||
|
||||
Args:
|
||||
memory_config: 从数据库加载的内存配置对象
|
||||
dialogue_text: 输入的对话文本
|
||||
progress_callback: 可选的进度回调函数
|
||||
- 参数1 (stage): 当前处理阶段标识符
|
||||
- 参数2 (message): 人类可读的进度消息
|
||||
- 参数3 (data): 可选的附加数据字典
|
||||
"""
|
||||
log_file = "logs/time.log"
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
neo4j_connector = None
|
||||
|
||||
try:
|
||||
# 步骤 1: 初始化客户端
|
||||
logger.info("Initializing clients...")
|
||||
step_start = time.time()
|
||||
|
||||
client_factory = MemoryClientFactory(db)
|
||||
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 2: 解析对话文本
|
||||
logger.info("Parsing dialogue text...")
|
||||
step_start = time.time()
|
||||
|
||||
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
|
||||
pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)"
|
||||
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
|
||||
messages = [
|
||||
ConversationMessage(role=r, msg=c.strip())
|
||||
for r, c in matches
|
||||
if c.strip()
|
||||
]
|
||||
|
||||
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
|
||||
if not messages:
|
||||
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
|
||||
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context,
|
||||
ref_id="pilot_dialog_1",
|
||||
group_id=str(memory_config.workspace_id),
|
||||
user_id=str(memory_config.tenant_id),
|
||||
apply_id=str(memory_config.config_id),
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"},
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
chunker_strategy=memory_config.chunker_strategy,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed dialogue text: {len(messages)} messages")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dlg in chunked_dialogs:
|
||||
for i, chunk in enumerate(dlg.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dlg.id,
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 3: 初始化流水线编排器
|
||||
logger.info("Initializing extraction orchestrator...")
|
||||
step_start = time.time()
|
||||
|
||||
config = get_pipeline_config(memory_config)
|
||||
logger.info(
|
||||
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
|
||||
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
|
||||
)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 4: 执行知识提取流水线
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=True,
|
||||
)
|
||||
|
||||
# 解包 extraction_result tuple (与 main.py 保持一致)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
|
||||
try:
|
||||
logger.info("Generating memory summaries...")
|
||||
step_start = time.time()
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
)
|
||||
|
||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
|
||||
logger.info("Pilot run completed: Skipping Neo4j save")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if neo4j_connector:
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")
|
||||
Reference in New Issue
Block a user