dev新增短期记忆功能 (#47)

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能

* dev新增短期记忆功能
This commit is contained in:
lixinyue11
2026-01-07 16:36:11 +08:00
committed by GitHub
parent 5fe8043ff8
commit bcb3d587a1
9 changed files with 765 additions and 45 deletions

View File

@@ -4,6 +4,7 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services,
health checks, and message type classification.
"""
import datetime
import json
import os
import re
@@ -24,6 +25,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
from app.services.memory_config_service import MemoryConfigService
@@ -393,7 +395,7 @@ class MemoryAgentService:
import time
start_time = time.time()
ori_message=message
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
@@ -406,15 +408,15 @@ class MemoryAgentService:
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
# 导入审计日志记录器
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
# Get group lock to prevent concurrent processing
group_lock = self.get_group_lock(group_id)
@@ -430,7 +432,7 @@ class MemoryAgentService:
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# Log failed operation
if audit_logger:
duration = time.time() - start_time
@@ -442,9 +444,9 @@ class MemoryAgentService:
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
@@ -452,7 +454,7 @@ class MemoryAgentService:
# Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
async with client.session('data_flow') as session:
logger.debug("Connected to MCP Server: data_flow")
tools = await load_mcp_tools(session)
@@ -475,7 +477,7 @@ class MemoryAgentService:
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
@@ -483,7 +485,7 @@ class MemoryAgentService:
"role": msg_role,
"content": msg_content
})
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
@@ -496,7 +498,7 @@ class MemoryAgentService:
break
else:
continue # No text block found
# Try to parse content as JSON
if isinstance(content_to_parse, str):
try:
@@ -506,16 +508,16 @@ class MemoryAgentService:
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
@@ -523,7 +525,7 @@ class MemoryAgentService:
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
workflow_duration = time.time() - start
logger.info(f"Read graph workflow completed in {workflow_duration}s")
@@ -532,7 +534,7 @@ class MemoryAgentService:
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
@@ -542,7 +544,7 @@ class MemoryAgentService:
break
else:
continue # No text block found
try:
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
@@ -552,15 +554,15 @@ class MemoryAgentService:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作
total_duration = time.time() - start_time
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger:
audit_logger.log_operation(
operation="READ",
@@ -577,11 +579,11 @@ class MemoryAgentService:
"errors": workflow_errors
}
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation(
operation="READ",
@@ -596,7 +598,31 @@ class MemoryAgentService:
"has_answer": bool(final_answer)
}
)
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
statements=list(set(statements))
retrieved_content.append({query:statements})
if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]:
# 使用 upsert 方法
repo.upsert(
end_user_id=group_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return {
"answer": final_answer,
"intermediate_outputs": intermediate_outputs

View File

@@ -0,0 +1,56 @@
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.repositories.memory_short_repository import ShortTermMemoryRepository
api_logger = get_api_logger()
db=next(get_db())
class ShortService:
def __init__(self, end_user_id):
self.short_repo = ShortTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_short_databasets(self):
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
short_result = []
for memory in short_memories:
deep_expanded = {} # Create a new dictionary for each memory
messages = memory.messages
aimessages = memory.aimessages
retrieved_content = memory.retrieved_content or []
api_logger.debug(f"Retrieved content: {retrieved_content}")
retrieval_source = []
for item in retrieved_content:
if isinstance(item, dict):
for key, values in item.items():
retrieval_source.append({"query": key, "retrieval": values})
deep_expanded['retrieval'] = retrieval_source
deep_expanded['message'] = messages # 修正拼写错误
deep_expanded['answer'] = aimessages
short_result.append(deep_expanded)
return short_result
def get_short_count(self):
short_count = self.short_repo.count_by_user_id(self.end_user_id)
return short_count
class LongService:
def __init__(self, end_user_id):
self.long_repo = LongTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_long_databasets(self):
# 获取长期记忆数据
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
long_result = []
for long_memory in long_memories:
if long_memory.retrieved_content:
for memory_item in long_memory.retrieved_content:
if isinstance(memory_item, dict):
for key, values in memory_item.items():
long_result.append({"query": key, "retrieval": values})
return long_result

View File

@@ -1496,8 +1496,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str
field_whitelist = {
"Dialogue": ["content", "created_at"],
"Chunk": ["content", "created_at"],
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"],
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"],
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
}