Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
@@ -4,6 +4,7 @@ Memory Agent Service
|
||||
Handles business logic for memory agent operations including read/write services,
|
||||
health checks, and message type classification.
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -24,6 +25,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
@@ -393,7 +395,7 @@ class MemoryAgentService:
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
ori_message=message
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
@@ -406,15 +408,15 @@ class MemoryAgentService:
|
||||
raise # Re-raise our specific error
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
|
||||
|
||||
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
|
||||
|
||||
|
||||
# 导入审计日志记录器
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
|
||||
|
||||
# Get group lock to prevent concurrent processing
|
||||
group_lock = self.get_group_lock(group_id)
|
||||
|
||||
@@ -430,7 +432,7 @@ class MemoryAgentService:
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
@@ -442,9 +444,9 @@ class MemoryAgentService:
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
# Step 2: Prepare history
|
||||
history.append({"role": "user", "content": message})
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
@@ -452,7 +454,7 @@ class MemoryAgentService:
|
||||
# Step 3: Initialize MCP client and execute read workflow
|
||||
mcp_config = get_mcp_server_config()
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
|
||||
async with client.session('data_flow') as session:
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
tools = await load_mcp_tools(session)
|
||||
@@ -475,7 +477,7 @@ class MemoryAgentService:
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
|
||||
for msg in messages:
|
||||
msg_content = msg.content
|
||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
||||
@@ -483,7 +485,7 @@ class MemoryAgentService:
|
||||
"role": msg_role,
|
||||
"content": msg_content
|
||||
})
|
||||
|
||||
|
||||
# Extract intermediate outputs
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
@@ -496,7 +498,7 @@ class MemoryAgentService:
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
|
||||
# Try to parse content as JSON
|
||||
if isinstance(content_to_parse, str):
|
||||
try:
|
||||
@@ -506,16 +508,16 @@ class MemoryAgentService:
|
||||
if '_intermediate' in parsed:
|
||||
intermediate_data = parsed['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in parsed:
|
||||
for intermediate_data in parsed['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
@@ -523,7 +525,7 @@ class MemoryAgentService:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
|
||||
workflow_duration = time.time() - start
|
||||
logger.info(f"Read graph workflow completed in {workflow_duration}s")
|
||||
|
||||
@@ -532,7 +534,7 @@ class MemoryAgentService:
|
||||
for messages in outputs:
|
||||
if messages['role'] == 'tool':
|
||||
message = messages['content']
|
||||
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(message, list):
|
||||
# Extract text from MCP content blocks
|
||||
@@ -542,7 +544,7 @@ class MemoryAgentService:
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
|
||||
try:
|
||||
parsed = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(parsed, dict):
|
||||
@@ -552,15 +554,15 @@ class MemoryAgentService:
|
||||
final_answer = summary_result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
|
||||
# 记录成功的操作
|
||||
total_duration = time.time() - start_time
|
||||
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
||||
|
||||
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
@@ -577,11 +579,11 @@ class MemoryAgentService:
|
||||
"errors": workflow_errors
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Raise error if no answer was produced
|
||||
if not final_answer:
|
||||
raise ValueError(f"Read workflow failed: {error_details}")
|
||||
|
||||
|
||||
if audit_logger and not workflow_errors:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
@@ -596,7 +598,31 @@ class MemoryAgentService:
|
||||
"has_answer": bool(final_answer)
|
||||
}
|
||||
)
|
||||
|
||||
retrieved_content=[]
|
||||
repo = ShortTermMemoryRepository(db)
|
||||
if str(search_switch)!="2":
|
||||
for intermediate in intermediate_outputs:
|
||||
intermediate_type=intermediate['type']
|
||||
if intermediate_type=="search_result":
|
||||
query=intermediate['query']
|
||||
raw_results=intermediate['raw_results']
|
||||
reranked_results=raw_results.get('reranked_results',[])
|
||||
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
statements=list(set(statements))
|
||||
retrieved_content.append({query:statements})
|
||||
if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]:
|
||||
# 使用 upsert 方法
|
||||
repo.upsert(
|
||||
end_user_id=group_id, # 确保这个变量在作用域内
|
||||
messages=ori_message,
|
||||
aimessages=final_answer,
|
||||
retrieved_content=retrieved_content,
|
||||
search_switch=str(search_switch)
|
||||
)
|
||||
print("写入成功")
|
||||
|
||||
|
||||
|
||||
return {
|
||||
"answer": final_answer,
|
||||
"intermediate_outputs": intermediate_outputs
|
||||
|
||||
166
api/app/services/memory_perceptual_service.py
Normal file
166
api/app/services/memory_perceptual_service.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
|
||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualTimelineResponse,
|
||||
PerceptualMemoryItem,
|
||||
AudioModal, Content, VideoModal, TextModal
|
||||
)
|
||||
|
||||
# Module-level business logger; MemoryPerceptualService writes its info and error messages through it.
business_logger = get_business_logger()
|
||||
|
||||
|
||||
class MemoryPerceptualService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.repository = MemoryPerceptualRepository(db)
|
||||
|
||||
def get_memory_count(self, end_user_id: uuid.UUID) -> Dict[str, Any]:
|
||||
"""Retrieve perceptual memory statistics for a user."""
|
||||
business_logger.info(f"Fetching perceptual memory statistics: end_user_id={end_user_id}")
|
||||
try:
|
||||
total_count = self.repository.get_count_by_user_id(end_user_id=end_user_id)
|
||||
|
||||
vision_count = self.repository.get_count_by_type(end_user_id, PerceptualType.VISION)
|
||||
audio_count = self.repository.get_count_by_type(end_user_id, PerceptualType.AUDIO)
|
||||
text_count = self.repository.get_count_by_type(end_user_id, PerceptualType.TEXT)
|
||||
conversation_count = self.repository.get_count_by_type(end_user_id, PerceptualType.CONVERSATION)
|
||||
|
||||
stats = {
|
||||
"total": total_count,
|
||||
"by_type": {
|
||||
"vision": vision_count,
|
||||
"audio": audio_count,
|
||||
"text": text_count,
|
||||
"conversation": conversation_count
|
||||
}
|
||||
}
|
||||
|
||||
business_logger.info(f"Memory statistics fetched successfully: total={total_count}")
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch memory statistics: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch memory statistics: {str(e)}", BizCode.DB_ERROR)
|
||||
|
||||
def _get_latest_memory_by_type(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
perceptual_type: PerceptualType
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Internal helper to retrieve the latest memory by type."""
|
||||
business_logger.info(f"Fetching latest {perceptual_type.name.lower()} memory: end_user_id={end_user_id}")
|
||||
try:
|
||||
memories = self.repository.get_by_type(
|
||||
end_user_id=end_user_id,
|
||||
perceptual_type=perceptual_type,
|
||||
limit=1,
|
||||
offset=0
|
||||
)
|
||||
if not memories:
|
||||
business_logger.info(f"No {perceptual_type.name.lower()} memory found: end_user_id={end_user_id}")
|
||||
return None
|
||||
|
||||
memory = memories[0]
|
||||
meta_data = memory.meta_data or {}
|
||||
modalities = meta_data.get("modalities")
|
||||
content = meta_data.get("content")
|
||||
|
||||
if not modalities:
|
||||
raise BusinessException(f"Modalities not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR)
|
||||
if not content:
|
||||
raise BusinessException(f"Content not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR)
|
||||
content = Content(**content)
|
||||
match perceptual_type:
|
||||
case PerceptualType.VISION:
|
||||
modal = VideoModal(**modalities)
|
||||
case PerceptualType.AUDIO:
|
||||
modal = AudioModal(**modalities)
|
||||
case PerceptualType.TEXT:
|
||||
modal = TextModal(**modalities)
|
||||
case _:
|
||||
raise BusinessException("Unsupported perceptual type", BizCode.DB_ERROR)
|
||||
detail = modal.model_dump()
|
||||
|
||||
result = {
|
||||
"id": str(memory.id),
|
||||
"file_name": memory.file_name,
|
||||
"file_path": memory.file_path,
|
||||
"storage_type": memory.storage_service,
|
||||
"summary": memory.summary,
|
||||
"keywords": content.keywords,
|
||||
"topic": content.topic,
|
||||
"domain": content.domain,
|
||||
"created_time": memory.created_time.isoformat() if memory.created_time else None,
|
||||
**detail
|
||||
}
|
||||
|
||||
business_logger.info(
|
||||
f"Latest {perceptual_type.name.lower()} memory retrieved successfully: file={memory.file_name}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||
BizCode.DB_ERROR)
|
||||
|
||||
def get_latest_visual_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]:
|
||||
return self._get_latest_memory_by_type(end_user_id, PerceptualType.VISION)
|
||||
|
||||
def get_latest_audio_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]:
|
||||
return self._get_latest_memory_by_type(end_user_id, PerceptualType.AUDIO)
|
||||
|
||||
def get_latest_text_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]:
|
||||
return self._get_latest_memory_by_type(end_user_id, PerceptualType.TEXT)
|
||||
|
||||
def get_time_line(self, end_user_id: uuid.UUID, query: PerceptualQuerySchema) -> PerceptualTimelineResponse:
|
||||
"""Retrieve a timeline of perceptual memories for a user."""
|
||||
business_logger.info(f"Fetching perceptual memory timeline: "
|
||||
f"end_user_id={end_user_id}, filter={query.filter}")
|
||||
|
||||
try:
|
||||
if query.page < 1:
|
||||
raise BusinessException("Page number must be greater than 0", BizCode.INVALID_PARAMETER)
|
||||
if query.page_size < 1 or query.page_size > 100:
|
||||
raise BusinessException("Page size must be between 1 and 100", BizCode.INVALID_PARAMETER)
|
||||
|
||||
total_count, memories = self.repository.get_timeline(end_user_id, query)
|
||||
|
||||
memory_items = []
|
||||
for memory in memories:
|
||||
memory_item = PerceptualMemoryItem(
|
||||
id=memory.id,
|
||||
perceptual_type=PerceptualType(memory.perceptual_type),
|
||||
file_path=memory.file_path,
|
||||
file_name=memory.file_name,
|
||||
summary=memory.summary,
|
||||
created_time=memory.created_time,
|
||||
storage_type=FileStorageType(memory.storage_service),
|
||||
)
|
||||
memory_items.append(memory_item)
|
||||
|
||||
timeline_response = PerceptualTimelineResponse(
|
||||
total=total_count,
|
||||
page=query.page,
|
||||
page_size=query.page_size,
|
||||
total_pages=(total_count + query.page_size - 1) // query.page_size,
|
||||
memories=memory_items
|
||||
)
|
||||
|
||||
business_logger.info(f"Perceptual memory timeline retrieved successfully: "
|
||||
f"total={total_count}, returned={len(memories)}")
|
||||
return timeline_response
|
||||
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||
56
api/app/services/memory_short_service.py
Normal file
56
api/app/services/memory_short_service.py
Normal file
@@ -0,0 +1,56 @@
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
|
||||
|
||||
# Module-level API logger; ShortService uses it for debug output of retrieved content.
api_logger = get_api_logger()
|
||||
# NOTE(review): this takes a single value from get_db() when the module is
# imported and keeps it in a module global; both ShortService and LongService
# pass that same object to their repositories. Nothing in this module advances
# or closes the get_db() iterator afterwards - confirm that one shared,
# import-time session is intended rather than a per-request one.
db=next(get_db())
|
||||
class ShortService:
    """Read access to one end user's short-term memory records."""

    def __init__(self, end_user_id):
        """Create the repository on the module-level session and remember the user id."""
        self.short_repo = ShortTermMemoryRepository(db)
        self.end_user_id = end_user_id

    def get_short_databasets(self):
        """Return the user's three latest short-term memories in expanded form.

        Each entry is a dict with ``retrieval`` (a list of
        ``{"query", "retrieval"}`` pairs taken from the dict items of the
        stored retrieved content), ``message`` and ``answer``.
        """
        latest_records = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
        expanded_records = []
        for record in latest_records:
            user_message = record.messages
            ai_answer = record.aimessages
            hits = record.retrieved_content or []

            api_logger.debug(f"Retrieved content: {hits}")

            # Flatten every {query: results} mapping; non-dict entries are ignored.
            sources = [
                {"query": query, "retrieval": found}
                for entry in hits
                if isinstance(entry, dict)
                for query, found in entry.items()
            ]

            expanded_records.append({
                "retrieval": sources,
                "message": user_message,
                "answer": ai_answer,
            })
        return expanded_records

    def get_short_count(self):
        """Return the repository's count of short-term memories for this user."""
        return self.short_repo.count_by_user_id(self.end_user_id)
|
||||
|
||||
class LongService:
    """Read access to one end user's long-term memory records."""

    def __init__(self, end_user_id):
        """Create the repository on the module-level session and remember the user id."""
        self.long_repo = LongTermMemoryRepository(db)
        self.end_user_id = end_user_id

    def get_long_databasets(self):
        """Return the user's long-term retrievals as a flat list.

        Queries the repository with ``(end_user_id, 1)`` and turns every
        ``{query: results}`` mapping found in the records' retrieved content
        into a ``{"query", "retrieval"}`` dict. Records without retrieved
        content and non-dict entries are ignored.
        """
        records = self.long_repo.get_by_user_id(self.end_user_id, 1)

        def iter_pairs():
            # Walk records -> stored entries -> (query, results) pairs in order.
            for record in records:
                if not record.retrieved_content:
                    continue
                for entry in record.retrieved_content:
                    if not isinstance(entry, dict):
                        continue
                    yield from entry.items()

        return [{"query": query, "retrieval": found} for query, found in iter_pairs()]
|
||||
@@ -166,6 +166,8 @@ class PromptOptimizerService:
|
||||
model_config = self.get_model_config(tenant_id, model_id)
|
||||
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
||||
|
||||
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
|
||||
|
||||
# Create LLM instance
|
||||
api_config: ModelApiKey = model_config.api_keys[0]
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
@@ -203,7 +205,6 @@ class PromptOptimizerService:
|
||||
|
||||
messages.extend(session_history[:-1]) # last message is current message
|
||||
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
||||
logger.info(f"Prompt optimization message: {messages}")
|
||||
buffer = ""
|
||||
prompt_started = False
|
||||
prompt_finished = False
|
||||
@@ -250,6 +251,7 @@ class PromptOptimizerService:
|
||||
content=desc
|
||||
)
|
||||
variables = self.parser_prompt_variables(optim_result.get("prompt"))
|
||||
logger.info(f"Prompt optimization completed, user_id={user_id}, session_id={session_id}")
|
||||
yield {"desc": optim_result.get("desc"), "variables": variables}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1496,8 +1496,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str
|
||||
field_whitelist = {
|
||||
"Dialogue": ["content", "created_at"],
|
||||
"Chunk": ["content", "created_at"],
|
||||
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"],
|
||||
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"],
|
||||
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
|
||||
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
|
||||
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user