config_config替换成memory_config
This commit is contained in:
@@ -8,6 +8,8 @@ Classes:
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
@@ -37,7 +39,7 @@ class EmotionConfigService:
|
||||
self.db = db
|
||||
logger.info("情绪配置服务初始化完成")
|
||||
|
||||
def get_emotion_config(self, config_id: int) -> Dict[str, Any]:
|
||||
def get_emotion_config(self, config_id: UUID) -> Dict[str, Any]:
|
||||
"""获取情绪引擎配置
|
||||
|
||||
查询指定配置ID的情绪相关配置字段。
|
||||
@@ -144,7 +146,7 @@ class EmotionConfigService:
|
||||
|
||||
def update_emotion_config(
|
||||
self,
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
config_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""更新情绪引擎配置
|
||||
|
||||
@@ -9,6 +9,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
@@ -266,7 +267,7 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
@@ -319,85 +320,52 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
try:
|
||||
if storage_type == "rag":
|
||||
# For RAG storage, convert messages to single string
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
return result
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
print(massages)
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
|
||||
# try:
|
||||
# if storage_type == "rag":
|
||||
# # For RAG storage, convert messages to single string
|
||||
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
# result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
# return result
|
||||
# else:
|
||||
# async with make_write_graph() as graph:
|
||||
# config = {"configurable": {"thread_id": end_user_id}}
|
||||
# # Convert structured messages to LangChain messages
|
||||
# langchain_messages = []
|
||||
# for msg in messages:
|
||||
# if msg['role'] == 'user':
|
||||
# langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
# elif msg['role'] == 'assistant':
|
||||
# langchain_messages.append(AIMessage(content=msg['content']))
|
||||
#
|
||||
# # 初始状态 - 包含所有必要字段
|
||||
# initial_state = {
|
||||
# "messages": langchain_messages,
|
||||
# "end_user_id": end_user_id,
|
||||
# "memory_config": memory_config
|
||||
# }
|
||||
#
|
||||
# # 获取节点更新信息
|
||||
# async for update_event in graph.astream(
|
||||
# initial_state,
|
||||
# stream_mode="updates",
|
||||
# config=config
|
||||
# ):
|
||||
# for node_name, node_data in update_event.items():
|
||||
# if 'save_neo4j' == node_name:
|
||||
# massages = node_data
|
||||
# massagesstatus = massages.get('write_result')['status']
|
||||
# contents = massages.get('write_result')
|
||||
# # Convert messages back to string for logging
|
||||
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
# except Exception as e:
|
||||
# # Ensure proper error handling and logging
|
||||
# error_msg = f"Write operation failed: {str(e)}"
|
||||
# logger.error(error_msg)
|
||||
# if audit_logger:
|
||||
# duration = time.time() - start_time
|
||||
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
# raise ValueError(error_msg)
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
|
||||
@@ -408,7 +376,7 @@ class MemoryAgentService:
|
||||
message: str,
|
||||
history: List[Dict],
|
||||
search_switch: str,
|
||||
config_id: Optional[str],
|
||||
config_id: Optional[UUID],
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str) -> Dict:
|
||||
@@ -685,7 +653,7 @@ class MemoryAgentService:
|
||||
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
|
||||
return user_input.messages
|
||||
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
async def classify_message_type(self, message: str, config_id: UUID, db: Session) -> Dict:
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
@@ -716,7 +684,7 @@ class MemoryAgentService:
|
||||
retrieve_info: str,
|
||||
history: List[Dict],
|
||||
query: str,
|
||||
config_id: str,
|
||||
config_id: UUID,
|
||||
db: Session
|
||||
) -> str:
|
||||
"""
|
||||
|
||||
@@ -23,53 +23,12 @@ from app.schemas.memory_config_schema import (
|
||||
ModelNotFoundError,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from uuid import UUID
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
|
||||
def _validate_config_id(config_id):
|
||||
"""Validate configuration ID format."""
|
||||
if config_id is None:
|
||||
raise InvalidConfigError(
|
||||
"Configuration ID cannot be None",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
if isinstance(config_id, int):
|
||||
if config_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {config_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return config_id
|
||||
|
||||
if isinstance(config_id, str):
|
||||
try:
|
||||
parsed_id = int(config_id.strip())
|
||||
if parsed_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {parsed_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return parsed_id
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid configuration ID format: '{config_id}'",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
raise InvalidConfigError(
|
||||
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
class MemoryConfigService:
|
||||
"""
|
||||
Centralized service for memory configuration loading and validation.
|
||||
@@ -93,14 +52,14 @@ class MemoryConfigService:
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
config_id: Configuration ID (UUID) from database
|
||||
service_name: Name of the calling service (for logging purposes)
|
||||
|
||||
Returns:
|
||||
@@ -116,18 +75,34 @@ class MemoryConfigService:
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": config_id,
|
||||
"config_id": str(config_id),
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Loading memory configuration from database: config_id={config_id}")
|
||||
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
# Validate config_id is UUID
|
||||
if not isinstance(config_id, UUID):
|
||||
if isinstance(config_id, str):
|
||||
try:
|
||||
config_id = UUID(config_id)
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid UUID format for config_id: {config_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
else:
|
||||
raise InvalidConfigError(
|
||||
f"config_id must be UUID or valid UUID string, got {type(config_id).__name__}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
# Step 1: Get config and workspace
|
||||
db_query_start = time.time()
|
||||
result = MemoryConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
result = MemoryConfigRepository.get_config_with_workspace(self.db, config_id)
|
||||
db_query_time = time.time() - db_query_start
|
||||
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||
if not result:
|
||||
@@ -136,14 +111,14 @@ class MemoryConfigService:
|
||||
"Configuration not found in database",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"config_id": validated_config_id,
|
||||
"config_id": str(config_id),
|
||||
"load_result": "not_found",
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"service": service_name,
|
||||
},
|
||||
)
|
||||
raise ConfigurationError(
|
||||
f"Configuration {validated_config_id} not found in database"
|
||||
f"Configuration {config_id} not found in database"
|
||||
)
|
||||
|
||||
memory_config, workspace = result
|
||||
@@ -151,7 +126,7 @@ class MemoryConfigService:
|
||||
# Step 2: Validate embedding model (returns both UUID and name)
|
||||
embed_start = time.time()
|
||||
embedding_uuid, embedding_name = validate_embedding_model(
|
||||
validated_config_id,
|
||||
config_id,
|
||||
memory_config.embedding_id,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
@@ -168,7 +143,7 @@ class MemoryConfigService:
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=True,
|
||||
config_id=validated_config_id,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
llm_time = time.time() - llm_start
|
||||
@@ -185,7 +160,7 @@ class MemoryConfigService:
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
rerank_time = time.time() - rerank_start
|
||||
@@ -243,7 +218,7 @@ class MemoryConfigService:
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": validated_config_id,
|
||||
"config_id": str(config_id),
|
||||
"config_name": config.config_name,
|
||||
"workspace_id": str(config.workspace_id),
|
||||
"load_result": "success",
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -87,7 +88,7 @@ class MemoryForgetService:
|
||||
async def _get_forgetting_components(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Tuple[ACTRCalculator, ForgettingStrategy, ForgettingScheduler, Dict[str, Any]]:
|
||||
"""
|
||||
获取遗忘引擎组件(计算器、策略、调度器)
|
||||
@@ -294,7 +295,7 @@ class MemoryForgetService:
|
||||
end_user_id: str,
|
||||
max_merge_batch_size: Optional[int] = None,
|
||||
min_days_since_access: Optional[int] = None,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
手动触发遗忘周期
|
||||
@@ -389,7 +390,7 @@ class MemoryForgetService:
|
||||
def read_forgetting_config(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: int
|
||||
config_id: UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘引擎配置
|
||||
@@ -416,7 +417,7 @@ class MemoryForgetService:
|
||||
def update_forgetting_config(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
update_fields: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -466,7 +467,7 @@ class MemoryForgetService:
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘引擎统计信息
|
||||
@@ -677,7 +678,7 @@ class MemoryForgetService:
|
||||
db: Session,
|
||||
importance_score: float,
|
||||
days: int,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘曲线数据
|
||||
|
||||
Reference in New Issue
Block a user