97 lines
3.6 KiB
Python
97 lines
3.6 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""陈述句仓储模块
|
||
|
||
本模块提供陈述句节点的数据访问功能。
|
||
|
||
Classes:
|
||
StatementRepository: 陈述句仓储,管理StatementNode的CRUD操作
|
||
"""
|
||
|
||
from typing import List, Dict
|
||
from datetime import datetime
|
||
|
||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||
from app.core.memory.models.graph_models import StatementNode
|
||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||
|
||
|
||
class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
||
"""陈述句仓储
|
||
|
||
管理陈述句节点的创建、查询、更新和删除操作。
|
||
提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。
|
||
|
||
Attributes:
|
||
connector: Neo4j连接器实例
|
||
node_label: 节点标签,固定为"Statement"
|
||
"""
|
||
|
||
def __init__(self, connector: Neo4jConnector):
|
||
"""初始化陈述句仓储
|
||
|
||
Args:
|
||
connector: Neo4j连接器实例
|
||
"""
|
||
super().__init__(connector, "Statement")
|
||
|
||
def _map_to_entity(self, node_data: Dict) -> StatementNode:
|
||
"""将节点数据映射为陈述句实体
|
||
|
||
Args:
|
||
node_data: 从Neo4j查询返回的节点数据字典
|
||
|
||
Returns:
|
||
StatementNode: 陈述句实体对象
|
||
"""
|
||
# 从查询结果中提取节点数据
|
||
n = node_data.get('n', node_data)
|
||
|
||
# 处理datetime字段
|
||
if isinstance(n.get('created_at'), str):
|
||
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
||
if n.get('expired_at') and isinstance(n['expired_at'], str):
|
||
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
|
||
if n.get('valid_at') and isinstance(n['valid_at'], str):
|
||
n['valid_at'] = datetime.fromisoformat(n['valid_at'])
|
||
if n.get('invalid_at') and isinstance(n['invalid_at'], str):
|
||
n['invalid_at'] = datetime.fromisoformat(n['invalid_at'])
|
||
|
||
# 处理temporal_info字段
|
||
if isinstance(n.get('temporal_info'), str):
|
||
# 从字符串转换为枚举值
|
||
n['temporal_info'] = TemporalInfo(n['temporal_info'])
|
||
elif isinstance(n.get('temporal_info'), dict):
|
||
n['temporal_info'] = TemporalInfo(**n['temporal_info'])
|
||
elif not n.get('temporal_info'):
|
||
# 如果没有temporal_info,创建一个默认的
|
||
n['temporal_info'] = TemporalInfo.STATIC
|
||
|
||
# 处理情绪字段 - 映射 Neo4j 节点属性到 StatementNode 模型
|
||
# 处理空值情况,确保字段存在
|
||
n['emotion_type'] = n.get('emotion_type')
|
||
n['emotion_intensity'] = n.get('emotion_intensity')
|
||
n['emotion_keywords'] = n.get('emotion_keywords', [])
|
||
n['emotion_subject'] = n.get('emotion_subject')
|
||
n['emotion_target'] = n.get('emotion_target')
|
||
|
||
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
||
n['importance_score'] = n.get('importance_score', 0.5)
|
||
n['activation_value'] = n.get('activation_value')
|
||
n['access_history'] = n.get('access_history') or []
|
||
n['last_access_time'] = n.get('last_access_time')
|
||
n['access_count'] = n.get('access_count', 0)
|
||
|
||
return StatementNode(**n)
|
||
|
||
async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]:
|
||
"""根据chunk_id查询陈述句
|
||
|
||
Args:
|
||
chunk_id: 分块ID
|
||
|
||
Returns:
|
||
List[StatementNode]: 陈述句列表
|
||
"""
|
||
return await self.find({"chunk_id": chunk_id})
|