Fix/memory bug fix (#171)

This commit is contained in:
lixinyue11
2026-01-26 11:53:34 +08:00
committed by GitHub
parent 714c624dc6
commit 3601737869
119 changed files with 1711 additions and 1695 deletions

View File

@@ -1,18 +1,19 @@
# -*- coding: utf-8 -*-
"""数据配置Repository模块
"""记忆配置Repository模块
本模块提供data_config表的数据访问层使用SQLAlchemy ORM进行数据库操作
本模块提供memory_config表的数据访问层使用SQLAlchemy ORM进行数据库操作
包括CRUD操作和Neo4j Cypher查询常量
Classes:
DataConfigRepository: 数据配置仓储类提供CRUD操作
MemoryConfigRepository: 记忆配置仓储类提供CRUD操作
"""
import uuid
from uuid import UUID
from typing import Dict, List, Optional, Tuple
from app.core.exceptions import BusinessException
from app.core.logging_config import get_config_logger, get_db_logger
from app.models.data_config_model import DataConfig
from app.models.memory_config_model import MemoryConfig
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
@@ -28,11 +29,11 @@ db_logger = get_db_logger()
# 获取配置专用日志器
config_logger = get_config_logger()
TABLE_NAME = "data_config"
class DataConfigRepository:
"""数据配置Repository
TABLE_NAME = "memory_config"
class MemoryConfigRepository:
"""记忆配置Repository
提供data_config表的数据访问方法包括
提供memory_config表的数据访问方法包括
- SQLAlchemy ORM 数据库操作
- Neo4j Cypher查询常量
"""
@@ -41,48 +42,48 @@ class DataConfigRepository:
# Dialogue count by group
SEARCH_FOR_DIALOGUE = """
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
"""
# Chunk count by group
SEARCH_FOR_CHUNK = """
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
"""
# Statement count by group
SEARCH_FOR_STATEMENT = """
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
"""
# ExtractedEntity count by group
SEARCH_FOR_ENTITY = """
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num
"""
# All counts by label and total
SEARCH_FOR_ALL = """
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:Statement) WHERE n.group_id = $group_id RETURN 'Statement' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN 'Statement' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
"""
# Extracted entity details within group/app/user
SEARCH_FOR_DETIALS = """
MATCH (n:ExtractedEntity)
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN n.entity_idx AS entity_idx,
n.connect_strength AS connect_strength,
n.description AS description,
n.entity_type AS entity_type,
n.name AS name,
COALESCE(n.fact_summary, '') AS fact_summary,
n.group_id AS group_id,
n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.id AS id
@@ -91,9 +92,9 @@ class DataConfigRepository:
# Edges between extracted entities within group/app/user
SEARCH_FOR_EDGES = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN
r.group_id AS group_id,
r.end_user_id AS end_user_id,
r.apply_id AS apply_id,
r.user_id AS user_id,
elementId(r) AS rel_id,
@@ -107,7 +108,7 @@ class DataConfigRepository:
@staticmethod
def update_reflection_config(
db: Session,
config_id: int,
config_id: uuid.UUID,
enable_self_reflexion: bool,
iteration_period: str,
reflexion_range: str,
@@ -115,7 +116,7 @@ class DataConfigRepository:
reflection_model_id: str,
memory_verify: bool,
quality_assessment: bool
) -> DataConfig:
) -> MemoryConfig:
"""构建反思配置更新语句SQLAlchemy text() 命名参数)
Args:
@@ -130,28 +131,28 @@ class DataConfigRepository:
config_id: 配置ID
Returns:
Data
MemoryConfig
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
data_config_obj = db.scalars(stmt).first()
if not data_config_obj:
stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id)
memory_config_obj = db.scalars(stmt).first()
if not memory_config_obj:
raise BusinessException
data_config_obj.enable_self_reflexion = enable_self_reflexion
data_config_obj.iteration_period = iteration_period
data_config_obj.reflexion_range = reflexion_range
data_config_obj.baseline = baseline
data_config_obj.reflection_model_id = reflection_model_id
data_config_obj.memory_verify = memory_verify
data_config_obj.quality_assessment = quality_assessment
memory_config_obj.enable_self_reflexion = enable_self_reflexion
memory_config_obj.iteration_period = iteration_period
memory_config_obj.reflexion_range = reflexion_range
memory_config_obj.baseline = baseline
memory_config_obj.reflection_model_id = reflection_model_id
memory_config_obj.memory_verify = memory_verify
memory_config_obj.quality_assessment = quality_assessment
return data_config_obj
return memory_config_obj
@staticmethod
def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig:
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args:
@@ -162,13 +163,13 @@ class DataConfigRepository:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
data_config = db.scalars(stmt).first()
if not data_config:
stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id)
memory_config = db.scalars(stmt).first()
if not memory_config:
raise RuntimeError("reflection config not found")
return data_config
return memory_config
@staticmethod
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig:
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig:
"""构建查询所有配置的语句SQLAlchemy text() 命名参数)
Args:
@@ -180,11 +181,11 @@ class DataConfigRepository:
"""
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id)
data_config = db.scalars(stmt).first()
if not data_config:
stmt = select(MemoryConfig).where(MemoryConfig.workspace_id == workspace_id)
memory_config = db.scalars(stmt).first()
if not memory_config:
raise RuntimeError("reflection config not found")
return data_config
return memory_config
@staticmethod
@@ -208,20 +209,21 @@ class DataConfigRepository:
return query, params
@staticmethod
def create(db: Session, params: ConfigParamsCreate) -> DataConfig:
"""创建数据配置
def create(db: Session, params: ConfigParamsCreate) -> MemoryConfig:
"""创建记忆配置
Args:
db: 数据库会话
params: 配置参数创建模型
Returns:
DataConfig: 创建的配置对象
MemoryConfig: 创建的配置对象
"""
db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
db_logger.debug(f"创建记忆配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
try:
db_config = DataConfig(
db_config = MemoryConfig(
config_id=uuid.uuid4(),
config_name=params.config_name,
config_desc=params.config_desc,
workspace_id=params.workspace_id,
@@ -232,16 +234,16 @@ class DataConfigRepository:
db.add(db_config)
db.flush() # 获取自增ID但不提交事务
db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
db_logger.info(f"记忆配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}")
db_logger.error(f"创建记忆配置失败: {params.config_name} - {str(e)}")
raise
@staticmethod
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]:
def update(db: Session, update: ConfigUpdate) -> Optional[MemoryConfig]:
"""更新基础配置
Args:
@@ -249,17 +251,17 @@ class DataConfigRepository:
update: 配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"更新数据配置: config_id={update.config_id}")
db_logger.debug(f"更新记忆配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None
# 更新字段
@@ -277,17 +279,17 @@ class DataConfigRepository:
db.commit()
db.refresh(db_config)
db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})")
db_logger.info(f"记忆配置更新成功: {db_config.config_name} (ID: {update.config_id})")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}")
db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]:
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
"""更新记忆萃取引擎配置
Args:
@@ -295,7 +297,7 @@ class DataConfigRepository:
update: 萃取配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
@@ -303,9 +305,9 @@ class DataConfigRepository:
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None
# 更新字段映射
@@ -360,7 +362,7 @@ class DataConfigRepository:
raise
@staticmethod
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]:
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[MemoryConfig]:
"""更新遗忘引擎配置
Args:
@@ -368,7 +370,7 @@ class DataConfigRepository:
update: 遗忘配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Optional[MemoryConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
@@ -376,9 +378,9 @@ class DataConfigRepository:
db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
db_logger.warning(f"记忆配置不存在: config_id={update.config_id}")
return None
# 更新字段
@@ -408,7 +410,7 @@ class DataConfigRepository:
raise
@staticmethod
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]:
def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
"""获取萃取配置,通过主键查询某条配置
Args:
@@ -421,7 +423,7 @@ class DataConfigRepository:
db_logger.debug(f"查询萃取配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config:
db_logger.debug(f"萃取配置不存在: config_id={config_id}")
return None
@@ -457,7 +459,7 @@ class DataConfigRepository:
raise
@staticmethod
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]:
def get_forget_config(db: Session, config_id: UUID) -> Optional[Dict]:
"""获取遗忘配置,通过主键查询某条配置
Args:
@@ -470,7 +472,7 @@ class DataConfigRepository:
db_logger.debug(f"查询遗忘配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config:
db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
return None
@@ -489,39 +491,39 @@ class DataConfigRepository:
raise
@staticmethod
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]:
"""根据ID获取数据配置
def get_by_id(db: Session, config_id: uuid.UUID) -> Optional[MemoryConfig]:
"""根据ID获取记忆配置
Args:
db: 数据库会话
config_id: 配置ID
Returns:
Optional[DataConfig]: 配置对象不存在则返回None
Optional[MemoryConfig]: 配置对象不存在则返回None
"""
db_logger.debug(f"根据ID查询数据配置: config_id={config_id}")
db_logger.debug(f"根据ID查询记忆配置: config_id={config_id}")
try:
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if config:
db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})")
db_logger.debug(f"记忆配置查询成功: {config.config_name} (ID: {config_id})")
else:
db_logger.debug(f"数据配置不存在: config_id={config_id}")
db_logger.debug(f"记忆配置不存在: config_id={config_id}")
return config
except Exception as e:
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]:
"""Get data config and its associated workspace information
def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]:
"""Get memory config and its associated workspace information
Args:
db: Database session
config_id: Configuration ID
Returns:
Optional[tuple]: (DataConfig, Workspace) tuple, None if not found
Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found
Raises:
ValueError: Raised when config exists but workspace doesn't
@@ -541,19 +543,19 @@ class DataConfigRepository:
}
)
db_logger.debug(f"Querying data config and workspace: config_id={config_id}")
db_logger.debug(f"Querying memory config and workspace: config_id={config_id}")
try:
# Use join query to get both config and workspace
result = db.query(DataConfig, Workspace).join(
Workspace, DataConfig.workspace_id == Workspace.id
).filter(DataConfig.config_id == config_id).first()
result = db.query(MemoryConfig, Workspace).join(
Workspace, MemoryConfig.workspace_id == Workspace.id
).filter(MemoryConfig.config_id == config_id).first()
elapsed_ms = (time.time() - start_time) * 1000
if not result:
# Check if config exists but workspace is missing
config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if config_only:
if config_only.workspace_id is None:
config_logger.error(
@@ -566,7 +568,7 @@ class DataConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.error(f"Data config {config_id} has no associated workspace ID")
db_logger.error(f"Memory config {config_id} has no associated workspace ID")
raise ValueError(f"Configuration {config_id} has no associated workspace")
else:
config_logger.error(
@@ -579,7 +581,7 @@ class DataConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}")
db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
config_logger.debug(
@@ -591,7 +593,7 @@ class DataConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.debug(f"Data config not found: config_id={config_id}")
db_logger.debug(f"Memory config not found: config_id={config_id}")
return None
config, workspace = result
@@ -611,7 +613,7 @@ class DataConfigRepository:
}
)
db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
return (config, workspace)
except ValueError:
@@ -633,10 +635,10 @@ class DataConfigRepository:
exc_info=True
)
db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}")
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
"""获取所有配置参数
Args:
@@ -644,17 +646,17 @@ class DataConfigRepository:
workspace_id: 工作空间ID用于过滤查询结果
Returns:
List[DataConfig]: 配置列表
List[MemoryConfig]: 配置列表
"""
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
try:
query = db.query(DataConfig)
query = db.query(MemoryConfig)
if workspace_id:
query = query.filter(DataConfig.workspace_id == workspace_id)
query = query.filter(MemoryConfig.workspace_id == workspace_id)
configs = query.order_by(desc(DataConfig.updated_at)).all()
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
return configs
@@ -664,8 +666,8 @@ class DataConfigRepository:
raise
@staticmethod
def delete(db: Session, config_id: int) -> bool:
"""删除数据配置
def delete(db: Session, config_id: uuid.UUID) -> bool:
"""删除记忆配置
Args:
db: 数据库会话
@@ -674,22 +676,22 @@ class DataConfigRepository:
Returns:
bool: 删除成功返回True配置不存在返回False
"""
db_logger.debug(f"删除数据配置: config_id={config_id}")
db_logger.debug(f"删除记忆配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={config_id}")
db_logger.warning(f"记忆配置不存在: config_id={config_id}")
return False
db.delete(db_config)
db.commit()
db_logger.info(f"数据配置删除成功: config_id={config_id}")
db_logger.info(f"记忆配置删除成功: config_id={config_id}")
return True
except Exception as e:
db.rollback()
db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}")
db_logger.error(f"删除记忆配置失败: config_id={config_id} - {str(e)}")
raise

View File

@@ -6,7 +6,7 @@ from sqlalchemy import and_, desc
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageService
from app.schemas.memory_perceptual_schema import PerceptualQuerySchema
db_logger = get_db_logger()
@@ -28,7 +28,7 @@ class MemoryPerceptualRepository:
file_ext: str,
summary: Optional[str] = None,
meta_data: Optional[dict] = None,
storage_service: FileStorageType = FileStorageType.LOCAL
storage_service: FileStorageService = FileStorageService.LOCAL
) -> MemoryPerceptualModel:

View File

@@ -32,7 +32,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
"id": stable_edge_id,
"source": chunk.id,
"target": stmt.id,
"group_id": getattr(stmt, 'group_id', None),
"end_user_id": getattr(stmt, 'end_user_id', None),
"user_id":getattr(stmt, 'user_id', None),
"apply_id": getattr(stmt, 'apply_id', None),
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
@@ -83,7 +83,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
edges.append({
"summary_id": s.id,
"chunk_id": chunk_id,
"group_id": s.group_id,
"end_user_id": s.end_user_id,
"run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -6,10 +6,10 @@ from app.core.memory.models.graph_models import DialogueNode, StatementNode, Chu
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def delete_all_nodes(group_id: str, connector: Neo4jConnector):
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database."""
result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n")
print(f"All group_id: {group_id} node and edge deleted successfully")
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
@@ -32,9 +32,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
for dialogue in dialogues:
flattened_dialogues.append({
"id": dialogue.id,
"group_id": dialogue.group_id,
"user_id": dialogue.user_id,
"apply_id": dialogue.apply_id,
"end_user_id": dialogue.end_user_id,
"run_id": dialogue.run_id,
"ref_id": dialogue.ref_id,
"name": dialogue.name,
@@ -79,9 +77,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
flattened_statement = {
"id": statement.id,
"name": statement.name,
"group_id": statement.group_id,
"user_id": statement.user_id,
"apply_id": statement.apply_id,
"end_user_id": statement.end_user_id,
"run_id": statement.run_id,
"chunk_id": statement.chunk_id,
# "created_at": statement.created_at.isoformat(),
@@ -154,9 +150,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
flattened_chunk = {
"id": chunk.id,
"name": chunk.name,
"group_id": chunk.group_id,
"user_id": chunk.user_id,
"apply_id": chunk.apply_id,
"end_user_id": chunk.end_user_id,
"run_id": chunk.run_id,
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
@@ -206,9 +200,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
flattened.append({
"id": s.id,
"name": s.name,
"group_id": s.group_id,
"user_id": s.user_id,
"apply_id": s.apply_id,
"end_user_id": s.end_user_id,
"run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,

View File

@@ -152,7 +152,7 @@ class BaseNeo4jRepository(BaseRepository[T]):
Example:
>>> results = await repository.find(
... {"group_id": "group_123", "user_id": "user_456"},
... {"end_user_id": "group_123", "user_id": "user_456"},
... limit=50
... )
"""

View File

@@ -3,9 +3,7 @@ DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue
MERGE (n:Dialogue {id: dialogue.id})
SET n.uuid = coalesce(n.uuid, dialogue.id),
n.group_id = dialogue.group_id,
n.user_id = dialogue.user_id,
n.apply_id = dialogue.apply_id,
n.end_user_id = dialogue.end_user_id,
n.run_id = dialogue.run_id,
n.ref_id = dialogue.ref_id,
n.created_at = dialogue.created_at,
@@ -22,9 +20,7 @@ SET s += {
id: statement.id,
run_id: statement.run_id,
chunk_id: statement.chunk_id,
group_id: statement.group_id,
user_id: statement.user_id,
apply_id: statement.apply_id,
end_user_id: statement.end_user_id,
stmt_type: statement.stmt_type,
statement: statement.statement,
emotion_intensity: statement.emotion_intensity,
@@ -54,9 +50,7 @@ MERGE (c:Chunk {id: chunk.id})
SET c += {
id: chunk.id,
name: chunk.name,
group_id: chunk.group_id,
user_id: chunk.user_id,
apply_id: chunk.apply_id,
end_user_id: chunk.end_user_id,
run_id: chunk.run_id,
created_at: chunk.created_at,
expired_at: chunk.expired_at,
@@ -76,9 +70,7 @@ EXTRACTED_ENTITY_NODE_SAVE = """
UNWIND $entities AS entity
MERGE (e:ExtractedEntity {id: entity.id})
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END,
e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
e.end_user_id = CASE WHEN entity.end_user_id IS NOT NULL AND entity.end_user_id <> '' THEN entity.end_user_id ELSE e.end_user_id END,
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
e.created_at = CASE
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
@@ -134,9 +126,9 @@ RETURN e.id AS uuid
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
ENTITY_RELATIONSHIP_SAVE = """
UNWIND $relationships AS rel
// Match entities by stable id within group, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id})
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id})
// Match entities by stable id within end_user_id, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, end_user_id: rel.end_user_id})
MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate,
@@ -148,7 +140,7 @@ SET r.predicate = rel.predicate,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
r.run_id = rel.run_id,
r.group_id = rel.group_id
r.end_user_id = rel.end_user_id
RETURN elementId(r) AS uuid
"""
@@ -160,7 +152,7 @@ UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += {
name: entity.name,
group_id: entity.group_id,
end_user_id: entity.end_user_id,
run_id: entity.run_id,
description: entity.description,
chunk_id: entity.chunk_id,
@@ -175,11 +167,11 @@ RETURN e.id AS id
SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id}
SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id}
SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET o.is_strong = true
"""
@@ -194,7 +186,7 @@ DIALOGUE_STATEMENT_EDGE_SAVE = """
// 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id,
e.group_id = edge.group_id,
e.end_user_id = edge.end_user_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
RETURN e.uuid AS uuid
@@ -208,7 +200,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.group_id = edge.group_id,
SET e.end_user_id = edge.end_user_id,
e.run_id = edge.run_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
@@ -218,13 +210,12 @@ CHUNK_STATEMENT_EDGE_SAVE = """
STATEMENT_ENTITY_EDGE_SAVE = """
UNWIND $relationships AS rel
// Statement nodes are per-run; keep run_id constraint on statements
// Statement nodes are per-run; keep run_id constraint on statements
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
// Entities are shared across runs within a group; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id})
// Entities are shared across runs within end_user_id; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for same endpoints
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.group_id = rel.group_id,
SET r.end_user_id = rel.end_user_id,
r.run_id = rel.run_id,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
@@ -236,10 +227,10 @@ ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL
AND ($group_id IS NULL OR e.group_id = $group_id)
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
@@ -254,10 +245,10 @@ STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL
AND ($group_id IS NULL OR s.group_id = $group_id)
AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
@@ -277,9 +268,9 @@ CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL
AND ($group_id IS NULL OR c.group_id = $group_id)
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
COALESCE(c.activation_value, 0.5) AS activation_value,
@@ -292,12 +283,12 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
@@ -316,15 +307,13 @@ LIMIT $limit
# 查询实体名称包含指定字符串的实体
SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($group_id IS NULL OR e.group_id = $group_id)
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
e.apply_id AS apply_id,
e.user_id AS user_id,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.entity_idx AS entity_idx,
@@ -347,11 +336,11 @@ LIMIT $limit
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
WHERE ($group_id IS NULL OR c.group_id = $group_id)
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number,
@@ -413,10 +402,10 @@ LIMIT $limit
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE ($group_id IS NULL OR d.group_id = $group_id)
WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
AND d.id = $dialog_id
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.end_user_id AS end_user_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at
@@ -426,10 +415,10 @@ LIMIT $limit
SEARCH_CHUNK_BY_CHUNK_ID = """
MATCH (c:Chunk)
WHERE ($group_id IS NULL OR c.group_id = $group_id)
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
AND c.id = $chunk_id
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.created_at AS created_at,
@@ -441,18 +430,14 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_TEMPORAL = """
MATCH (s:Statement)
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
@@ -468,9 +453,7 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
@@ -479,9 +462,7 @@ OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
@@ -499,15 +480,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -519,15 +496,11 @@ LIMIT $limit
SEARCH_STATEMENTS_BY_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -539,15 +512,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -559,15 +528,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -579,15 +544,11 @@ LIMIT $limit
SEARCH_STATEMENTS_G_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -599,15 +560,11 @@ LIMIT $limit
SEARCH_STATEMENTS_L_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.end_user_id AS end_user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
@@ -665,18 +622,18 @@ LIMIT $limit
# 根据id修改句子的invalid_at的值
UPDATE_STATEMENT_INVALID_AT = """
MATCH (n:Statement {group_id: $group_id, id: $id})
MATCH (n:Statement {end_user_id: $end_user_id, id: $id})
SET n.invalid_at = $new_invalid_at
"""
# MemorySummary keyword search using fulltext index
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
WHERE ($group_id IS NULL OR m.group_id = $group_id)
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
@@ -695,10 +652,10 @@ MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL
AND ($group_id IS NULL OR m.group_id = $group_id)
AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
@@ -718,9 +675,7 @@ MERGE (m:MemorySummary {id: summary.id})
SET m += {
id: summary.id,
name: summary.name,
group_id: summary.group_id,
user_id: summary.user_id,
apply_id: summary.apply_id,
end_user_id: summary.end_user_id,
run_id: summary.run_id,
created_at: summary.created_at,
expired_at: summary.expired_at,
@@ -745,7 +700,7 @@ MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id})
MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id})
MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
SET r.group_id = e.group_id,
SET r.end_user_id = e.end_user_id,
r.run_id = e.run_id,
r.created_at = e.created_at,
r.expired_at = e.expired_at
@@ -774,7 +729,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
source_statement_id: rel.source_statement_id,
valid_at: rel.valid_at,
invalid_at: rel.invalid_at,
group_id: rel.group_id,
end_user_id: rel.end_user_id,
user_id: rel.user_id,
apply_id: rel.apply_id,
run_id: rel.run_id,
@@ -796,7 +751,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
source_statement_id: rel.source_statement_id,
valid_at: rel.valid_at,
invalid_at: rel.invalid_at,
group_id: rel.group_id,
end_user_id: rel.end_user_id,
user_id: rel.user_id,
apply_id: rel.apply_id,
run_id: rel.run_id,
@@ -814,7 +769,7 @@ RETURN count(losing) as deleted
neo4j_statement_part = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
RETURN
n.statement as statement_name,
@@ -824,7 +779,7 @@ RETURN
'''
neo4j_statement_all = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
RETURN
n.statement as statement_name,
n.id as statement_id
@@ -832,7 +787,7 @@ RETURN
'''
neo4j_query_part = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
@@ -853,7 +808,7 @@ neo4j_query_part = """
"""
neo4j_query_all = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
WHERE n.end_user_id = "{}"
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN
@@ -1027,14 +982,14 @@ RETURN DISTINCT
Memory_Space_User="""
MATCH (n)-[r]->(m)
WHERE n.group_id = $group_id AND m.name="用户"
WHERE n.end_user_id = $end_user_id AND m.name="用户"
return DISTINCT elementId(m) as id
"""
Memory_Space_Entity="""
MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN
DISTINCT m.name as name,m.group_id as group_id
DISTINCT m.name as name,m.end_user_id as end_user_id
"""
Memory_Space_Associative="""
MATCH (u)-[]-(x)-[]-(h)

View File

@@ -19,7 +19,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""对话仓储
管理对话节点的创建、查询、更新和删除操作。
提供按group_id、user_id、ref_id等条件查询对话的方法。
提供按end_user_id、user_id、ref_id等条件查询对话的方法。
Attributes:
connector: Neo4j连接器实例
@@ -54,17 +54,17 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
return DialogueNode(**n)
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据group_id查询对话
async def find_by_end_user_id(self, end_user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据end_user_id查询对话
Args:
group_id: 组ID
end_user_id: 组ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find({"group_id": group_id}, limit=limit)
return await self.find({"end_user_id": end_user_id}, limit=limit)
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据user_id查询对话
@@ -94,14 +94,14 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_group_and_user(
self,
group_id: str,
end_user_id: str,
user_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据group_id和user_id查询对话
"""根据end_user_id和user_id查询对话
Args:
group_id: 组ID
end_user_id: 组ID
user_id: 用户ID
limit: 返回结果的最大数量
@@ -109,20 +109,20 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
List[DialogueNode]: 对话列表
"""
return await self.find(
{"group_id": group_id, "user_id": user_id},
{"end_user_id": end_user_id, "user_id": user_id},
limit=limit
)
async def find_recent_dialogs(
self,
group_id: str,
end_user_id: str,
days: int = 7,
limit: int = 100
) -> List[DialogueNode]:
"""查询最近的对话
Args:
group_id: 组ID
end_user_id: 组ID
days: 查询最近多少天的对话
limit: 返回结果的最大数量
@@ -131,7 +131,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC
@@ -139,7 +139,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""
results = await self.connector.execute_query(
query,
group_id=group_id,
end_user_id=end_user_id,
days=days,
limit=limit
)
@@ -164,22 +164,22 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
async def find_by_config_and_group(
self,
config_id: str,
group_id: str,
end_user_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据config_id和group_id查询对话
"""根据config_id和end_user_id查询对话
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
Args:
config_id: 配置ID
group_id: 组ID
end_user_id: 组ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find(
{"config_id": config_id, "group_id": group_id},
{"config_id": config_id, "end_user_id": end_user_id},
limit=limit
)

View File

@@ -40,7 +40,7 @@ class EmotionRepository:
async def get_emotion_tags(
self,
group_id: str,
end_user_id: str,
emotion_type: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
@@ -51,7 +51,7 @@ class EmotionRepository:
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
Args:
group_id: 用户组ID宿主ID
end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤joy/sadness/anger/fear/surprise/neutral
start_date: 可选的开始日期ISO格式字符串
end_date: 可选的结束日期ISO格式字符串
@@ -65,8 +65,8 @@ class EmotionRepository:
- avg_intensity: 平均强度
"""
# 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"]
params = {"group_id": group_id, "limit": limit}
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_type IS NOT NULL"]
params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type")
@@ -119,7 +119,7 @@ class EmotionRepository:
async def get_emotion_wordcloud(
self,
group_id: str,
end_user_id: str,
emotion_type: Optional[str] = None,
limit: int = 50
) -> List[Dict[str, Any]]:
@@ -128,7 +128,7 @@ class EmotionRepository:
查询情绪关键词及其频率,用于生成词云可视化。
Args:
group_id: 用户组ID宿主ID
end_user_id: 用户组ID宿主ID
emotion_type: 可选的情绪类型过滤
limit: 返回关键词的最大数量
@@ -140,8 +140,8 @@ class EmotionRepository:
- avg_intensity: 平均强度
"""
# 构建查询条件
where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"]
params = {"group_id": group_id, "limit": limit}
where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_keywords IS NOT NULL"]
params = {"end_user_id": end_user_id, "limit": limit}
if emotion_type:
where_clauses.append("s.emotion_type = $emotion_type")
@@ -186,7 +186,7 @@ class EmotionRepository:
async def get_emotions_in_range(
self,
group_id: str,
end_user_id: str,
time_range: str = "30d"
) -> List[Dict[str, Any]]:
"""获取时间范围内的情绪数据
@@ -194,7 +194,7 @@ class EmotionRepository:
查询指定时间范围内的所有情绪数据,用于健康指数计算。
Args:
group_id: 用户组ID宿主ID
end_user_id: 用户组ID宿主ID
time_range: 时间范围7d/30d/90d
Returns:
@@ -214,7 +214,7 @@ class EmotionRepository:
# 优化的 Cypher 查询:使用字符串比较避免时区问题
query = """
MATCH (s:Statement)
WHERE s.group_id = $group_id
WHERE s.end_user_id = $end_user_id
AND s.emotion_type IS NOT NULL
AND s.created_at >= $start_date
RETURN s.id as statement_id,
@@ -227,7 +227,7 @@ class EmotionRepository:
try:
results = await self.connector.execute_query(
query,
group_id=group_id,
end_user_id=end_user_id,
start_date=start_date
)
formatted_results = [

View File

@@ -44,9 +44,7 @@ async def save_entities_and_relationships(
'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(),
'run_id': edge.run_id,
'group_id': edge.group_id,
'user_id': edge.user_id,
'apply_id': edge.apply_id,
'end_user_id': edge.end_user_id,
}
all_relationships.append(relationship)
@@ -101,9 +99,7 @@ async def save_statement_chunk_edges(
"id": edge.id,
"source": edge.source,
"target": edge.target,
"group_id": edge.group_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"end_user_id": edge.end_user_id,
"run_id": edge.run_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
@@ -132,9 +128,7 @@ async def save_statement_entity_edges(
edge_data = {
"source": edge.source,
"target": edge.target,
"group_id": edge.group_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"end_user_id": edge.end_user_id,
"run_id": edge.run_id,
"connect_strength": edge.connect_strength,
"created_at": edge.created_at.isoformat() if edge.created_at else None,

View File

@@ -33,7 +33,7 @@ async def _update_activation_values_batch(
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
max_retries: int = 3
) -> List[Dict[str, Any]]:
"""
@@ -46,7 +46,7 @@ async def _update_activation_values_batch(
connector: Neo4j连接器
nodes: 节点列表,每个节点必须包含 'id' 字段
node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选
end_user_id: 组ID可选
max_retries: 最大重试次数
Returns:
@@ -97,7 +97,7 @@ async def _update_activation_values_batch(
updated_nodes = await access_manager.record_batch_access(
node_ids=unique_node_ids,
node_label=node_label,
group_id=group_id
end_user_id=end_user_id
)
logger.info(
@@ -118,7 +118,7 @@ async def _update_activation_values_batch(
async def _update_search_results_activation(
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
更新搜索结果中所有知识节点的激活值
@@ -129,7 +129,7 @@ async def _update_search_results_activation(
Args:
connector: Neo4j连接器
results: 搜索结果字典,包含不同类型节点的列表
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
@@ -152,7 +152,7 @@ async def _update_search_results_activation(
connector=connector,
nodes=results[key],
node_label=label,
group_id=group_id
end_user_id=end_user_id
)
)
update_keys.append(key)
@@ -218,7 +218,7 @@ async def _update_search_results_activation(
async def search_graph(
connector: Neo4jConnector,
q: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -236,7 +236,7 @@ async def search_graph(
Args:
connector: Neo4j connector
q: Query text
group_id: Optional group filter
end_user_id: Optional group filter
limit: Max results per category
include: List of categories to search (default: all)
@@ -254,7 +254,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
@@ -263,7 +263,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
@@ -272,7 +272,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
@@ -281,7 +281,7 @@ async def search_graph(
tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
q=q,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("summaries")
@@ -310,12 +310,12 @@ async def search_graph(
key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks']
)
if needs_activation_update:
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
@@ -325,7 +325,7 @@ async def search_graph_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"],
) -> Dict[str, List[Dict[str, Any]]]:
@@ -337,7 +337,7 @@ async def search_graph_by_embedding(
- Computes query embedding with the provided embedder_client
- Ranks by cosine similarity in Cypher
- Filters by group_id if provided
- Filters by end_user_id if provided
- Returns up to 'limit' per included type
"""
import time
@@ -346,7 +346,7 @@ async def search_graph_by_embedding(
embed_start = time.time()
embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s")
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
@@ -361,7 +361,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
@@ -371,7 +371,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
@@ -381,7 +381,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
@@ -391,7 +391,7 @@ async def search_graph_by_embedding(
tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("summaries")
@@ -400,7 +400,7 @@ async def search_graph_by_embedding(
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = {
@@ -429,13 +429,13 @@ async def search_graph_by_embedding(
key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks']
)
if needs_activation_update:
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
update_time = time.time() - update_start
logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s")
@@ -445,7 +445,7 @@ async def search_graph_by_embedding(
return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector,
group_id: str,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
@@ -453,7 +453,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
) -> Dict[str, List[Dict[str, Any]]]:
"""
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选;
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (end_user_id, name) 检索候选;
- 保留并发控制与返回结构incoming_id -> [db_entity_props...]
- 若提供 `entity_type`,在本地对返回结果做类型过滤;
- `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。
@@ -477,7 +477,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=name,
group_id=group_id,
end_user_id=end_user_id,
limit=100,
)
except Exception:
@@ -501,7 +501,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=name.lower(),
group_id=group_id,
end_user_id=end_user_id,
limit=100,
)
for r in rows:
@@ -532,9 +532,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal(
connector: Neo4jConnector,
query_text: str,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -547,32 +545,30 @@ async def search_graph_by_keyword_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements containing query_text created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
if not query_text:
logger.warning(f"query_text cannot be empty")
print(f"query_text不能为空")
return {"statements": []}
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
q=query_text,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
start_date=start_date,
end_date=end_date,
valid_date=valid_date,
invalid_date=invalid_date,
limit=limit,
)
logger.debug(f"Temporal keyword search results: {len(statements)} statements found")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
@@ -580,9 +576,7 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -595,14 +589,12 @@ async def search_graph_by_temporal(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_TEMPORAL,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
start_date=start_date,
end_date=end_date,
valid_date=valid_date,
@@ -610,16 +602,16 @@ async def search_graph_by_temporal(
limit=limit,
)
logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}")
logger.debug(f"Temporal search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
@@ -628,23 +620,23 @@ async def search_graph_by_temporal(
async def search_graph_by_dialog_id(
connector: Neo4jConnector,
dialog_id: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Dialogues.
- Matches dialogues with dialog_id
- Optionally filters by group_id
- Optionally filters by end_user_id
- Returns up to 'limit' dialogues
"""
if not dialog_id:
logger.warning(f"dialog_id cannot be empty")
print(f"dialog_id不能为空")
return {"dialogues": []}
dialogues = await connector.execute_query(
SEARCH_DIALOGUE_BY_DIALOG_ID,
group_id=group_id,
end_user_id=end_user_id,
dialog_id=dialog_id,
limit=limit,
)
@@ -654,15 +646,15 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id(
connector: Neo4jConnector,
chunk_id : str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
logger.warning(f"chunk_id cannot be empty")
print(f"chunk_id不能为空")
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
group_id=group_id,
end_user_id=end_user_id,
chunk_id=chunk_id,
limit=limit,
)
@@ -671,9 +663,9 @@ async def search_graph_by_chunk_id(
async def search_graph_by_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -683,37 +675,37 @@ async def search_graph_by_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_by_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -723,37 +715,37 @@ async def search_graph_by_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_g_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -763,37 +755,37 @@ async def search_graph_g_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_g_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -803,37 +795,37 @@ async def search_graph_g_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_l_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -843,37 +835,37 @@ async def search_graph_l_created_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results
async def search_graph_l_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
@@ -883,28 +875,28 @@ async def search_graph_l_valid_at(
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Optionally filters by end_user_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}")
logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}")
logger.debug(f"Search results: {len(statements)} statements found")
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
end_user_id=end_user_id
)
return results

View File

@@ -18,7 +18,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""Memory Summary Repository
Manages CRUD operations for MemorySummary nodes.
Provides methods to query summaries by group_id, user_id, and time ranges.
Provides methods to query summaries by end_user_id, user_id, and time ranges.
Attributes:
connector: Neo4j connector instance
@@ -51,17 +51,17 @@ class MemorySummaryRepository(BaseNeo4jRepository):
return dict(n)
async def find_by_group_id(
async def find_by_end_user_id(
self,
group_id: str,
end_user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by group_id
"""Query memory summaries by end_user_id
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
end_date: Optional end date filter
@@ -71,10 +71,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
"""
params = {"group_id": group_id, "limit": limit}
params = {"end_user_id": end_user_id, "limit": limit}
# Add date range filters if provided
if start_date:
@@ -139,16 +139,16 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_by_group_and_user(
self,
group_id: str,
end_user_id: str,
user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by both group_id and user_id
"""Query memory summaries by both end_user_id and user_id
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
user_id: User ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
@@ -159,10 +159,10 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id AND n.user_id = $user_id
WHERE n.end_user_id = $end_user_id AND n.user_id = $user_id
"""
params = {"group_id": group_id, "user_id": user_id, "limit": limit}
params = {"end_user_id": end_user_id, "user_id": user_id, "limit": limit}
# Add date range filters if provided
if start_date:
@@ -184,14 +184,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_recent_summaries(
self,
group_id: str,
end_user_id: str,
days: int = 7,
limit: int = 1000
) -> List[Dict[str, Any]]:
"""Query recent memory summaries
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
days: Number of recent days to query
limit: Maximum number of results to return
@@ -200,7 +200,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC
@@ -209,7 +209,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
results = await self.connector.execute_query(
query,
group_id=group_id,
end_user_id=end_user_id,
days=days,
limit=limit
)
@@ -217,14 +217,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
async def find_by_content_keywords(
self,
group_id: str,
end_user_id: str,
keywords: List[str],
limit: int = 100
) -> List[Dict[str, Any]]:
"""Query memory summaries by content keywords
Args:
group_id: Group ID to filter by
end_user_id: Group ID to filter by
keywords: List of keywords to search for in content
limit: Maximum number of results to return
@@ -233,7 +233,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
"""
# Build keyword search conditions
keyword_conditions = []
params = {"group_id": group_id, "limit": limit}
params = {"end_user_id": end_user_id, "limit": limit}
for i, keyword in enumerate(keywords):
keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})")
@@ -243,7 +243,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
AND ({keyword_filter})
RETURN n
ORDER BY n.created_at DESC
@@ -253,21 +253,21 @@ class MemorySummaryRepository(BaseNeo4jRepository):
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def get_summary_count_by_group(self, group_id: str) -> int:
async def get_summary_count_by_group(self, end_user_id: str) -> int:
"""Get count of memory summaries for a group
Args:
group_id: Group ID to count summaries for
end_user_id: Group ID to count summaries for
Returns:
int: Number of memory summaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
WHERE n.end_user_id = $end_user_id
RETURN count(n) as count
"""
results = await self.connector.execute_query(query, group_id=group_id)
results = await self.connector.execute_query(query, end_user_id=end_user_id)
return results[0]['count'] if results else 0

View File

@@ -70,11 +70,7 @@ class Neo4jConnector:
List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典
Example:
>>> connector = Neo4jConnector()
>>> results = await connector.execute_query(
... "MATCH (n:Person {name: $name}) RETURN n",
... name="Alice"
... )
"""
result = await self.driver.execute_query(
query,
@@ -98,17 +94,7 @@ class Neo4jConnector:
Any: 事务函数的返回值
Example:
>>> async def create_node(tx, name):
... result = await tx.run(
... "CREATE (n:Person {name: $name}) RETURN n",
... name=name
... )
... return await result.single()
>>>
>>> connector = Neo4jConnector()
>>> result = await connector.execute_write_transaction(
... create_node, name="Alice"
... )
"""
async with self.driver.session(database="neo4j") as session:
return await session.execute_write(transaction_func, **kwargs)
@@ -126,45 +112,33 @@ class Neo4jConnector:
Any: 事务函数的返回值
Example:
>>> async def get_node(tx, name):
... result = await tx.run(
... "MATCH (n:Person {name: $name}) RETURN n",
... name=name
... )
... return await result.single()
>>>
>>> connector = Neo4jConnector()
>>> result = await connector.execute_read_transaction(
... get_node, name="Alice"
... )
"""
async with self.driver.session(database="neo4j") as session:
return await session.execute_read(transaction_func, **kwargs)
async def delete_group(self, group_id: str):
async def delete_group(self, end_user_id: str):
"""删除指定组的所有数据
删除所有属于指定group_id的节点和边。
删除所有属于指定end_user_id的节点和边。
这是一个危险操作,会永久删除数据。
Args:
group_id: 要删除的组ID
end_user_id: 要删除的组ID
Example:
>>> connector = Neo4jConnector()
>>> await connector.delete_group("group_123")
Group group_123 deleted.
"""
# 删除节点DETACH DELETE会同时删除相关的边
await self.driver.execute_query(
"MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n",
"MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n",
database="neo4j",
group_id=group_id
end_user_id=end_user_id
)
# 删除独立的边(如果有的话)
await self.driver.execute_query(
"MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r",
"MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r",
database="neo4j",
group_id=group_id
end_user_id=end_user_id
)
print(f"Group {group_id} deleted.")
print(f"Group {end_user_id} deleted.")

View File

@@ -20,7 +20,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
"""陈述句仓储
管理陈述句节点的创建、查询、更新和删除操作。
提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。
提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。
Attributes:
connector: Neo4j连接器实例