[MODIFY] Code optimization

This commit is contained in:
Mark
2025-12-15 14:09:43 +08:00
parent d2a630addb
commit a4e276ab27
157 changed files with 15976 additions and 3601 deletions

View File

@@ -7,7 +7,7 @@ Classes:
StatementRepository: 陈述句仓储管理StatementNode的CRUD操作
"""
from typing import List, Optional, Dict
from typing import List, Dict
from datetime import datetime
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
@@ -76,244 +76,3 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
List[StatementNode]: 陈述句列表
"""
return await self.find({"chunk_id": chunk_id})
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[StatementNode]:
"""根据group_id查询陈述句
Args:
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 陈述句列表
"""
return await self.find({"group_id": group_id}, limit=limit)
async def search_by_embedding(
self,
embedding: List[float],
group_id: Optional[str] = None,
limit: int = 10,
min_score: float = 0.7
) -> List[Dict]:
"""基于向量相似度搜索陈述句
使用余弦相似度计算查询向量与陈述句向量的相似度。
Args:
embedding: 查询向量
group_id: 可选的组ID过滤
limit: 返回结果的最大数量
min_score: 最小相似度分数阈值
Returns:
List[Dict]: 包含陈述句和相似度分数的字典列表
每个字典包含: statement (StatementNode), score (float)
"""
# 构建查询条件
where_clause = "n.statement_embedding IS NOT NULL"
if group_id:
where_clause += " AND n.group_id = $group_id"
query = f"""
MATCH (n:{self.node_label})
WHERE {where_clause}
WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
WHERE score > $min_score
RETURN n, score
ORDER BY score DESC
LIMIT $limit
"""
params = {
"embedding": embedding,
"min_score": min_score,
"limit": limit
}
if group_id:
params["group_id"] = group_id
results = await self.connector.execute_query(query, **params)
return [
{
"statement": self._map_to_entity(r),
"score": r.get("score", 0.0)
}
for r in results
]
async def search_by_keyword(
self,
keyword: str,
group_id: Optional[str] = None,
limit: int = 50
) -> List[StatementNode]:
"""基于关键词搜索陈述句
Args:
keyword: 搜索关键词
group_id: 可选的组ID过滤
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 陈述句列表
"""
where_clause = "n.statement CONTAINS $keyword"
if group_id:
where_clause += " AND n.group_id = $group_id"
query = f"""
MATCH (n:{self.node_label})
WHERE {where_clause}
RETURN n
LIMIT $limit
"""
params = {"keyword": keyword, "limit": limit}
if group_id:
params["group_id"] = group_id
results = await self.connector.execute_query(query, **params)
return [self._map_to_entity(r) for r in results]
async def find_by_temporal_range(
self,
group_id: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100
) -> List[StatementNode]:
"""根据时间范围查询陈述句
查询在指定时间范围内有效的陈述句。
Args:
group_id: 组ID
start_date: 开始日期(可选)
end_date: 结束日期(可选)
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 陈述句列表
"""
where_clauses = ["n.group_id = $group_id"]
params = {"group_id": group_id, "limit": limit}
if start_date:
where_clauses.append("n.valid_at >= $start_date")
params["start_date"] = start_date.isoformat()
if end_date:
where_clauses.append("(n.invalid_at IS NULL OR n.invalid_at <= $end_date)")
params["end_date"] = end_date.isoformat()
where_str = " AND ".join(where_clauses)
query = f"""
MATCH (n:{self.node_label})
WHERE {where_str}
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_entity(r) for r in results]
async def find_strong_statements(
self,
group_id: str,
limit: int = 100
) -> List[StatementNode]:
"""查询强连接的陈述句
Args:
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 强连接的陈述句列表
"""
return await self.find(
{"group_id": group_id, "connect_strength": "Strong"},
limit=limit
)
async def find_by_config_id(
self,
config_id: str,
limit: int = 100
) -> List[StatementNode]:
"""根据config_id查询陈述句
Args:
config_id: 配置ID
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 陈述句列表
"""
return await self.find({"config_id": config_id}, limit=limit)
async def search_by_embedding_with_config(
self,
embedding: List[float],
config_id: Optional[str] = None,
group_id: Optional[str] = None,
limit: int = 10,
min_score: float = 0.7
) -> List[Dict]:
"""基于向量相似度搜索陈述句,可选择按config_id过滤
使用余弦相似度计算查询向量与陈述句向量的相似度。
支持按config_id过滤结果,确保只返回使用特定配置处理的陈述句。
Args:
embedding: 查询向量
config_id: 可选的配置ID过滤
group_id: 可选的组ID过滤
limit: 返回结果的最大数量
min_score: 最小相似度分数阈值
Returns:
List[Dict]: 包含陈述句和相似度分数的字典列表
每个字典包含: statement (StatementNode), score (float)
"""
# 构建查询条件
where_clauses = ["n.statement_embedding IS NOT NULL"]
params = {
"embedding": embedding,
"min_score": min_score,
"limit": limit
}
if config_id:
where_clauses.append("n.config_id = $config_id")
params["config_id"] = config_id
if group_id:
where_clauses.append("n.group_id = $group_id")
params["group_id"] = group_id
where_str = " AND ".join(where_clauses)
query = f"""
MATCH (n:{self.node_label})
WHERE {where_str}
WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
WHERE score > $min_score
RETURN n, score
ORDER BY score DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [
{
"statement": self._map_to_entity(r),
"score": r.get("score", 0.0)
}
for r in results
]