Fix/memory bug fix (#162)
* 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 读取的接口,去掉全局锁 * 输出数组 * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化测试接口 * 反思优化测试接口 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察) * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段
This commit is contained in:
@@ -16,6 +16,7 @@ import json
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.memory_episodic_schema import EmotionType
|
||||
from app.services.memory_base_service import Translation_English
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +25,7 @@ class MemoryEntityService:
|
||||
self.id = id
|
||||
self.table = table
|
||||
self.connector = Neo4jConnector()
|
||||
async def get_timeline_memories_server(self):
|
||||
async def get_timeline_memories_server(self,model_id, language_type):
|
||||
"""
|
||||
获取时间线记忆数据
|
||||
|
||||
@@ -48,10 +49,10 @@ class MemoryEntityService:
|
||||
logger.info(f"获取时间线记忆数据 - ID: {self.id}, Table: {self.table}")
|
||||
|
||||
# 根据表类型选择查询
|
||||
if self.table == 'Statement':
|
||||
if self.table == 'Statement':
|
||||
# Statement只需要输入ID,使用简化查询
|
||||
results = await self.connector.execute_query(Memory_Timeline_Statement, id=self.id)
|
||||
elif self.table == 'ExtractedEntity':
|
||||
elif self.table == 'ExtractedEntity':
|
||||
# ExtractedEntity类型查询
|
||||
results = await self.connector.execute_query(Memory_Timeline_ExtractedEntity, id=self.id)
|
||||
else:
|
||||
@@ -62,7 +63,7 @@ class MemoryEntityService:
|
||||
logger.info(f"时间线查询结果类型: {type(results)}, 长度: {len(results) if isinstance(results, list) else 'N/A'}")
|
||||
|
||||
# 处理查询结果
|
||||
timeline_data = self._process_timeline_results(results)
|
||||
timeline_data =await self._process_timeline_results(results, model_id, language_type)
|
||||
|
||||
logger.info(f"成功获取时间线记忆数据: 总计 {len(timeline_data.get('timelines_memory', []))} 条")
|
||||
|
||||
@@ -71,12 +72,14 @@ class MemoryEntityService:
|
||||
except Exception as e:
|
||||
logger.error(f"获取时间线记忆数据失败: {str(e)}", exc_info=True)
|
||||
return str(e)
|
||||
def _process_timeline_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
async def _process_timeline_results(self, results: List[Dict[str, Any]], model_id: str, language_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
处理时间线查询结果
|
||||
|
||||
Args:
|
||||
results: Neo4j查询结果
|
||||
model_id: 模型ID用于翻译
|
||||
language_type: 语言类型 ('zh' 或其他)
|
||||
|
||||
Returns:
|
||||
处理后的时间线数据字典
|
||||
@@ -104,19 +107,19 @@ class MemoryEntityService:
|
||||
# 处理MemorySummary
|
||||
summary = data.get('MemorySummary')
|
||||
if summary is not None:
|
||||
processed_summary = self._process_field_value(summary, "MemorySummary")
|
||||
processed_summary = await self._process_field_value(summary, "MemorySummary")
|
||||
memory_summary_list.extend(processed_summary)
|
||||
|
||||
# 处理Statement
|
||||
statement = data.get('statement')
|
||||
if statement is not None:
|
||||
processed_statement = self._process_field_value(statement, "Statement")
|
||||
processed_statement = await self._process_field_value(statement, "Statement")
|
||||
statement_list.extend(processed_statement)
|
||||
|
||||
# 处理ExtractedEntity
|
||||
extracted_entity = data.get('ExtractedEntity')
|
||||
if extracted_entity is not None:
|
||||
processed_entity = self._process_field_value(extracted_entity, "ExtractedEntity")
|
||||
processed_entity = await self._process_field_value(extracted_entity, "ExtractedEntity")
|
||||
extracted_entity_list.extend(processed_entity)
|
||||
|
||||
# 去重 - 现在处理的是字典列表,需要更智能的去重
|
||||
@@ -128,6 +131,8 @@ class MemoryEntityService:
|
||||
all_timeline_data = memory_summary_list + statement_list
|
||||
all_timeline_data = self._merge_same_text_items(all_timeline_data)
|
||||
|
||||
# 如果需要翻译(非中文),对整个结果进行翻译
|
||||
|
||||
result = {
|
||||
"MemorySummary": memory_summary_list,
|
||||
"Statement": statement_list,
|
||||
@@ -233,7 +238,7 @@ class MemoryEntityService:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
|
||||
async def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理字段值,支持字符串、列表等类型
|
||||
|
||||
@@ -251,13 +256,13 @@ class MemoryEntityService:
|
||||
# 如果是列表,处理每个元素
|
||||
for item in value:
|
||||
if self._is_valid_item(item):
|
||||
processed_item = self._process_single_item(item)
|
||||
processed_item = await self._process_single_item(item)
|
||||
if processed_item:
|
||||
processed_values.append(processed_item)
|
||||
elif isinstance(value, dict):
|
||||
# 如果是字典,直接处理
|
||||
if self._is_valid_item(value):
|
||||
processed_item = self._process_single_item(value)
|
||||
processed_item = await self._process_single_item(value)
|
||||
if processed_item:
|
||||
processed_values.append(processed_item)
|
||||
elif isinstance(value, str):
|
||||
@@ -304,7 +309,7 @@ class MemoryEntityService:
|
||||
return (str(item).strip() != '' and
|
||||
"MemorySummaryChunk" not in str(item))
|
||||
|
||||
def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
async def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
处理单个项目
|
||||
|
||||
@@ -369,6 +374,117 @@ class MemoryEntityService:
|
||||
logger.warning(f"转换时间格式失败: {e}, 原始值: {dt}")
|
||||
return str(dt) if dt is not None else None
|
||||
|
||||
async def _translate_list(
    self,
    data_list: List[Dict[str, Any]],
    model_id: str,
    fields: List[str]
) -> List[Dict[str, Any]]:
    """
    Translate selected fields of every dict in a list, with bounded concurrency.

    Each non-empty field value is converted with ``str()`` and passed to
    ``Translation_English(model_id, text)``. At most ``concurrency_limit``
    translation calls are awaited at the same time. A field whose
    translation raises, is cancelled, or comes back empty keeps its
    original value.

    Args:
        data_list: Dicts to translate. Entries that are not dicts are
            passed through untouched.
        model_id: Model ID forwarded to ``Translation_English``.
        fields: Names of the fields to translate in each dict.

    Returns:
        ``data_list`` itself when there is nothing to translate; otherwise
        a new list whose dict entries are shallow copies of the originals
        with the translated values filled in.
    """
    # Nothing to do for an empty list or an empty field selection.
    if not data_list or not fields:
        return data_list

    import asyncio

    # Cap the number of in-flight translation requests so a long list does
    # not fire every request at once (tune as needed; 5-10 is suggested).
    concurrency_limit = 5
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def translate_single_field(
        index: int,
        field: str,
        value: Any,
    ) -> Optional[tuple]:
        """
        Translate one field value.

        Returns:
            ``(index, field, translated_value)``, or ``None`` when the
            translation failed or was empty (the caller then keeps the
            original value).
        """
        # Always translate a string, so non-string values cannot break the call.
        text = str(value)

        try:
            async with semaphore:
                # Argument order of Translation_English is (model_id, text).
                translated = await Translation_English(model_id, text)

            # An empty translation means "keep the original value".
            if translated is None or translated == "":
                return None

            return index, field, translated
        except Exception as e:
            logger.warning(f"翻译字段 {field} (索引 {index}) 失败: {e}")
            return None

    # Build one task per (item, field) pair that actually has content.
    tasks = []
    for idx, item in enumerate(data_list):
        # Defensive: only dict entries carry translatable fields.
        if not isinstance(item, dict):
            continue

        for field in fields:
            if field not in item:
                continue

            value = item.get(field)

            # Empty values are skipped here, so no task is created for them.
            if value is None or value == "":
                continue

            tasks.append(
                asyncio.create_task(
                    translate_single_field(idx, field, value)
                )
            )

    # No translatable content: hand back the original list unchanged.
    if not tasks:
        return data_list

    # Run all tasks concurrently (throttled by the semaphore).
    # return_exceptions=True keeps one failing task from failing the batch.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Shallow-copy each dict so the caller's dicts are not mutated when the
    # translated values are written back.
    translated_list = [item.copy() if isinstance(item, dict) else item for item in data_list]

    # Write the translated values back into the copies.
    for result in results:
        if result is None:
            continue

        # asyncio.CancelledError derives from BaseException (Python 3.8+),
        # so the check must be against BaseException; otherwise a cancelled
        # child task would reach the tuple unpacking below and raise TypeError.
        if isinstance(result, BaseException):
            logger.warning(f"翻译任务异常: {result}")
            continue

        idx, field, translated = result

        # Defensive bounds/type check before writing back.
        if 0 <= idx < len(translated_list) and isinstance(translated_list[idx], dict):
            translated_list[idx][field] = translated

    return translated_list
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -426,15 +542,19 @@ class MemoryEmotion:
|
||||
# 如果解析失败,返回原始字符串
|
||||
return iso_string
|
||||
|
||||
async def get_emotion(self) -> Dict[str, Any]:
|
||||
async def get_emotion(self, model_id: str = None, language_type: str = 'zh') -> Dict[str, Any]:
|
||||
"""
|
||||
获取情绪随时间变化数据
|
||||
|
||||
Args:
|
||||
model_id: 模型ID用于翻译
|
||||
language_type: 语言类型 ('zh' 或其他)
|
||||
|
||||
Returns:
|
||||
包含情绪数据的字典
|
||||
"""
|
||||
try:
|
||||
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}")
|
||||
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}, language_type={language_type}")
|
||||
|
||||
if self.table == 'Statement':
|
||||
results = await self.connector.execute_query(Memory_Space_Emotion_Statement, id=self.id)
|
||||
@@ -450,6 +570,10 @@ class MemoryEmotion:
|
||||
# 转换Neo4j类型
|
||||
final_data = self._convert_neo4j_types(emotion_data)
|
||||
|
||||
# 如果需要翻译(非中文)
|
||||
if language_type != 'zh' and model_id and final_data:
|
||||
final_data = await self._translate_emotion_data(final_data, model_id)
|
||||
|
||||
logger.info(f"成功获取 {len(final_data)} 条情绪数据")
|
||||
|
||||
return final_data
|
||||
@@ -590,16 +714,14 @@ class MemoryInteraction:
|
||||
"""
|
||||
try:
|
||||
logger.info(f"获取交互数据 - ID: {self.id}, Table: {self.table}")
|
||||
|
||||
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
|
||||
if ori_data!=[]:
|
||||
# name = ori_data[0]['name']
|
||||
group_id = ori_data[0]['group_id']
|
||||
group_id = [i['group_id'] for i in ori_data][0]
|
||||
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
|
||||
if not Space_User:
|
||||
return []
|
||||
user_id=Space_User[0]['id']
|
||||
|
||||
results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user