Initial commit

This commit is contained in:
Ke Sun
2025-11-30 18:22:17 +08:00
commit aea2fe391e
449 changed files with 83030 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""Neo4j repository package.

This package contains the repository implementations for the Neo4j graph
database, used to manage the nodes and edges of the knowledge graph.

Modules:
    neo4j_connector: Neo4j database connector
    base_neo4j_repository: Base class for Neo4j repositories
    dialog_repository: Dialogue repository
    statement_repository: Statement repository
    entity_repository: Entity repository
    cypher_queries: Cypher query strings
    graph_search: Graph search functions
    graph_saver: Graph data persistence functions
    add_nodes: Node creation functions
    add_edges: Edge creation functions
    create_indexes: Index creation functions
"""
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.repositories.neo4j.dialog_repository import DialogRepository
from app.repositories.neo4j.statement_repository import StatementRepository
from app.repositories.neo4j.entity_repository import EntityRepository
# Names re-exported as the package's public API.
__all__ = [
'Neo4jConnector',
'BaseNeo4jRepository',
'DialogRepository',
'StatementRepository',
'EntityRepository',
]

View File

@@ -0,0 +1,102 @@
from typing import List, Optional
import hashlib
from datetime import datetime
from uuid import uuid4
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE, MEMORY_SUMMARY_STATEMENT_EDGE_SAVE
from app.core.memory.models.message_models import Chunk
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.models.graph_models import MemorySummaryNode
async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnector) -> Optional[List[str]]:
    """Create CONTAINS edges from chunk nodes to their statement nodes in Neo4j.

    One edge is built per (chunk, statement) pair. The edge id is the SHA-1 of
    ``"<chunk.id>|<stmt.id>"``, so saving the same pair again yields the same id.

    Args:
        chunks: Chunk objects; each chunk's ``statements`` attribute (if
            present and non-empty) lists the statements it contains.
        connector: Neo4j connector used to execute the save query.

    Returns:
        The ``uuid`` value of every record returned by the query, an empty
        list when there is nothing to save, or None if building or saving
        the edges failed (the error is printed).
    """
    if not chunks:
        print("No chunks provided to create edges")
        return []
    try:
        # Build edges deterministically per (chunk, statement) pair
        edges: List[dict] = []
        for chunk in chunks:
            for stmt in getattr(chunk, "statements", []) or []:
                stable_edge_id = hashlib.sha1(f"{chunk.id}|{stmt.id}".encode("utf-8")).hexdigest()
                edges.append({
                    "id": stable_edge_id,
                    # CHUNK_STATEMENT_EDGE_SAVE matches the Statement node on
                    # edge.source and the Chunk node on edge.target, so the
                    # statement id must go in "source" and the chunk id in
                    # "target"; otherwise neither MATCH finds a node.
                    "source": stmt.id,
                    "target": chunk.id,
                    "group_id": getattr(stmt, 'group_id', None),
                    "user_id": getattr(stmt, 'user_id', None),
                    "apply_id": getattr(stmt, 'apply_id', None),
                    # Fall back to the chunk's run_id when the statement has none.
                    "run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
                    "created_at": getattr(stmt, 'created_at', None),
                    "expired_at": getattr(stmt, 'expired_at', None),
                })
        if not edges:
            print("No statements found in chunks to create edges")
            return []
        # Execute the query to create edges
        result = await connector.execute_query(
            CHUNK_STATEMENT_EDGE_SAVE,
            chunk_statement_edges=edges
        )
        created_uuids = [record.get("uuid") for record in result] if result else []
        print(f"Successfully created {len(created_uuids)} chunk-statement edges")
        return created_uuids
    except Exception as e:
        # Best-effort save: report the problem and signal failure with None.
        print(f"Error creating chunk-statement edges: {e}")
        return None
async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Create edges from MemorySummary nodes to Statements via their chunk_ids.

    One payload entry is built per (summary, chunk_id) pair and handed to the
    MEMORY_SUMMARY_STATEMENT_EDGE_SAVE query. The intent is to link each summary
    to the statements contained in that chunk (DERIVED_FROM_STATEMENT), which
    supports queries like summary -> statement -> entity with minimal hops.

    Args:
        summaries: MemorySummaryNode objects; ``chunk_ids`` may be missing or empty.
        connector: Neo4j connector used to execute the save query.

    Returns:
        The ``uuid`` value of every record returned by the query, an empty
        list when there is nothing to save, or None if building or saving
        the edges failed (the error is printed).
    """
    if not summaries:
        return []
    try:
        edges: List[dict] = []
        for s in summaries:
            for chunk_id in getattr(s, "chunk_ids", []) or []:
                edges.append({
                    "summary_id": s.id,
                    "chunk_id": chunk_id,
                    "group_id": s.group_id,
                    "run_id": s.run_id,
                    "created_at": s.created_at.isoformat() if s.created_at else None,
                    "expired_at": s.expired_at.isoformat() if s.expired_at else None,
                })
        if not edges:
            return []
        result = await connector.execute_query(
            MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
            edges=edges
        )
        return [record.get("uuid") for record in result] if result else []
    except Exception as e:
        # Still best-effort (None signals failure), but report the cause
        # instead of hiding it, matching add_chunk_statement_edges.
        print(f"Error creating memory summary-statement edges: {e}")
        return None

View File

@@ -0,0 +1,215 @@
from typing import List, Optional
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def delete_all_nodes(group_id: str, connector: Neo4jConnector):
    """Delete every node whose ``group_id`` matches, together with its relationships.

    Only nodes carrying the given ``group_id`` property are removed; nodes of
    other groups are left untouched.

    Args:
        group_id: Value of the ``group_id`` property identifying the nodes to delete.
        connector: Neo4j connector used to execute the query.

    Returns:
        Whatever ``connector.execute_query`` returns for the delete statement.
    """
    # Pass group_id as a query parameter instead of formatting it into the
    # Cypher text: this prevents Cypher injection and also works for ids that
    # contain quotes.
    result = await connector.execute_query(
        "MATCH (n {group_id: $group_id}) DETACH DELETE n",
        group_id=group_id,
    )
    print(f"All group_id: {group_id} node and edge deleted successfully")
    return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Persist dialogue nodes through the DIALOGUE_NODE_SAVE query.

    Args:
        dialogues: DialogueNode objects to save.
        connector: Neo4j connector used to execute the query.

    Returns:
        The ``uuid`` field of each returned record, an empty list when
        ``dialogues`` is empty, or None if anything failed (the error is printed).
    """
    if not dialogues:
        print("No dialogues to save")
        return []

    def _iso(moment):
        # Timestamps are sent as ISO-8601 strings; missing ones as None.
        return moment.isoformat() if moment else None

    try:
        # One flat property map per dialogue, in the shape the query unwinds.
        payload = [
            {
                "id": node.id,
                "group_id": node.group_id,
                "user_id": node.user_id,
                "apply_id": node.apply_id,
                "run_id": node.run_id,
                "ref_id": node.ref_id,
                "name": node.name,
                "created_at": _iso(node.created_at),
                "expired_at": _iso(node.expired_at),
                "content": node.content,
                "dialog_embedding": node.dialog_embedding,
            }
            for node in dialogues
        ]
        records = await connector.execute_query(DIALOGUE_NODE_SAVE, dialogues=payload)
        created_uuids = [row["uuid"] for row in records]
        print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
        return created_uuids
    except Exception as e:
        print(f"Error creating dialogue nodes: {e}")
        return None
async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Persist statement nodes through the STATEMENT_NODE_SAVE query.

    Each StatementNode is flattened to a map of primitive values (timestamps
    as ISO-8601 strings, ``temporal_info`` as its ``.value``, empty embeddings
    as None) before being sent.

    Args:
        statements: StatementNode objects to save.
        connector: Neo4j connector used to execute the query.

    Returns:
        The ``uuid`` field of each returned record, an empty list when
        ``statements`` is empty, or None if anything failed (the error is printed).
    """
    if not statements:
        print("No statements to save")
        return []

    def _iso(moment):
        return moment.isoformat() if moment else None

    def _flatten(node):
        # NOTE(review): STATEMENT_NODE_SAVE also reads statement.relevence_info,
        # which is not supplied here - confirm whether that is intended.
        return {
            "id": node.id,
            "name": node.name,
            "group_id": node.group_id,
            "user_id": node.user_id,
            "apply_id": node.apply_id,
            "run_id": node.run_id,
            "chunk_id": node.chunk_id,
            "created_at": _iso(node.created_at),
            "expired_at": _iso(node.expired_at),
            "stmt_type": node.stmt_type,
            "temporal_info": node.temporal_info.value,
            "statement": node.statement,
            "connect_strength": node.connect_strength,
            "chunk_embedding": node.chunk_embedding or None,
            "valid_at": _iso(node.valid_at),
            "invalid_at": _iso(node.invalid_at),
            "statement_embedding": node.statement_embedding or None,
        }

    try:
        payload = [_flatten(node) for node in statements]
        records = await connector.execute_query(STATEMENT_NODE_SAVE, statements=payload)
        created_uuids = [row["uuid"] for row in records]
        print(f"Successfully created {len(created_uuids)} statement nodes")
        return created_uuids
    except Exception as e:
        print(f"Error creating statement nodes: {e}")
        return None
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Persist chunk nodes in one batch through the CHUNK_NODE_SAVE query.

    Args:
        chunks: ChunkNode objects to save.
        connector: Neo4j connector used to execute the query.

    Returns:
        The ``uuid`` field of each returned record, an empty list when
        ``chunks`` is empty, or None if anything failed (the error is printed).
    """
    if not chunks:
        print("No chunk nodes to add")
        return []

    def _iso(moment):
        return moment.isoformat() if moment else None

    try:
        payload = []
        for node in chunks:
            # Only start_index/end_index are lifted out of the metadata dict,
            # so no nested map is sent as a node property.
            meta = node.metadata or {}
            payload.append({
                "id": node.id,
                "name": node.name,
                "group_id": node.group_id,
                "user_id": node.user_id,
                "apply_id": node.apply_id,
                "run_id": node.run_id,
                "created_at": _iso(node.created_at),
                "expired_at": _iso(node.expired_at),
                "dialog_id": node.dialog_id,
                "content": node.content,
                "chunk_embedding": node.chunk_embedding or None,
                "sequence_number": node.sequence_number,
                "start_index": meta.get("start_index"),
                "end_index": meta.get("end_index"),
            })
        records = await connector.execute_query(CHUNK_NODE_SAVE, chunks=payload)
        created_uuids = [row["uuid"] for row in records]
        print(f"Successfully created {len(created_uuids)} chunk nodes")
        return created_uuids
    except Exception as e:
        print(f"Error creating chunk nodes: {e}")
        return None
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Persist memory summary nodes in one batch through MEMORY_SUMMARY_NODE_SAVE.

    Args:
        summaries: MemorySummaryNode objects to save.
        connector: Neo4j connector used to execute the query.

    Returns:
        The ``uuid`` value of each returned record, an empty list when
        ``summaries`` is empty, or None if anything failed (the error is printed).
    """
    if not summaries:
        print("No memory summary nodes to add")
        return []
    try:
        flattened = [
            {
                "id": s.id,
                "name": s.name,
                "group_id": s.group_id,
                "user_id": s.user_id,
                "apply_id": s.apply_id,
                "run_id": s.run_id,
                "created_at": s.created_at.isoformat() if s.created_at else None,
                "expired_at": s.expired_at.isoformat() if s.expired_at else None,
                "dialog_id": s.dialog_id,
                "chunk_ids": s.chunk_ids,
                "content": s.content,
                "summary_embedding": s.summary_embedding if s.summary_embedding else None,
                "config_id": s.config_id,
            }
            for s in summaries
        ]
        result = await connector.execute_query(
            MEMORY_SUMMARY_NODE_SAVE,
            summaries=flattened
        )
        return [record.get("uuid") for record in result]
    except Exception as e:
        # Still best-effort (None signals failure), but report the cause like
        # the other add_*_nodes functions instead of failing silently.
        print(f"Error creating memory summary nodes: {e}")
        return None

View File

@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
"""Neo4j仓储基类模块
本模块提供Neo4j仓储的基类实现封装了通用的Neo4j节点操作。
Classes:
BaseNeo4jRepository: Neo4j仓储基类实现通用的CRUD操作
"""
from typing import List, Optional, Dict, Any, TypeVar
from app.repositories.base_repository import BaseRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
T = TypeVar('T')
class BaseNeo4jRepository(BaseRepository[T]):
    """Base class for Neo4j repositories implementing generic node operations.

    It wraps the common CRUD operations on Neo4j nodes; subclasses only need
    to implement the entity mapping (``_map_to_entity``) and their own
    business queries.

    Attributes:
        connector: Neo4j connector instance.
        node_label: Node label (e.g. "Dialogue", "Statement").

    Type Parameters:
        T: Entity type, usually a Pydantic model.
    """

    def __init__(self, connector: Neo4jConnector, node_label: str):
        """Initialise the repository.

        Args:
            connector: Neo4j connector instance.
            node_label: Node label interpolated into the Cypher queries.
        """
        self.connector = connector
        self.node_label = node_label

    async def create(self, entity: T) -> T:
        """Create a node from the entity's ``model_dump()`` properties.

        Args:
            entity: Entity object to create.

        Returns:
            T: The same entity object that was passed in.

        Example:
            >>> dialog = DialogueNode(id="123", name="dialogue 1", ...)
            >>> created = await repository.create(dialog)
        """
        query = f"""
        CREATE (n:{self.node_label} $props)
        RETURN n
        """
        await self.connector.execute_query(
            query,
            props=entity.model_dump()
        )
        return entity

    async def get_by_id(self, entity_id: str) -> Optional[T]:
        """Fetch a node by its ``id`` property.

        Args:
            entity_id: Node id.

        Returns:
            Optional[T]: The mapped entity for the first matching record, or
            None if the query returned nothing.
        """
        query = f"""
        MATCH (n:{self.node_label} {{id: $id}})
        RETURN n
        """
        result = await self.connector.execute_query(query, id=entity_id)
        if result:
            return self._map_to_entity(result[0])
        return None

    async def update(self, entity: T) -> T:
        """Update an existing node, merging properties with ``SET n += $props``.

        Args:
            entity: Entity to update; must expose an ``id`` attribute.

        Returns:
            T: The same entity object that was passed in.
        """
        query = f"""
        MATCH (n:{self.node_label} {{id: $id}})
        SET n += $props
        RETURN n
        """
        await self.connector.execute_query(
            query,
            id=entity.id,
            props=entity.model_dump()
        )
        return entity

    async def delete(self, entity_id: str) -> bool:
        """Delete the node with the given id, together with its relationships.

        Args:
            entity_id: Id of the node to delete.

        Returns:
            bool: True if the query reports at least one deleted node,
            otherwise False.
        """
        query = f"""
        MATCH (n:{self.node_label} {{id: $id}})
        DETACH DELETE n
        RETURN count(n) as deleted
        """
        result = await self.connector.execute_query(query, id=entity_id)
        return result[0]['deleted'] > 0 if result else False

    async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]:
        """Find nodes whose properties equal the given filter values.

        Args:
            filters: Mapping of property name to expected value; all
                conditions are combined with AND. An empty mapping matches
                every node with this repository's label.
            limit: Maximum number of results to return.

        Returns:
            List[T]: The mapped entities for the matching records.

        Raises:
            ValueError: If a filter key is not a plain identifier and so
                cannot be used safely as a property name in the query text.

        Example:
            >>> results = await repository.find(
            ...     {"group_id": "group_123", "user_id": "user_456"},
            ...     limit=50
            ... )
        """
        where_clauses = []
        params: Dict[str, Any] = {}
        for index, (key, value) in enumerate((filters or {}).items()):
            # Property names are interpolated into the query text, so only
            # plain identifiers are accepted (guards against Cypher injection).
            if not isinstance(key, str) or not key.isidentifier():
                raise ValueError(f"Invalid filter property name: {key!r}")
            # Values travel under synthetic parameter names so a filter key such
            # as "limit" cannot clash with the keyword arguments passed below.
            param_name = f"filter_{index}"
            where_clauses.append(f"n.{key} = ${param_name}")
            params[param_name] = value
        where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_str}
        RETURN n
        LIMIT $limit
        """
        results = await self.connector.execute_query(
            query,
            limit=limit,
            **params
        )
        return [self._map_to_entity(r) for r in results]

    def _map_to_entity(self, node_data: Dict) -> T:
        """Map a query record to an entity object.

        Abstract hook: subclasses must provide the concrete mapping.

        Args:
            node_data: Record returned by the Neo4j query.

        Returns:
            T: The mapped entity object.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError("Subclasses must implement _map_to_entity method")

View File

@@ -0,0 +1,332 @@
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes():
    """Create the full-text indexes used for keyword search (BM25 scoring).

    Creates one index each for Statement.statement, ExtractedEntity.name,
    Chunk.content and MemorySummary.content, all with the 'cjk' analyzer.
    Errors are printed, not raised; the connector is always closed.
    """
    # (index name, node variable, node label, indexed property)
    # A full-text index on Dialogue.content is intentionally not created.
    targets = (
        ("statementsFulltext", "s", "Statement", "statement"),
        ("entitiesFulltext", "e", "ExtractedEntity", "name"),
        ("chunksFulltext", "c", "Chunk", "content"),
        ("summariesFulltext", "m", "MemorySummary", "content"),
    )
    connector = Neo4jConnector()
    try:
        rule = "=" * 70
        print("\n" + rule)
        print("Creating Full-Text Indexes (for keyword search)")
        print(rule)
        for index_name, var, label, prop in targets:
            await connector.execute_query(
                f"\nCREATE FULLTEXT INDEX {index_name} IF NOT EXISTS FOR ({var}:{label}) ON EACH [{var}.{prop}]\n"
                "OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }\n"
            )
            print(f"✓ Created: {index_name}")
        print("\nFull-text indexes created successfully with BM25 support.")
    except Exception as e:
        print(f"✗ Error creating full-text indexes: {e}")
    finally:
        await connector.close()
async def create_vector_indexes(dimensions: int = 1024):
    """Create the vector indexes used for embedding similarity search.

    One cosine-similarity index is created for each of:
    Statement.statement_embedding, Chunk.chunk_embedding,
    ExtractedEntity.name_embedding, MemorySummary.summary_embedding and
    Dialogue.dialog_embedding.

    Errors raised while talking to the database are printed, not raised;
    the connector is always closed.

    Args:
        dimensions: Dimensionality of the stored embeddings. Defaults to 1024
            (the value previously hard-coded, noted as matching bge-m3); pass
            another value when a different embedding model is used.

    Raises:
        ValueError: If ``dimensions`` is not a positive integer.
    """
    # The value is formatted into the Cypher text, so accept real positive
    # integers only (bool is an int subclass and is rejected explicitly).
    if isinstance(dimensions, bool) or not isinstance(dimensions, int) or dimensions <= 0:
        raise ValueError(f"dimensions must be a positive integer, got {dimensions!r}")
    # (index name, node variable, node label, embedding property)
    targets = (
        ("statement_embedding_index", "s", "Statement", "statement_embedding"),
        ("chunk_embedding_index", "c", "Chunk", "chunk_embedding"),
        ("entity_embedding_index", "e", "ExtractedEntity", "name_embedding"),
        ("summary_embedding_index", "m", "MemorySummary", "summary_embedding"),
        ("dialogue_embedding_index", "d", "Dialogue", "dialog_embedding"),
    )
    connector = Neo4jConnector()
    try:
        print("\n" + "=" * 70)
        print("Creating Vector Indexes (for embedding search)")
        print("=" * 70)
        print("Note: Adjust vector.dimensions if using different embedding model")
        # Keep the original wording for the default; only the default is
        # documented as the bge-m3 size.
        model_note = " (for bge-m3)" if dimensions == 1024 else ""
        print(f" Current setting: {dimensions} dimensions{model_note}")
        print()
        for index_name, var, label, prop in targets:
            await connector.execute_query(
                f"\nCREATE VECTOR INDEX {index_name} IF NOT EXISTS\n"
                f"FOR ({var}:{label})\n"
                f"ON {var}.{prop}\n"
                "OPTIONS {indexConfig: {\n"
                f"`vector.dimensions`: {dimensions},\n"
                "`vector.similarity_function`: 'cosine'\n"
                "}}\n"
            )
            print(f"✓ Created: {index_name}")
        print("\nVector indexes created successfully!")
        print("\nExpected performance improvement:")
        print(" Before: ~1.4s for embedding search")
        print(" After: ~0.05-0.2s for embedding search (10-30x faster!)")
    except Exception as e:
        print(f"✗ Error creating vector indexes: {e}")
    finally:
        await connector.close()
async def create_config_id_indexes():
    """Create indexes on the ``config_id`` property of the main node labels.

    Indexes are created for Dialogue, Statement, ExtractedEntity and
    MemorySummary so nodes can be filtered quickly by configuration id.
    Errors are printed, not raised; the connector is always closed.
    """
    # (index name, node variable, node label)
    targets = (
        ("dialogue_config_id_index", "d", "Dialogue"),
        ("statement_config_id_index", "s", "Statement"),
        ("entity_config_id_index", "e", "ExtractedEntity"),
        ("summary_config_id_index", "m", "MemorySummary"),
    )
    connector = Neo4jConnector()
    try:
        rule = "=" * 70
        print("\n" + rule)
        print("Creating Config ID Indexes")
        print(rule)
        for index_name, var, label in targets:
            await connector.execute_query(
                f"\nCREATE INDEX {index_name} IF NOT EXISTS\n"
                f"FOR ({var}:{label}) ON ({var}.config_id)\n"
            )
            print(f"✓ Created: {index_name}")
        print("\nConfig ID indexes created successfully!")
        print("These indexes enable fast filtering by configuration ID.")
    except Exception as e:
        print(f"✗ Error creating config_id indexes: {e}")
    finally:
        await connector.close()
async def create_unique_constraints():
    """Create uniqueness constraints on ``id`` for Dialogue, Statement and Chunk.

    The constraints are meant to keep concurrent MERGE operations safe and to
    prevent duplicate nodes. Errors are printed, not raised; the connector is
    always closed.
    """
    # (constraint name, node variable, node label)
    targets = (
        ("dialog_id_unique", "d", "Dialogue"),
        ("statement_id_unique", "s", "Statement"),
        ("chunk_id_unique", "c", "Chunk"),
    )
    connector = Neo4jConnector()
    try:
        rule = "=" * 70
        print("\n" + rule)
        print("Creating Unique Constraints")
        print(rule)
        for constraint_name, var, label in targets:
            await connector.execute_query(
                f"\nCREATE CONSTRAINT {constraint_name} IF NOT EXISTS\n"
                f"FOR ({var}:{label}) REQUIRE {var}.id IS UNIQUE\n"
            )
            print(f"✓ Created: {constraint_name}")
        print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.")
    except Exception as e:
        print(f"✗ Error creating unique constraints: {e}")
    finally:
        await connector.close()
async def create_all_indexes():
    """Run every index and constraint creation step in sequence.

    Order: full-text indexes, vector indexes, config_id indexes, unique
    constraints. Each step handles its own errors, so the closing success
    banner is printed even when a step reported an error.
    """
    rule = "=" * 70
    print("\n".join([
        "\n" + rule,
        "Neo4j Index & Constraint Setup",
        rule,
        "This will create:",
        " 1. Full-text indexes (for keyword/BM25 search)",
        " 2. Vector indexes (for embedding similarity search)",
        " 3. Config ID indexes (for configuration isolation)",
        " 4. Unique constraints (for data integrity)",
        rule,
    ]))
    await create_fulltext_indexes()
    await create_vector_indexes()
    await create_config_id_indexes()
    await create_unique_constraints()
    # The trailing empty item reproduces the final blank line.
    print("\n".join([
        "\n" + rule,
        "✓ All indexes and constraints created successfully!",
        rule,
        "\nTo verify, run in Neo4j Browser:",
        " SHOW INDEXES",
        " SHOW CONSTRAINTS",
        "",
    ]))
async def check_indexes():
    """Print a report of the indexes that currently exist.

    Runs ``SHOW INDEXES``, lists the FULLTEXT, VECTOR and RANGE indexes by
    name, warns when no vector index exists and when fewer than four RANGE
    indexes have ``config_id`` in their name. The connector is always closed;
    query errors propagate to the caller.
    """
    connector = Neo4jConnector()
    try:
        rule = "=" * 70
        print("\n" + rule)
        print("Checking Existing Indexes")
        print(rule)
        rows = await connector.execute_query("SHOW INDEXES")

        def _of_type(kind):
            return [row for row in rows if row.get('type') == kind]

        fulltext_indexes = _of_type('FULLTEXT')
        vector_indexes = _of_type('VECTOR')
        range_indexes = _of_type('RANGE')
        sections = (
            ("\nFull-text indexes: ", fulltext_indexes),
            ("\nVector indexes: ", vector_indexes),
            ("\nRange indexes (including config_id): ", range_indexes),
        )
        for heading, members in sections:
            print(f"{heading}{len(members)}")
            for member in members:
                print(f"{member.get('name')}")
        if not vector_indexes:
            print("\n⚠️ WARNING: No vector indexes found!")
            print(" Embedding search will be VERY SLOW (~1.4s)")
            print(" Run: python create_indexes.py")
        # Four config_id indexes are expected, one per label handled by
        # create_config_id_indexes.
        config_id_indexes = [row for row in range_indexes if 'config_id' in row.get('name', '')]
        if len(config_id_indexes) < 4:
            print("\n⚠️ WARNING: Not all config_id indexes found!")
            print(f" Expected 4, found {len(config_id_indexes)}")
            print(" Run: python create_indexes.py config_id")
        print(rule)
    finally:
        await connector.close()
if __name__ == "__main__":
    import asyncio
    import sys

    # Sub-command name -> coroutine function that implements it.
    handlers = {
        "check": check_indexes,
        "fulltext": create_fulltext_indexes,
        "vector": create_vector_indexes,
        "config_id": create_config_id_indexes,
        "constraints": create_unique_constraints,
    }
    if len(sys.argv) <= 1:
        # No sub-command: create everything.
        asyncio.run(create_all_indexes())
    else:
        command = sys.argv[1]
        handler = handlers.get(command)
        if handler is not None:
            asyncio.run(handler())
        else:
            print(f"Unknown command: {command}")
            print("\nUsage:")
            print(" python create_indexes.py # Create all indexes")
            print(" python create_indexes.py check # Check existing indexes")
            print(" python create_indexes.py fulltext # Create only full-text indexes")
            print(" python create_indexes.py vector # Create only vector indexes")
            print(" python create_indexes.py config_id # Create only config_id indexes")
            print(" python create_indexes.py constraints # Create only constraints")

View File

@@ -0,0 +1,684 @@
# Upsert Dialogue nodes from the $dialogues list, keyed by id.
# n.uuid keeps its existing value if already set, otherwise it takes dialogue.id.
# Returns the node id under the column name "uuid".
# NOTE(review): the query does not write a "name" property, although
# add_dialogue_nodes includes "name" in its payload - confirm whether intended.
DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue
MERGE (n:Dialogue {id: dialogue.id})
SET n.uuid = coalesce(n.uuid, dialogue.id),
n.group_id = dialogue.group_id,
n.user_id = dialogue.user_id,
n.apply_id = dialogue.apply_id,
n.run_id = dialogue.run_id,
n.ref_id = dialogue.ref_id,
n.created_at = dialogue.created_at,
n.expired_at = dialogue.expired_at,
n.content = dialogue.content,
n.dialog_embedding = dialogue.dialog_embedding
RETURN n.id AS uuid
"""
# Upsert Statement nodes from the $statements list, keyed by id; the listed
# properties are merged into the node with "SET s += {...}".
# Returns the node id under the column name "uuid".
# NOTE(review): the query reads statement.relevence_info (spelled this way),
# which the add_statement_nodes payload does not supply - confirm the
# intended source of that field.
STATEMENT_NODE_SAVE = """
UNWIND $statements AS statement
MERGE (s:Statement {id: statement.id})
SET s += {
id: statement.id,
group_id: statement.group_id,
user_id: statement.user_id,
apply_id: statement.apply_id,
chunk_id: statement.chunk_id,
run_id: statement.run_id,
created_at: statement.created_at,
expired_at: statement.expired_at,
stmt_type: statement.stmt_type,
temporal_info: statement.temporal_info,
relevence_info: statement.relevence_info,
statement: statement.statement,
valid_at: statement.valid_at,
invalid_at: statement.invalid_at,
statement_embedding: statement.statement_embedding
}
RETURN s.id AS uuid
"""
# Upsert Chunk nodes from the $chunks list, keyed by id; the listed
# properties are merged into the node with "SET c += {...}".
# Returns the node id under the column name "uuid".
CHUNK_NODE_SAVE = """
UNWIND $chunks AS chunk
MERGE (c:Chunk {id: chunk.id})
SET c += {
id: chunk.id,
name: chunk.name,
group_id: chunk.group_id,
user_id: chunk.user_id,
apply_id: chunk.apply_id,
run_id: chunk.run_id,
created_at: chunk.created_at,
expired_at: chunk.expired_at,
dialog_id: chunk.dialog_id,
content: chunk.content,
chunk_embedding: chunk.chunk_embedding,
sequence_number: chunk.sequence_number,
start_index: chunk.start_index,
end_index: chunk.end_index
}
RETURN c.id AS uuid
"""
# Bug-fix point.
# Upsert ExtractedEntity nodes from the $entities list, keyed by id, keeping
# existing values when the incoming value is null or empty. Per-field rules:
#   - created_at keeps the earlier value, expired_at the later one;
#   - entity_idx is only filled in when the stored value is null or 0;
#   - description and fact_summary are replaced only by a longer incoming text;
#   - aliases are concatenated onto the stored list
#     (NOTE(review): no de-duplication, so repeated saves can grow the list);
#   - connect_strength becomes 'both' when a stored 'strong' meets an incoming
#     'weak' or vice versa.
# Returns the node id under the column name "uuid".
EXTRACTED_ENTITY_NODE_SAVE = """
// Upsert entity nodes safely: preserve existing non-empty fields when incoming is empty
UNWIND $entities AS entity
MERGE (e:ExtractedEntity {id: entity.id})
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END,
e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
e.created_at = CASE
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
THEN entity.created_at ELSE e.created_at END,
e.expired_at = CASE
WHEN entity.expired_at IS NOT NULL AND (e.expired_at IS NULL OR entity.expired_at > e.expired_at)
THEN entity.expired_at ELSE e.expired_at END,
e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
e.description = CASE
WHEN entity.description IS NOT NULL AND entity.description <> ''
AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
THEN entity.description ELSE e.description END,
e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
e.aliases = CASE
WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
THEN CASE WHEN e.aliases IS NULL THEN entity.aliases ELSE e.aliases + entity.aliases END
ELSE e.aliases END,
e.name_embedding = CASE
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
ELSE e.name_embedding END,
e.fact_summary = CASE
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
THEN entity.fact_summary ELSE e.fact_summary END,
e.connect_strength = CASE
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
ELSE CASE
WHEN e.connect_strength = 'strong' AND entity.connect_strength = 'weak' THEN 'both'
WHEN e.connect_strength = 'weak' AND entity.connect_strength = 'strong' THEN 'both'
WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength
ELSE e.connect_strength
END
END
RETURN e.id AS uuid
"""
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
# Upsert EXTRACTED_RELATIONSHIP edges from the $relationships list. Both
# endpoints are matched by id and group_id; the edge is merged on its endpoints
# only, and its properties are overwritten on every save.
# Returns elementId(r) under the column name "uuid".
# NOTE(review): because the MERGE does not include the predicate, a second
# relationship with a different predicate between the same two entities
# overwrites the first - confirm this is intended.
ENTITY_RELATIONSHIP_SAVE = """
UNWIND $relationships AS rel
// Match entities by stable id within group, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id})
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id})
// Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate,
r.statement_id = rel.statement_id,
r.value = rel.value,
r.statement = rel.statement,
r.valid_at = rel.valid_at,
r.invalid_at = rel.invalid_at,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
r.run_id = rel.run_id,
r.group_id = rel.group_id
RETURN elementId(r) AS uuid
"""
# In Neo4j 5 and later the id() function is deprecated; use elementId() instead.
# Save weak-relationship entities: sets e.is_weak = true and does not maintain
# an aggregated e.relations field.
# Nodes are merged on the (id, run_id) pair taken from the $weak_entities list;
# the node id is returned under the column name "id".
WEAK_ENTITY_NODE_SAVE = """
UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += {
name: entity.name,
group_id: entity.group_id,
run_id: entity.run_id,
description: entity.description,
chunk_id: entity.chunk_id,
dialog_id: entity.dialog_id
}
// Independent weak flag仅标记弱关系不再维护 relations 聚合字段
SET e.is_weak = true
RETURN e.id AS id
"""
# Create/update the entity nodes for the subject and object of each
# strong-relationship triple in $items; only sets e.is_strong = true and does
# not maintain an e.relations field.
# Both nodes are merged on the (id, run_id) pair. The query has no RETURN clause.
SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id}
// Independent strong flag
SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id}
// Independent strong flag
SET o.is_strong = true
"""
# Upsert MENTIONS edges (Dialogue -> Statement) from $dialogue_statement_edges.
# The Dialogue endpoint is matched when either its uuid or its ref_id equals
# edge.source; the Statement endpoint is matched by id = edge.target.
# The edge is merged on its endpoints only and its properties are overwritten.
# Returns the edge's uuid property under the column name "uuid".
DIALOGUE_STATEMENT_EDGE_SAVE = """
UNWIND $dialogue_statement_edges AS edge
// 支持按 uuid 或 ref_id 连接到 Dialogue避免因来源 ID 不一致而断链
MATCH (dialogue:Dialogue)
WHERE dialogue.uuid = edge.source OR dialogue.ref_id = edge.source
MATCH (statement:Statement {id: edge.target})
// 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id,
e.group_id = edge.group_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
RETURN e.uuid AS uuid
"""
# In Neo4j 5 and later the id() function is deprecated; use elementId() instead.
# Upsert CONTAINS edges (Chunk -> Statement) from $chunk_statement_edges.
# Payload contract: edge.source is the Statement id and edge.target is the
# Chunk id; both nodes must also carry run_id = edge.run_id to be matched.
# The edge is merged on (endpoints, id) and its id is returned as "uuid".
CHUNK_STATEMENT_EDGE_SAVE = """
UNWIND $chunk_statement_edges AS edge
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.group_id = edge.group_id,
e.run_id = edge.run_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
RETURN e.id AS uuid
"""
# Upsert REFERENCES_ENTITY edges (Statement -> ExtractedEntity) from
# $relationships. The statement is matched by id and run_id, the entity by id
# and group_id; the edge is merged on its endpoints only and its properties
# are overwritten. Returns elementId(r) under the column name "uuid".
STATEMENT_ENTITY_EDGE_SAVE = """
UNWIND $relationships AS rel
// Statement nodes are per-run; keep run_id constraint on statements
// Statement nodes are per-run; keep run_id constraint on statements
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
// Entities are shared across runs within a group; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id})
// Avoid duplicate edges across runs for same endpoints
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.group_id = rel.group_id,
r.run_id = rel.run_id,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
r.connect_strength = rel.connect_strength
RETURN elementId(r) AS uuid
"""
# Vector-index search over entities using the 'entity_embedding_index' index.
# The index is asked for $limit * 100 candidates, which are then filtered by the
# optional $group_id and cut down to $limit rows ordered by score.
# NOTE(review): the similarity metric is defined by the index configuration, which is
# not visible in this file.
ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL
AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.entity_type AS entity_type,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search on Statement.statement_embedding via the
# 'statement_embedding_index' vector index. The index is asked for $limit * 100
# candidates, which are then filtered by the optional $group_id and cut down to
# $limit rows ordered by score.
# NOTE(review): described as cosine similarity in the original comment; the actual
# metric is defined by the index configuration, which is not visible in this file.
STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL
AND ($group_id IS NULL OR s.group_id = $group_id)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search on Chunk.chunk_embedding via the 'chunk_embedding_index'
# vector index. The index is asked for $limit * 100 candidates, which are then
# filtered by the optional $group_id and cut down to $limit rows ordered by score.
# NOTE(review): described as cosine similarity in the original comment; the actual
# metric is defined by the index configuration, which is not visible in this file.
CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL
AND ($group_id IS NULL OR c.group_id = $group_id)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.content AS content,
c.dialog_id AS dialog_id,
score
ORDER BY score DESC
LIMIT $limit
"""
# Full-text search over statements using the "statementsFulltext" index with query
# string $q, optionally restricted to $group_id. For every hit the containing Chunk
# (via CONTAINS) and the referenced entities (via REFERENCES_ENTITY) are joined in
# optionally; entity ids are aggregated into `entity_ids`, while the chunk id is a
# grouping column (`chunk_id_from_rel`), so a statement linked from several chunks
# produces one row per chunk. Rows are ordered by score and limited to $limit.
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
c.id AS chunk_id_from_rel,
collect(DISTINCT e.id) AS entity_ids,
score
ORDER BY score DESC
LIMIT $limit
"""
# Find entities matching the search string $q through the "entitiesFulltext"
# full-text index (which properties the index covers is defined where the index is
# created, not here), optionally restricted to $group_id. The statements that
# reference each entity and the chunks containing those statements are joined in
# optionally and aggregated into `statement_ids` / `chunk_ids`.
SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($group_id IS NULL OR e.group_id = $group_id)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.entity_type AS entity_type,
e.apply_id AS apply_id,
e.user_id AS user_id,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.entity_idx AS entity_idx,
e.statement_id AS statement_id,
e.description AS description,
e.aliases AS aliases,
e.name_embedding AS name_embedding,
e.fact_summary AS fact_summary,
e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids,
score
ORDER BY score DESC
LIMIT $limit
"""
# Full-text search over chunks using the "chunksFulltext" index with query string
# $q, optionally restricted to $group_id. The statements contained in each chunk and
# the entities those statements reference are joined in optionally and aggregated
# into `statement_ids` / `entity_ids`. Rows are ordered by score, limited to $limit.
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
WHERE ($group_id IS NULL OR c.group_id = $group_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT e.id) AS entity_ids,
score
ORDER BY score DESC
LIMIT $limit
"""
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """
# UNWIND $rows AS row
# MATCH (e:ExtractedEntity)
# WHERE e.group_id = row.group_id
# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name)))
# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type)
# RETURN row.id AS incoming_id,
# e.id AS id,
# e.name AS name,
# e.group_id AS group_id,
# e.entity_idx AS entity_idx,
# e.entity_type AS entity_type,
# e.description AS description,
# e.statement_id AS statement_id,
# e.aliases AS aliases,
# e.name_embedding AS name_embedding,
# e.fact_summary AS fact_summary,
# e.connect_strength AS connect_strength,
# e.created_at AS created_at,
# e.expired_at AS expired_at
# """
# # 同组group_id下按name contains召回补充
# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """
# UNWIND $rows AS row
# MATCH (e:ExtractedEntity)
# WHERE e.group_id = row.group_id
# AND toLower(e.name) CONTAINS toLower(row.name)
# RETURN row.id AS incoming_id,
# e.id AS id,
# e.name AS name,
# e.group_id AS group_id,
# e.entity_idx AS entity_idx,
# e.entity_type AS entity_type,
# e.description AS description,
# e.statement_id AS statement_id,
# e.aliases AS aliases,
# e.name_embedding AS name_embedding,
# e.fact_summary AS fact_summary,
# e.connect_strength AS connect_strength,
# e.created_at AS created_at,
# e.expired_at AS expired_at
# """
# Look up Dialogue nodes whose `id` equals $dialog_id, optionally restricted to
# $group_id. Results are ordered by created_at (newest first) and limited to $limit.
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE ($group_id IS NULL OR d.group_id = $group_id)
AND d.id = $dialog_id
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at
ORDER BY d.created_at DESC
LIMIT $limit
"""
# Look up Chunk nodes whose `id` equals $chunk_id, optionally restricted to
# $group_id. Results are ordered by created_at (newest first) and limited to $limit.
SEARCH_CHUNK_BY_CHUNK_ID = """
MATCH (c:Chunk)
WHERE ($group_id IS NULL OR c.group_id = $group_id)
AND c.id = $chunk_id
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.created_at AS created_at,
c.expired_at AS expired_at,
c.sequence_number AS sequence_number
ORDER BY c.created_at DESC
LIMIT $limit
"""
# Time-based statement search with optional $group_id / $apply_id / $user_id filters.
# A statement passes the temporal predicate when EITHER its created_at lies inside
# [$start_date, $end_date] OR its valid_at / invalid_at satisfy $valid_date /
# $invalid_date; every bound is optional (NULL means "no bound").
# NOTE(review): because the two windows are OR-ed and a window whose bounds are both
# NULL evaluates to true, supplying only one pair of bounds (e.g. only $start_date)
# does not filter anything. Confirm whether AND was intended.
# `statement_ids` only ever holds the row's own id, since s.id is a grouping column.
SEARCH_STATEMENTS_BY_TEMPORAL = """
MATCH (s:Statement)
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
collect(DISTINCT s.id) AS statement_ids
ORDER BY datetime(s.created_at) DESC
LIMIT $limit
"""
# Full-text statement search ("statementsFulltext" index, query string $q) combined
# with optional $group_id / $apply_id / $user_id filters and a temporal predicate.
# A statement passes the temporal predicate when EITHER its created_at lies inside
# [$start_date, $end_date] OR its valid_at / invalid_at satisfy $valid_date /
# $invalid_date; every bound is optional (NULL means "no bound").
# NOTE(review): because the two windows are OR-ed and a window whose bounds are both
# NULL evaluates to true, supplying only one pair of bounds does not filter anything.
# Confirm whether AND was intended.
# Rows are ordered by created_at first and by full-text score second.
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
c.id AS chunk_id_from_rel,
collect(DISTINCT e.id) AS entity_ids,
score
ORDER BY s.created_at DESC, score DESC
LIMIT $limit
"""
# Statements created ON the calendar day $created_at, with optional $group_id /
# $apply_id / $user_id filters. The stored created_at is handled as a string: its
# first 10 characters (the YYYY-MM-DD prefix of an ISO timestamp) are parsed with
# date(). A NULL $created_at matches nothing.
SEARCH_STATEMENTS_BY_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
# Statements whose valid_at falls ON the calendar day $valid_at, with optional
# $group_id / $apply_id / $user_id filters. The stored valid_at is handled as a
# string: its first 10 characters (the YYYY-MM-DD prefix) are parsed with date().
# A NULL $valid_at matches nothing.
SEARCH_STATEMENTS_BY_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
# Statements created AFTER the calendar day $created_at ("G" = greater than), with
# optional $group_id / $apply_id / $user_id filters. A NULL $created_at matches
# nothing.
# The comparison is strictly greater-than, mirroring SEARCH_STATEMENTS_G_VALID_AT;
# the previous text used "=", which made this query a duplicate of the BY_ variant.
# Only the 10-character YYYY-MM-DD prefix of the stored ISO string is handed to
# date(), as in the sibling queries, so no time component reaches the date parser.
SEARCH_STATEMENTS_G_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) > date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
# Statements created BEFORE the calendar day $created_at ("L" = less than), with
# optional $group_id / $apply_id / $user_id filters. A NULL $created_at matches
# nothing.
# Only the 10-character YYYY-MM-DD prefix of the stored ISO string is handed to
# date(), as in SEARCH_STATEMENTS_BY_CREATED_AT and the *_VALID_AT queries; the
# previous 19-character prefix included the time part ("...T18:22:17").
SEARCH_STATEMENTS_L_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) < date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
# Statements whose valid_at falls AFTER the calendar day $valid_at ("G" = greater
# than), with optional $group_id / $apply_id / $user_id filters. The stored valid_at
# is handled as a string: its first 10 characters are parsed with date().
# A NULL $valid_at matches nothing.
SEARCH_STATEMENTS_G_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
# Statements whose valid_at falls BEFORE the calendar day $valid_at ("L" = less
# than), with optional $group_id / $apply_id / $user_id filters. The stored valid_at
# is handled as a string: its first 10 characters are parsed with date().
# A NULL $valid_at matches nothing.
SEARCH_STATEMENTS_L_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """
# UNWIND $rows AS row
# MATCH (e:ExtractedEntity)
# WHERE e.group_id = row.group_id
# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name)))
# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type)
# RETURN row.id AS incoming_id,
# e.id AS id,
# e.name AS name,
# e.group_id AS group_id,
# e.entity_idx AS entity_idx,
# e.entity_type AS entity_type,
# e.description AS description,
# e.statement_id AS statement_id,
# e.aliases AS aliases,
# e.name_embedding AS name_embedding,
# e.fact_summary AS fact_summary,
# e.connect_strength AS connect_strength,
# e.created_at AS created_at,
# e.expired_at AS expired_at
# """
# # 同组group_id下按name contains召回补充
# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """
# UNWIND $rows AS row
# MATCH (e:ExtractedEntity)
# WHERE e.group_id = row.group_id
# AND toLower(e.name) CONTAINS toLower(row.name)
# RETURN row.id AS incoming_id,
# e.id AS id,
# e.name AS name,
# e.group_id AS group_id,
# e.entity_idx AS entity_idx,
# e.entity_type AS entity_type,
# e.description AS description,
# e.statement_id AS statement_id,
# e.aliases AS aliases,
# e.name_embedding AS name_embedding,
# e.fact_summary AS fact_summary,
# e.connect_strength AS connect_strength,
# e.created_at AS created_at,
# e.expired_at AS expired_at
# """
# Set the invalid_at value of the statement identified by (group_id, id) to
# $new_invalid_at. The query returns no rows; if no statement matches, nothing is
# changed.
UPDATE_STATEMENT_INVALID_AT = """
MATCH (n:Statement {group_id: $group_id, id: $id})
SET n.invalid_at = $new_invalid_at
"""
# MemorySummary keyword search using the "summariesFulltext" full-text index with
# query string $q, optionally restricted to $group_id. Rows are ordered by score and
# limited to $limit.
# The former `OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)` was
# removed: `s` was never returned or aggregated, so its only effect was to repeat
# each summary once per derived statement and let those duplicates consume LIMIT.
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
WHERE ($group_id IS NULL OR m.group_id = $group_id)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
m.created_at AS created_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search on MemorySummary.summary_embedding via the
# 'summary_embedding_index' vector index. The index is asked for $limit * 100
# candidates, which are then filtered by the optional $group_id and cut down to
# $limit rows ordered by score.
# NOTE(review): described as cosine similarity in the original comment; the actual
# metric is defined by the index configuration, which is not visible in this file.
MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL
AND ($group_id IS NULL OR m.group_id = $group_id)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
m.created_at AS created_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Create or update MemorySummary nodes from the $summaries list. Nodes are merged on
# `id` alone (run_id is not part of the merge key), and every listed property is
# overwritten with the incoming value. Returns each node id as `uuid`.
MEMORY_SUMMARY_NODE_SAVE = """
UNWIND $summaries AS summary
MERGE (m:MemorySummary {id: summary.id})
SET m += {
id: summary.id,
name: summary.name,
group_id: summary.group_id,
user_id: summary.user_id,
apply_id: summary.apply_id,
run_id: summary.run_id,
created_at: summary.created_at,
expired_at: summary.expired_at,
dialog_id: summary.dialog_id,
chunk_ids: summary.chunk_ids,
content: summary.content,
summary_embedding: summary.summary_embedding,
config_id: summary.config_id
}
RETURN m.id AS uuid
"""
# Link MemorySummary nodes to the statements they were derived from. For each input
# edge the summary and the chunk are matched by id within e.run_id; every Statement
# of the same run that the chunk CONTAINS is then connected to the summary with a
# DERIVED_FROM_STATEMENT relationship (merged on the endpoints, properties
# overwritten). Returns elementId(r) as `uuid`, one row per created/updated link.
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE = """
UNWIND $edges AS e
MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id})
MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id})
MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
SET r.group_id = e.group_id,
r.run_id = e.run_id,
r.created_at = e.created_at,
r.expired_at = e.expired_at
RETURN elementId(r) AS uuid
"""

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
"""对话仓储模块
本模块提供对话节点的数据访问功能。
Classes:
DialogRepository: 对话仓储管理DialogueNode的CRUD操作
"""
from typing import List, Optional, Dict
from datetime import datetime
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.core.memory.models.graph_models import DialogueNode
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
class DialogRepository(BaseNeo4jRepository[DialogueNode]):
    """Repository for dialogue nodes.

    Adds lookups by group_id, user_id, ref_id and config_id, plus a
    "recent dialogs" query, on top of what ``BaseNeo4jRepository`` provides.

    Attributes:
        connector: Neo4j connector instance.
        node_label: Node label, fixed to "Dialogue".
    """

    def __init__(self, connector: Neo4jConnector):
        """Initialise the repository.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "Dialogue")

    def _map_to_entity(self, node_data: Dict) -> DialogueNode:
        """Map a query result row to a ``DialogueNode``.

        Args:
            node_data: Row returned by a Neo4j query. The node properties are
                read from the ``'n'`` key when present, otherwise the row
                itself is treated as the property mapping.

        Returns:
            DialogueNode: The dialogue object built from the properties.
        """
        # Copy first: the datetime normalisation below must not rewrite the
        # caller's row, and a copy also works when the driver hands back a
        # read-only mapping instead of a plain dict.
        props = dict(node_data.get('n', node_data))

        # ISO-formatted strings are converted to datetime objects.
        if isinstance(props.get('created_at'), str):
            props['created_at'] = datetime.fromisoformat(props['created_at'])
        if props.get('expired_at') and isinstance(props['expired_at'], str):
            props['expired_at'] = datetime.fromisoformat(props['expired_at'])

        return DialogueNode(**props)

    async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
        """Find dialogues by group_id.

        Args:
            group_id: Group ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find({"group_id": group_id}, limit=limit)

    async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
        """Find dialogues by user_id.

        Args:
            user_id: User ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find({"user_id": user_id}, limit=limit)

    async def find_by_ref_id(self, ref_id: str) -> Optional[DialogueNode]:
        """Find a dialogue by ref_id.

        ref_id is the reference ID of the external dialogue system and is
        usually unique; only the first match is returned.

        Args:
            ref_id: Reference ID.

        Returns:
            Optional[DialogueNode]: The dialogue, or None when nothing matches.
        """
        results = await self.find({"ref_id": ref_id}, limit=1)
        return results[0] if results else None

    async def find_by_group_and_user(
        self,
        group_id: str,
        user_id: str,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find dialogues by group_id and user_id.

        Args:
            group_id: Group ID.
            user_id: User ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find(
            {"group_id": group_id, "user_id": user_id},
            limit=limit
        )

    async def find_recent_dialogs(
        self,
        group_id: str,
        days: int = 7,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find the dialogues of a group created within the last ``days`` days.

        Args:
            group_id: Group ID.
            days: Size of the look-back window in days.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Dialogues ordered by creation time, newest first.
        """
        # created_at is normalised through toString()/datetime() before the
        # comparison: comparing an ISO *string* property directly with a
        # DateTime yields null in Cypher, which would filter out every row.
        # The round trip is a no-op for properties stored as native temporals.
        # NOTE(review): the storage type of Dialogue.created_at is not visible
        # here; _map_to_entity accepts strings, hence the defensive form.
        query = f"""
        MATCH (n:{self.node_label})
        WHERE n.group_id = $group_id
        AND datetime(toString(n.created_at)) >= datetime() - duration({{days: $days}})
        RETURN n
        ORDER BY n.created_at DESC
        LIMIT $limit
        """
        results = await self.connector.execute_query(
            query,
            group_id=group_id,
            days=days,
            limit=limit
        )
        return [self._map_to_entity(r) for r in results]

    async def find_by_config_id(
        self,
        config_id: str,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find dialogues by config_id.

        Args:
            config_id: Configuration ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find({"config_id": config_id}, limit=limit)

    async def find_by_config_and_group(
        self,
        config_id: str,
        group_id: str,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find dialogues by config_id and group_id.

        Both filters are applied together, so only dialogues processed with
        the given configuration inside the given group are returned.

        Args:
            config_id: Configuration ID.
            group_id: Group ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find(
            {"config_id": config_id, "group_id": group_id},
            limit=limit
        )

View File

@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""实体仓储模块
本模块提供实体节点的数据访问功能。
Classes:
EntityRepository: 实体仓储管理ExtractedEntityNode的CRUD操作
"""
from typing import List, Optional, Dict
from datetime import datetime
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.core.memory.models.graph_models import ExtractedEntityNode
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
    """Repository for extracted-entity nodes.

    Adds lookups by type, name, group, statement and config, related-entity
    traversal, per-type counts and embedding similarity search on top of
    what ``BaseNeo4jRepository`` provides.

    Attributes:
        connector: Neo4j connector instance.
        node_label: Node label, fixed to "ExtractedEntity".
    """

    def __init__(self, connector: Neo4jConnector):
        """Initialise the repository.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "ExtractedEntity")

    def _map_to_entity(self, node_data: Dict) -> ExtractedEntityNode:
        """Map a query result row to an ``ExtractedEntityNode``.

        Args:
            node_data: Row returned by a Neo4j query. The node properties are
                read from the ``'n'`` key when present, otherwise the row
                itself is treated as the property mapping.

        Returns:
            ExtractedEntityNode: The entity object built from the properties.
        """
        # Copy first: the datetime normalisation below must not rewrite the
        # caller's row, and a copy also works when the driver hands back a
        # read-only mapping instead of a plain dict.
        props = dict(node_data.get('n', node_data))

        # ISO-formatted strings are converted to datetime objects.
        if isinstance(props.get('created_at'), str):
            props['created_at'] = datetime.fromisoformat(props['created_at'])
        if props.get('expired_at') and isinstance(props['expired_at'], str):
            props['expired_at'] = datetime.fromisoformat(props['expired_at'])

        return ExtractedEntityNode(**props)

    async def _search_by_cosine_similarity(
        self,
        where_clauses: List[str],
        params: Dict
    ) -> List[Dict]:
        """Run the shared cosine-similarity query used by the embedding searches.

        Args:
            where_clauses: Cypher predicates on ``n`` that are AND-ed together.
            params: Query parameters; must contain ``embedding``, ``min_score``
                and ``limit`` plus anything the predicates reference.

        Returns:
            List[Dict]: One ``{"entity": ExtractedEntityNode, "score": float}``
            dict per hit, best score first.
        """
        where_str = " AND ".join(where_clauses)
        # NOTE(review): gds.similarity.cosine belongs to the Graph Data Science
        # library, so this query presumably needs that plugin on the server.
        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_str}
        WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score
        WHERE score > $min_score
        RETURN n, score
        ORDER BY score DESC
        LIMIT $limit
        """
        results = await self.connector.execute_query(query, **params)
        return [
            {
                "entity": self._map_to_entity(r),
                "score": r.get("score", 0.0)
            }
            for r in results
        ]

    async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]:
        """Find entities by entity type.

        Args:
            entity_type: Entity type (e.g. "Person", "Organization").
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"entity_type": entity_type}, limit=limit)

    async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[ExtractedEntityNode]:
        """Find entities by group_id.

        Args:
            group_id: Group ID.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"group_id": group_id}, limit=limit)

    async def find_by_name(
        self,
        name: str,
        group_id: Optional[str] = None,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities whose name contains ``name`` (Cypher CONTAINS).

        Args:
            name: Substring to look for in the entity name.
            group_id: Optional group ID filter.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        where_clause = "n.name CONTAINS $name"
        params = {"name": name, "limit": limit}
        if group_id:
            where_clause += " AND n.group_id = $group_id"
            params["group_id"] = group_id

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_clause}
        RETURN n
        LIMIT $limit
        """
        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def find_related_entities(
        self,
        entity_id: str,
        relation_type: Optional[str] = None,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities reachable from an entity over an outgoing RELATES_TO edge.

        Args:
            entity_id: ID of the source entity.
            relation_type: Optional filter on the edge's ``relation_type`` property.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Related entities.
        """
        params = {"entity_id": entity_id, "limit": limit}
        relation_filter = ""
        if relation_type:
            # The filter is only added (and the parameter only sent) when a
            # relation type was requested.
            relation_filter = " {relation_type: $relation_type}"
            params["relation_type"] = relation_type

        query = f"""
        MATCH (e1:ExtractedEntity {{id: $entity_id}})-[r:RELATES_TO{relation_filter}]->(e2:ExtractedEntity)
        RETURN e2 as n
        LIMIT $limit
        """
        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def search_by_embedding(
        self,
        embedding: List[float],
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search entities by similarity between ``embedding`` and name embeddings.

        Args:
            embedding: Query vector.
            group_id: Optional group ID filter.
            limit: Maximum number of results.
            min_score: Hits must score strictly above this threshold.

        Returns:
            List[Dict]: Dicts with ``entity`` (ExtractedEntityNode) and
            ``score`` (float), best score first.
        """
        where_clauses = ["n.name_embedding IS NOT NULL"]
        params = {
            "embedding": embedding,
            "min_score": min_score,
            "limit": limit
        }
        if group_id:
            where_clauses.append("n.group_id = $group_id")
            params["group_id"] = group_id

        return await self._search_by_cosine_similarity(where_clauses, params)

    async def find_by_statement_id(self, statement_id: str) -> List[ExtractedEntityNode]:
        """Find the entities whose ``statement_id`` property equals the given ID.

        Args:
            statement_id: Statement ID.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"statement_id": statement_id})

    async def find_strong_entities(
        self,
        group_id: str,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find the entities of a group whose connect_strength is "Strong".

        Args:
            group_id: Group ID.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Strongly connected entities.
        """
        return await self.find(
            {"group_id": group_id, "connect_strength": "Strong"},
            limit=limit
        )

    async def get_entity_count_by_type(self, group_id: str) -> Dict[str, int]:
        """Count the entities of a group per entity type.

        Args:
            group_id: Group ID.

        Returns:
            Dict[str, int]: Mapping from entity type to number of entities.
        """
        query = """
        MATCH (n:ExtractedEntity {group_id: $group_id})
        RETURN n.entity_type as entity_type, count(n) as count
        ORDER BY count DESC
        """
        results = await self.connector.execute_query(query, group_id=group_id)
        return {r["entity_type"]: r["count"] for r in results}

    async def find_by_config_id(
        self,
        config_id: str,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities by config_id.

        Args:
            config_id: Configuration ID.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"config_id": config_id}, limit=limit)

    async def search_by_embedding_with_config(
        self,
        embedding: List[float],
        config_id: Optional[str] = None,
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search entities by embedding similarity, optionally filtered by config_id.

        Same search as ``search_by_embedding`` with an additional optional
        ``config_id`` filter, so that only entities processed with a given
        configuration are returned.

        Args:
            embedding: Query vector.
            config_id: Optional configuration ID filter.
            group_id: Optional group ID filter.
            limit: Maximum number of results.
            min_score: Hits must score strictly above this threshold.

        Returns:
            List[Dict]: Dicts with ``entity`` (ExtractedEntityNode) and
            ``score`` (float), best score first.
        """
        where_clauses = ["n.name_embedding IS NOT NULL"]
        params = {
            "embedding": embedding,
            "min_score": min_score,
            "limit": limit
        }
        if config_id:
            where_clauses.append("n.config_id = $config_id")
            params["config_id"] = config_id
        if group_id:
            where_clauses.append("n.group_id = $group_id")
            params["group_id"] = group_id

        return await self._search_by_cosine_similarity(where_clauses, params)

View File

@@ -0,0 +1,216 @@
from typing import List
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.add_nodes import add_dialogue_nodes, add_statement_nodes, add_chunk_nodes
from app.repositories.neo4j.cypher_queries import (
STATEMENT_ENTITY_EDGE_SAVE,
ENTITY_RELATIONSHIP_SAVE,
EXTRACTED_ENTITY_NODE_SAVE,
CHUNK_STATEMENT_EDGE_SAVE,
STATEMENT_ENTITY_EDGE_SAVE,
ENTITY_RELATIONSHIP_SAVE,
EXTRACTED_ENTITY_NODE_SAVE,
)
from app.core.memory.models.graph_models import (
DialogueNode,
ChunkNode,
StatementChunkEdge,
StatementEntityEdge,
StatementNode,
ExtractedEntityNode,
EntityEntityEdge,
)
async def save_entities_and_relationships(
    entity_nodes: List[ExtractedEntityNode],
    entity_entity_edges: List[EntityEntityEdge],
    connector: Neo4jConnector
):
    """Save entity nodes and entity-to-entity relationships to Neo4j.

    Entities are serialised with ``model_dump()`` and written with
    EXTRACTED_ENTITY_NODE_SAVE; edges are flattened into plain dicts and
    written with ENTITY_RELATIONSHIP_SAVE. Progress is reported with print().

    Args:
        entity_nodes: Entity nodes to save.
        entity_entity_edges: Entity-to-entity edges to save.
        connector: Neo4j connector instance.
    """
    def _iso(value):
        # Every timestamp is optional: a missing value is stored as None
        # instead of raising AttributeError on `.isoformat()`.
        return value.isoformat() if value else None

    all_entities = [entity.model_dump() for entity in entity_nodes]
    all_relationships = [
        {
            'source_id': edge.source,
            'target_id': edge.target,
            'predicate': edge.relation_type,
            'statement_id': edge.source_statement_id,
            'value': edge.relation_value,
            'statement': edge.statement,
            'valid_at': _iso(edge.valid_at),
            'invalid_at': _iso(edge.invalid_at),
            'created_at': _iso(edge.created_at),
            'expired_at': _iso(edge.expired_at),
            'run_id': edge.run_id,
            'group_id': edge.group_id,
            'user_id': edge.user_id,
            'apply_id': edge.apply_id,
        }
        for edge in entity_entity_edges
    ]

    # Save entities
    if all_entities:
        entity_uuids = await connector.execute_query(EXTRACTED_ENTITY_NODE_SAVE, entities=all_entities)
        if entity_uuids:
            print(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
        else:
            print("Failed to save entity nodes to Neo4j")
    else:
        print("No entity nodes to save")

    # Create relationships
    if all_relationships:
        relationship_uuids = await connector.execute_query(ENTITY_RELATIONSHIP_SAVE, relationships=all_relationships)
        if relationship_uuids:
            print(f"Successfully saved {len(relationship_uuids)} entity relationships (edges) to Neo4j")
        else:
            print("Failed to save entity relationships to Neo4j")
    else:
        print("No entity relationships to save")
async def save_chunk_nodes(
    chunk_nodes: List[ChunkNode],
    connector: Neo4jConnector
):
    """Write chunk nodes to Neo4j through ``add_chunk_nodes`` and report the outcome.

    Args:
        chunk_nodes: Chunk nodes to save; an empty list is reported and skipped.
        connector: Neo4j connector instance.
    """
    if not chunk_nodes:
        print("No chunk nodes to save")
        return

    saved_ids = await add_chunk_nodes(chunk_nodes, connector)
    # A falsy result from add_chunk_nodes is reported as a failure.
    outcome = (
        f"Successfully saved {len(saved_ids)} chunk nodes to Neo4j"
        if saved_ids
        else "Failed to save chunk nodes to Neo4j"
    )
    print(outcome)
async def save_statement_chunk_edges(
    statement_chunk_edges: List[StatementChunkEdge],
    connector: Neo4jConnector
):
    """Save statement-chunk edges to Neo4j with CHUNK_STATEMENT_EDGE_SAVE.

    The write is best-effort: a failure does not propagate to the caller,
    but it is reported with print() instead of being dropped silently.

    Args:
        statement_chunk_edges: Edges to save; nothing happens for an empty list.
        connector: Neo4j connector instance.
    """
    if not statement_chunk_edges:
        return

    all_sc_edges = [
        {
            "id": edge.id,
            "source": edge.source,
            "target": edge.target,
            "group_id": edge.group_id,
            "user_id": edge.user_id,
            "apply_id": edge.apply_id,
            "run_id": edge.run_id,
            "created_at": edge.created_at.isoformat() if edge.created_at else None,
            "expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
        }
        for edge in statement_chunk_edges
    ]

    try:
        await connector.execute_query(
            CHUNK_STATEMENT_EDGE_SAVE,
            chunk_statement_edges=all_sc_edges
        )
    except Exception as e:
        # Keep the pipeline running, but make the failure visible.
        print(f"Failed to save statement-chunk edges to Neo4j: {e}")
async def save_statement_entity_edges(
    statement_entity_edges: List[StatementEntityEdge],
    connector: Neo4jConnector
):
    """Save statement-entity edges to Neo4j with STATEMENT_ENTITY_EDGE_SAVE.

    The write is best-effort: a failure does not propagate to the caller,
    but it is reported with print() instead of being dropped silently.

    Args:
        statement_entity_edges: Edges to save; an empty list is reported and skipped.
        connector: Neo4j connector instance.
    """
    if not statement_entity_edges:
        print("No statement-entity edges to save")
        return

    all_se_edges = [
        {
            "source": edge.source,
            "target": edge.target,
            "group_id": edge.group_id,
            "user_id": edge.user_id,
            "apply_id": edge.apply_id,
            "run_id": edge.run_id,
            "connect_strength": edge.connect_strength,
            "created_at": edge.created_at.isoformat() if edge.created_at else None,
            "expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
        }
        for edge in statement_entity_edges
    ]

    try:
        await connector.execute_query(
            STATEMENT_ENTITY_EDGE_SAVE,
            relationships=all_se_edges
        )
    except Exception as e:
        # Keep the pipeline running, but make the failure visible.
        print(f"Failed to save statement-entity edges to Neo4j: {e}")
async def save_dialog_and_statements_to_neo4j(
    dialogue_nodes: List[DialogueNode],
    chunk_nodes: List[ChunkNode],
    statement_nodes: List[StatementNode],
    entity_nodes: List[ExtractedEntityNode],
    entity_edges: List[EntityEntityEdge],
    statement_chunk_edges: List[StatementChunkEdge],
    statement_entity_edges: List[StatementEntityEdge],
    connector: Neo4jConnector
) -> bool:
    """Persist a complete extraction result to Neo4j.

    Writes, in this order: dialogue nodes, chunk nodes, statement nodes,
    entities with their relationships, statement-chunk edges and
    statement-entity edges. Progress is reported with print().

    Args:
        dialogue_nodes: DialogueNode objects to save.
        chunk_nodes: ChunkNode objects to save.
        statement_nodes: StatementNode objects to save.
        entity_nodes: ExtractedEntityNode objects to save.
        entity_edges: EntityEntityEdge objects to save.
        statement_chunk_edges: StatementChunkEdge objects to save.
        statement_entity_edges: StatementEntityEdge objects to save.
        connector: Neo4j connector instance.

    Returns:
        bool: True when every step ran; False when saving dialogues or
        statements produced no result or when any step raised.
    """
    try:
        # Dialogues come first; nothing else is written without them.
        saved_dialogues = await add_dialogue_nodes(dialogue_nodes, connector)
        if not saved_dialogues:
            print("Failed to save dialogues to Neo4j")
            return False
        print(f"Dialogues saved to Neo4j with UUIDs: {saved_dialogues}")

        await save_chunk_nodes(chunk_nodes, connector)

        if not statement_nodes:
            print("No statement nodes to save")
        else:
            saved_statements = await add_statement_nodes(statement_nodes, connector)
            if not saved_statements:
                print("Failed to save statement nodes to Neo4j")
                return False
            print(f"Successfully saved {len(saved_statements)} statement nodes to Neo4j")

        await save_entities_and_relationships(entity_nodes, entity_edges, connector)
        print("Successfully saved entities and relationships to Neo4j")

        # Edges are written last, after the nodes they connect.
        await save_statement_chunk_edges(statement_chunk_edges, connector)
        await save_statement_entity_edges(statement_entity_edges, connector)
    except Exception as e:
        print(f"Neo4j integration error: {e}")
        print("Continuing without database storage...")
        return False
    return True

View File

@@ -0,0 +1,584 @@
from typing import Any, Dict, List, Optional
import asyncio
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.cypher_queries import (
SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_ENTITIES_BY_NAME,
SEARCH_CHUNKS_BY_CONTENT,
STATEMENT_EMBEDDING_SEARCH,
CHUNK_EMBEDDING_SEARCH,
ENTITY_EMBEDDING_SEARCH,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
SEARCH_STATEMENTS_BY_TEMPORAL,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
SEARCH_DIALOGUE_BY_DIALOG_ID,
SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_VALID_AT,
SEARCH_STATEMENTS_G_CREATED_AT,
SEARCH_STATEMENTS_L_CREATED_AT,
SEARCH_STATEMENTS_G_VALID_AT,
SEARCH_STATEMENTS_L_VALID_AT,
)
async def search_graph(
    connector: Neo4jConnector,
    q: str,
    group_id: Optional[str] = None,
    limit: int = 50,
    include: Optional[List[str]] = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """Run a free-text keyword search over several node categories at once.

    One Cypher query is issued per requested category and all of them are
    awaited together with ``asyncio.gather``. The matching rules themselves
    live in the imported Cypher constants, not in this function.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs each query.
        q: Query text, passed to every query as the ``q`` parameter.
        group_id: Optional group filter, passed through as ``group_id``.
        limit: Maximum number of rows requested per category.
        include: Categories to search: any of ``"statements"``,
            ``"entities"``, ``"chunks"`` and ``"summaries"``. ``None`` selects
            all four; names outside that set are ignored.

    Returns:
        A dict with one entry per searched category (categories that were not
        requested are absent). A category whose query failed maps to ``[]``
        and the failure is reported on stdout.
    """
    if include is None:
        include = ["statements", "chunks", "entities", "summaries"]

    # Category -> query, listed in the order the result keys are produced.
    query_by_category = (
        ("statements", SEARCH_STATEMENTS_BY_KEYWORD),
        ("entities", SEARCH_ENTITIES_BY_NAME),
        ("chunks", SEARCH_CHUNKS_BY_CONTENT),
        ("summaries", SEARCH_MEMORY_SUMMARIES_BY_KEYWORD),
    )
    selected = [(key, query) for key, query in query_by_category if key in include]

    # Create every coroutine first, then await them together so the queries
    # overlap instead of running back to back.
    pending = [
        connector.execute_query(query, q=q, group_id=group_id, limit=limit)
        for _, query in selected
    ]
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    results: Dict[str, List[Dict[str, Any]]] = {}
    for (key, _), outcome in zip(selected, outcomes):
        # BaseException (not just Exception) so a cancelled sub-query is
        # treated as a failure rather than stored as if it were a row list.
        if isinstance(outcome, BaseException):
            print(f"[WARN] search_graph: '{key}' query failed: {outcome!r}")
            results[key] = []
        else:
            results[key] = outcome
    return results
async def search_graph_by_embedding(
    connector: Neo4jConnector,
    embedder_client,
    query_text: str,
    group_id: Optional[str] = None,
    limit: int = 50,
    include: Optional[List[str]] = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """Run an embedding-based search over statements, chunks, entities and summaries.

    The query text is embedded once with ``embedder_client.response`` and the
    resulting vector is passed to one Cypher query per requested category. All
    queries are awaited together with ``asyncio.gather``. Scoring and ranking
    happen inside the imported Cypher queries, not in this function.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs each query.
        embedder_client: Object with an async ``response(list_of_texts)``
            method returning one embedding per text.
        query_text: Text to embed and search for.
        group_id: Optional group filter, passed through as ``group_id``.
        limit: Maximum number of rows requested per category.
        include: Categories to search: any of ``"statements"``, ``"chunks"``,
            ``"entities"`` and ``"summaries"``. ``None`` selects all four.

    Returns:
        A dict that always has the four category keys. Categories that were
        not requested, or whose query failed, map to ``[]``. If no embedding
        could be produced, every category is ``[]``.
    """
    import time

    # The default is resolved here instead of using a list literal in the
    # signature, which would be one shared mutable object across all calls.
    if include is None:
        include = ["statements", "chunks", "entities", "summaries"]

    # Embed the query; perf_counter is monotonic, so the measured interval
    # cannot be distorted by wall-clock adjustments.
    embed_start = time.perf_counter()
    embeddings = await embedder_client.response([query_text])
    embed_time = time.perf_counter() - embed_start
    print(f"[PERF] Embedding generation took: {embed_time:.4f}s")

    results: Dict[str, List[Dict[str, Any]]] = {
        "statements": [],
        "chunks": [],
        "entities": [],
        "summaries": [],
    }
    if not embeddings or not embeddings[0]:
        return results
    embedding = embeddings[0]

    # Category -> query, in the order the queries are issued.
    query_by_category = (
        ("statements", STATEMENT_EMBEDDING_SEARCH),
        ("chunks", CHUNK_EMBEDDING_SEARCH),
        ("entities", ENTITY_EMBEDDING_SEARCH),
        ("summaries", MEMORY_SUMMARY_EMBEDDING_SEARCH),
    )
    selected = [(key, query) for key, query in query_by_category if key in include]
    pending = [
        connector.execute_query(
            query,
            embedding=embedding,
            group_id=group_id,
            limit=limit,
        )
        for _, query in selected
    ]

    query_start = time.perf_counter()
    outcomes = await asyncio.gather(*pending, return_exceptions=True)
    query_time = time.perf_counter() - query_start
    print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")

    for (key, _), outcome in zip(selected, outcomes):
        # A failed category keeps its empty default; report it instead of
        # dropping the error without a trace.
        if isinstance(outcome, BaseException):
            print(f"[WARN] search_graph_by_embedding: '{key}' query failed: {outcome!r}")
        else:
            results[key] = outcome
    return results
async def get_dedup_candidates_for_entities(  # adapted to the new queries: candidates are looked up by name
    connector: Neo4jConnector,
    group_id: str,
    entities: List[Dict[str, Any]],
    use_contains_fallback: bool = True,
    batch_size: int = 500,
    max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]:
    """Fetch candidate database entities for second-layer deduplication.

    For every incoming entity the ``SEARCH_ENTITIES_BY_NAME`` query is run
    with the entity's stripped ``name`` and ``group_id`` (at most 100 rows).
    The original author's note says that query uses a full-text index; the
    query text is not visible from here.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the lookup.
        group_id: Group the candidates must belong to.
        entities: Incoming entities as dicts. The keys read are ``id``,
            ``name`` and, optionally, ``entity_type``.
        use_contains_fallback: When true and the first lookup yields no
            candidates, the lookup is repeated with the lowercased name.
        batch_size: Unused; kept so existing callers keep working.
        max_concurrency: Maximum number of lookups in flight at once. Values
            below 1 are treated as 1.

    Returns:
        Mapping ``incoming_id -> [candidate rows]``. Every row carries an
        ``incoming_id`` key. Entities without an ``id`` are grouped under
        ``"__unknown__"``; entities with a blank name map to ``[]``.
    """
    if not entities:
        return {}

    # Semaphore(0) would block forever and a negative value raises, so clamp.
    sem = asyncio.Semaphore(max(1, max_concurrency))

    async def _fetch(name_query: str, wanted_type: Any, inc_id: str) -> List[Dict[str, Any]]:
        """Run one name lookup, apply the type filter and tag the rows."""
        try:
            rows = await connector.execute_query(
                SEARCH_ENTITIES_BY_NAME,
                q=name_query,
                group_id=group_id,
                limit=100,
            )
        except Exception as exc:
            # Best effort: one failed lookup must not abort the whole batch.
            print(
                f"[WARN] get_dedup_candidates_for_entities: "
                f"lookup for {name_query!r} failed: {exc!r}"
            )
            return []
        # Optional local type filter, applied only when the incoming entity
        # declares a type.
        if wanted_type:
            rows = [r for r in rows if r.get("entity_type") == wanted_type]
        # Tag rows so downstream merge logic can relate them to the input.
        for r in rows:
            r["incoming_id"] = inc_id
        return rows

    async def _query_by_name(incoming: Dict[str, Any]) -> tuple[str, List[Dict[str, Any]]]:
        async with sem:
            inc_id = incoming.get("id") or "__unknown__"
            name = (incoming.get("name") or "").strip()
            if not name:
                return inc_id, []
            typ = incoming.get("entity_type")
            rows = await _fetch(name, typ, inc_id)
            # The fallback goes through the same filter as the first lookup;
            # otherwise rows rejected for their type would come straight back.
            if use_contains_fallback and not rows:
                rows = await _fetch(name.lower(), typ, inc_id)
            return inc_id, rows

    tasks = [_query_by_name(e) for e in entities]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    merged: Dict[str, List[Dict[str, Any]]] = {}
    for res in results:
        if isinstance(res, BaseException):
            # Skip an individual failure and keep the rest of the batch.
            continue
        inc_id, rows = res
        bucket = merged.setdefault(inc_id or "__unknown__", [])
        seen = {x.get("id") for x in bucket if x.get("id") is not None}
        for rec in rows:
            rec_id = rec.get("id")
            # Rows without an id cannot be shown to be duplicates; keep them.
            if rec_id is not None:
                if rec_id in seen:
                    continue
                seen.add(rec_id)
            bucket.append(rec)
    return merged
async def search_graph_by_keyword_temporal(
    connector: Neo4jConnector,
    query_text: str,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    valid_date: Optional[str] = None,
    invalid_date: Optional[str] = None,
    limit: int = 50,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements by keyword combined with temporal filters.

    Runs ``SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL`` with every argument passed
    through as a query parameter. How the keyword and the date bounds are
    applied is defined by that query, not by this function.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        query_text: Keyword text, sent as the ``q`` parameter. Must be
            non-empty.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        start_date: Optional lower date bound.
        end_date: Optional upper date bound.
        valid_date: Optional validity date bound.
        invalid_date: Optional invalidity date bound.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}``. With an empty ``query_text`` no query is
        run and ``{"statements": []}`` is returned.
    """
    if not query_text:
        print("query_text不能为空")
        return {"statements": []}
    statements = await connector.execute_query(
        SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
        q=query_text,
        group_id=group_id,
        apply_id=apply_id,
        user_id=user_id,
        start_date=start_date,
        end_date=end_date,
        valid_date=valid_date,
        invalid_date=invalid_date,
        limit=limit,
    )
    # Debug trace of the raw rows.
    print(f"查询结果为:\n{statements}")
    return {"statements": statements}
async def search_graph_by_temporal(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    valid_date: Optional[str] = None,
    invalid_date: Optional[str] = None,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements using temporal filters only.

    Runs ``SEARCH_STATEMENTS_BY_TEMPORAL`` with all arguments passed through
    as query parameters; how the date bounds are applied is defined by that
    query. The query text, parameters and raw rows are printed afterwards.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        start_date: Optional lower date bound.
        end_date: Optional upper date bound.
        valid_date: Optional validity date bound.
        invalid_date: Optional invalidity date bound.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "start_date": start_date,
        "end_date": end_date,
        "valid_date": valid_date,
        "invalid_date": invalid_date,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_BY_TEMPORAL, **params)
    # Debug trace: query text, bound parameters, raw rows.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
async def search_graph_by_dialog_id(
    connector: Neo4jConnector,
    dialog_id: str,
    group_id: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Look up dialogues by their dialog id.

    Runs ``SEARCH_DIALOGUE_BY_DIALOG_ID`` with ``dialog_id``, ``group_id``
    and ``limit`` as query parameters.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        dialog_id: Id of the dialogue to fetch. Must be non-empty.
        group_id: Optional group filter.
        limit: Maximum number of rows requested.

    Returns:
        ``{"dialogues": rows}``. With an empty ``dialog_id`` no query is run
        and ``{"dialogues": []}`` is returned.
    """
    if not dialog_id:
        print("dialog_id不能为空")
        return {"dialogues": []}
    dialogues = await connector.execute_query(
        SEARCH_DIALOGUE_BY_DIALOG_ID,
        group_id=group_id,
        dialog_id=dialog_id,
        limit=limit,
    )
    return {"dialogues": dialogues}
async def search_graph_by_chunk_id(
    connector: Neo4jConnector,
    chunk_id: str,
    group_id: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Look up chunks by their chunk id.

    Runs ``SEARCH_CHUNK_BY_CHUNK_ID`` with ``chunk_id``, ``group_id`` and
    ``limit`` as query parameters.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        chunk_id: Id of the chunk to fetch. Must be non-empty.
        group_id: Optional group filter.
        limit: Maximum number of rows requested.

    Returns:
        ``{"chunks": rows}``. With an empty ``chunk_id`` no query is run and
        ``{"chunks": []}`` is returned.
    """
    if not chunk_id:
        print("chunk_id不能为空")
        return {"chunks": []}
    chunks = await connector.execute_query(
        SEARCH_CHUNK_BY_CHUNK_ID,
        group_id=group_id,
        chunk_id=chunk_id,
        limit=limit,
    )
    return {"chunks": chunks}
async def search_graph_by_created_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    created_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements by their creation time.

    Runs ``SEARCH_STATEMENTS_BY_CREATED_AT`` with all arguments passed through
    as query parameters; how ``created_at`` is compared is defined by that
    query. The query text, parameters and raw rows are printed afterwards.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        created_at: Creation time to match.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "created_at": created_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_BY_CREATED_AT, **params)
    # Debug trace: query text, bound parameters, raw rows.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
async def search_graph_by_valid_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    valid_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements by their validity time.

    Runs ``SEARCH_STATEMENTS_BY_VALID_AT`` with all arguments passed through
    as query parameters; how ``valid_at`` is compared is defined by that
    query. The query text, parameters and raw rows are printed afterwards.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        valid_at: Validity time to match.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "valid_at": valid_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_BY_VALID_AT, **params)
    # Debug trace: query text, bound parameters, raw rows.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
async def search_graph_g_created_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    created_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements relative to a creation time ("G" variant).

    Runs ``SEARCH_STATEMENTS_G_CREATED_AT`` with all arguments passed through
    as query parameters. The query text, parameters and raw rows are printed
    afterwards.

    NOTE(review): the "G" in the query name presumably stands for
    "greater than" ``created_at`` - confirm against the query definition.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        created_at: Reference creation time.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "created_at": created_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_G_CREATED_AT, **params)
    # Debug trace: query text, bound parameters, raw rows.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
async def search_graph_g_valid_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    valid_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements relative to a validity time ("G" variant).

    Runs ``SEARCH_STATEMENTS_G_VALID_AT`` with all arguments passed through
    as query parameters. The query text, parameters and raw rows are printed
    afterwards.

    NOTE(review): the "G" in the query name presumably stands for
    "greater than" ``valid_at`` - confirm against the query definition.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        valid_at: Reference validity time.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "valid_at": valid_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_G_VALID_AT, **params)
    # Debug trace: query text, bound parameters, raw rows.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
async def search_graph_l_created_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    created_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements relative to a creation time ("L" variant).

    Runs ``SEARCH_STATEMENTS_L_CREATED_AT`` with all arguments passed through
    as query parameters. The query text, parameters and raw rows are printed
    afterwards.

    NOTE(review): the "L" in the query name presumably stands for
    "less than" ``created_at`` - confirm against the query definition.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        created_at: Reference creation time.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "created_at": created_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_L_CREATED_AT, **params)
    # Debug trace: query text, bound parameters, raw rows.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
async def search_graph_l_valid_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    valid_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements relative to a validity time ("L" variant).

    Runs ``SEARCH_STATEMENTS_L_VALID_AT`` with all arguments passed through
    as query parameters. The query text, parameters and raw rows are printed
    afterwards.

    NOTE(review): the "L" in the query name presumably stands for
    "less than" ``valid_at`` - confirm against the query definition.

    Args:
        connector: Connector whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        valid_at: Reference validity time.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "valid_at": valid_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_L_VALID_AT, **params)
    # Debug trace: query text, bound parameters, raw rows.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}

View File

@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
"""Neo4j连接器模块
本模块提供Neo4j图数据库的连接和查询功能。
从 app/core/memory/src/database/neo4j_connector.py 迁移而来。
Classes:
Neo4jConnector: Neo4j数据库连接器提供异步查询接口
"""
import os
from typing import Any, List, Dict
from neo4j import AsyncGraphDatabase, basic_auth
from app.core.config import settings
class Neo4jConnector:
    """Asynchronous connector for a Neo4j graph database.

    Wraps a Neo4j async driver and exposes a small query API.

    Attributes:
        driver: The Neo4j async driver instance created in ``__init__``.

    Methods:
        close: Close the driver.
        execute_query: Run a Cypher query and return the rows as dicts.
        delete_group: Delete all nodes and relationships of one group.
    """

    # Database every query is sent to.
    _DATABASE = "neo4j"

    def __init__(self):
        """Create the async driver from the global settings.

        Reads ``NEO4J_URI``, ``NEO4J_USERNAME`` and ``NEO4J_PASSWORD`` from
        ``settings``.

        Raises:
            RuntimeError: If ``settings.NEO4J_PASSWORD`` is empty.
        """
        uri = settings.NEO4J_URI
        username = settings.NEO4J_USERNAME
        password = settings.NEO4J_PASSWORD
        if not password:
            raise RuntimeError(
                "NEO4J_PASSWORD is not set. Create a .env with NEO4J_PASSWORD or export it before running."
            )
        self.driver = AsyncGraphDatabase.driver(
            uri,
            auth=basic_auth(username, password)
        )

    async def close(self):
        """Close the driver and release its connections.

        Intended to be called when the application shuts down.
        """
        await self.driver.close()

    async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
        """Run a Cypher query and return its rows as plain dicts.

        Args:
            query: Cypher query text.
            **kwargs: Query parameters, referenced in the query as ``$name``.

        Returns:
            One dict per record, produced by ``record.data()``.

        Example::

            connector = Neo4jConnector()
            results = await connector.execute_query(
                "MATCH (n:Person {name: $name}) RETURN n",
                name="Alice",
            )
        """
        # The driver's own options end in an underscore (``database_``); a
        # bare ``database=...`` keyword is treated as a Cypher parameter and
        # does not select the database. Parameters go through ``parameters_``
        # so that no caller-supplied name can collide with a driver option.
        records, _summary, _keys = await self.driver.execute_query(
            query,
            parameters_=kwargs,
            database_=self._DATABASE,
        )
        return [record.data() for record in records]

    async def delete_group(self, group_id: str):
        """Delete every node and relationship that belongs to one group.

        This is destructive and cannot be undone.

        Args:
            group_id: Id of the group to delete.

        Example::

            connector = Neo4jConnector()
            await connector.delete_group("group_123")
            # prints: Group group_123 deleted.
        """
        # Nodes first: DETACH DELETE also removes their relationships.
        await self.driver.execute_query(
            "MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n",
            parameters_={"group_id": group_id},
            database_=self._DATABASE,
        )
        # Then any remaining relationships that carry the group id themselves.
        await self.driver.execute_query(
            "MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r",
            parameters_={"group_id": group_id},
            database_=self._DATABASE,
        )
        print(f"Group {group_id} deleted.")

View File

@@ -0,0 +1,319 @@
# -*- coding: utf-8 -*-
"""陈述句仓储模块
本模块提供陈述句节点的数据访问功能。
Classes:
StatementRepository: 陈述句仓储管理StatementNode的CRUD操作
"""
from typing import List, Optional, Dict
from datetime import datetime
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.core.memory.models.graph_models import StatementNode
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.data.ontology import TemporalInfo
class StatementRepository(BaseNeo4jRepository[StatementNode]):
    """Repository for statement nodes.

    Provides lookups of ``Statement`` nodes by chunk, group, config, keyword,
    time range and embedding similarity. Plain property lookups are delegated
    to the base repository's ``find``.

    Attributes:
        connector: Neo4j connector used to run queries.
        node_label: Node label, fixed to ``"Statement"``.
    """

    # Node properties that may arrive as ISO-8601 strings and are parsed
    # into ``datetime`` objects before building the entity.
    _DATETIME_FIELDS = ('created_at', 'expired_at', 'valid_at', 'invalid_at')

    def __init__(self, connector: Neo4jConnector):
        """Initialise the repository.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "Statement")

    def _map_to_entity(self, node_data: Dict) -> StatementNode:
        """Build a ``StatementNode`` from one query result row.

        Args:
            node_data: Either a row of the form ``{"n": {...}}`` or the node
                property dict itself.

        Returns:
            StatementNode: The entity built from the node properties.
        """
        # Work on a shallow copy so the caller's row is never modified.
        n = dict(node_data.get('n', node_data))

        # Parse non-empty ISO strings; values of any other type pass through.
        for field in self._DATETIME_FIELDS:
            value = n.get(field)
            if value and isinstance(value, str):
                n[field] = datetime.fromisoformat(value)

        # A dict is expanded into TemporalInfo; a missing or empty value is
        # replaced by a default TemporalInfo.
        temporal_info = n.get('temporal_info')
        if isinstance(temporal_info, dict):
            n['temporal_info'] = TemporalInfo(**temporal_info)
        elif not temporal_info:
            n['temporal_info'] = TemporalInfo()
        return StatementNode(**n)

    async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]:
        """Find statements that have the given ``chunk_id``.

        Args:
            chunk_id: Chunk id to match.

        Returns:
            List[StatementNode]: Matching statements.
        """
        return await self.find({"chunk_id": chunk_id})

    async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[StatementNode]:
        """Find statements that have the given ``group_id``.

        Args:
            group_id: Group id to match.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Matching statements.
        """
        return await self.find({"group_id": group_id}, limit=limit)

    async def _search_by_embedding(
        self,
        embedding: List[float],
        filters: Dict[str, Optional[str]],
        limit: int,
        min_score: float
    ) -> List[Dict]:
        """Run the cosine-similarity query shared by the public embedding searches.

        Args:
            embedding: Query vector.
            filters: Property name -> required value; falsy values are skipped.
            limit: Maximum number of results.
            min_score: Rows must score strictly above this value.

        Returns:
            List[Dict]: Dicts with ``statement`` (StatementNode) and ``score``.
        """
        where_clauses = ["n.statement_embedding IS NOT NULL"]
        params = {
            "embedding": embedding,
            "min_score": min_score,
            "limit": limit
        }
        for field, value in filters.items():
            if value:
                where_clauses.append(f"n.{field} = ${field}")
                params[field] = value
        where_str = " AND ".join(where_clauses)

        # The score is computed in Cypher with gds.similarity.cosine.
        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_str}
        WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
        WHERE score > $min_score
        RETURN n, score
        ORDER BY score DESC
        LIMIT $limit
        """
        results = await self.connector.execute_query(query, **params)
        return [
            {
                "statement": self._map_to_entity(r),
                "score": r.get("score", 0.0)
            }
            for r in results
        ]

    async def search_by_embedding(
        self,
        embedding: List[float],
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search statements by embedding similarity.

        Ranks statements by the cosine similarity between ``embedding`` and
        each node's ``statement_embedding``.

        Args:
            embedding: Query vector.
            group_id: Optional group filter.
            limit: Maximum number of results.
            min_score: Rows must score strictly above this value.

        Returns:
            List[Dict]: Dicts with ``statement`` (StatementNode) and
            ``score`` (float), best score first.
        """
        return await self._search_by_embedding(
            embedding,
            {"group_id": group_id},
            limit,
            min_score
        )

    async def search_by_keyword(
        self,
        keyword: str,
        group_id: Optional[str] = None,
        limit: int = 50
    ) -> List[StatementNode]:
        """Search statements whose text contains a keyword.

        Uses Cypher ``CONTAINS`` on the node's ``statement`` property.

        Args:
            keyword: Keyword to look for.
            group_id: Optional group filter.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Matching statements.
        """
        where_clause = "n.statement CONTAINS $keyword"
        if group_id:
            where_clause += " AND n.group_id = $group_id"
        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_clause}
        RETURN n
        LIMIT $limit
        """
        params = {"keyword": keyword, "limit": limit}
        if group_id:
            params["group_id"] = group_id
        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def find_by_temporal_range(
        self,
        group_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: int = 100
    ) -> List[StatementNode]:
        """Find a group's statements inside a validity time range.

        With ``start_date`` the query requires ``valid_at >= start_date``;
        with ``end_date`` it requires ``invalid_at`` to be null or
        ``<= end_date``. Results are ordered by ``created_at`` descending.

        NOTE(review): the bounds are sent as ISO strings, so the comparison
        presumably relies on the node properties being stored as ISO strings
        too - confirm against the code that writes them.

        Args:
            group_id: Group id to match.
            start_date: Optional lower bound.
            end_date: Optional upper bound.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Matching statements.
        """
        where_clauses = ["n.group_id = $group_id"]
        params = {"group_id": group_id, "limit": limit}
        if start_date is not None:
            where_clauses.append("n.valid_at >= $start_date")
            params["start_date"] = start_date.isoformat()
        if end_date is not None:
            where_clauses.append("(n.invalid_at IS NULL OR n.invalid_at <= $end_date)")
            params["end_date"] = end_date.isoformat()
        where_str = " AND ".join(where_clauses)
        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_str}
        RETURN n
        ORDER BY n.created_at DESC
        LIMIT $limit
        """
        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def find_strong_statements(
        self,
        group_id: str,
        limit: int = 100
    ) -> List[StatementNode]:
        """Find a group's statements whose ``connect_strength`` is ``"Strong"``.

        Args:
            group_id: Group id to match.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Matching statements.
        """
        return await self.find(
            {"group_id": group_id, "connect_strength": "Strong"},
            limit=limit
        )

    async def find_by_config_id(
        self,
        config_id: str,
        limit: int = 100
    ) -> List[StatementNode]:
        """Find statements that have the given ``config_id``.

        Args:
            config_id: Config id to match.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Matching statements.
        """
        return await self.find({"config_id": config_id}, limit=limit)

    async def search_by_embedding_with_config(
        self,
        embedding: List[float],
        config_id: Optional[str] = None,
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search statements by embedding similarity, optionally per config.

        Same ranking as ``search_by_embedding``, with an additional optional
        filter on the node's ``config_id``.

        Args:
            embedding: Query vector.
            config_id: Optional config filter.
            group_id: Optional group filter.
            limit: Maximum number of results.
            min_score: Rows must score strictly above this value.

        Returns:
            List[Dict]: Dicts with ``statement`` (StatementNode) and
            ``score`` (float), best score first.
        """
        return await self._search_by_embedding(
            embedding,
            {"config_id": config_id, "group_id": group_id},
            limit,
            min_score
        )