Merge #32 into develop from fix/memory_reflection

更新 self_reflexion.py

* fix/memory_reflection: (50 commits squashed)

  - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期)

  - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期)

  - 新增反思功能(检测代码/规范化程序)

  - 新增反思功能(检测代码/规范化程序)

  - 新增反思功能(检测代码/规范化程序)

  - 新增反思功能(检测代码/规范化程序)

  - 新增反思功能(检测代码/规范化程序)

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - Merge branch develop into fix/memory_reflection (Conflict resolved online)
    
    
    # Conflicts:  
    #      api/app/controllers/memory_reflection_controller.py
    #      api/app/schemas/memory_reflection_schemas.py

  - 反思优化

  - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection

  - 统一输出

  - 统一输出

  - 统一输出

  - Merge branch develop into fix/memory_reflection (Conflict resolved online)
    
    
    # Conflicts:  
    #      api/app/controllers/memory_reflection_controller.py

  - 统一输出

  - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection

  - 统一输出

  - 反思速度提升,从4分钟优化成1分10-40秒

  - 反思速度提升,从4分钟优化成1分10-40秒

  - 反思速度提升,从4分钟优化成1分10-40秒

  - Merge branch develop into fix/memory_reflection (Conflict resolved online)
    
    
    # Conflicts:  
    #      api/app/core/memory/storage_services/reflection_engine/self_reflexion.py

  - 反思速度提升,从4分钟优化成1分10-40秒

  - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection
    
    # Conflicts:
    #	api/app/core/memory/storage_services/reflection_engine/self_reflexion.py

  - 更新 self_reflexion.py

  - 反思图谱添加边的修改

  - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection
    
    # Conflicts:
    #	api/app/core/memory/storage_services/reflection_engine/self_reflexion.py

  - 反思图谱添加边的修改

  - 反思图谱添加边的修改

  - 反思图谱添加边的修改

  - 反思图谱添加边的修改

  - 反思图谱添加边的修改

  - update
    
    
    # Conflicts:  
    #      api/app/core/memory/storage_services/reflection_engine/self_reflexion.py
    #      api/app/core/memory/utils/prompt/prompts/reflexion.jinja2

Signed-off-by: aliyun8644380055 <accounts_68c0f5d519f260d93ee2997e@mail.teambition.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/32
This commit is contained in:
李新月
2025-12-22 12:26:45 +00:00
committed by 孙科
parent 6d41a12c8f
commit cd644b6eab
6 changed files with 254 additions and 130 deletions

View File

@@ -783,7 +783,9 @@ neo4j_query_part = """
m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
rel as relationship,
rel.predicate as predicate,
rel.statement as relationship,
rel.statement_id as relationship_statement_id,
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
other as entity2
"""
@@ -799,7 +801,9 @@ neo4j_query_all = """
m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
rel as relationship,
rel.predicate as predicate,
rel.statement as relationship,
rel.statement_id as relationship_statement_id,
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
other as entity2
"""

View File

@@ -67,11 +67,81 @@ async def update_neo4j_data(neo4j_dict_data, update_databases):
traceback.print_exc()
return False
async def update_neo4j_data_edge(neo4j_dict_data, update_databases):
"""
Update Neo4j data based on query criteria and update parameters
Args:
neo4j_dict_data: find
update_databases: update
"""
try:
# 构建WHERE条件
where_conditions = []
params = {}
for key, value in neo4j_dict_data.items():
if value is not None:
param_name = f"param_{key}"
where_conditions.append(f"r.{key} = ${param_name}")
params[param_name] = value
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
# 构建SET条件
set_conditions = []
for key, value in update_databases.items():
if value is not None:
param_name = f"update_{key}"
set_conditions.append(f"r.{key} = ${param_name}")
params[param_name] = value
set_clause = ", ".join(set_conditions)
if not set_clause:
print("警告: 没有需要更新的字段")
return False
# 构建Cypher查询
cypher_query = f"""
MATCH (n)-[r]->(m)
WHERE {where_clause}
SET {set_clause}
RETURN count(r) as updated_count, collect(type(r)) as relation_types
"""
print(f"\n执行Cypher查询: {cypher_query}")
print(f"参数: {params}")
# 执行更新
result = await neo4j_connector.execute_query(cypher_query, **params)
if result:
updated_count = result[0].get('updated_count', 0)
updated_names = result[0].get('updated_names', [])
print(f"成功更新 {updated_count} 个节点")
if updated_names:
print(f"更新的实体名称: {updated_names}")
return updated_count > 0
else:
return False
except Exception as e:
print(f"更新过程中出现错误: {e}")
import traceback
traceback.print_exc()
return False
def map_field_names(data_dict):
mapped_dict = {}
has_name_field = False
# 辅助函数:提取值(如果是数组则取最后一个值,否则直接返回)
def extract_value(value):
if isinstance(value, list) and len(value) > 0:
# 如果是数组 [old_value, new_value],取新值(最后一个)
return value[-1]
return value
# 第一遍检查是否有name相关字段
for key, value in data_dict.items():
if key in ['name', 'entity2.name', 'entity1.name']:
@@ -82,22 +152,25 @@ def map_field_names(data_dict):
# 第二遍:根据规则映射和过滤字段
for key, value in data_dict.items():
# 提取实际值(处理数组格式)
actual_value = extract_value(value)
if key == 'entity2.name' or key == 'entity2_name':
# 将 entity2.name 映射为 name
mapped_dict['name'] = value
print(f"字段名映射: {key} -> name")
mapped_dict['name'] = actual_value
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
elif key == 'entity1.name' or key == 'entity1_name':
# 将 entity1.name 映射为 name
mapped_dict['name'] = value
print(f"字段名映射: {key} -> name")
mapped_dict['name'] = actual_value
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
elif key == 'entity1.description':
# 将 entity1.description 映射为 description
mapped_dict['description'] = value
print(f"字段名映射: {key} -> description")
mapped_dict['description'] = actual_value
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
elif key == 'entity2.description':
# 将 entity2.description 映射为 description
mapped_dict['description'] = value
print(f"字段名映射: {key} -> description")
mapped_dict['description'] = actual_value
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
elif key == 'relationship_type':
# 跳过relationship_type字段
print(f"字段过滤: 跳过不需要的字段 '{key}'")
@@ -109,8 +182,8 @@ def map_field_names(data_dict):
continue
else:
# 如果没有name字段保留entity1_name
mapped_dict[key] = value
print(f"字段保留: {key}")
mapped_dict[key] = actual_value
print(f"字段保留: {key} (值: {value} -> {actual_value})")
elif key == 'entity2_name':
if has_name_field:
# 如果有name字段跳过entity2_name
@@ -122,7 +195,11 @@ def map_field_names(data_dict):
continue
elif '.' not in key:
# 不包含点号的其他字段直接保留
mapped_dict[key] = value
mapped_dict[key] = actual_value
if isinstance(value, list):
print(f"字段保留: {key} (数组值: {value} -> {actual_value})")
else:
print(f"字段保留: {key}")
else:
# 其他包含点号的字段跳过并警告
print(f"警告: 跳过不支持的嵌套字段 '{key}'")
@@ -139,89 +216,57 @@ async def neo4j_data(solved_data):
"""
success_count = 0
ori_entity = {}
updata_entity = {}
ori_edge = {}
updata_edge = {}
ori_expired_at={}
updat_expired_at={}
for i in solved_data:
neo4j_dict_data = {}
update_databases = {}
results = i['results']
for data in results:
resolved = data.get('resolved')
if not resolved:
print("跳过resolved为None")
databasets = i['data']
for key, values in databasets.items():
if str(values)=='NONE':
continue
if isinstance(values, list):
if key == 'description':
ori_entity[key] = values[0]
updata_entity[key] = values[1]
if key == 'entity2_name' or key == 'entity1_name':
key = 'name'
ori_entity[key] = values[0]
updata_entity[key] = values[1]
ori_expired_at[key] = values[0]
if key == 'statement':
ori_edge[key] = values[0]
updata_edge[key] = values[1]
if key=='expired_at':
updat_expired_at[key] = values[1]
try:
change_list = resolved.get('change', [])
except (AttributeError, TypeError):
change_list = []
elif key == 'statement_id':
ori_edge[key] = values
updata_edge[key] = values
if change_list == []:
print("跳过change_list为空")
continue
ori_entity[key] = values
updata_entity[key] = values
if change_list and len(change_list) > 0:
change = change_list[0]
print(f"change: {change}")
field_data = change.get('field', [])
print(f"field_data: {field_data}")
print(f"field_data type: {type(field_data)}")
# 字段名映射和过滤函数
ori_expired_at[key] = values
# 处理field数据可能是字典或列表
if isinstance(field_data, dict):
# 如果是字典,映射字段名后更新
mapped_data = map_field_names(field_data)
update_databases.update(mapped_data)
elif isinstance(field_data, list):
# 如果是列表,遍历每个字典并更新
for field_item in field_data:
if isinstance(field_item, dict):
mapped_item = map_field_names(field_item)
update_databases.update(mapped_item)
else:
print(f"警告: field_item不是字典: {field_item}")
else:
print(f"警告: field_data类型不支持: {type(field_data)}")
if 'entity1_name' in data:
data['name'] = data.pop('entity1_name')
if 'entity2_name' in data:
data.pop('entity2_name', None)
resolved_memory = resolved.get('resolved_memory', {})
entity2 = None
if isinstance(resolved_memory, dict):
entity2 = resolved_memory.get('entity2')
if entity2 and isinstance(entity2, dict) and len(entity2) >= 5:
stat_id = resolved.get('original_memory_id')
# 安全地获取description
statement_id = None
if isinstance(resolved_memory, dict):
statement_id = resolved_memory.get('statement_id')
# 只有当neo4j_dict_data中还没有statement_id时才使用original_memory_id
if statement_id and 'id' not in neo4j_dict_data:
neo4j_dict_data['id'] = stat_id
neo4j_dict_data['statement_id'] = statement_id
else:
# 处理original_memory_id它可能是字符串或字典
try:
for key, value in resolved_memory.items():
if key == 'statement_id':
neo4j_dict_data['statement_id'] = value
if key == 'description':
neo4j_dict_data['description'] = value
except AttributeError:
neo4j_dict_data=[]
print(neo4j_dict_data)
print(update_databases)
if neo4j_dict_data!=[]:
await update_neo4j_data(neo4j_dict_data, update_databases)
success_count += 1
print(ori_entity)
print(updata_entity)
print(100*'-')
print(ori_edge)
print(updata_edge)
expired_at_ = updat_expired_at.get('expired_at', None)
if expired_at_ is not None:
await update_neo4j_data(ori_expired_at, updat_expired_at)
success_count += 1
if ori_entity != updata_entity:
await update_neo4j_data(ori_entity, updata_entity)
success_count += 1
if ori_edge != updata_edge:
await update_neo4j_data_edge(ori_edge, updata_edge)
success_count += 1
return success_count