Initial commit
This commit is contained in:
32
app/repositories/neo4j/__init__.py
Normal file
32
app/repositories/neo4j/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j仓储模块
|
||||
|
||||
本模块包含Neo4j图数据库的仓储实现,用于管理知识图谱的节点和边。
|
||||
|
||||
Modules:
|
||||
neo4j_connector: Neo4j数据库连接器
|
||||
base_neo4j_repository: Neo4j仓储基类
|
||||
dialog_repository: 对话仓储
|
||||
statement_repository: 陈述句仓储
|
||||
entity_repository: 实体仓储
|
||||
cypher_queries: Cypher查询语句
|
||||
graph_search: 图搜索功能
|
||||
graph_saver: 图数据保存功能
|
||||
add_nodes: 添加节点功能
|
||||
add_edges: 添加边功能
|
||||
create_indexes: 创建索引功能
|
||||
"""
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.repositories.neo4j.dialog_repository import DialogRepository
|
||||
from app.repositories.neo4j.statement_repository import StatementRepository
|
||||
from app.repositories.neo4j.entity_repository import EntityRepository
|
||||
|
||||
__all__ = [
|
||||
'Neo4jConnector',
|
||||
'BaseNeo4jRepository',
|
||||
'DialogRepository',
|
||||
'StatementRepository',
|
||||
'EntityRepository',
|
||||
]
|
||||
102
app/repositories/neo4j/add_edges.py
Normal file
102
app/repositories/neo4j/add_edges.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import List, Optional
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE, MEMORY_SUMMARY_STATEMENT_EDGE_SAVE
|
||||
from app.core.memory.models.message_models import Chunk
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
|
||||
async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnector) -> Optional[List[str]]:
    """Persist one edge per (chunk, statement) pair via ``CHUNK_STATEMENT_EDGE_SAVE``.

    Each edge id is the SHA-1 hex digest of ``"<chunk.id>|<statement.id>"``,
    so the same pair always maps to the same id.

    Args:
        chunks: Chunk objects; the statements to link are read from each
            chunk's ``statements`` attribute (missing or empty means none).
        connector: Neo4j connector used to run the query.

    Returns:
        The ``uuid`` value of every record returned by the query, ``[]`` when
        there is nothing to link, or ``None`` if any error occurred.
    """
    if not chunks:
        print("No chunks provided to create edges")
        return []

    try:
        # Tenant/run metadata is taken from the statement; only run_id falls
        # back to the chunk when the statement has none.
        edge_payload: List[dict] = [
            {
                "id": hashlib.sha1(f"{chunk.id}|{statement.id}".encode("utf-8")).hexdigest(),
                "source": chunk.id,
                "target": statement.id,
                "group_id": getattr(statement, 'group_id', None),
                "user_id": getattr(statement, 'user_id', None),
                "apply_id": getattr(statement, 'apply_id', None),
                "run_id": getattr(statement, 'run_id', None) or getattr(chunk, 'run_id', None),
                "created_at": getattr(statement, 'created_at', None),
                "expired_at": getattr(statement, 'expired_at', None),
            }
            for chunk in chunks
            for statement in (getattr(chunk, "statements", []) or [])
        ]

        if not edge_payload:
            print("No statements found in chunks to create edges")
            return []

        records = await connector.execute_query(
            CHUNK_STATEMENT_EDGE_SAVE,
            chunk_statement_edges=edge_payload,
        )
        if records:
            created_uuids = [record.get("uuid") for record in records]
        else:
            created_uuids = []
        print(f"Successfully created {len(created_uuids)} chunk-statement edges")
        return created_uuids
    except Exception as e:
        print(f"Error creating chunk-statement edges: {e}")
        return None
|
||||
|
||||
async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Create edges from MemorySummary nodes to Statements via their chunk ids.

    One payload row is built per (summary, chunk_id) pair and handed to
    ``MEMORY_SUMMARY_STATEMENT_EDGE_SAVE``. The linking itself happens in that
    query, which is defined in ``cypher_queries``.
    NOTE(review): the original description says the query links the summary to
    every statement of the chunk with DERIVED_FROM_STATEMENT - confirm against
    the query text.

    Args:
        summaries: MemorySummaryNode objects whose ``chunk_ids`` are linked.
        connector: Neo4j connector used to run the query.

    Returns:
        The ``uuid`` value of every record returned by the query, ``[]`` when
        there is nothing to link, or ``None`` if any error occurred.
    """
    if not summaries:
        return []

    try:
        edges: List[dict] = [
            {
                "summary_id": s.id,
                "chunk_id": chunk_id,
                "group_id": s.group_id,
                "run_id": s.run_id,
                "created_at": s.created_at.isoformat() if s.created_at else None,
                "expired_at": s.expired_at.isoformat() if s.expired_at else None,
            }
            for s in summaries
            for chunk_id in (getattr(s, "chunk_ids", []) or [])
        ]

        if not edges:
            return []

        result = await connector.execute_query(
            MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
            edges=edges,
        )
        return [record.get("uuid") for record in result] if result else []
    except Exception as e:
        # Still best-effort (callers get None), but report the failure the
        # same way the sibling save functions do instead of hiding it.
        print(f"Error creating memory-summary-statement edges: {e}")
        return None
|
||||
215
app/repositories/neo4j/add_nodes.py
Normal file
215
app/repositories/neo4j/add_nodes.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def delete_all_nodes(group_id: str, connector: Neo4jConnector):
    """Delete every node whose ``group_id`` property equals *group_id*.

    The query uses ``DETACH DELETE``, so relationships attached to the matched
    nodes are removed with them. Nodes of other groups are not matched.

    Args:
        group_id: Value of the ``group_id`` property to match.
        connector: Neo4j connector used to run the query.

    Returns:
        Whatever ``connector.execute_query`` returns for the delete query.
    """
    # Bind group_id as a query parameter rather than formatting it into the
    # Cypher text: a value containing a quote can neither break the query nor
    # inject additional Cypher.
    result = await connector.execute_query(
        "MATCH (n {group_id: $group_id}) DETACH DELETE n",
        group_id=group_id,
    )
    print(f"All group_id: {group_id} node and edge deleted successfully")
    return result
|
||||
|
||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save dialogue nodes through ``DIALOGUE_NODE_SAVE``.

    Args:
        dialogues: DialogueNode objects to persist.
        connector: Neo4j connector used to run the query.

    Returns:
        The ``uuid`` value of every record returned by the query, ``[]`` when
        *dialogues* is empty, or ``None`` if any error occurred.
    """
    if not dialogues:
        print("No dialogues to save")
        return []

    def _iso(moment):
        # Timestamps are sent as ISO-8601 strings; missing ones stay None.
        return moment.isoformat() if moment else None

    try:
        # Reduce each DialogueNode to the flat map the Cypher query unwinds.
        rows = [
            {
                "id": node.id,
                "group_id": node.group_id,
                "user_id": node.user_id,
                "apply_id": node.apply_id,
                "run_id": node.run_id,
                "ref_id": node.ref_id,
                "name": node.name,
                "created_at": _iso(node.created_at),
                "expired_at": _iso(node.expired_at),
                "content": node.content,
                "dialog_embedding": node.dialog_embedding,
            }
            for node in dialogues
        ]

        records = await connector.execute_query(DIALOGUE_NODE_SAVE, dialogues=rows)

        created_uuids = []
        for record in records:
            created_uuids.append(record["uuid"])
        print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
        return created_uuids

    except Exception as e:
        print(f"Error creating dialogue nodes: {e}")
        return None
|
||||
|
||||
|
||||
async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save statement nodes through ``STATEMENT_NODE_SAVE``.

    Args:
        statements: StatementNode objects to persist.
        connector: Neo4j connector used to run the query.

    Returns:
        The ``uuid`` value of every record returned by the query, ``[]`` when
        *statements* is empty, or ``None`` if any error occurred.
    """
    if not statements:
        print("No statements to save")
        return []

    def _iso(moment):
        # Timestamps are sent as ISO-8601 strings; missing ones stay None.
        return moment.isoformat() if moment else None

    try:
        # Reduce each StatementNode to a flat map of plain values.
        # NOTE(review): STATEMENT_NODE_SAVE reads `statement.relevence_info`,
        # which this payload does not provide, and does not write `name`,
        # `connect_strength` or `chunk_embedding` - confirm which side is
        # out of date.
        rows = []
        for node in statements:
            rows.append({
                "id": node.id,
                "name": node.name,
                "group_id": node.group_id,
                "user_id": node.user_id,
                "apply_id": node.apply_id,
                "run_id": node.run_id,
                "chunk_id": node.chunk_id,
                "created_at": _iso(node.created_at),
                "expired_at": _iso(node.expired_at),
                "stmt_type": node.stmt_type,
                "temporal_info": node.temporal_info.value,
                "statement": node.statement,
                "connect_strength": node.connect_strength,
                # Empty embeddings are stored as None.
                "chunk_embedding": node.chunk_embedding or None,
                "valid_at": _iso(node.valid_at),
                "invalid_at": _iso(node.invalid_at),
                "statement_embedding": node.statement_embedding or None,
            })

        records = await connector.execute_query(STATEMENT_NODE_SAVE, statements=rows)

        created_uuids = []
        for record in records:
            created_uuids.append(record["uuid"])
        print(f"Successfully created {len(created_uuids)} statement nodes")
        return created_uuids

    except Exception as e:
        print(f"Error creating statement nodes: {e}")
        return None
|
||||
|
||||
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save chunk nodes in one batch through ``CHUNK_NODE_SAVE``.

    Args:
        chunks: ChunkNode objects to persist.
        connector: Neo4j connector used to run the query.

    Returns:
        The ``uuid`` value of every record returned by the query, ``[]`` when
        *chunks* is empty, or ``None`` if any error occurred.
    """
    if not chunks:
        print("No chunk nodes to add")
        return []

    def _iso(moment):
        # Timestamps are sent as ISO-8601 strings; missing ones stay None.
        return moment.isoformat() if moment else None

    try:
        rows = []
        for node in chunks:
            # Only start_index / end_index are lifted out of metadata, so the
            # payload stays a flat map (the original comment cites Neo4j Map
            # type issues with nested maps).
            meta = node.metadata or {}
            rows.append({
                "id": node.id,
                "name": node.name,
                "group_id": node.group_id,
                "user_id": node.user_id,
                "apply_id": node.apply_id,
                "run_id": node.run_id,
                "created_at": _iso(node.created_at),
                "expired_at": _iso(node.expired_at),
                "dialog_id": node.dialog_id,
                "content": node.content,
                "chunk_embedding": node.chunk_embedding or None,
                "sequence_number": node.sequence_number,
                "start_index": meta.get("start_index"),
                "end_index": meta.get("end_index"),
            })

        records = await connector.execute_query(CHUNK_NODE_SAVE, chunks=rows)

        created_uuids = []
        for record in records:
            created_uuids.append(record["uuid"])
        print(f"Successfully created {len(created_uuids)} chunk nodes")
        return created_uuids

    except Exception as e:
        print(f"Error creating chunk nodes: {e}")
        return None
|
||||
|
||||
|
||||
|
||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save memory summary nodes in one batch through ``MEMORY_SUMMARY_NODE_SAVE``.

    Args:
        summaries: MemorySummaryNode objects to persist.
        connector: Neo4j connector used to run the query.

    Returns:
        The ``uuid`` value of every record returned by the query, ``[]`` when
        *summaries* is empty, or ``None`` if any error occurred.
    """
    if not summaries:
        print("No memory summary nodes to add")
        return []

    try:
        flattened = [
            {
                "id": s.id,
                "name": s.name,
                "group_id": s.group_id,
                "user_id": s.user_id,
                "apply_id": s.apply_id,
                "run_id": s.run_id,
                "created_at": s.created_at.isoformat() if s.created_at else None,
                "expired_at": s.expired_at.isoformat() if s.expired_at else None,
                "dialog_id": s.dialog_id,
                "chunk_ids": s.chunk_ids,
                "content": s.content,
                # Empty embeddings are stored as None.
                "summary_embedding": s.summary_embedding if s.summary_embedding else None,
                "config_id": s.config_id,
            }
            for s in summaries
        ]

        result = await connector.execute_query(
            MEMORY_SUMMARY_NODE_SAVE,
            summaries=flattened,
        )
        return [record.get("uuid") for record in result]
    except Exception as e:
        # Still best-effort (callers get None), but report the failure the
        # same way the sibling save functions do instead of hiding it.
        print(f"Error creating memory summary nodes: {e}")
        return None
|
||||
|
||||
|
||||
175
app/repositories/neo4j/base_neo4j_repository.py
Normal file
175
app/repositories/neo4j/base_neo4j_repository.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j仓储基类模块
|
||||
|
||||
本模块提供Neo4j仓储的基类实现,封装了通用的Neo4j节点操作。
|
||||
|
||||
Classes:
|
||||
BaseNeo4jRepository: Neo4j仓储基类,实现通用的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, TypeVar
|
||||
from app.repositories.base_repository import BaseRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class BaseNeo4jRepository(BaseRepository[T]):
    """Base repository implementing generic node operations for Neo4j.

    Wraps create / read / update / delete / find for nodes carrying a single
    label. Subclasses supply the record-to-entity mapping
    (``_map_to_entity``) and any domain-specific queries.

    Attributes:
        connector: Neo4j connector used to run every query.
        node_label: Label inserted into the Cypher text (e.g. "Dialogue",
            "Statement"). It is formatted into the query, not parameterized,
            so it must come from trusted code.

    Type Parameters:
        T: Entity type. ``create`` and ``update`` call ``entity.model_dump()``
            and ``update`` reads ``entity.id``.
    """

    def __init__(self, connector: Neo4jConnector, node_label: str):
        """Store the connector and the node label used by all queries.

        Args:
            connector: Neo4j connector instance.
            node_label: Node label used in the generated Cypher.
        """
        self.connector = connector
        self.node_label = node_label

    async def create(self, entity: T) -> T:
        """Create a node whose properties are ``entity.model_dump()``.

        Args:
            entity: Entity to persist.

        Returns:
            The same *entity* object that was passed in.

        Example:
            >>> dialog = DialogueNode(id="123", name="dialog 1", ...)
            >>> created = await repository.create(dialog)
        """
        query = f"""
        CREATE (n:{self.node_label} $props)
        RETURN n
        """
        await self.connector.execute_query(query, props=entity.model_dump())
        return entity

    async def get_by_id(self, entity_id: str) -> Optional[T]:
        """Fetch the node with the given ``id`` property.

        Args:
            entity_id: Value of the node's ``id`` property.

        Returns:
            The first matching record mapped through ``_map_to_entity``, or
            ``None`` when the query returns nothing.
        """
        query = f"""
        MATCH (n:{self.node_label} {{id: $id}})
        RETURN n
        """
        result = await self.connector.execute_query(query, id=entity_id)
        if result:
            return self._map_to_entity(result[0])
        return None

    async def update(self, entity: T) -> T:
        """Merge ``entity.model_dump()`` into the node matched by ``entity.id``.

        Uses ``SET n += $props``, so properties absent from the dump are left
        untouched on the node.

        Args:
            entity: Entity to update; must expose ``id``.

        Returns:
            The same *entity* object that was passed in.
        """
        query = f"""
        MATCH (n:{self.node_label} {{id: $id}})
        SET n += $props
        RETURN n
        """
        await self.connector.execute_query(
            query,
            id=entity.id,
            props=entity.model_dump(),
        )
        return entity

    async def delete(self, entity_id: str) -> bool:
        """Delete the node with the given id together with its relationships.

        Args:
            entity_id: Value of the node's ``id`` property.

        Returns:
            ``True`` when the query reports at least one deleted node,
            otherwise ``False``.
        """
        query = f"""
        MATCH (n:{self.node_label} {{id: $id}})
        DETACH DELETE n
        RETURN count(n) as deleted
        """
        result = await self.connector.execute_query(query, id=entity_id)
        return result[0]['deleted'] > 0 if result else False

    async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]:
        """Return nodes whose properties equal every entry of *filters*.

        Args:
            filters: Mapping of property name to required value; entries are
                combined with AND. An empty (or ``None``) mapping matches
                every node carrying the label.
            limit: Maximum number of nodes returned.

        Returns:
            Matching records mapped through ``_map_to_entity``.

        Example:
            >>> results = await repository.find(
            ...     {"group_id": "group_123", "user_id": "user_456"},
            ...     limit=50
            ... )
        """
        # Bind filter values under generated parameter names so a property
        # called e.g. "limit" cannot collide with the `limit` keyword below,
        # and backtick-quote property keys so any key is read as one
        # identifier rather than as query text.
        params: Dict[str, Any] = {"limit": limit}
        where_clauses = []
        for index, (key, value) in enumerate((filters or {}).items()):
            param_name = f"filter_{index}"
            quoted_key = str(key).replace("`", "``")
            where_clauses.append(f"n.`{quoted_key}` = ${param_name}")
            params[param_name] = value
        where_str = " AND ".join(where_clauses) if where_clauses else "1=1"

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_str}
        RETURN n
        LIMIT $limit
        """
        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    def _map_to_entity(self, node_data: Dict) -> T:
        """Map one query record to an entity; subclasses must override this.

        Args:
            node_data: A single record as returned by
                ``connector.execute_query``.

        Returns:
            The mapped entity.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement _map_to_entity method")
|
||||
332
app/repositories/neo4j/create_indexes.py
Normal file
332
app/repositories/neo4j/create_indexes.py
Normal file
@@ -0,0 +1,332 @@
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
# Full-text indexes, as (index name, Cypher statement) pairs.
# The Dialogue.content full-text index was commented out in the original
# script and is therefore not created here either.
_FULLTEXT_INDEX_STATEMENTS = (
    (
        "statementsFulltext",
        """
        CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
        OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
        """,
    ),
    (
        "entitiesFulltext",
        """
        CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
        OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
        """,
    ),
    (
        "chunksFulltext",
        """
        CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
        OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
        """,
    ),
    (
        "summariesFulltext",
        """
        CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
        OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
        """,
    ),
)

# Vector indexes; every one is declared with 1024 dimensions and cosine similarity.
_VECTOR_INDEX_STATEMENTS = (
    (
        "statement_embedding_index",
        """
        CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
        FOR (s:Statement)
        ON s.statement_embedding
        OPTIONS {indexConfig: {
        `vector.dimensions`: 1024,
        `vector.similarity_function`: 'cosine'
        }}
        """,
    ),
    (
        "chunk_embedding_index",
        """
        CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
        FOR (c:Chunk)
        ON c.chunk_embedding
        OPTIONS {indexConfig: {
        `vector.dimensions`: 1024,
        `vector.similarity_function`: 'cosine'
        }}
        """,
    ),
    (
        "entity_embedding_index",
        """
        CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
        FOR (e:ExtractedEntity)
        ON e.name_embedding
        OPTIONS {indexConfig: {
        `vector.dimensions`: 1024,
        `vector.similarity_function`: 'cosine'
        }}
        """,
    ),
    (
        "summary_embedding_index",
        """
        CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
        FOR (m:MemorySummary)
        ON m.summary_embedding
        OPTIONS {indexConfig: {
        `vector.dimensions`: 1024,
        `vector.similarity_function`: 'cosine'
        }}
        """,
    ),
    (
        "dialogue_embedding_index",
        """
        CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
        FOR (d:Dialogue)
        ON d.dialog_embedding
        OPTIONS {indexConfig: {
        `vector.dimensions`: 1024,
        `vector.similarity_function`: 'cosine'
        }}
        """,
    ),
)

# Plain indexes on config_id.
_CONFIG_ID_INDEX_STATEMENTS = (
    (
        "dialogue_config_id_index",
        """
        CREATE INDEX dialogue_config_id_index IF NOT EXISTS
        FOR (d:Dialogue) ON (d.config_id)
        """,
    ),
    (
        "statement_config_id_index",
        """
        CREATE INDEX statement_config_id_index IF NOT EXISTS
        FOR (s:Statement) ON (s.config_id)
        """,
    ),
    (
        "entity_config_id_index",
        """
        CREATE INDEX entity_config_id_index IF NOT EXISTS
        FOR (e:ExtractedEntity) ON (e.config_id)
        """,
    ),
    (
        "summary_config_id_index",
        """
        CREATE INDEX summary_config_id_index IF NOT EXISTS
        FOR (m:MemorySummary) ON (m.config_id)
        """,
    ),
)

# Uniqueness constraints on the id property.
_UNIQUE_CONSTRAINT_STATEMENTS = (
    (
        "dialog_id_unique",
        """
        CREATE CONSTRAINT dialog_id_unique IF NOT EXISTS
        FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
        """,
    ),
    (
        "statement_id_unique",
        """
        CREATE CONSTRAINT statement_id_unique IF NOT EXISTS
        FOR (s:Statement) REQUIRE s.id IS UNIQUE
        """,
    ),
    (
        "chunk_id_unique",
        """
        CREATE CONSTRAINT chunk_id_unique IF NOT EXISTS
        FOR (c:Chunk) REQUIRE c.id IS UNIQUE
        """,
    ),
)


async def _apply_schema_statements(title, statements, *, error_label, intro_lines=(), success_lines=()):
    """Run named schema statements on a fresh connector and report progress.

    Args:
        title: Heading printed between two separator lines.
        statements: Iterable of ``(name, cypher)`` pairs, executed in order.
        error_label: Noun phrase used in the error message.
        intro_lines: Lines printed after the heading, before any statement.
        success_lines: Lines printed once every statement has run.

    Returns:
        ``True`` when every statement ran without raising, ``False`` when an
        exception stopped the run (remaining statements are skipped).
    """
    connector = Neo4jConnector()
    try:
        print("\n" + "=" * 70)
        print(title)
        print("=" * 70)
        for line in intro_lines:
            print(line)

        for name, statement in statements:
            await connector.execute_query(statement)
            print(f"✓ Created: {name}")

        for line in success_lines:
            print(line)
        return True
    except Exception as e:
        print(f"✗ Error creating {error_label}: {e}")
        return False
    finally:
        # Always release the connector, whether or not a statement failed.
        await connector.close()


async def create_fulltext_indexes():
    """Create the full-text indexes (``cjk`` analyzer) used for keyword search.

    Returns:
        ``True`` if all indexes were created without error, else ``False``.
    """
    return await _apply_schema_statements(
        "Creating Full-Text Indexes (for keyword search)",
        _FULLTEXT_INDEX_STATEMENTS,
        error_label="full-text indexes",
        success_lines=("\nFull-text indexes created successfully with BM25 support.",),
    )


async def create_vector_indexes():
    """Create the vector indexes used for embedding similarity search.

    All indexes are declared with 1024 dimensions and cosine similarity; the
    dimension must match the embedding model in use.

    Returns:
        ``True`` if all indexes were created without error, else ``False``.
    """
    return await _apply_schema_statements(
        "Creating Vector Indexes (for embedding search)",
        _VECTOR_INDEX_STATEMENTS,
        error_label="vector indexes",
        intro_lines=(
            "Note: Adjust vector.dimensions if using different embedding model",
            " Current setting: 1024 dimensions (for bge-m3)",
            "",
        ),
        success_lines=(
            "\nVector indexes created successfully!",
            "\nExpected performance improvement:",
            " Before: ~1.4s for embedding search",
            " After: ~0.05-0.2s for embedding search (10-30x faster!)",
        ),
    )


async def create_config_id_indexes():
    """Create indexes on ``config_id`` for Dialogue, Statement, ExtractedEntity and MemorySummary.

    Returns:
        ``True`` if all indexes were created without error, else ``False``.
    """
    return await _apply_schema_statements(
        "Creating Config ID Indexes",
        _CONFIG_ID_INDEX_STATEMENTS,
        error_label="config_id indexes",
        success_lines=(
            "\nConfig ID indexes created successfully!",
            "These indexes enable fast filtering by configuration ID.",
        ),
    )


async def create_unique_constraints():
    """Create uniqueness constraints on ``id`` for Dialogue, Statement and Chunk.

    Returns:
        ``True`` if all constraints were created without error, else ``False``.
    """
    return await _apply_schema_statements(
        "Creating Unique Constraints",
        _UNIQUE_CONSTRAINT_STATEMENTS,
        error_label="unique constraints",
        success_lines=("\nUnique constraints ensured for Dialogue, Statement, and Chunk.",),
    )


async def create_all_indexes():
    """Create all indexes and constraints, then print a summary.

    Every step runs even if an earlier one failed. The final banner reports
    success only when all steps succeeded.

    Returns:
        ``True`` if every step succeeded, else ``False``.
    """
    print("\n" + "=" * 70)
    print("Neo4j Index & Constraint Setup")
    print("=" * 70)
    print("This will create:")
    print(" 1. Full-text indexes (for keyword/BM25 search)")
    print(" 2. Vector indexes (for embedding similarity search)")
    print(" 3. Config ID indexes (for configuration isolation)")
    print(" 4. Unique constraints (for data integrity)")
    print("=" * 70)

    steps = (
        ("full-text indexes", create_fulltext_indexes),
        ("vector indexes", create_vector_indexes),
        ("config_id indexes", create_config_id_indexes),
        ("unique constraints", create_unique_constraints),
    )
    failed_steps = []
    for label, step in steps:
        # Each step swallows its own exception and signals failure by
        # returning False, so the outcome has to be collected here.
        if not await step():
            failed_steps.append(label)

    print("\n" + "=" * 70)
    if failed_steps:
        print(f"✗ Setup finished with errors in: {', '.join(failed_steps)}")
    else:
        print("✓ All indexes and constraints created successfully!")
    print("=" * 70)
    print("\nTo verify, run in Neo4j Browser:")
    print(" SHOW INDEXES")
    print(" SHOW CONSTRAINTS")
    print()
    return not failed_steps
|
||||
|
||||
|
||||
async def check_indexes():
    """Print the indexes reported by ``SHOW INDEXES``, grouped by type.

    Lists FULLTEXT, VECTOR and RANGE indexes by name, warns when no vector
    index exists, and warns when fewer than four RANGE indexes have
    ``config_id`` in their name. The connector is always closed afterwards.
    """
    connector = Neo4jConnector()

    try:
        separator = "=" * 70
        print("\n" + separator)
        print("Checking Existing Indexes")
        print(separator)

        rows = await connector.execute_query("SHOW INDEXES")

        def _of_type(kind):
            # Keep only the rows whose 'type' column equals `kind`.
            return [row for row in rows if row.get('type') == kind]

        fulltext_indexes = _of_type('FULLTEXT')
        vector_indexes = _of_type('VECTOR')
        range_indexes = _of_type('RANGE')

        sections = (
            ("Full-text indexes", fulltext_indexes),
            ("Vector indexes", vector_indexes),
            ("Range indexes (including config_id)", range_indexes),
        )
        for heading, members in sections:
            print(f"\n{heading}: {len(members)}")
            for row in members:
                print(f" ✓ {row.get('name')}")

        if len(vector_indexes) == 0:
            print("\n⚠️ WARNING: No vector indexes found!")
            print(" Embedding search will be VERY SLOW (~1.4s)")
            print(" Run: python create_indexes.py")

        # Count the range indexes that look like config_id indexes by name.
        config_id_indexes = [row for row in range_indexes if 'config_id' in row.get('name', '')]
        found = len(config_id_indexes)
        if found < 4:
            print("\n⚠️ WARNING: Not all config_id indexes found!")
            print(f" Expected 4, found {found}")
            print(" Run: python create_indexes.py config_id")

        print(separator)

    finally:
        await connector.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio
    import sys

    # With no argument everything is created; otherwise the first argument
    # selects a single task.
    if len(sys.argv) <= 1:
        asyncio.run(create_all_indexes())
    else:
        command = sys.argv[1]
        tasks = {
            "check": check_indexes,
            "fulltext": create_fulltext_indexes,
            "vector": create_vector_indexes,
            "config_id": create_config_id_indexes,
            "constraints": create_unique_constraints,
        }
        task = tasks.get(command)
        if task is not None:
            asyncio.run(task())
        else:
            print(f"Unknown command: {command}")
            print("\nUsage:")
            print(" python create_indexes.py # Create all indexes")
            print(" python create_indexes.py check # Check existing indexes")
            print(" python create_indexes.py fulltext # Create only full-text indexes")
            print(" python create_indexes.py vector # Create only vector indexes")
            print(" python create_indexes.py config_id # Create only config_id indexes")
            print(" python create_indexes.py constraints # Create only constraints")
|
||||
|
||||
684
app/repositories/neo4j/cypher_queries.py
Normal file
684
app/repositories/neo4j/cypher_queries.py
Normal file
@@ -0,0 +1,684 @@
|
||||
|
||||
# Upsert one Dialogue node per element of $dialogues, matched on `id`.
# `uuid` is set to the dialogue id only when the node has no uuid yet; the
# other listed properties are overwritten on every save. Each row returns
# the node id under the column name `uuid`.
# NOTE(review): this query does not write a `name` property - confirm whether
# callers that send `name` expect it to be stored.
DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue
MERGE (n:Dialogue {id: dialogue.id})
SET n.uuid = coalesce(n.uuid, dialogue.id),
n.group_id = dialogue.group_id,
n.user_id = dialogue.user_id,
n.apply_id = dialogue.apply_id,
n.run_id = dialogue.run_id,
n.ref_id = dialogue.ref_id,
n.created_at = dialogue.created_at,
n.expired_at = dialogue.expired_at,
n.content = dialogue.content,
n.dialog_embedding = dialogue.dialog_embedding
RETURN n.id AS uuid
"""

# Upsert one Statement node per element of $statements, matched on `id`, and
# merge the listed properties into it with `SET s += {...}`. Each row returns
# the node id under the column name `uuid`.
# NOTE(review): the map reads `statement.relevence_info` (spelled as written
# here). Presumably a key missing from the incoming map evaluates to null and
# leaves the property unset - confirm that callers supply it, and whether the
# spelling is intended.
STATEMENT_NODE_SAVE = """
UNWIND $statements AS statement
MERGE (s:Statement {id: statement.id})
SET s += {
id: statement.id,
group_id: statement.group_id,
user_id: statement.user_id,
apply_id: statement.apply_id,
chunk_id: statement.chunk_id,
run_id: statement.run_id,
created_at: statement.created_at,
expired_at: statement.expired_at,
stmt_type: statement.stmt_type,
temporal_info: statement.temporal_info,
relevence_info: statement.relevence_info,
statement: statement.statement,
valid_at: statement.valid_at,
invalid_at: statement.invalid_at,
statement_embedding: statement.statement_embedding
}
RETURN s.id AS uuid
"""

# Upsert one Chunk node per element of $chunks, matched on `id`, and merge
# the listed properties into it with `SET c += {...}`. Each row returns the
# node id under the column name `uuid`.
CHUNK_NODE_SAVE = """
UNWIND $chunks AS chunk
MERGE (c:Chunk {id: chunk.id})
SET c += {
id: chunk.id,
name: chunk.name,
group_id: chunk.group_id,
user_id: chunk.user_id,
apply_id: chunk.apply_id,
run_id: chunk.run_id,
created_at: chunk.created_at,
expired_at: chunk.expired_at,
dialog_id: chunk.dialog_id,
content: chunk.content,
chunk_embedding: chunk.chunk_embedding,
sequence_number: chunk.sequence_number,
start_index: chunk.start_index,
end_index: chunk.end_index
}
RETURN c.id AS uuid
"""
|
||||
# Bug-fix location
|
||||
|
||||
# Field-by-field upsert of ExtractedEntity nodes, keyed on `id` only.
# Merge rules visible in the query:
#   * scalar text fields: incoming value wins only when it is non-null and non-empty;
#   * created_at keeps the earliest value, expired_at keeps the latest;
#   * entity_idx is only filled in when the stored value is NULL or 0;
#   * description / fact_summary: the longer non-empty text wins;
#   * aliases: incoming aliases are appended, skipping ones already stored so that
#     re-running the same upsert does not grow the list with duplicates;
#   * connect_strength: 'strong' + 'weak' (either order) collapses to 'both'.
EXTRACTED_ENTITY_NODE_SAVE = """
// Upsert entity nodes safely: preserve existing non-empty fields when incoming is empty
UNWIND $entities AS entity
MERGE (e:ExtractedEntity {id: entity.id})
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
    e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END,
    e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
    e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
    e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
    e.created_at = CASE
        WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
        THEN entity.created_at ELSE e.created_at END,
    e.expired_at = CASE
        WHEN entity.expired_at IS NOT NULL AND (e.expired_at IS NULL OR entity.expired_at > e.expired_at)
        THEN entity.expired_at ELSE e.expired_at END,
    e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
    e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
    e.description = CASE
        WHEN entity.description IS NOT NULL AND entity.description <> ''
            AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
        THEN entity.description ELSE e.description END,
    e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
    e.aliases = CASE
        WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
        THEN CASE
            WHEN e.aliases IS NULL THEN entity.aliases
            ELSE e.aliases + [a IN entity.aliases WHERE NOT a IN e.aliases]
        END
        ELSE e.aliases END,
    e.name_embedding = CASE
        WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
        ELSE e.name_embedding END,
    e.fact_summary = CASE
        WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
            AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
        THEN entity.fact_summary ELSE e.fact_summary END,
    e.connect_strength = CASE
        WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
        ELSE CASE
            WHEN e.connect_strength = 'strong' AND entity.connect_strength = 'weak' THEN 'both'
            WHEN e.connect_strength = 'weak' AND entity.connect_strength = 'strong' THEN 'both'
            WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength
            ELSE e.connect_strength
        END
    END
RETURN e.id AS uuid
"""
|
||||
|
||||
# Used by graph_saver.save_entities_and_relationships.
# For each element of $relationships, looks up both endpoint entities by
# (id, group_id) and MERGEs a single EXTRACTED_RELATIONSHIP edge between them,
# then overwrites the edge properties with the incoming values.
# NOTE(review): the MERGE pattern carries no predicate, so two relationships with
# different predicates between the same ordered pair share one edge and the later
# write wins — confirm this is intended.
ENTITY_RELATIONSHIP_SAVE = """
UNWIND $relationships AS rel
// Match entities by stable id within group, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id})
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id})
// Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate,
    r.statement_id = rel.statement_id,
    r.value = rel.value,
    r.statement = rel.statement,
    r.valid_at = rel.valid_at,
    r.invalid_at = rel.invalid_at,
    r.created_at = rel.created_at,
    r.expired_at = rel.expired_at,
    r.run_id = rel.run_id,
    r.group_id = rel.group_id
RETURN elementId(r) AS uuid
"""
|
||||
|
||||
# In Neo4j 5 and later the id() function is deprecated; elementId() is used instead.
|
||||
|
||||
# Save weak-relation entities and flag them with e.is_weak = true.
# The aggregated e.relations field is not maintained here.
# NOTE(review): this MERGEs on (id, run_id) whereas EXTRACTED_ENTITY_NODE_SAVE
# MERGEs on id alone — confirm the per-run key is intended.
WEAK_ENTITY_NODE_SAVE = """
UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += {
    name: entity.name,
    group_id: entity.group_id,
    run_id: entity.run_id,
    description: entity.description,
    chunk_id: entity.chunk_id,
    dialog_id: entity.dialog_id
}
// Independent weak flag,仅标记弱关系,不再维护 relations 聚合字段
SET e.is_weak = true
RETURN e.id AS id
"""

# Create or update the subject and object entity nodes of each strong-relation
# triple in $items and flag both with is_strong = true.
# The e.relations field is not maintained here. The query has no RETURN clause.
SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id}
// Independent strong flag
SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id}
// Independent strong flag
SET o.is_strong = true
"""
|
||||
|
||||
|
||||
# Link Dialogue nodes to Statement nodes with a MENTIONS edge.
# The dialogue endpoint is matched by either its `uuid` or its `ref_id`; the edge
# is MERGEd on its endpoints only, and its properties are overwritten on each save.
DIALOGUE_STATEMENT_EDGE_SAVE = """
UNWIND $dialogue_statement_edges AS edge
// 支持按 uuid 或 ref_id 连接到 Dialogue,避免因来源 ID 不一致而断链
MATCH (dialogue:Dialogue)
WHERE dialogue.uuid = edge.source OR dialogue.ref_id = edge.source
MATCH (statement:Statement {id: edge.target})
// 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id,
    e.group_id = edge.group_id,
    e.created_at = edge.created_at,
    e.expired_at = edge.expired_at
RETURN e.uuid AS uuid
"""

# In Neo4j 5 and later the id() function is deprecated; elementId() is used instead.

# Link Chunk nodes to Statement nodes with a CONTAINS edge keyed on the edge id.
# Note the direction: edge.source is the statement id and edge.target is the chunk
# id, but the relationship is written chunk -> statement. Both endpoints must
# belong to the same run_id.
CHUNK_STATEMENT_EDGE_SAVE = """
UNWIND $chunk_statement_edges AS edge
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.group_id = edge.group_id,
    e.run_id = edge.run_id,
    e.created_at = edge.created_at,
    e.expired_at = edge.expired_at
RETURN e.id AS uuid
"""
|
||||
|
||||
# Link Statement nodes to ExtractedEntity nodes with a REFERENCES_ENTITY edge.
# Statements are matched by (id, run_id); entities by (id, group_id). The edge is
# MERGEd on its endpoints only and its properties are overwritten on each save.
STATEMENT_ENTITY_EDGE_SAVE = """
UNWIND $relationships AS rel
// Statement nodes are per-run; keep run_id constraint on statements
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
// Entities are shared across runs within a group; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id})
// Avoid duplicate edges across runs for same endpoints
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.group_id = rel.group_id,
    r.run_id = rel.run_id,
    r.created_at = rel.created_at,
    r.expired_at = rel.expired_at,
    r.connect_strength = rel.connect_strength
RETURN elementId(r) AS uuid
"""
|
||||
|
||||
# Vector-index search over ExtractedEntity.name_embedding.
# The index is asked for $limit * 100 candidates and the optional $group_id filter
# is applied afterwards (presumably to leave enough rows once other groups are
# filtered out); the final result is cut back to $limit by score.
ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL
  AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id,
       e.name AS name,
       e.group_id AS group_id,
       e.entity_type AS entity_type,
       score
ORDER BY score DESC
LIMIT $limit
"""

# Embedding-based search: cosine similarity on Statement.statement_embedding
STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL
  AND ($group_id IS NULL OR s.group_id = $group_id)
RETURN s.id AS id,
       s.statement AS statement,
       s.group_id AS group_id,
       s.chunk_id AS chunk_id,
       s.created_at AS created_at,
       s.expired_at AS expired_at,
       s.valid_at AS valid_at,
       s.invalid_at AS invalid_at,
       score
ORDER BY score DESC
LIMIT $limit
"""

# Embedding-based search: cosine similarity on Chunk.chunk_embedding
CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL
  AND ($group_id IS NULL OR c.group_id = $group_id)
RETURN c.id AS chunk_id,
       c.group_id AS group_id,
       c.content AS content,
       c.dialog_id AS dialog_id,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Fulltext search over statements (index "statementsFulltext"), optionally limited
# to one group. Each hit is joined to its containing chunk and the ids of the
# entities it references are collected into `entity_ids`.
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
       s.statement AS statement,
       s.group_id AS group_id,
       s.chunk_id AS chunk_id,
       s.created_at AS created_at,
       s.expired_at AS expired_at,
       s.valid_at AS valid_at,
       s.invalid_at AS invalid_at,
       c.id AS chunk_id_from_rel,
       collect(DISTINCT e.id) AS entity_ids,
       score
ORDER BY score DESC
LIMIT $limit
"""

# Fulltext search over entities (index "entitiesFulltext"), optionally limited to
# one group. The ids of referencing statements and of the chunks containing those
# statements are collected per entity.
SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($group_id IS NULL OR e.group_id = $group_id)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
       e.name AS name,
       e.group_id AS group_id,
       e.entity_type AS entity_type,
       e.apply_id AS apply_id,
       e.user_id AS user_id,
       e.created_at AS created_at,
       e.expired_at AS expired_at,
       e.entity_idx AS entity_idx,
       e.statement_id AS statement_id,
       e.description AS description,
       e.aliases AS aliases,
       e.name_embedding AS name_embedding,
       e.fact_summary AS fact_summary,
       e.connect_strength AS connect_strength,
       collect(DISTINCT s.id) AS statement_ids,
       collect(DISTINCT c.id) AS chunk_ids,
       score
ORDER BY score DESC
LIMIT $limit
"""

# Fulltext search over chunks (index "chunksFulltext"), optionally limited to one
# group. The ids of contained statements and of the entities those statements
# reference are collected per chunk.
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
WHERE ($group_id IS NULL OR c.group_id = $group_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
       c.group_id AS group_id,
       c.content AS content,
       c.dialog_id AS dialog_id,
       c.sequence_number AS sequence_number,
       collect(DISTINCT s.id) AS statement_ids,
       collect(DISTINCT e.id) AS entity_ids,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
|
||||
|
||||
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
|
||||
# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """
|
||||
# UNWIND $rows AS row
|
||||
# MATCH (e:ExtractedEntity)
|
||||
# WHERE e.group_id = row.group_id
|
||||
# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name)))
|
||||
# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type)
|
||||
# RETURN row.id AS incoming_id,
|
||||
# e.id AS id,
|
||||
# e.name AS name,
|
||||
# e.group_id AS group_id,
|
||||
# e.entity_idx AS entity_idx,
|
||||
# e.entity_type AS entity_type,
|
||||
# e.description AS description,
|
||||
# e.statement_id AS statement_id,
|
||||
# e.aliases AS aliases,
|
||||
# e.name_embedding AS name_embedding,
|
||||
# e.fact_summary AS fact_summary,
|
||||
# e.connect_strength AS connect_strength,
|
||||
# e.created_at AS created_at,
|
||||
# e.expired_at AS expired_at
|
||||
# """
|
||||
# # 同组group_id下按name contains召回补充
|
||||
# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """
|
||||
# UNWIND $rows AS row
|
||||
# MATCH (e:ExtractedEntity)
|
||||
# WHERE e.group_id = row.group_id
|
||||
# AND toLower(e.name) CONTAINS toLower(row.name)
|
||||
# RETURN row.id AS incoming_id,
|
||||
# e.id AS id,
|
||||
# e.name AS name,
|
||||
# e.group_id AS group_id,
|
||||
# e.entity_idx AS entity_idx,
|
||||
# e.entity_type AS entity_type,
|
||||
# e.description AS description,
|
||||
# e.statement_id AS statement_id,
|
||||
# e.aliases AS aliases,
|
||||
# e.name_embedding AS name_embedding,
|
||||
# e.fact_summary AS fact_summary,
|
||||
# e.connect_strength AS connect_strength,
|
||||
# e.created_at AS created_at,
|
||||
# e.expired_at AS expired_at
|
||||
# """
|
||||
|
||||
# Look up Dialogue nodes by id, optionally restricted to one group.
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE ($group_id IS NULL OR d.group_id = $group_id)
  AND d.id = $dialog_id
RETURN d.id AS dialog_id,
       d.group_id AS group_id,
       d.content AS content,
       d.created_at AS created_at,
       d.expired_at AS expired_at
ORDER BY d.created_at DESC
LIMIT $limit
"""

# Look up Chunk nodes by id, optionally restricted to one group.
SEARCH_CHUNK_BY_CHUNK_ID = """
MATCH (c:Chunk)
WHERE ($group_id IS NULL OR c.group_id = $group_id)
  AND c.id = $chunk_id
RETURN c.id AS chunk_id,
       c.group_id AS group_id,
       c.content AS content,
       c.dialog_id AS dialog_id,
       c.created_at AS created_at,
       c.expired_at AS expired_at,
       c.sequence_number AS sequence_number
ORDER BY c.created_at DESC
LIMIT $limit
"""
|
||||
|
||||
# Time-window search over Statement nodes with optional group/apply/user filters.
# Two windows can be given:
#   * creation window: $start_date <= created_at <= $end_date
#   * validity window: valid_at >= $valid_date and invalid_at <= $invalid_date
# A statement matches when it satisfies ANY window that was actually supplied.
# A window with both bounds NULL is ignored instead of counting as "always true";
# otherwise the OR between the two windows would let every statement through as
# soon as one of them was left unset. With no bounds at all, every statement matches.
SEARCH_STATEMENTS_BY_TEMPORAL = """
MATCH (s:Statement)
WHERE ($group_id IS NULL OR s.group_id = $group_id)
  AND ($apply_id IS NULL OR s.apply_id = $apply_id)
  AND ($user_id IS NULL OR s.user_id = $user_id)
  AND (
    ($start_date IS NULL AND $end_date IS NULL AND $valid_date IS NULL AND $invalid_date IS NULL)
    OR (($start_date IS NOT NULL OR $end_date IS NOT NULL)
        AND ($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
        AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
    OR (($valid_date IS NOT NULL OR $invalid_date IS NOT NULL)
        AND ($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
        AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))
  )
RETURN s.id AS id,
       s.statement AS statement,
       s.group_id AS group_id,
       s.apply_id AS apply_id,
       s.user_id AS user_id,
       s.chunk_id AS chunk_id,
       s.created_at AS created_at,
       s.valid_at AS valid_at,
       s.invalid_at AS invalid_at,
       collect(DISTINCT s.id) AS statement_ids
ORDER BY datetime(s.created_at) DESC
LIMIT $limit
"""
|
||||
|
||||
# Fulltext search over statements combined with the same two optional time windows
# as SEARCH_STATEMENTS_BY_TEMPORAL (creation window / validity window).
# A statement matches when it satisfies ANY window that was actually supplied; a
# window with both bounds NULL is ignored rather than treated as "always true", so
# giving only one window really filters. With no bounds at all, every hit matches.
# Results are ordered by creation time first, then by fulltext score.
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
  AND ($apply_id IS NULL OR s.apply_id = $apply_id)
  AND ($user_id IS NULL OR s.user_id = $user_id)
  AND (
    ($start_date IS NULL AND $end_date IS NULL AND $valid_date IS NULL AND $invalid_date IS NULL)
    OR (($start_date IS NOT NULL OR $end_date IS NOT NULL)
        AND ($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
        AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
    OR (($valid_date IS NOT NULL OR $invalid_date IS NOT NULL)
        AND ($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
        AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))
  )
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
       s.statement AS statement,
       s.group_id AS group_id,
       s.apply_id AS apply_id,
       s.user_id AS user_id,
       s.chunk_id AS chunk_id,
       s.created_at AS created_at,
       s.valid_at AS valid_at,
       s.invalid_at AS invalid_at,
       c.id AS chunk_id_from_rel,
       collect(DISTINCT e.id) AS entity_ids,
       score
ORDER BY s.created_at DESC, score DESC
LIMIT $limit
"""
|
||||
|
||||
# Statements whose creation DATE equals $created_at (optional group/apply/user
# filters). The stored created_at is cut to its first 10 characters before being
# parsed as a date — assumes an ISO "YYYY-MM-DD..." string; TODO confirm.
# A NULL $created_at matches nothing.
SEARCH_STATEMENTS_BY_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""

# Statements whose valid_at DATE equals $valid_at (optional group/apply/user
# filters). A NULL $valid_at matches nothing.
SEARCH_STATEMENTS_BY_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
|
||||
|
||||
# Statements whose creation DATE is strictly later than $created_at (optional
# group/apply/user filters). A NULL $created_at matches nothing.
# The comparison is `>` (the "G" variant) and the stored timestamp is cut to its
# first 10 characters so that date() receives a plain "YYYY-MM-DD" string, the same
# way SEARCH_STATEMENTS_G_VALID_AT treats valid_at.
SEARCH_STATEMENTS_G_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) > date($created_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
|
||||
|
||||
# Statements whose creation DATE is strictly earlier than $created_at (optional
# group/apply/user filters). A NULL $created_at matches nothing.
# The stored timestamp is cut to its first 10 characters so that date() receives a
# plain "YYYY-MM-DD" string, the same way SEARCH_STATEMENTS_L_VALID_AT treats valid_at.
SEARCH_STATEMENTS_L_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) < date($created_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
|
||||
|
||||
# Statements whose valid_at DATE is strictly later than $valid_at (optional
# group/apply/user filters). A NULL $valid_at matches nothing.
SEARCH_STATEMENTS_G_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""

# Statements whose valid_at DATE is strictly earlier than $valid_at (optional
# group/apply/user filters). A NULL $valid_at matches nothing.
SEARCH_STATEMENTS_L_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
|
||||
|
||||
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
|
||||
|
||||
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
|
||||
# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """
|
||||
# UNWIND $rows AS row
|
||||
# MATCH (e:ExtractedEntity)
|
||||
# WHERE e.group_id = row.group_id
|
||||
# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name)))
|
||||
# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type)
|
||||
# RETURN row.id AS incoming_id,
|
||||
# e.id AS id,
|
||||
# e.name AS name,
|
||||
# e.group_id AS group_id,
|
||||
# e.entity_idx AS entity_idx,
|
||||
# e.entity_type AS entity_type,
|
||||
# e.description AS description,
|
||||
# e.statement_id AS statement_id,
|
||||
# e.aliases AS aliases,
|
||||
# e.name_embedding AS name_embedding,
|
||||
# e.fact_summary AS fact_summary,
|
||||
# e.connect_strength AS connect_strength,
|
||||
# e.created_at AS created_at,
|
||||
# e.expired_at AS expired_at
|
||||
# """
|
||||
# # 同组group_id下按name contains召回补充
|
||||
# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """
|
||||
# UNWIND $rows AS row
|
||||
# MATCH (e:ExtractedEntity)
|
||||
# WHERE e.group_id = row.group_id
|
||||
# AND toLower(e.name) CONTAINS toLower(row.name)
|
||||
# RETURN row.id AS incoming_id,
|
||||
# e.id AS id,
|
||||
# e.name AS name,
|
||||
# e.group_id AS group_id,
|
||||
# e.entity_idx AS entity_idx,
|
||||
# e.entity_type AS entity_type,
|
||||
# e.description AS description,
|
||||
# e.statement_id AS statement_id,
|
||||
# e.aliases AS aliases,
|
||||
# e.name_embedding AS name_embedding,
|
||||
# e.fact_summary AS fact_summary,
|
||||
# e.connect_strength AS connect_strength,
|
||||
# e.created_at AS created_at,
|
||||
# e.expired_at AS expired_at
|
||||
# """
|
||||
|
||||
# Set invalid_at on the Statement identified by (group_id, id).
# The query has no RETURN clause; nothing is written when no node matches.
UPDATE_STATEMENT_INVALID_AT = """
MATCH (n:Statement {group_id: $group_id, id: $id})
SET n.invalid_at = $new_invalid_at
"""
|
||||
|
||||
# MemorySummary keyword search using the "summariesFulltext" fulltext index,
# optionally restricted to one group.
# No statement join is performed: nothing from the linked statements is returned,
# and joining them would repeat each summary once per DERIVED_FROM_STATEMENT edge,
# with the repeats counting against LIMIT.
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
WHERE ($group_id IS NULL OR m.group_id = $group_id)
RETURN m.id AS id,
       m.name AS name,
       m.group_id AS group_id,
       m.dialog_id AS dialog_id,
       m.chunk_ids AS chunk_ids,
       m.content AS content,
       m.created_at AS created_at,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Embedding-based search: cosine similarity on MemorySummary.summary_embedding
# The index is asked for $limit * 100 candidates; the optional group filter is
# applied afterwards and the result is cut back to $limit by score.
MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL
  AND ($group_id IS NULL OR m.group_id = $group_id)
RETURN m.id AS id,
       m.name AS name,
       m.group_id AS group_id,
       m.dialog_id AS dialog_id,
       m.chunk_ids AS chunk_ids,
       m.content AS content,
       m.created_at AS created_at,
       score
ORDER BY score DESC
LIMIT $limit
"""

# Batch upsert of MemorySummary nodes: one MERGE per element of $summaries, keyed
# on `id`; the listed properties are written with `SET += map`.
MEMORY_SUMMARY_NODE_SAVE = """
UNWIND $summaries AS summary
MERGE (m:MemorySummary {id: summary.id})
SET m += {
    id: summary.id,
    name: summary.name,
    group_id: summary.group_id,
    user_id: summary.user_id,
    apply_id: summary.apply_id,
    run_id: summary.run_id,
    created_at: summary.created_at,
    expired_at: summary.expired_at,
    dialog_id: summary.dialog_id,
    chunk_ids: summary.chunk_ids,
    content: summary.content,
    summary_embedding: summary.summary_embedding,
    config_id: summary.config_id
}
RETURN m.id AS uuid
"""

# Link a MemorySummary to every Statement contained in one of its chunks.
# For each element of $edges the summary and the chunk are matched by (id, run_id),
# the chunk's CONTAINS edges are followed to statements of the same run, and a
# DERIVED_FROM_STATEMENT edge is MERGEd from the summary to each such statement.
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE = """
UNWIND $edges AS e
MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id})
MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id})
MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
SET r.group_id = e.group_id,
    r.run_id = e.run_id,
    r.created_at = e.created_at,
    r.expired_at = e.expired_at
RETURN elementId(r) AS uuid
"""
|
||||
185
app/repositories/neo4j/dialog_repository.py
Normal file
185
app/repositories/neo4j/dialog_repository.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""对话仓储模块
|
||||
|
||||
本模块提供对话节点的数据访问功能。
|
||||
|
||||
Classes:
|
||||
DialogRepository: 对话仓储,管理DialogueNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.core.memory.models.graph_models import DialogueNode
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
class DialogRepository(BaseNeo4jRepository[DialogueNode]):
    """Repository for dialogue nodes.

    Wraps the generic Neo4j repository for the "Dialogue" label and adds
    convenience finders by group_id, user_id, ref_id, config_id and recency.

    Attributes:
        connector: Neo4j connector instance.
        node_label: Node label, fixed to "Dialogue".
    """

    def __init__(self, connector: Neo4jConnector):
        """Initialise the repository.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "Dialogue")

    def _map_to_entity(self, node_data: Dict) -> DialogueNode:
        """Convert one query result row into a DialogueNode.

        Args:
            node_data: Row returned by a query. Either a mapping holding the
                node under the key 'n', or the node's property mapping itself.

        Returns:
            DialogueNode: The dialogue built from the node properties.
        """
        # Work on a shallow copy: the datetime normalisation below must not
        # rewrite the row mapping that still belongs to the caller.
        props = dict(node_data.get('n', node_data))

        # ISO-8601 strings are parsed into datetime objects; other types pass through.
        if isinstance(props.get('created_at'), str):
            props['created_at'] = datetime.fromisoformat(props['created_at'])
        if props.get('expired_at') and isinstance(props['expired_at'], str):
            props['expired_at'] = datetime.fromisoformat(props['expired_at'])

        return DialogueNode(**props)

    async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
        """Find dialogues by group_id.

        Args:
            group_id: Group ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find({"group_id": group_id}, limit=limit)

    async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
        """Find dialogues by user_id.

        Args:
            user_id: User ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find({"user_id": user_id}, limit=limit)

    async def find_by_ref_id(self, ref_id: str) -> Optional[DialogueNode]:
        """Find a dialogue by ref_id.

        ref_id is the reference ID from the external dialogue system and is
        usually unique; only the first match is returned.

        Args:
            ref_id: Reference ID.

        Returns:
            Optional[DialogueNode]: The dialogue, or None when nothing matches.
        """
        results = await self.find({"ref_id": ref_id}, limit=1)
        return results[0] if results else None

    async def find_by_group_and_user(
        self,
        group_id: str,
        user_id: str,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find dialogues by group_id and user_id together.

        Args:
            group_id: Group ID.
            user_id: User ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find(
            {"group_id": group_id, "user_id": user_id},
            limit=limit
        )

    async def find_recent_dialogs(
        self,
        group_id: str,
        days: int = 7,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find the most recent dialogues of a group.

        Args:
            group_id: Group ID.
            days: How many days back to look.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Dialogues ordered by creation time, newest first.
        """
        # NOTE(review): n.created_at is compared directly with a Cypher datetime,
        # while _map_to_entity also accepts created_at as an ISO string — confirm
        # the stored type, a string value would never satisfy this comparison.
        query = f"""
        MATCH (n:{self.node_label})
        WHERE n.group_id = $group_id
        AND n.created_at >= datetime() - duration({{days: $days}})
        RETURN n
        ORDER BY n.created_at DESC
        LIMIT $limit
        """
        results = await self.connector.execute_query(
            query,
            group_id=group_id,
            days=days,
            limit=limit
        )
        return [self._map_to_entity(r) for r in results]

    async def find_by_config_id(
        self,
        config_id: str,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find dialogues by config_id.

        Args:
            config_id: Configuration ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find({"config_id": config_id}, limit=limit)

    async def find_by_config_and_group(
        self,
        config_id: str,
        group_id: str,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find dialogues by config_id and group_id together.

        Filtering on both keeps the result to dialogues of the given group that
        were processed with the given configuration.

        Args:
            config_id: Configuration ID.
            group_id: Group ID.
            limit: Maximum number of results.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find(
            {"config_id": config_id, "group_id": group_id},
            limit=limit
        )
|
||||
339
app/repositories/neo4j/entity_repository.py
Normal file
339
app/repositories/neo4j/entity_repository.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""实体仓储模块
|
||||
|
||||
本模块提供实体节点的数据访问功能。
|
||||
|
||||
Classes:
|
||||
EntityRepository: 实体仓储,管理ExtractedEntityNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.core.memory.models.graph_models import ExtractedEntityNode
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
    """Repository for ``ExtractedEntity`` nodes.

    Adds entity-specific lookups (by type, group, name, relationship,
    statement, config and name-embedding similarity) on top of the generic
    operations inherited from ``BaseNeo4jRepository``.

    Attributes:
        connector: Neo4j connector instance used to run Cypher queries.
        node_label: Node label, fixed to ``"ExtractedEntity"``.
    """

    def __init__(self, connector: Neo4jConnector):
        """Initialize the repository.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "ExtractedEntity")

    def _map_to_entity(self, node_data: Dict) -> ExtractedEntityNode:
        """Map one query-result record to an ``ExtractedEntityNode``.

        Args:
            node_data: Record returned by a query. The node properties are
                read from its ``'n'`` key when present, otherwise the record
                itself is treated as the property mapping.

        Returns:
            ExtractedEntityNode: Entity built from the node properties.
        """
        # Work on a copy so the caller's record is never mutated; dict() also
        # accepts read-only mappings that do not support item assignment.
        props = dict(node_data.get('n', node_data))

        # Timestamps may arrive as ISO-8601 strings; convert them to datetime.
        if isinstance(props.get('created_at'), str):
            props['created_at'] = datetime.fromisoformat(props['created_at'])
        if props.get('expired_at') and isinstance(props['expired_at'], str):
            props['expired_at'] = datetime.fromisoformat(props['expired_at'])

        return ExtractedEntityNode(**props)

    async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]:
        """Find entities by entity type.

        Args:
            entity_type: Entity type (e.g. "Person", "Organization").
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"entity_type": entity_type}, limit=limit)

    async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[ExtractedEntityNode]:
        """Find entities by group id.

        Args:
            group_id: Group id.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"group_id": group_id}, limit=limit)

    async def find_by_name(
        self,
        name: str,
        group_id: Optional[str] = None,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities whose name contains ``name`` (Cypher ``CONTAINS``).

        Args:
            name: Substring to look for in the entity name.
            group_id: Optional group id filter; ignored when falsy.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        conditions = ["n.name CONTAINS $name"]
        params = {"name": name, "limit": limit}
        if group_id:
            conditions.append("n.group_id = $group_id")
            params["group_id"] = group_id

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {" AND ".join(conditions)}
        RETURN n
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def find_related_entities(
        self,
        entity_id: str,
        relation_type: Optional[str] = None,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities reachable from an entity via outgoing ``RELATES_TO`` edges.

        Args:
            entity_id: Id of the source entity.
            relation_type: Optional filter on the edge's ``relation_type``
                property; ignored when falsy.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Related (target) entities.
        """
        params = {"entity_id": entity_id, "limit": limit}
        edge_filter = ""
        if relation_type:
            edge_filter = " {relation_type: $relation_type}"
            params["relation_type"] = relation_type

        # The two original query variants differed only in the edge property
        # filter, so a single template covers both.
        query = f"""
        MATCH (e1:ExtractedEntity {{id: $entity_id}})-[r:RELATES_TO{edge_filter}]->(e2:ExtractedEntity)
        RETURN e2 as n
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def _search_by_embedding(
        self,
        embedding: List[float],
        filters: Dict[str, Optional[str]],
        limit: int,
        min_score: float
    ) -> List[Dict]:
        """Shared implementation of the name-embedding similarity searches.

        Args:
            embedding: Query vector.
            filters: Node property name -> required value. Entries with a
                falsy value are skipped. Keys are fixed identifiers chosen by
                the public methods, never user input.
            limit: Maximum number of results.
            min_score: Results must score strictly above this threshold.

        Returns:
            List[Dict]: Dicts with ``entity`` (ExtractedEntityNode) and
            ``score`` (float), ordered by descending score.
        """
        where_clauses = ["n.name_embedding IS NOT NULL"]
        params = {
            "embedding": embedding,
            "min_score": min_score,
            "limit": limit
        }
        for prop, value in filters.items():
            if value:
                where_clauses.append(f"n.{prop} = ${prop}")
                params[prop] = value

        # gds.similarity.cosine is provided by the Neo4j Graph Data Science library.
        query = f"""
        MATCH (n:{self.node_label})
        WHERE {" AND ".join(where_clauses)}
        WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score
        WHERE score > $min_score
        RETURN n, score
        ORDER BY score DESC
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)
        return [
            {
                "entity": self._map_to_entity(r),
                "score": r.get("score", 0.0)
            }
            for r in results
        ]

    async def search_by_embedding(
        self,
        embedding: List[float],
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search entities by cosine similarity of their name embedding.

        Args:
            embedding: Query vector.
            group_id: Optional group id filter.
            limit: Maximum number of results.
            min_score: Minimum similarity threshold (exclusive).

        Returns:
            List[Dict]: Dicts with ``entity`` (ExtractedEntityNode) and
            ``score`` (float).
        """
        return await self._search_by_embedding(
            embedding, {"group_id": group_id}, limit, min_score
        )

    async def find_by_statement_id(self, statement_id: str) -> List[ExtractedEntityNode]:
        """Find the entities whose ``statement_id`` property matches.

        Args:
            statement_id: Statement id.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"statement_id": statement_id})

    async def find_strong_entities(
        self,
        group_id: str,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities in a group whose ``connect_strength`` is ``"Strong"``.

        Args:
            group_id: Group id.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Strongly connected entities.
        """
        return await self.find(
            {"group_id": group_id, "connect_strength": "Strong"},
            limit=limit
        )

    async def get_entity_count_by_type(self, group_id: str) -> Dict[str, int]:
        """Count the entities of a group per entity type.

        Args:
            group_id: Group id.

        Returns:
            Dict[str, int]: Mapping of entity type to number of entities.
        """
        query = """
        MATCH (n:ExtractedEntity {group_id: $group_id})
        RETURN n.entity_type as entity_type, count(n) as count
        ORDER BY count DESC
        """
        results = await self.connector.execute_query(query, group_id=group_id)
        return {r["entity_type"]: r["count"] for r in results}

    async def find_by_config_id(
        self,
        config_id: str,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities by config id.

        Args:
            config_id: Config id.
            limit: Maximum number of results.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"config_id": config_id}, limit=limit)

    async def search_by_embedding_with_config(
        self,
        embedding: List[float],
        config_id: Optional[str] = None,
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search entities by name-embedding similarity, optionally per config.

        Same search as ``search_by_embedding`` with an additional optional
        ``config_id`` filter.

        Args:
            embedding: Query vector.
            config_id: Optional config id filter.
            group_id: Optional group id filter.
            limit: Maximum number of results.
            min_score: Minimum similarity threshold (exclusive).

        Returns:
            List[Dict]: Dicts with ``entity`` (ExtractedEntityNode) and
            ``score`` (float).
        """
        return await self._search_by_embedding(
            embedding,
            {"config_id": config_id, "group_id": group_id},
            limit,
            min_score
        )
|
||||
216
app/repositories/neo4j/graph_saver.py
Normal file
216
app/repositories/neo4j/graph_saver.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from typing import List
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.add_nodes import add_dialogue_nodes, add_statement_nodes, add_chunk_nodes
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
STATEMENT_ENTITY_EDGE_SAVE,
|
||||
ENTITY_RELATIONSHIP_SAVE,
|
||||
EXTRACTED_ENTITY_NODE_SAVE,
|
||||
CHUNK_STATEMENT_EDGE_SAVE,
|
||||
STATEMENT_ENTITY_EDGE_SAVE,
|
||||
ENTITY_RELATIONSHIP_SAVE,
|
||||
EXTRACTED_ENTITY_NODE_SAVE,
|
||||
)
|
||||
from app.core.memory.models.graph_models import (
|
||||
DialogueNode,
|
||||
ChunkNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
StatementNode,
|
||||
ExtractedEntityNode,
|
||||
EntityEntityEdge,
|
||||
)
|
||||
|
||||
async def save_entities_and_relationships(
    entity_nodes: List[ExtractedEntityNode],
    entity_entity_edges: List[EntityEntityEdge],
    connector: Neo4jConnector
):
    """Save entity nodes and entity-to-entity relationships to Neo4j.

    Entities are serialized with ``model_dump()`` and written with
    ``EXTRACTED_ENTITY_NODE_SAVE``; edges are flattened to plain dicts and
    written with ``ENTITY_RELATIONSHIP_SAVE``. Progress is reported on stdout.

    Args:
        entity_nodes: Entity nodes to save.
        entity_entity_edges: Entity-to-entity edges to save.
        connector: Neo4j connector instance.
    """
    def _iso(moment):
        # Every timestamp is optional: serialize when present, else pass None
        # through instead of raising AttributeError on a missing value.
        return moment.isoformat() if moment else None

    all_entities = [entity.model_dump() for entity in entity_nodes]
    all_relationships = [
        {
            'source_id': edge.source,
            'target_id': edge.target,
            'predicate': edge.relation_type,
            'statement_id': edge.source_statement_id,
            'value': edge.relation_value,
            'statement': edge.statement,
            'valid_at': _iso(edge.valid_at),
            'invalid_at': _iso(edge.invalid_at),
            'created_at': _iso(edge.created_at),
            'expired_at': _iso(edge.expired_at),
            'run_id': edge.run_id,
            'group_id': edge.group_id,
            'user_id': edge.user_id,
            'apply_id': edge.apply_id,
        }
        for edge in entity_entity_edges
    ]

    # Save entities
    if all_entities:
        entity_uuids = await connector.execute_query(EXTRACTED_ENTITY_NODE_SAVE, entities=all_entities)
        if entity_uuids:
            print(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
        else:
            print("Failed to save entity nodes to Neo4j")
    else:
        print("No entity nodes to save")

    # Create relationships
    if all_relationships:
        relationship_uuids = await connector.execute_query(ENTITY_RELATIONSHIP_SAVE, relationships=all_relationships)
        if relationship_uuids:
            print(f"Successfully saved {len(relationship_uuids)} entity relationships (edges) to Neo4j")
        else:
            print("Failed to save entity relationships to Neo4j")
    else:
        print("No entity relationships to save")
|
||||
|
||||
|
||||
async def save_chunk_nodes(
    chunk_nodes: List[ChunkNode],
    connector: Neo4jConnector
):
    """Save chunk nodes via ``add_chunk_nodes`` and report the outcome on stdout.

    Args:
        chunk_nodes: Chunk nodes to save; nothing is written when empty.
        connector: Neo4j connector instance.
    """
    if chunk_nodes:
        saved_ids = await add_chunk_nodes(chunk_nodes, connector)
        if saved_ids:
            outcome = f"Successfully saved {len(saved_ids)} chunk nodes to Neo4j"
        else:
            outcome = "Failed to save chunk nodes to Neo4j"
    else:
        outcome = "No chunk nodes to save"
    print(outcome)
|
||||
|
||||
|
||||
async def save_statement_chunk_edges(
    statement_chunk_edges: List[StatementChunkEdge],
    connector: Neo4jConnector
):
    """Save statement-chunk edges to Neo4j using ``CHUNK_STATEMENT_EDGE_SAVE``.

    The write is best-effort: a failure is reported on stdout and not raised,
    so callers can continue with the remaining save steps.

    Args:
        statement_chunk_edges: Edges to save; nothing is written when empty.
        connector: Neo4j connector instance.
    """
    if not statement_chunk_edges:
        return

    all_sc_edges = [
        {
            "id": edge.id,
            "source": edge.source,
            "target": edge.target,
            "group_id": edge.group_id,
            "user_id": edge.user_id,
            "apply_id": edge.apply_id,
            "run_id": edge.run_id,
            "created_at": edge.created_at.isoformat() if edge.created_at else None,
            "expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
        }
        for edge in statement_chunk_edges
    ]

    try:
        await connector.execute_query(
            CHUNK_STATEMENT_EDGE_SAVE,
            chunk_statement_edges=all_sc_edges
        )
    except Exception as e:
        # Stay non-raising (best-effort), but make the failure visible
        # instead of silently dropping it.
        print(f"Failed to save statement-chunk edges to Neo4j: {e}")
|
||||
|
||||
|
||||
async def save_statement_entity_edges(
    statement_entity_edges: List[StatementEntityEdge],
    connector: Neo4jConnector
):
    """Save statement-entity edges to Neo4j using ``STATEMENT_ENTITY_EDGE_SAVE``.

    The write is best-effort: a failure is reported on stdout and not raised,
    so callers can continue with the remaining save steps.

    Args:
        statement_entity_edges: Edges to save; nothing is written when empty.
        connector: Neo4j connector instance.
    """
    if not statement_entity_edges:
        print("No statement-entity edges to save")
        return

    all_se_edges = [
        {
            "source": edge.source,
            "target": edge.target,
            "group_id": edge.group_id,
            "user_id": edge.user_id,
            "apply_id": edge.apply_id,
            "run_id": edge.run_id,
            "connect_strength": edge.connect_strength,
            "created_at": edge.created_at.isoformat() if edge.created_at else None,
            "expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
        }
        for edge in statement_entity_edges
    ]

    try:
        await connector.execute_query(
            STATEMENT_ENTITY_EDGE_SAVE,
            relationships=all_se_edges
        )
    except Exception as e:
        # Stay non-raising (best-effort), but make the failure visible
        # instead of silently dropping it.
        print(f"Failed to save statement-entity edges to Neo4j: {e}")
|
||||
|
||||
|
||||
async def save_dialog_and_statements_to_neo4j(
    dialogue_nodes: List[DialogueNode],
    chunk_nodes: List[ChunkNode],
    statement_nodes: List[StatementNode],
    entity_nodes: List[ExtractedEntityNode],
    entity_edges: List[EntityEntityEdge],
    statement_chunk_edges: List[StatementChunkEdge],
    statement_entity_edges: List[StatementEntityEdge],
    connector: Neo4jConnector
) -> bool:
    """Persist a full extraction result (nodes and edges) to Neo4j.

    Steps run in order: dialogues, chunks, statements, entities with their
    relationships, then statement-chunk and statement-entity edges. The
    function stops early and returns False when saving dialogues or
    statements yields no ids. Any exception is caught, reported on stdout and
    turned into a False return.

    Args:
        dialogue_nodes: DialogueNode objects to save.
        chunk_nodes: ChunkNode objects to save.
        statement_nodes: StatementNode objects to save.
        entity_nodes: ExtractedEntityNode objects to save.
        entity_edges: EntityEntityEdge objects to save.
        statement_chunk_edges: StatementChunkEdge objects to save.
        statement_entity_edges: StatementEntityEdge objects to save.
        connector: Neo4j connector instance.

    Returns:
        bool: True if every step completed, False otherwise.
    """
    try:
        # Dialogues first: without them nothing else is written.
        saved_dialogues = await add_dialogue_nodes(dialogue_nodes, connector)
        if not saved_dialogues:
            print("Failed to save dialogues to Neo4j")
            return False
        print(f"Dialogues saved to Neo4j with UUIDs: {saved_dialogues}")

        # Chunks (the helper reports its own outcome).
        await save_chunk_nodes(chunk_nodes, connector)

        # Statements: an empty list is fine, a failed write aborts.
        if not statement_nodes:
            print("No statement nodes to save")
        else:
            saved_statements = await add_statement_nodes(statement_nodes, connector)
            if not saved_statements:
                print("Failed to save statement nodes to Neo4j")
                return False
            print(f"Successfully saved {len(saved_statements)} statement nodes to Neo4j")

        # Entities and entity-to-entity relationships.
        await save_entities_and_relationships(entity_nodes, entity_edges, connector)
        print("Successfully saved entities and relationships to Neo4j")

        # Edges linking statements to chunks and to entities.
        await save_statement_chunk_edges(statement_chunk_edges, connector)
        await save_statement_entity_edges(statement_entity_edges, connector)
    except Exception as e:
        print(f"Neo4j integration error: {e}")
        print("Continuing without database storage...")
        return False

    return True
|
||||
584
app/repositories/neo4j/graph_search.py
Normal file
584
app/repositories/neo4j/graph_search.py
Normal file
@@ -0,0 +1,584 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import asyncio
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
STATEMENT_EMBEDDING_SEARCH,
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
SEARCH_STATEMENTS_BY_TEMPORAL,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||
SEARCH_STATEMENTS_G_VALID_AT,
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
)
|
||||
|
||||
|
||||
async def search_graph(
    connector: Neo4jConnector,
    q: str,
    group_id: Optional[str] = None,
    limit: int = 50,
    include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """Keyword search across statements, entities, chunks and summaries.

    One query per requested category is issued, all of them concurrently via
    ``asyncio.gather``. A category whose query raises is reported as an empty
    list rather than failing the whole search.

    Args:
        connector: Neo4j connector.
        q: Free-text query passed to every category query.
        group_id: Optional group filter.
        limit: Maximum number of results per category.
        include: Categories to search, any of "statements", "entities",
            "chunks", "summaries" (default: all).

    Returns:
        Dictionary mapping each requested category to its result rows.
    """
    if include is None:
        include = ["statements", "chunks", "entities", "summaries"]

    categories: List[str] = []
    pending = []

    def _enqueue(category: str, cypher: str) -> None:
        # Register one category query for the parallel run.
        categories.append(category)
        pending.append(connector.execute_query(
            cypher,
            q=q,
            group_id=group_id,
            limit=limit,
        ))

    if "statements" in include:
        _enqueue("statements", SEARCH_STATEMENTS_BY_KEYWORD)
    if "entities" in include:
        _enqueue("entities", SEARCH_ENTITIES_BY_NAME)
    if "chunks" in include:
        _enqueue("chunks", SEARCH_CHUNKS_BY_CONTENT)
    if "summaries" in include:
        _enqueue("summaries", SEARCH_MEMORY_SUMMARIES_BY_KEYWORD)

    # Run everything in parallel; exceptions come back as result values.
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    return {
        category: [] if isinstance(outcome, Exception) else outcome
        for category, outcome in zip(categories, outcomes)
    }
|
||||
|
||||
|
||||
async def search_graph_by_embedding(
    connector: Neo4jConnector,
    embedder_client,
    query_text: str,
    group_id: Optional[str] = None,
    limit: int = 50,
    include: Optional[List[str]] = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """Embedding-based search across statements, chunks, entities and summaries.

    The query text is embedded with ``embedder_client.response`` and one
    embedding-search query per requested category is issued, all of them
    concurrently via ``asyncio.gather``. A category whose query raises is
    reported as an empty list. Timing information is printed to stdout.

    Args:
        connector: Neo4j connector.
        embedder_client: Object exposing an async ``response(list_of_texts)``
            method that returns one embedding per text.
        query_text: Text to embed and search for.
        group_id: Optional group filter.
        limit: Maximum number of results per category.
        include: Categories to search, any of "statements", "chunks",
            "entities", "summaries" (default: all).

    Returns:
        Dictionary with all four category keys; categories that were not
        requested, failed, or could not be searched map to an empty list.
    """
    import time

    # None instead of a list literal avoids a mutable default argument.
    if include is None:
        include = ["statements", "chunks", "entities", "summaries"]

    # Get embedding for the query (perf_counter: monotonic interval timer)
    embed_start = time.perf_counter()
    embeddings = await embedder_client.response([query_text])
    embed_time = time.perf_counter() - embed_start
    print(f"[PERF] Embedding generation took: {embed_time:.4f}s")

    if not embeddings or not embeddings[0]:
        return {"statements": [], "chunks": [], "entities": [], "summaries": []}
    embedding = embeddings[0]

    # Prepare tasks for parallel execution
    tasks = []
    task_keys = []

    def _enqueue(category: str, cypher: str) -> None:
        # Register one category query for the parallel run.
        task_keys.append(category)
        tasks.append(connector.execute_query(
            cypher,
            embedding=embedding,
            group_id=group_id,
            limit=limit,
        ))

    if "statements" in include:
        _enqueue("statements", STATEMENT_EMBEDDING_SEARCH)
    if "chunks" in include:
        _enqueue("chunks", CHUNK_EMBEDDING_SEARCH)
    if "entities" in include:
        _enqueue("entities", ENTITY_EMBEDDING_SEARCH)
    if "summaries" in include:
        _enqueue("summaries", MEMORY_SUMMARY_EMBEDDING_SEARCH)

    # Execute all queries in parallel
    query_start = time.perf_counter()
    task_results = await asyncio.gather(*tasks, return_exceptions=True)
    query_time = time.perf_counter() - query_start
    print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")

    # Build results dictionary; every category key is always present
    results: Dict[str, List[Dict[str, Any]]] = {
        "statements": [],
        "chunks": [],
        "entities": [],
        "summaries": [],
    }
    for key, result in zip(task_keys, task_results):
        results[key] = [] if isinstance(result, Exception) else result

    return results
|
||||
async def get_dedup_candidates_for_entities(
    connector: Neo4jConnector,
    group_id: str,
    entities: List[Dict[str, Any]],
    use_contains_fallback: bool = True,
    batch_size: int = 500,
    max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]:
    """Fetch stored candidate entities for second-layer dedup/disambiguation.

    For every incoming entity the ``SEARCH_ENTITIES_BY_NAME`` query is run
    with its name and ``group_id``. If the incoming entity carries an
    ``entity_type``, only candidates of that type are kept. When nothing is
    found and ``use_contains_fallback`` is set, the lookup is repeated with
    the lower-cased name, with the same type filter applied. Lookups run
    concurrently, bounded by ``max_concurrency``; a failed lookup yields no
    candidates instead of raising.

    Args:
        connector: Neo4j connector.
        group_id: Group the candidates must belong to.
        entities: Incoming entities as dicts; ``id``, ``name`` and optionally
            ``entity_type`` are read.
        use_contains_fallback: Retry with the lower-cased name when the first
            lookup returns nothing.
        batch_size: Currently unused; kept for interface compatibility.
        max_concurrency: Maximum number of lookups in flight.

    Returns:
        Mapping of incoming entity id (``"__unknown__"`` when missing) to the
        list of candidate rows, each tagged with ``incoming_id``.
    """
    if not entities:
        return {}

    sem = asyncio.Semaphore(max_concurrency)

    async def _lookup(term: str, inc_id: str, wanted_type: Any) -> List[Dict[str, Any]]:
        # One name lookup: query, optional local type filter, tag with incoming id.
        try:
            rows = await connector.execute_query(
                SEARCH_ENTITIES_BY_NAME,
                q=term,
                group_id=group_id,
                limit=100,
            )
        except Exception:
            # Best-effort: a failed lookup simply yields no candidates.
            return []
        if wanted_type:
            rows = [r for r in rows if r.get("entity_type") == wanted_type]
        # Tag rows so downstream merge logic can map them back to the input.
        for r in rows:
            r["incoming_id"] = inc_id
        return rows

    async def _query_by_name(incoming: Dict[str, Any]) -> tuple[str, List[Dict[str, Any]]]:
        async with sem:
            inc_id = incoming.get("id") or "__unknown__"
            name = (incoming.get("name") or "").strip()
            if not name:
                return inc_id, []

            wanted_type = incoming.get("entity_type")
            rows = await _lookup(name, inc_id, wanted_type)

            # Simple degradation: retry with the lower-cased name. The type
            # filter applies here too, so fallback candidates obey the same
            # contract as primary ones.
            if use_contains_fallback and not rows:
                rows = await _lookup(name.lower(), inc_id, wanted_type)

            return inc_id, rows

    tasks = [_query_by_name(e) for e in entities]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    merged: Dict[str, List[Dict[str, Any]]] = {}
    for res in results:
        if isinstance(res, Exception):
            # Silently skip a single failed lookup
            continue
        inc_id, rows = res
        inc_id = inc_id or "__unknown__"
        merged.setdefault(inc_id, [])
        # Skip rows already collected for this incoming id by an earlier result.
        existing_ids = {x.get("id") for x in merged[inc_id]}
        for rec in rows:
            if rec.get("id") not in existing_ids:
                merged[inc_id].append(rec)
    return merged
|
||||
|
||||
|
||||
async def search_graph_by_keyword_temporal(
    connector: Neo4jConnector,
    query_text: str,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    valid_date: Optional[str] = None,
    invalid_date: Optional[str] = None,
    limit: int = 50,
) -> Dict[str, List[Any]]:
    """Keyword search over statements combined with temporal filters.

    Runs ``SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL`` with the query text and
    all filter arguments, and prints the result rows to stdout. An empty
    ``query_text`` short-circuits to an empty result without querying.

    Args:
        connector: Neo4j connector.
        query_text: Keyword text; must be non-empty.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        start_date: Optional start-date bound passed to the query.
        end_date: Optional end-date bound passed to the query.
        valid_date: Optional valid-date bound passed to the query.
        invalid_date: Optional invalid-date bound passed to the query.
        limit: Maximum number of statements.

    Returns:
        ``{"statements": rows}``.
    """
    if not query_text:
        print("query_text不能为空")
        return {"statements": []}

    query_args = {
        "q": query_text,
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "start_date": start_date,
        "end_date": end_date,
        "valid_date": valid_date,
        "invalid_date": invalid_date,
        "limit": limit,
    }
    rows = await connector.execute_query(
        SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, **query_args
    )
    print(f"查询结果为:\n{rows}")

    return {"statements": rows}
|
||||
|
||||
|
||||
async def search_graph_by_temporal(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    valid_date: Optional[str] = None,
    invalid_date: Optional[str] = None,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """Temporal search over statements.

    Runs ``SEARCH_STATEMENTS_BY_TEMPORAL`` with all filter arguments, then
    prints the query, its parameters and the result rows to stdout.

    Args:
        connector: Neo4j connector.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        start_date: Optional start-date bound passed to the query.
        end_date: Optional end-date bound passed to the query.
        valid_date: Optional valid-date bound passed to the query.
        invalid_date: Optional invalid-date bound passed to the query.
        limit: Maximum number of statements.

    Returns:
        ``{"statements": rows}``.
    """
    query_args = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "start_date": start_date,
        "end_date": end_date,
        "valid_date": valid_date,
        "invalid_date": invalid_date,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_BY_TEMPORAL, **query_args)

    # Debug trace of the query, its parameters and the rows it returned.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
|
||||
async def search_graph_by_dialog_id(
    connector: Neo4jConnector,
    dialog_id: str,
    group_id: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Look up dialogues by dialog id.

    Runs ``SEARCH_DIALOGUE_BY_DIALOG_ID``. An empty ``dialog_id``
    short-circuits to an empty result without querying.

    Args:
        connector: Neo4j connector.
        dialog_id: Id of the dialogue to fetch; must be non-empty.
        group_id: Optional group filter.
        limit: Maximum number of dialogues.

    Returns:
        ``{"dialogues": rows}``.
    """
    if dialog_id:
        rows = await connector.execute_query(
            SEARCH_DIALOGUE_BY_DIALOG_ID,
            group_id=group_id,
            dialog_id=dialog_id,
            limit=limit,
        )
        return {"dialogues": rows}

    print("dialog_id不能为空")
    return {"dialogues": []}
|
||||
|
||||
|
||||
async def search_graph_by_chunk_id(
    connector: Neo4jConnector,
    chunk_id : str,
    group_id: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Look up chunks by chunk id.

    Runs ``SEARCH_CHUNK_BY_CHUNK_ID``. An empty ``chunk_id`` short-circuits
    to an empty result without querying.

    Args:
        connector: Neo4j connector.
        chunk_id: Id of the chunk to fetch; must be non-empty.
        group_id: Optional group filter.
        limit: Maximum number of chunks.

    Returns:
        ``{"chunks": rows}``.
    """
    if chunk_id:
        rows = await connector.execute_query(
            SEARCH_CHUNK_BY_CHUNK_ID,
            group_id=group_id,
            chunk_id=chunk_id,
            limit=limit,
        )
        return {"chunks": rows}

    print("chunk_id不能为空")
    return {"chunks": []}
|
||||
|
||||
|
||||
async def search_graph_by_created_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    created_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements by their creation time.

    Runs ``SEARCH_STATEMENTS_BY_CREATED_AT``, then prints the query, its
    parameters and the result rows to stdout.

    Args:
        connector: Neo4j connector.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        created_at: Creation timestamp passed to the query.
        limit: Maximum number of statements.

    Returns:
        ``{"statements": rows}``.
    """
    query_args = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "created_at": created_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_BY_CREATED_AT, **query_args)

    # Debug trace of the query, its parameters and the rows it returned.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
async def search_graph_by_valid_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    valid_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements by their valid-at time.

    Runs ``SEARCH_STATEMENTS_BY_VALID_AT``, then prints the query, its
    parameters and the result rows to stdout.

    Args:
        connector: Neo4j connector.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        valid_at: Valid-at timestamp passed to the query.
        limit: Maximum number of statements.

    Returns:
        ``{"statements": rows}``.
    """
    query_args = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "valid_at": valid_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_BY_VALID_AT, **query_args)

    # Debug trace of the query, its parameters and the rows it returned.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
async def search_graph_g_created_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    created_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements relative to a creation time.

    Runs ``SEARCH_STATEMENTS_G_CREATED_AT``, then prints the query, its
    parameters and the result rows to stdout.

    NOTE(review): the ``G`` in the query name presumably stands for
    "greater than" (statements created after ``created_at``) — confirm
    against the Cypher definition.

    Args:
        connector: Neo4j connector.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        created_at: Creation timestamp passed to the query.
        limit: Maximum number of statements.

    Returns:
        ``{"statements": rows}``.
    """
    query_args = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "created_at": created_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_G_CREATED_AT, **query_args)

    # Debug trace of the query, its parameters and the rows it returned.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
async def search_graph_g_valid_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    valid_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements relative to a valid-at time.

    Runs ``SEARCH_STATEMENTS_G_VALID_AT``, then prints the query, its
    parameters and the result rows to stdout.

    NOTE(review): the ``G`` in the query name presumably stands for
    "greater than" (statements valid after ``valid_at``) — confirm against
    the Cypher definition.

    Args:
        connector: Neo4j connector.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        valid_at: Valid-at timestamp passed to the query.
        limit: Maximum number of statements.

    Returns:
        ``{"statements": rows}``.
    """
    query_args = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "valid_at": valid_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_G_VALID_AT, **query_args)

    # Debug trace of the query, its parameters and the rows it returned.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
async def search_graph_l_created_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    created_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Run the ``SEARCH_STATEMENTS_L_CREATED_AT`` temporal query over Statements.

    Every argument except ``connector`` is forwarded unchanged as a Cypher
    parameter of the same name. How ``created_at`` is compared and how the
    optional ID filters are applied is decided by the query text, which is
    defined outside this function.

    Args:
        connector: Connector used to execute the query.
        group_id: Value bound to ``$group_id``.
        apply_id: Value bound to ``$apply_id``.
        user_id: Value bound to ``$user_id``.
        created_at: Value bound to ``$created_at``.
        limit: Value bound to ``$limit``.

    Returns:
        A dict with the single key ``"statements"`` holding the rows returned
        by ``connector.execute_query``.
    """
    query_params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "created_at": created_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_L_CREATED_AT, **query_params)

    # Debug trace: query text, bound parameters, then the raw result rows.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")

    return {"statements": rows}
|
||||
|
||||
async def search_graph_l_valid_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    valid_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Run the ``SEARCH_STATEMENTS_L_VALID_AT`` temporal query over Statements.

    Every argument except ``connector`` is forwarded unchanged as a Cypher
    parameter of the same name. How ``valid_at`` is compared and how the
    optional ID filters are applied is decided by the query text, which is
    defined outside this function.

    Args:
        connector: Connector used to execute the query.
        group_id: Value bound to ``$group_id``.
        apply_id: Value bound to ``$apply_id``.
        user_id: Value bound to ``$user_id``.
        valid_at: Value bound to ``$valid_at``.
        limit: Value bound to ``$limit``.

    Returns:
        A dict with the single key ``"statements"`` holding the rows returned
        by ``connector.execute_query``.
    """
    query_params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "valid_at": valid_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_L_VALID_AT, **query_params)

    # Debug trace: query text, bound parameters, then the raw result rows.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")

    return {"statements": rows}
|
||||
114
app/repositories/neo4j/neo4j_connector.py
Normal file
114
app/repositories/neo4j/neo4j_connector.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j连接器模块
|
||||
|
||||
本模块提供Neo4j图数据库的连接和查询功能。
|
||||
从 app/core/memory/src/database/neo4j_connector.py 迁移而来。
|
||||
|
||||
Classes:
|
||||
Neo4jConnector: Neo4j数据库连接器,提供异步查询接口
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from neo4j import AsyncGraphDatabase, basic_auth
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class Neo4jConnector:
    """Thin asynchronous wrapper around a Neo4j driver.

    Attributes:
        driver: The async Neo4j driver created in ``__init__``.

    Methods:
        close: Close the underlying driver.
        execute_query: Run a Cypher query and return the rows as dicts.
        delete_group: Delete every node and relationship tagged with a group ID.
    """

    # Database every query is executed against.
    _DATABASE = "neo4j"

    def __init__(self):
        """Create the async driver from the application settings.

        Reads ``NEO4J_URI``, ``NEO4J_USERNAME`` and ``NEO4J_PASSWORD`` from
        ``settings``.

        Raises:
            RuntimeError: If ``settings.NEO4J_PASSWORD`` is empty or unset.
        """
        uri = settings.NEO4J_URI
        username = settings.NEO4J_USERNAME
        password = settings.NEO4J_PASSWORD

        if not password:
            raise RuntimeError(
                "NEO4J_PASSWORD is not set. Create a .env with NEO4J_PASSWORD or export it before running."
            )
        self.driver = AsyncGraphDatabase.driver(
            uri,
            auth=basic_auth(username, password)
        )

    async def close(self):
        """Close the driver and release its connections."""
        await self.driver.close()

    async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
        """Execute a Cypher query and return its records as plain dicts.

        Args:
            query: Cypher query text.
            **kwargs: Query parameters, forwarded to the driver as keyword
                arguments and referenced in the query as ``$name``.

        Returns:
            One dict per returned record, produced by ``record.data()``.

        Example::

            connector = Neo4jConnector()
            rows = await connector.execute_query(
                "MATCH (n:Person {name: $name}) RETURN n",
                name="Alice",
            )
        """
        # The driver reserves keyword arguments ending in "_" for its own
        # configuration; a plain ``database=`` would be sent to the server as
        # a Cypher parameter and would not select the target database.
        records, _summary, _keys = await self.driver.execute_query(
            query,
            database_=self._DATABASE,
            **kwargs
        )
        return [record.data() for record in records]

    async def delete_group(self, group_id: str):
        """Permanently delete all data tagged with ``group_id``.

        Runs two statements: first ``DETACH DELETE`` on every node whose
        ``group_id`` property matches, then ``DELETE`` on any remaining
        relationship whose own ``group_id`` property matches. Prints a
        confirmation line when both have completed.

        Args:
            group_id: Group identifier whose nodes and relationships are removed.
        """
        # DETACH DELETE also removes the relationships attached to those nodes.
        await self.driver.execute_query(
            "MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n",
            database_=self._DATABASE,
            group_id=group_id
        )
        # Relationships carrying the group_id between nodes that were not
        # deleted above still need to be removed explicitly.
        await self.driver.execute_query(
            "MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r",
            database_=self._DATABASE,
            group_id=group_id
        )
        print(f"Group {group_id} deleted.")
|
||||
319
app/repositories/neo4j/statement_repository.py
Normal file
319
app/repositories/neo4j/statement_repository.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""陈述句仓储模块
|
||||
|
||||
本模块提供陈述句节点的数据访问功能。
|
||||
|
||||
Classes:
|
||||
StatementRepository: 陈述句仓储,管理StatementNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.core.memory.models.graph_models import StatementNode
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
|
||||
|
||||
class StatementRepository(BaseNeo4jRepository[StatementNode]):
    """Repository for ``Statement`` nodes.

    Offers lookups by chunk, group and config ID, keyword search, a
    validity-window query and cosine-similarity search over
    ``statement_embedding``.

    Attributes:
        connector: Neo4j connector used to run queries.
        node_label: Node label; this repository passes ``"Statement"`` to the
            base class.
    """

    def __init__(self, connector: Neo4jConnector):
        """Initialise the repository for the ``Statement`` label.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "Statement")

    def _map_to_entity(self, node_data: Dict) -> StatementNode:
        """Convert one query-result row into a ``StatementNode``.

        Args:
            node_data: Either a row containing the node under the key ``'n'``
                or the node's property dict itself.

        Returns:
            The ``StatementNode`` built from the node's properties.
        """
        # Work on a shallow copy so the caller's row is not modified while
        # string timestamps are being replaced by datetime objects.
        n = dict(node_data.get('n', node_data))

        # created_at is parsed whenever it is a string (even an empty one);
        # the optional timestamps are parsed only when they are non-empty.
        if isinstance(n.get('created_at'), str):
            n['created_at'] = datetime.fromisoformat(n['created_at'])
        for key in ('expired_at', 'valid_at', 'invalid_at'):
            value = n.get(key)
            if value and isinstance(value, str):
                n[key] = datetime.fromisoformat(value)

        # temporal_info: build from a dict, or fall back to a default instance
        # when the value is missing or falsy.
        if isinstance(n.get('temporal_info'), dict):
            n['temporal_info'] = TemporalInfo(**n['temporal_info'])
        elif not n.get('temporal_info'):
            n['temporal_info'] = TemporalInfo()

        return StatementNode(**n)

    async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]:
        """Find statements whose ``chunk_id`` equals the given value.

        Args:
            chunk_id: Chunk identifier.

        Returns:
            Matching statements, as returned by the inherited ``find``.
        """
        return await self.find({"chunk_id": chunk_id})

    async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[StatementNode]:
        """Find statements whose ``group_id`` equals the given value.

        Args:
            group_id: Group identifier.
            limit: Maximum number of results, forwarded to ``find``.

        Returns:
            Matching statements, as returned by the inherited ``find``.
        """
        return await self.find({"group_id": group_id}, limit=limit)

    async def _search_by_similarity(
        self,
        embedding: List[float],
        filters: Dict[str, Optional[str]],
        limit: int,
        min_score: float
    ) -> List[Dict]:
        """Shared implementation of the cosine-similarity searches.

        Args:
            embedding: Query vector.
            filters: Property name -> required value. Entries with a falsy
                value are ignored. The names are interpolated into the query
                text, so they must be fixed identifiers chosen by this class,
                never caller-supplied strings.
            limit: Maximum number of results.
            min_score: Rows are kept only when ``score > min_score``.

        Returns:
            A list of ``{"statement": StatementNode, "score": float}`` dicts,
            ordered by descending score.
        """
        where_clauses = ["n.statement_embedding IS NOT NULL"]
        params = {
            "embedding": embedding,
            "min_score": min_score,
            "limit": limit
        }
        for field, value in filters.items():
            if value:
                where_clauses.append(f"n.{field} = ${field}")
                params[field] = value

        where_str = " AND ".join(where_clauses)

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_str}
        WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
        WHERE score > $min_score
        RETURN n, score
        ORDER BY score DESC
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)

        return [
            {
                "statement": self._map_to_entity(r),
                "score": r.get("score", 0.0)
            }
            for r in results
        ]

    async def search_by_embedding(
        self,
        embedding: List[float],
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search statements by cosine similarity to a query vector.

        Args:
            embedding: Query vector.
            group_id: Optional group filter; ignored when falsy.
            limit: Maximum number of results.
            min_score: Rows are kept only when ``score > min_score``.

        Returns:
            A list of dicts, each with ``statement`` (``StatementNode``) and
            ``score`` (float), ordered by descending score.
        """
        return await self._search_by_similarity(
            embedding,
            {"group_id": group_id},
            limit,
            min_score
        )

    async def search_by_keyword(
        self,
        keyword: str,
        group_id: Optional[str] = None,
        limit: int = 50
    ) -> List[StatementNode]:
        """Find statements whose ``statement`` text contains a keyword.

        Args:
            keyword: Substring matched with Cypher ``CONTAINS``.
            group_id: Optional group filter; ignored when falsy.
            limit: Maximum number of results.

        Returns:
            Matching statements.
        """
        where_clause = "n.statement CONTAINS $keyword"
        params = {"keyword": keyword, "limit": limit}
        if group_id:
            where_clause += " AND n.group_id = $group_id"
            params["group_id"] = group_id

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_clause}
        RETURN n
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def find_by_temporal_range(
        self,
        group_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: int = 100
    ) -> List[StatementNode]:
        """Find a group's statements constrained by a validity window.

        When ``start_date`` is given, requires ``n.valid_at >= $start_date``.
        When ``end_date`` is given, requires ``n.invalid_at`` to be NULL or
        ``<= $end_date``. Results are ordered by ``created_at`` descending.

        Args:
            group_id: Group identifier (always applied).
            start_date: Optional lower bound on ``valid_at``.
            end_date: Optional upper bound on ``invalid_at``.
            limit: Maximum number of results.

        Returns:
            Matching statements.
        """
        where_clauses = ["n.group_id = $group_id"]
        params = {"group_id": group_id, "limit": limit}

        # Bounds are sent as ISO-8601 strings.
        # NOTE(review): this presumes valid_at / invalid_at are stored in a
        # form that compares correctly against ISO strings — confirm.
        if start_date:
            where_clauses.append("n.valid_at >= $start_date")
            params["start_date"] = start_date.isoformat()

        if end_date:
            where_clauses.append("(n.invalid_at IS NULL OR n.invalid_at <= $end_date)")
            params["end_date"] = end_date.isoformat()

        where_str = " AND ".join(where_clauses)

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_str}
        RETURN n
        ORDER BY n.created_at DESC
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def find_strong_statements(
        self,
        group_id: str,
        limit: int = 100
    ) -> List[StatementNode]:
        """Find a group's statements whose ``connect_strength`` is ``"Strong"``.

        Args:
            group_id: Group identifier.
            limit: Maximum number of results, forwarded to ``find``.

        Returns:
            Matching statements, as returned by the inherited ``find``.
        """
        return await self.find(
            {"group_id": group_id, "connect_strength": "Strong"},
            limit=limit
        )

    async def find_by_config_id(
        self,
        config_id: str,
        limit: int = 100
    ) -> List[StatementNode]:
        """Find statements whose ``config_id`` equals the given value.

        Args:
            config_id: Configuration identifier.
            limit: Maximum number of results, forwarded to ``find``.

        Returns:
            Matching statements, as returned by the inherited ``find``.
        """
        return await self.find({"config_id": config_id}, limit=limit)

    async def search_by_embedding_with_config(
        self,
        embedding: List[float],
        config_id: Optional[str] = None,
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search statements by cosine similarity, optionally filtered by config.

        Args:
            embedding: Query vector.
            config_id: Optional config filter; ignored when falsy.
            group_id: Optional group filter; ignored when falsy.
            limit: Maximum number of results.
            min_score: Rows are kept only when ``score > min_score``.

        Returns:
            A list of dicts, each with ``statement`` (``StatementNode``) and
            ``score`` (float), ordered by descending score.
        """
        # Dict order fixes the clause order: config_id first, then group_id.
        return await self._search_by_similarity(
            embedding,
            {"config_id": config_id, "group_id": group_id},
            limit,
            min_score
        )
|
||||
Reference in New Issue
Block a user