Fix/develop memory deail (#63)
* 新增记忆空间详情 * 新增记忆空间详情 * 新增记忆关联的数量 * 修改记忆时间线 * 修改记忆时间线 * 修改记忆时间线 * Parameterize elementId in Cypher query --------- Co-authored-by: Ke Sun <33739460+keeees@users.noreply.github.com>
This commit is contained in:
@@ -24,9 +24,6 @@ class MemoryEntityService:
|
||||
self.id = id
|
||||
self.table = table
|
||||
self.connector = Neo4jConnector()
|
||||
|
||||
|
||||
|
||||
async def get_timeline_memories_server(self):
|
||||
"""
|
||||
获取时间线记忆数据
|
||||
@@ -135,13 +132,14 @@ class MemoryEntityService:
|
||||
processed_entity = self._process_field_value(extracted_entity, "ExtractedEntity")
|
||||
extracted_entity_list.extend(processed_entity)
|
||||
|
||||
# 去重
|
||||
memory_summary_list = list(set(memory_summary_list))
|
||||
statement_list = list(set(statement_list))
|
||||
extracted_entity_list = list(set(extracted_entity_list))
|
||||
# 去重 - 现在处理的是字典列表,需要更智能的去重
|
||||
memory_summary_list = self._deduplicate_dict_list(memory_summary_list)
|
||||
statement_list = self._deduplicate_dict_list(statement_list)
|
||||
extracted_entity_list = self._deduplicate_dict_list(extracted_entity_list)
|
||||
|
||||
# 合并所有数据
|
||||
# 合并所有数据并处理相同text的合并
|
||||
all_timeline_data = memory_summary_list + statement_list + extracted_entity_list
|
||||
all_timeline_data = self._merge_same_text_items(all_timeline_data)
|
||||
|
||||
result = {
|
||||
"MemorySummary": memory_summary_list,
|
||||
@@ -154,7 +152,101 @@ class MemoryEntityService:
|
||||
|
||||
return result
|
||||
|
||||
def _process_field_value(self, value: Any, field_name: str) -> List[str]:
|
||||
def _deduplicate_dict_list(self, dict_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
对字典列表进行去重
|
||||
|
||||
Args:
|
||||
dict_list: 字典列表
|
||||
|
||||
Returns:
|
||||
去重后的字典列表
|
||||
"""
|
||||
seen = set()
|
||||
result = []
|
||||
|
||||
for item in dict_list:
|
||||
# 使用text作为去重的键
|
||||
text = item.get('text', '')
|
||||
if text and text not in seen:
|
||||
seen.add(text)
|
||||
result.append(item)
|
||||
|
||||
return result
|
||||
|
||||
def _merge_same_text_items(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
合并具有相同text的项目,合并type字段,保留一个时间
|
||||
|
||||
Args:
|
||||
items: 项目列表
|
||||
|
||||
Returns:
|
||||
合并后的项目列表
|
||||
"""
|
||||
text_groups = {}
|
||||
|
||||
# 按text分组
|
||||
for item in items:
|
||||
text = item.get('text', '')
|
||||
if not text:
|
||||
continue
|
||||
|
||||
if text not in text_groups:
|
||||
text_groups[text] = {
|
||||
'text': text,
|
||||
'types': set(),
|
||||
'created_at': item.get('created_at'),
|
||||
'latest_time': item.get('created_at')
|
||||
}
|
||||
|
||||
# 添加type到集合中
|
||||
item_type = item.get('type')
|
||||
if item_type:
|
||||
text_groups[text]['types'].add(item_type)
|
||||
|
||||
# 保留最新的时间(如果有的话)
|
||||
current_time = item.get('created_at')
|
||||
if current_time and (not text_groups[text]['latest_time'] or
|
||||
self._is_later_time(current_time, text_groups[text]['latest_time'])):
|
||||
text_groups[text]['latest_time'] = current_time
|
||||
|
||||
# 转换为最终格式
|
||||
result = []
|
||||
for text, group_data in text_groups.items():
|
||||
merged_item = {
|
||||
'text': text,
|
||||
'type': ', '.join(sorted(group_data['types'])), # 合并多个type
|
||||
'created_at': group_data['latest_time']
|
||||
}
|
||||
result.append(merged_item)
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
result.sort(key=lambda x: x.get('created_at', ''), reverse=True)
|
||||
|
||||
return result
|
||||
|
||||
def _is_later_time(self, time1: str, time2: str) -> bool:
|
||||
"""
|
||||
比较两个时间字符串,判断time1是否晚于time2
|
||||
|
||||
Args:
|
||||
time1: 时间字符串1
|
||||
time2: 时间字符串2
|
||||
|
||||
Returns:
|
||||
time1是否晚于time2
|
||||
"""
|
||||
try:
|
||||
if not time1 or not time2:
|
||||
return bool(time1) # 如果time2为空,time1存在就算更晚
|
||||
|
||||
# 简单的字符串比较(适用于ISO格式的时间)
|
||||
return time1 > time2
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理字段值,支持字符串、列表等类型
|
||||
|
||||
@@ -163,30 +255,133 @@ class MemoryEntityService:
|
||||
field_name: 字段名称(用于日志)
|
||||
|
||||
Returns:
|
||||
处理后的字符串列表
|
||||
处理后的字典列表
|
||||
"""
|
||||
processed_values = []
|
||||
|
||||
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
# 如果是列表,处理每个元素
|
||||
for item in value:
|
||||
if item is not None and str(item).strip() != '' and "MemorySummaryChunk" not in str(item):
|
||||
processed_values.append(str(item))
|
||||
if self._is_valid_item(item):
|
||||
processed_item = self._process_single_item(item)
|
||||
if processed_item:
|
||||
processed_values.append(processed_item)
|
||||
elif isinstance(value, dict):
|
||||
# 如果是字典,直接处理
|
||||
if self._is_valid_item(value):
|
||||
processed_item = self._process_single_item(value)
|
||||
if processed_item:
|
||||
processed_values.append(processed_item)
|
||||
elif isinstance(value, str):
|
||||
# 如果是字符串,直接处理
|
||||
# 如果是字符串,转换为字典格式
|
||||
if value.strip() != '' and "MemorySummaryChunk" not in value:
|
||||
processed_values.append(value)
|
||||
processed_values.append({
|
||||
'text': value,
|
||||
'type': field_name,
|
||||
'created_at': None
|
||||
})
|
||||
elif value is not None:
|
||||
# 其他类型转换为字符串
|
||||
str_value = str(value)
|
||||
if str_value.strip() != '' and "MemorySummaryChunk" not in str_value:
|
||||
processed_values.append(str_value)
|
||||
processed_values.append({
|
||||
'text': str_value,
|
||||
'type': field_name,
|
||||
'created_at': None
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"处理字段 {field_name} 的值时出错: {e}, 值类型: {type(value)}, 值: {value}")
|
||||
|
||||
return processed_values
|
||||
|
||||
def _is_valid_item(self, item: Any) -> bool:
|
||||
"""
|
||||
检查项目是否有效
|
||||
|
||||
Args:
|
||||
item: 要检查的项目
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
if item is None:
|
||||
return False
|
||||
|
||||
if isinstance(item, dict):
|
||||
text = item.get('text')
|
||||
return (text is not None and
|
||||
str(text).strip() != '' and
|
||||
"MemorySummaryChunk" not in str(text))
|
||||
|
||||
return (str(item).strip() != '' and
|
||||
"MemorySummaryChunk" not in str(item))
|
||||
|
||||
def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
处理单个项目
|
||||
|
||||
Args:
|
||||
item: 要处理的项目字典
|
||||
|
||||
Returns:
|
||||
处理后的项目字典
|
||||
"""
|
||||
try:
|
||||
text = item.get('text')
|
||||
created_at = item.get('created_at')
|
||||
item_type = item.get('type', '未知类型')
|
||||
|
||||
# 转换Neo4j时间格式
|
||||
formatted_time = self._convert_neo4j_datetime(created_at)
|
||||
|
||||
return {
|
||||
'text': text,
|
||||
'type': item_type,
|
||||
'created_at': formatted_time
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"处理单个项目时出错: {e}, 项目: {item}")
|
||||
return None
|
||||
|
||||
def _convert_neo4j_datetime(self, dt: Any) -> str:
|
||||
"""
|
||||
转换Neo4j时间格式为标准时间字符串
|
||||
|
||||
Args:
|
||||
dt: Neo4j时间对象或其他时间格式
|
||||
|
||||
Returns:
|
||||
格式化的时间字符串
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 处理Neo4j DateTime对象
|
||||
if isinstance(dt, Neo4jDateTime):
|
||||
return dt.iso_format().replace('T', ' ').split('.')[0]
|
||||
|
||||
# 处理其他neo4j时间类型
|
||||
if hasattr(dt, 'iso_format'):
|
||||
return dt.iso_format().replace('T', ' ').split('.')[0]
|
||||
|
||||
# 处理字符串格式的时间
|
||||
if isinstance(dt, str):
|
||||
# 尝试解析ISO格式
|
||||
try:
|
||||
parsed_dt = datetime.fromisoformat(dt.replace('Z', '+00:00'))
|
||||
return parsed_dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except ValueError:
|
||||
return dt
|
||||
|
||||
# 其他情况直接转换为字符串
|
||||
return str(dt)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"转换时间格式失败: {e}, 原始值: {dt}")
|
||||
return str(dt) if dt is not None else None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1394,7 +1394,7 @@ async def analytics_graph_data(
|
||||
"group_id": end_user_id,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
|
||||
# 执行节点查询
|
||||
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
|
||||
|
||||
@@ -1409,7 +1409,7 @@ async def analytics_graph_data(
|
||||
node_props = record["properties"]
|
||||
|
||||
# 根据节点类型提取需要的属性字段
|
||||
filtered_props = _extract_node_properties(node_label, node_props)
|
||||
filtered_props = await _extract_node_properties(node_label, node_props,node_id)
|
||||
|
||||
# 直接使用数据库中的 caption,如果没有则使用节点类型作为默认值
|
||||
caption = filtered_props.get("caption", node_label)
|
||||
@@ -1515,7 +1515,7 @@ async def analytics_graph_data(
|
||||
|
||||
# 辅助函数
|
||||
|
||||
def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _extract_node_properties(label: str, properties: Dict[str, Any],node_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
根据节点类型提取需要的属性字段
|
||||
|
||||
@@ -1542,7 +1542,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str
|
||||
if not allowed_fields:
|
||||
# 对于未定义的节点类型,只返回基本字段
|
||||
allowed_fields = ["name", "created_at", "caption"]
|
||||
|
||||
count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;"""
|
||||
node_results = await (_neo4j_connector.execute_query(count_neo4j))
|
||||
# 提取白名单中的字段
|
||||
filtered_props = {}
|
||||
for field in allowed_fields:
|
||||
@@ -1550,7 +1551,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str
|
||||
value = properties[field]
|
||||
# 清理 Neo4j 特殊类型
|
||||
filtered_props[field] = _clean_neo4j_value(value)
|
||||
|
||||
filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0]
|
||||
print(filtered_props)
|
||||
return filtered_props
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user