[MODIFY] Code optimization
This commit is contained in:
@@ -27,7 +27,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
|
||||
edges: List[dict] = []
|
||||
for chunk in chunks:
|
||||
for stmt in getattr(chunk, "statements", []) or []:
|
||||
stable_edge_id = hashlib.sha1(f"{chunk.id}|{stmt.id}".encode("utf-8")).hexdigest()
|
||||
stable_edge_id = hashlib.sha1(f"{chunk.id}|{stmt.id}".encode()).hexdigest()
|
||||
edge = {
|
||||
"id": stable_edge_id,
|
||||
"source": chunk.id,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j仓储基类模块
|
||||
|
||||
本模块提供Neo4j仓储的基类实现,封装了通用的Neo4j节点操作。
|
||||
@@ -57,9 +56,17 @@ class BaseNeo4jRepository(BaseRepository[T]):
|
||||
CREATE (n:{self.node_label} $props)
|
||||
RETURN n
|
||||
"""
|
||||
# 使用model_dump()获取所有字段,包括aliases
|
||||
props = entity.model_dump()
|
||||
|
||||
# 确保aliases字段存在且为列表(针对ExtractedEntity节点)
|
||||
if hasattr(entity, 'aliases'):
|
||||
if props.get('aliases') is None:
|
||||
props['aliases'] = []
|
||||
|
||||
result = await self.connector.execute_query(
|
||||
query,
|
||||
props=entity.model_dump()
|
||||
props=props
|
||||
)
|
||||
return entity
|
||||
|
||||
@@ -97,10 +104,18 @@ class BaseNeo4jRepository(BaseRepository[T]):
|
||||
SET n += $props
|
||||
RETURN n
|
||||
"""
|
||||
# 使用model_dump()获取所有字段,包括aliases
|
||||
props = entity.model_dump()
|
||||
|
||||
# 确保aliases字段存在且为列表(针对ExtractedEntity节点)
|
||||
if hasattr(entity, 'aliases'):
|
||||
if props.get('aliases') is None:
|
||||
props['aliases'] = []
|
||||
|
||||
await self.connector.execute_query(
|
||||
query,
|
||||
id=entity.id,
|
||||
props=entity.model_dump()
|
||||
props=props
|
||||
)
|
||||
return entity
|
||||
|
||||
@@ -142,7 +157,7 @@ class BaseNeo4jRepository(BaseRepository[T]):
|
||||
... )
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clauses = [f"n.{key} = ${key}" for key in filters.keys()]
|
||||
where_clauses = [f"n.{key} = ${key}" for key in filters]
|
||||
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
|
||||
|
||||
query = f"""
|
||||
|
||||
@@ -85,7 +85,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
||||
e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
|
||||
e.aliases = CASE
|
||||
WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
|
||||
THEN CASE WHEN e.aliases IS NULL THEN entity.aliases ELSE e.aliases + entity.aliases END
|
||||
THEN CASE
|
||||
WHEN e.aliases IS NULL THEN entity.aliases
|
||||
ELSE reduce(acc = [], alias IN (e.aliases + entity.aliases) |
|
||||
CASE WHEN alias IN acc THEN acc ELSE acc + alias END)
|
||||
END
|
||||
ELSE e.aliases END,
|
||||
e.name_embedding = CASE
|
||||
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
|
||||
@@ -682,3 +686,63 @@ SET r.group_id = e.group_id,
|
||||
r.expired_at = e.expired_at
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
|
||||
# Entity Merge Query
|
||||
MERGE_ENTITIES = """
|
||||
MATCH (canonical:ExtractedEntity {id: $canonical_id})
|
||||
MATCH (losing:ExtractedEntity {id: $losing_id})
|
||||
|
||||
// 更新canonical实体的aliases
|
||||
SET canonical.aliases = $merged_aliases
|
||||
|
||||
// 转移所有从losing出发的关系到canonical
|
||||
WITH canonical, losing
|
||||
OPTIONAL MATCH (losing)-[r]->(target)
|
||||
WHERE NOT (canonical)-[:RELATES_TO]->(target)
|
||||
FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
|
||||
CREATE (canonical)-[:RELATES_TO {
|
||||
id: rel.id,
|
||||
relation_type: rel.relation_type,
|
||||
relation_value: rel.relation_value,
|
||||
statement: rel.statement,
|
||||
source_statement_id: rel.source_statement_id,
|
||||
valid_at: rel.valid_at,
|
||||
invalid_at: rel.invalid_at,
|
||||
group_id: rel.group_id,
|
||||
user_id: rel.user_id,
|
||||
apply_id: rel.apply_id,
|
||||
run_id: rel.run_id,
|
||||
created_at: rel.created_at,
|
||||
expired_at: rel.expired_at
|
||||
}]->(target)
|
||||
)
|
||||
|
||||
// 转移所有指向losing的关系到canonical
|
||||
WITH canonical, losing
|
||||
OPTIONAL MATCH (source)-[r]->(losing)
|
||||
WHERE NOT (source)-[:RELATES_TO]->(canonical)
|
||||
FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
|
||||
CREATE (source)-[:RELATES_TO {
|
||||
id: rel.id,
|
||||
relation_type: rel.relation_type,
|
||||
relation_value: rel.relation_value,
|
||||
statement: rel.statement,
|
||||
source_statement_id: rel.source_statement_id,
|
||||
valid_at: rel.valid_at,
|
||||
invalid_at: rel.invalid_at,
|
||||
group_id: rel.group_id,
|
||||
user_id: rel.user_id,
|
||||
apply_id: rel.apply_id,
|
||||
run_id: rel.run_id,
|
||||
created_at: rel.created_at,
|
||||
expired_at: rel.expired_at
|
||||
}]->(canonical)
|
||||
)
|
||||
|
||||
// 删除losing实体及其所有关系
|
||||
WITH losing
|
||||
DETACH DELETE losing
|
||||
|
||||
RETURN count(losing) as deleted
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""实体仓储模块
|
||||
|
||||
本模块提供实体节点的数据访问功能。
|
||||
@@ -7,7 +6,7 @@ Classes:
|
||||
EntityRepository: 实体仓储,管理ExtractedEntityNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
@@ -49,9 +48,13 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
|
||||
# 处理datetime字段
|
||||
if isinstance(n.get('created_at'), str):
|
||||
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
||||
if n.get('expired_at') and isinstance(n['expired_at'], str):
|
||||
if n.get('expired_at') and isinstance(n.get('expired_at'), str):
|
||||
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
|
||||
|
||||
# 确保aliases字段存在且为列表
|
||||
if 'aliases' not in n or n['aliases'] is None:
|
||||
n['aliases'] = []
|
||||
|
||||
return ExtractedEntityNode(**n)
|
||||
|
||||
async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]:
|
||||
@@ -66,274 +69,4 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
|
||||
"""
|
||||
return await self.find({"entity_type": entity_type}, limit=limit)
|
||||
|
||||
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[ExtractedEntityNode]:
|
||||
"""根据group_id查询实体
|
||||
|
||||
Args:
|
||||
group_id: 组ID
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[ExtractedEntityNode]: 实体列表
|
||||
"""
|
||||
return await self.find({"group_id": group_id}, limit=limit)
|
||||
|
||||
async def find_by_name(
|
||||
self,
|
||||
name: str,
|
||||
group_id: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[ExtractedEntityNode]:
|
||||
"""根据名称查询实体
|
||||
|
||||
支持模糊匹配(CONTAINS)。
|
||||
|
||||
Args:
|
||||
name: 实体名称
|
||||
group_id: 可选的组ID过滤
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[ExtractedEntityNode]: 实体列表
|
||||
"""
|
||||
where_clause = "n.name CONTAINS $name"
|
||||
if group_id:
|
||||
where_clause += " AND n.group_id = $group_id"
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE {where_clause}
|
||||
RETURN n
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {"name": name, "limit": limit}
|
||||
if group_id:
|
||||
params["group_id"] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
return [self._map_to_entity(r) for r in results]
|
||||
|
||||
async def find_related_entities(
|
||||
self,
|
||||
entity_id: str,
|
||||
relation_type: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[ExtractedEntityNode]:
|
||||
"""查询相关实体
|
||||
|
||||
查询与指定实体有关系的其他实体。
|
||||
|
||||
Args:
|
||||
entity_id: 实体ID
|
||||
relation_type: 可选的关系类型过滤
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[ExtractedEntityNode]: 相关实体列表
|
||||
"""
|
||||
if relation_type:
|
||||
query = """
|
||||
MATCH (e1:ExtractedEntity {id: $entity_id})-[r:RELATES_TO {relation_type: $relation_type}]->(e2:ExtractedEntity)
|
||||
RETURN e2 as n
|
||||
LIMIT $limit
|
||||
"""
|
||||
results = await self.connector.execute_query(
|
||||
query,
|
||||
entity_id=entity_id,
|
||||
relation_type=relation_type,
|
||||
limit=limit
|
||||
)
|
||||
else:
|
||||
query = """
|
||||
MATCH (e1:ExtractedEntity {id: $entity_id})-[r:RELATES_TO]->(e2:ExtractedEntity)
|
||||
RETURN e2 as n
|
||||
LIMIT $limit
|
||||
"""
|
||||
results = await self.connector.execute_query(
|
||||
query,
|
||||
entity_id=entity_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return [self._map_to_entity(r) for r in results]
|
||||
|
||||
async def search_by_embedding(
|
||||
self,
|
||||
embedding: List[float],
|
||||
group_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
min_score: float = 0.7
|
||||
) -> List[Dict]:
|
||||
"""基于向量相似度搜索实体
|
||||
|
||||
使用余弦相似度计算查询向量与实体名称向量的相似度。
|
||||
|
||||
Args:
|
||||
embedding: 查询向量
|
||||
group_id: 可选的组ID过滤
|
||||
limit: 返回结果的最大数量
|
||||
min_score: 最小相似度分数阈值
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含实体和相似度分数的字典列表
|
||||
每个字典包含: entity (ExtractedEntityNode), score (float)
|
||||
"""
|
||||
where_clause = "n.name_embedding IS NOT NULL"
|
||||
if group_id:
|
||||
where_clause += " AND n.group_id = $group_id"
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE {where_clause}
|
||||
WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN n, score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {
|
||||
"embedding": embedding,
|
||||
"min_score": min_score,
|
||||
"limit": limit
|
||||
}
|
||||
if group_id:
|
||||
params["group_id"] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
return [
|
||||
{
|
||||
"entity": self._map_to_entity(r),
|
||||
"score": r.get("score", 0.0)
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
async def find_by_statement_id(self, statement_id: str) -> List[ExtractedEntityNode]:
|
||||
"""根据陈述句ID查询实体
|
||||
|
||||
查询从指定陈述句中提取的所有实体。
|
||||
|
||||
Args:
|
||||
statement_id: 陈述句ID
|
||||
|
||||
Returns:
|
||||
List[ExtractedEntityNode]: 实体列表
|
||||
"""
|
||||
return await self.find({"statement_id": statement_id})
|
||||
|
||||
async def find_strong_entities(
|
||||
self,
|
||||
group_id: str,
|
||||
limit: int = 100
|
||||
) -> List[ExtractedEntityNode]:
|
||||
"""查询强连接的实体
|
||||
|
||||
Args:
|
||||
group_id: 组ID
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[ExtractedEntityNode]: 强连接的实体列表
|
||||
"""
|
||||
return await self.find(
|
||||
{"group_id": group_id, "connect_strength": "Strong"},
|
||||
limit=limit
|
||||
)
|
||||
|
||||
async def get_entity_count_by_type(self, group_id: str) -> Dict[str, int]:
|
||||
"""统计各类型实体的数量
|
||||
|
||||
Args:
|
||||
group_id: 组ID
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: 实体类型到数量的映射
|
||||
"""
|
||||
query = """
|
||||
MATCH (n:ExtractedEntity {group_id: $group_id})
|
||||
RETURN n.entity_type as entity_type, count(n) as count
|
||||
ORDER BY count DESC
|
||||
"""
|
||||
results = await self.connector.execute_query(query, group_id=group_id)
|
||||
return {r["entity_type"]: r["count"] for r in results}
|
||||
|
||||
async def find_by_config_id(
|
||||
self,
|
||||
config_id: str,
|
||||
limit: int = 100
|
||||
) -> List[ExtractedEntityNode]:
|
||||
"""根据config_id查询实体
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[ExtractedEntityNode]: 实体列表
|
||||
"""
|
||||
return await self.find({"config_id": config_id}, limit=limit)
|
||||
|
||||
async def search_by_embedding_with_config(
|
||||
self,
|
||||
embedding: List[float],
|
||||
config_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
min_score: float = 0.7
|
||||
) -> List[Dict]:
|
||||
"""基于向量相似度搜索实体,可选择按config_id过滤
|
||||
|
||||
使用余弦相似度计算查询向量与实体名称向量的相似度。
|
||||
支持按config_id过滤结果,确保只返回使用特定配置处理的实体。
|
||||
|
||||
Args:
|
||||
embedding: 查询向量
|
||||
config_id: 可选的配置ID过滤
|
||||
group_id: 可选的组ID过滤
|
||||
limit: 返回结果的最大数量
|
||||
min_score: 最小相似度分数阈值
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含实体和相似度分数的字典列表
|
||||
每个字典包含: entity (ExtractedEntityNode), score (float)
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clauses = ["n.name_embedding IS NOT NULL"]
|
||||
params = {
|
||||
"embedding": embedding,
|
||||
"min_score": min_score,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
if config_id:
|
||||
where_clauses.append("n.config_id = $config_id")
|
||||
params["config_id"] = config_id
|
||||
|
||||
if group_id:
|
||||
where_clauses.append("n.group_id = $group_id")
|
||||
params["group_id"] = group_id
|
||||
|
||||
where_str = " AND ".join(where_clauses)
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE {where_str}
|
||||
WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN n, score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
return [
|
||||
{
|
||||
"entity": self._map_to_entity(r),
|
||||
"score": r.get("score", 0.0)
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ Classes:
|
||||
StatementRepository: 陈述句仓储,管理StatementNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
@@ -76,244 +76,3 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
||||
List[StatementNode]: 陈述句列表
|
||||
"""
|
||||
return await self.find({"chunk_id": chunk_id})
|
||||
|
||||
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[StatementNode]:
|
||||
"""根据group_id查询陈述句
|
||||
|
||||
Args:
|
||||
group_id: 组ID
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[StatementNode]: 陈述句列表
|
||||
"""
|
||||
return await self.find({"group_id": group_id}, limit=limit)
|
||||
|
||||
async def search_by_embedding(
|
||||
self,
|
||||
embedding: List[float],
|
||||
group_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
min_score: float = 0.7
|
||||
) -> List[Dict]:
|
||||
"""基于向量相似度搜索陈述句
|
||||
|
||||
使用余弦相似度计算查询向量与陈述句向量的相似度。
|
||||
|
||||
Args:
|
||||
embedding: 查询向量
|
||||
group_id: 可选的组ID过滤
|
||||
limit: 返回结果的最大数量
|
||||
min_score: 最小相似度分数阈值
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含陈述句和相似度分数的字典列表
|
||||
每个字典包含: statement (StatementNode), score (float)
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clause = "n.statement_embedding IS NOT NULL"
|
||||
if group_id:
|
||||
where_clause += " AND n.group_id = $group_id"
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE {where_clause}
|
||||
WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN n, score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {
|
||||
"embedding": embedding,
|
||||
"min_score": min_score,
|
||||
"limit": limit
|
||||
}
|
||||
if group_id:
|
||||
params["group_id"] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
return [
|
||||
{
|
||||
"statement": self._map_to_entity(r),
|
||||
"score": r.get("score", 0.0)
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
async def search_by_keyword(
|
||||
self,
|
||||
keyword: str,
|
||||
group_id: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[StatementNode]:
|
||||
"""基于关键词搜索陈述句
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
group_id: 可选的组ID过滤
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[StatementNode]: 陈述句列表
|
||||
"""
|
||||
where_clause = "n.statement CONTAINS $keyword"
|
||||
if group_id:
|
||||
where_clause += " AND n.group_id = $group_id"
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE {where_clause}
|
||||
RETURN n
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {"keyword": keyword, "limit": limit}
|
||||
if group_id:
|
||||
params["group_id"] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
return [self._map_to_entity(r) for r in results]
|
||||
|
||||
async def find_by_temporal_range(
|
||||
self,
|
||||
group_id: str,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: int = 100
|
||||
) -> List[StatementNode]:
|
||||
"""根据时间范围查询陈述句
|
||||
|
||||
查询在指定时间范围内有效的陈述句。
|
||||
|
||||
Args:
|
||||
group_id: 组ID
|
||||
start_date: 开始日期(可选)
|
||||
end_date: 结束日期(可选)
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[StatementNode]: 陈述句列表
|
||||
"""
|
||||
where_clauses = ["n.group_id = $group_id"]
|
||||
params = {"group_id": group_id, "limit": limit}
|
||||
|
||||
if start_date:
|
||||
where_clauses.append("n.valid_at >= $start_date")
|
||||
params["start_date"] = start_date.isoformat()
|
||||
|
||||
if end_date:
|
||||
where_clauses.append("(n.invalid_at IS NULL OR n.invalid_at <= $end_date)")
|
||||
params["end_date"] = end_date.isoformat()
|
||||
|
||||
where_str = " AND ".join(where_clauses)
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE {where_str}
|
||||
RETURN n
|
||||
ORDER BY n.created_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
return [self._map_to_entity(r) for r in results]
|
||||
|
||||
async def find_strong_statements(
|
||||
self,
|
||||
group_id: str,
|
||||
limit: int = 100
|
||||
) -> List[StatementNode]:
|
||||
"""查询强连接的陈述句
|
||||
|
||||
Args:
|
||||
group_id: 组ID
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[StatementNode]: 强连接的陈述句列表
|
||||
"""
|
||||
return await self.find(
|
||||
{"group_id": group_id, "connect_strength": "Strong"},
|
||||
limit=limit
|
||||
)
|
||||
|
||||
async def find_by_config_id(
|
||||
self,
|
||||
config_id: str,
|
||||
limit: int = 100
|
||||
) -> List[StatementNode]:
|
||||
"""根据config_id查询陈述句
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[StatementNode]: 陈述句列表
|
||||
"""
|
||||
return await self.find({"config_id": config_id}, limit=limit)
|
||||
|
||||
async def search_by_embedding_with_config(
|
||||
self,
|
||||
embedding: List[float],
|
||||
config_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
min_score: float = 0.7
|
||||
) -> List[Dict]:
|
||||
"""基于向量相似度搜索陈述句,可选择按config_id过滤
|
||||
|
||||
使用余弦相似度计算查询向量与陈述句向量的相似度。
|
||||
支持按config_id过滤结果,确保只返回使用特定配置处理的陈述句。
|
||||
|
||||
Args:
|
||||
embedding: 查询向量
|
||||
config_id: 可选的配置ID过滤
|
||||
group_id: 可选的组ID过滤
|
||||
limit: 返回结果的最大数量
|
||||
min_score: 最小相似度分数阈值
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含陈述句和相似度分数的字典列表
|
||||
每个字典包含: statement (StatementNode), score (float)
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clauses = ["n.statement_embedding IS NOT NULL"]
|
||||
params = {
|
||||
"embedding": embedding,
|
||||
"min_score": min_score,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
if config_id:
|
||||
where_clauses.append("n.config_id = $config_id")
|
||||
params["config_id"] = config_id
|
||||
|
||||
if group_id:
|
||||
where_clauses.append("n.group_id = $group_id")
|
||||
params["group_id"] = group_id
|
||||
|
||||
where_str = " AND ".join(where_clauses)
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE {where_str}
|
||||
WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN n, score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
return [
|
||||
{
|
||||
"statement": self._map_to_entity(r),
|
||||
"score": r.get("score", 0.0)
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user