* 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 读取的接口,去掉全局锁 * 输出数组 * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索)
294 lines
10 KiB
Python
294 lines
10 KiB
Python
from app.repositories import Neo4jConnector
|
||
|
||
neo4j_connector = Neo4jConnector()
|
||
|
||
async def update_neo4j_data(neo4j_dict_data, update_databases):
|
||
"""
|
||
Update Neo4j data based on query criteria and update parameters
|
||
|
||
Args:
|
||
neo4j_dict_data: find
|
||
update_databases: update
|
||
"""
|
||
try:
|
||
# 构建WHERE条件 - 只使用elementId
|
||
where_conditions = []
|
||
params = {}
|
||
|
||
# 优先使用id作为elementId进行查询
|
||
if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None:
|
||
where_conditions.append(f"elementId(e) = $param_id")
|
||
params['param_id'] = neo4j_dict_data['id']
|
||
else:
|
||
# 如果没有id,使用其他字段作为条件
|
||
for key, value in neo4j_dict_data.items():
|
||
if value is not None:
|
||
param_name = f"param_{key}"
|
||
where_conditions.append(f"e.{key} = ${param_name}")
|
||
params[param_name] = value
|
||
|
||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||
|
||
# 构建SET条件 - 排除id字段
|
||
set_conditions = []
|
||
for key, value in update_databases.items():
|
||
if value is not None and key != 'id': # 不更新id字段
|
||
param_name = f"update_{key}"
|
||
set_conditions.append(f"e.{key} = ${param_name}")
|
||
params[param_name] = value
|
||
|
||
set_clause = ", ".join(set_conditions)
|
||
|
||
if not set_clause:
|
||
print("警告: 没有需要更新的字段")
|
||
return False
|
||
|
||
# 构建Cypher查询
|
||
cypher_query = f"""
|
||
MATCH (e:ExtractedEntity)
|
||
WHERE {where_clause}
|
||
SET {set_clause}
|
||
RETURN count(e) as updated_count, collect(e.name) as updated_names
|
||
"""
|
||
|
||
print(f"\n执行Cypher查询: {cypher_query}")
|
||
print(f"参数: {params}")
|
||
|
||
# 执行更新
|
||
result = await neo4j_connector.execute_query(cypher_query, **params)
|
||
|
||
if result:
|
||
updated_count = result[0].get('updated_count', 0)
|
||
updated_names = result[0].get('updated_names', [])
|
||
print(f"成功更新 {updated_count} 个节点")
|
||
if updated_names:
|
||
print(f"更新的实体名称: {updated_names}")
|
||
return updated_count > 0
|
||
else:
|
||
return False
|
||
|
||
except Exception as e:
|
||
print(f"更新过程中出现错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
async def update_neo4j_data_edge(neo4j_dict_data, update_databases):
|
||
"""
|
||
Update Neo4j data based on query criteria and update parameters
|
||
|
||
Args:
|
||
neo4j_dict_data: find
|
||
update_databases: update
|
||
"""
|
||
try:
|
||
# 构建WHERE条件 - 只使用elementId
|
||
where_conditions = []
|
||
params = {}
|
||
|
||
# 优先使用id作为elementId进行查询
|
||
if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None:
|
||
where_conditions.append(f"elementId(r) = $param_id")
|
||
params['param_id'] = neo4j_dict_data['id']
|
||
else:
|
||
# 如果没有id,使用其他字段作为条件
|
||
for key, value in neo4j_dict_data.items():
|
||
if value is not None:
|
||
param_name = f"param_{key}"
|
||
where_conditions.append(f"r.{key} = ${param_name}")
|
||
params[param_name] = value
|
||
|
||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||
|
||
# 构建SET条件 - 排除id字段
|
||
set_conditions = []
|
||
for key, value in update_databases.items():
|
||
if value is not None and key != 'id': # 不更新id字段
|
||
param_name = f"update_{key}"
|
||
set_conditions.append(f"r.{key} = ${param_name}")
|
||
params[param_name] = value
|
||
|
||
set_clause = ", ".join(set_conditions)
|
||
|
||
if not set_clause:
|
||
print("警告: 没有需要更新的字段")
|
||
return False
|
||
|
||
# 构建Cypher查询
|
||
cypher_query = f"""
|
||
MATCH (n)-[r]->(m)
|
||
WHERE {where_clause}
|
||
SET {set_clause}
|
||
RETURN count(r) as updated_count, collect(type(r)) as relation_types
|
||
"""
|
||
|
||
print(f"\n执行Cypher查询: {cypher_query}")
|
||
print(f"参数: {params}")
|
||
|
||
# 执行更新
|
||
result = await neo4j_connector.execute_query(cypher_query, **params)
|
||
|
||
if result:
|
||
updated_count = result[0].get('updated_count', 0)
|
||
updated_names = result[0].get('updated_names', [])
|
||
print(f"成功更新 {updated_count} 个节点")
|
||
if updated_names:
|
||
print(f"更新的实体名称: {updated_names}")
|
||
return updated_count > 0
|
||
else:
|
||
return False
|
||
|
||
except Exception as e:
|
||
print(f"更新过程中出现错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
def map_field_names(data_dict):
|
||
mapped_dict = {}
|
||
has_name_field = False
|
||
|
||
# 辅助函数:提取值(如果是数组则取最后一个值,否则直接返回)
|
||
def extract_value(value):
|
||
if isinstance(value, list) and len(value) > 0:
|
||
# 如果是数组 [old_value, new_value],取新值(最后一个)
|
||
return value[-1]
|
||
return value
|
||
|
||
# 第一遍:检查是否有name相关字段
|
||
for key, value in data_dict.items():
|
||
if key in ['name', 'entity2.name', 'entity1.name']:
|
||
has_name_field = True
|
||
break
|
||
|
||
print(f"字段检查: has_name_field = {has_name_field}")
|
||
|
||
# 第二遍:根据规则映射和过滤字段
|
||
for key, value in data_dict.items():
|
||
# 提取实际值(处理数组格式)
|
||
actual_value = extract_value(value)
|
||
|
||
if key == 'entity2.name' or key == 'entity2_name':
|
||
# 将 entity2.name 映射为 name
|
||
mapped_dict['name'] = actual_value
|
||
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
|
||
elif key == 'entity1.name' or key == 'entity1_name':
|
||
# 将 entity1.name 映射为 name
|
||
mapped_dict['name'] = actual_value
|
||
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
|
||
elif key == 'entity1.description':
|
||
# 将 entity1.description 映射为 description
|
||
mapped_dict['description'] = actual_value
|
||
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
|
||
elif key == 'entity2.description':
|
||
# 将 entity2.description 映射为 description
|
||
mapped_dict['description'] = actual_value
|
||
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
|
||
elif key == 'relationship_type':
|
||
# 跳过relationship_type字段
|
||
print(f"字段过滤: 跳过不需要的字段 '{key}'")
|
||
continue
|
||
elif key == 'entity1_name':
|
||
if has_name_field:
|
||
# 如果有name字段,跳过entity1_name
|
||
print(f"字段过滤: 由于存在name字段,跳过 '{key}'")
|
||
continue
|
||
else:
|
||
# 如果没有name字段,保留entity1_name
|
||
mapped_dict[key] = actual_value
|
||
print(f"字段保留: {key} (值: {value} -> {actual_value})")
|
||
elif key == 'entity2_name':
|
||
if has_name_field:
|
||
# 如果有name字段,跳过entity2_name
|
||
print(f"字段过滤: 由于存在name字段,跳过 '{key}'")
|
||
continue
|
||
else:
|
||
# 即使没有name字段,也不使用entity2_name(根据需求)
|
||
print(f"字段过滤: 跳过不推荐的字段 '{key}'")
|
||
continue
|
||
elif '.' not in key:
|
||
# 不包含点号的其他字段直接保留
|
||
mapped_dict[key] = actual_value
|
||
if isinstance(value, list):
|
||
print(f"字段保留: {key} (数组值: {value} -> {actual_value})")
|
||
else:
|
||
print(f"字段保留: {key}")
|
||
else:
|
||
# 其他包含点号的字段跳过并警告
|
||
print(f"警告: 跳过不支持的嵌套字段 '{key}'")
|
||
|
||
print(f"字段映射结果: {mapped_dict}")
|
||
return mapped_dict
|
||
async def neo4j_data(solved_data):
|
||
"""
|
||
Process the resolved data and update the Neo4j database
|
||
Args:
|
||
Solved_data: Solution Data List
|
||
Returns:
|
||
Int: Number of successfully updated records
|
||
"""
|
||
success_count = 0
|
||
|
||
ori_entity = {}
|
||
updata_entity = {}
|
||
ori_edge = {}
|
||
updata_edge = {}
|
||
ori_expired_at={}
|
||
updat_expired_at={}
|
||
for i in solved_data:
|
||
databasets = i['data']
|
||
for key, values in databasets.items():
|
||
if str(values)=='NONE':
|
||
continue
|
||
if isinstance(values, list):
|
||
if key == 'description':
|
||
ori_entity[key] = values[0]
|
||
updata_entity[key] = values[1]
|
||
if key == 'entity2_name' or key == 'entity1_name':
|
||
key = 'name'
|
||
ori_entity[key] = values[0]
|
||
updata_entity[key] = values[1]
|
||
ori_expired_at[key] = values[0]
|
||
if key == 'statement':
|
||
ori_edge[key] = values[0]
|
||
updata_edge[key] = values[1]
|
||
if key=='expired_at':
|
||
updat_expired_at[key] = values[1]
|
||
|
||
elif key == 'id':
|
||
ori_edge[key] = values
|
||
updata_edge[key] = values
|
||
|
||
ori_entity[key] = values
|
||
updata_entity[key] = values
|
||
|
||
ori_expired_at[key] = values
|
||
elif key == 'rel_id':
|
||
key='id'
|
||
ori_edge[key] = values
|
||
updata_edge[key] = values
|
||
|
||
ori_entity[key] = values
|
||
updata_entity[key] = values
|
||
|
||
ori_expired_at[key] = values
|
||
|
||
|
||
print(ori_entity)
|
||
print(updata_entity)
|
||
print(100*'-')
|
||
print(ori_edge)
|
||
print(updata_edge)
|
||
expired_at_ = updat_expired_at.get('expired_at', None)
|
||
if expired_at_ is not None:
|
||
await update_neo4j_data(ori_expired_at, updat_expired_at)
|
||
success_count += 1
|
||
if ori_entity != updata_entity:
|
||
await update_neo4j_data(ori_entity, updata_entity)
|
||
success_count += 1
|
||
if ori_edge != updata_edge:
|
||
await update_neo4j_data_edge(ori_edge, updata_edge)
|
||
success_count += 1
|
||
|
||
return success_count
|
||
|