Merge #32 into develop from fix/memory_reflection

更新 self_reflexion.py

* fix/memory_reflection: (50 commits squashed)

  - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期)

  - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期)

  - 新增反思功能(检测代码/规范化程序)

  - 新增反思功能(检测代码/规范化程序)

  - 新增反思功能(检测代码/规范化程序)

  - 新增反思功能(检测代码/规范化程序)

  - 新增反思功能(检测代码/规范化程序)

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - 反思优化

  - Merge branch develop into fix/memory_reflection (Conflict resolved online)
    
    
    # Conflicts:  
    #      api/app/controllers/memory_reflection_controller.py
    #      api/app/schemas/memory_reflection_schemas.py

  - 反思优化

  - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection

  - 统一输出

  - 统一输出

  - 统一输出

  - Merge branch develop into fix/memory_reflection (Conflict resolved online)
    
    
    # Conflicts:  
    #      api/app/controllers/memory_reflection_controller.py

  - 统一输出

  - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection

  - 统一输出

  - 反思速度提升,从4分钟优化成1分10-40秒

  - 反思速度提升,从4分钟优化成1分10-40秒

  - 反思速度提升,从4分钟优化成1分10-40秒

  - Merge branch develop into fix/memory_reflection (Conflict resolved online)
    
    
    # Conflicts:  
    #      api/app/core/memory/storage_services/reflection_engine/self_reflexion.py

  - 反思速度提升,从4分钟优化成1分10-40秒

  - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection
    
    # Conflicts:
    #	api/app/core/memory/storage_services/reflection_engine/self_reflexion.py

  - 更新 self_reflexion.py

  - 反思图谱添加边的修改

  - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection
    
    # Conflicts:
    #	api/app/core/memory/storage_services/reflection_engine/self_reflexion.py

  - 反思图谱添加边的修改

  - 反思图谱添加边的修改

  - 反思图谱添加边的修改

  - 反思图谱添加边的修改

  - 反思图谱添加边的修改

  - update
    
    
    # Conflicts:  
    #      api/app/core/memory/storage_services/reflection_engine/self_reflexion.py
    #      api/app/core/memory/utils/prompt/prompts/reflexion.jinja2

Signed-off-by: aliyun8644380055 <accounts_68c0f5d519f260d93ee2997e@mail.teambition.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/32
This commit is contained in:
李新月
2025-12-22 12:26:45 +00:00
committed by 孙科
parent 6d41a12c8f
commit cd644b6eab
6 changed files with 254 additions and 130 deletions

View File

@@ -18,17 +18,10 @@ from enum import Enum
import uuid
from pydantic import BaseModel
from app.core.response_utils import success
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
from app.repositories.neo4j.neo4j_update import neo4j_data
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.config import get_model_config
from app.core.memory.utils.config.get_data import get_data
from app.core.memory.utils.config.get_data import get_data_statement
from app.core.memory.utils.config.get_data import get_data,get_data_statement,extract_and_process_changes
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
@@ -45,7 +38,6 @@ from app.repositories.neo4j.neo4j_update import neo4j_data
from app.schemas.memory_storage_schema import ConflictResultSchema
from app.schemas.memory_storage_schema import ReflexionResultSchema
# 配置日志
_root_logger = logging.getLogger()
if not _root_logger.handlers:
@@ -241,8 +233,7 @@ class ReflectionEngine:
print(100 * '-')
print(conflict_data)
print(100 * '-')
# 检查是否真的有冲突
# # 检查是否真的有冲突
has_conflict = conflict_data[0].get('conflict', False)
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
@@ -270,7 +261,7 @@ class ReflectionEngine:
await self._log_data("solved_data", solved_data)
# 4. 应用反思结果(更新记忆库)
memories_updated = await self._apply_reflection_results(solved_data)
memories_updated=await self._apply_reflection_results(solved_data)
execution_time = asyncio.get_event_loop().time() - start_time
@@ -297,6 +288,7 @@ class ReflectionEngine:
async def reflection_run(self):
self._lazy_init()
start_time = time.time()
memory_verifies_flag = self.config.memory_verify
asyncio.get_event_loop().time()
logging.info("====== 自我反思流程开始 ======")
@@ -316,8 +308,8 @@ class ReflectionEngine:
for item in conflict_data:
quality_assessments.append(item['quality_assessment'])
memory_verifies.append(item['memory_verify'])
result_data['quality_assessments'] = quality_assessments
result_data['memory_verifies'] = memory_verifies
result_data['quality_assessments'] = quality_assessments
# 检查是否真的有冲突
has_conflict = conflict_data[0].get('conflict', False)
@@ -354,6 +346,9 @@ class ReflectionEngine:
for result in item['results']:
reflexion_data.append(result['reflexion'])
result_data['reflexion_data'] = reflexion_data
if memory_verifies_flag==False:
result_data['memory_verifies']=[None]
print(time.time()-start_time,'----------')
return result_data
@@ -561,7 +556,8 @@ class ReflectionEngine:
Returns:
int: 成功更新的记忆数量
"""
success_count = await neo4j_data(solved_data)
changes = extract_and_process_changes(solved_data)
success_count = await neo4j_data(changes)
return success_count
async def _log_data(self, label: str, data: Any) -> None:
@@ -668,4 +664,4 @@ class ReflectionEngine:
execution_time=time_result.execution_time + fact_result.execution_time
)
else:
raise ValueError(f"未知的反思基线: {self.config.baseline}")
raise ValueError(f"未知的反思基线: {self.config.baseline}")

View File

@@ -3,6 +3,20 @@ import uuid
import logging
from typing import List, Dict, Any
from openai import BaseModel
import json
import sys
from pathlib import Path
from pydantic import model_validator, Field
from app.schemas.memory_storage_schema import SingleReflexionResultSchema
from app.schemas.memory_storage_schema import ReflexionResultSchema
from app.repositories.neo4j.neo4j_update import map_field_names
# 添加项目根目录到 Python 路径
sys.path.append(str(Path(__file__).parent))
logger = logging.getLogger(__name__)
async def _load_(data: List[Any]) -> List[Dict]:
@@ -59,6 +73,14 @@ async def get_data(result):
"""
从数据库中获取数据
"""
EXCLUDE_FIELDS = {
"user_id",
"group_id",
"entity_type",
"connect_strength",
"relationship_type",
"apply_id"
}
neo4j_databasets=[]
for item in result:
filtered_item = {}
@@ -73,14 +95,17 @@ async def get_data(result):
rel_filtered['statement_id'] = value.get('statement_id')
rel_filtered['expired_at'] = value.get('expired_at')
rel_filtered['created_at'] = value.get('created_at')
filtered_item[key] = rel_filtered
filtered_item[key] = value
elif key == 'entity2' and value is not None:
# 过滤entity2的name_embedding字段
entity2_filtered = {}
if hasattr(value, 'items'):
for e_key, e_value in value.items():
if 'name_embedding' not in e_key.lower():
entity2_filtered[e_key] = e_value
if e_key in EXCLUDE_FIELDS:
continue
if 'name_embedding' in e_key.lower():
continue
entity2_filtered[e_key] = e_value
filtered_item[key] = entity2_filtered
else:
filtered_item[key] = value
@@ -94,8 +119,57 @@ async def get_data_statement( result):
neo4j_databasets.append(i)
return neo4j_databasets
class ReflexionResultSchema(BaseModel):
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
# NOTE(review): a ReflexionResultSchema is also imported from
# app.schemas.memory_storage_schema in this module's import section; this class
# statement reuses that name -- confirm the rebinding is intended.
# Required field holding the individual conflict-resolution entries.
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
# "before" validator: receives the raw input and returns it (possibly mutated).
# Only dict inputs are inspected; anything else is returned untouched.
# NOTE(review): the keys read below ("conflict", "resolved") are not fields
# declared on this class, whose only declared field is "results" -- confirm the
# expected input shape.
@model_validator(mode="before")
def _normalize_resolved(cls, v):
if isinstance(v, dict):
conflict = v.get("conflict")
# An explicit {"conflict": False} marker clears the "resolved" entry.
if isinstance(conflict, dict) and conflict.get("conflict") is False:
v["resolved"] = None
else:
resolved = v.get("resolved")
if isinstance(resolved, dict):
orig = resolved.get("original_memory_id")
mem = resolved.get("resolved_memory")
# No original id and a missing/empty resolved_memory: normalise "resolved" to None.
if orig is None and (mem is None or mem == {}):
v["resolved"] = None
return v
def extract_and_process_changes(DATA):
"""提取并处理 change 字段"""
# Purpose: extract and process the "change" entries of each reflexion result.
# Each item of DATA is unpacked as keyword arguments into ReflexionResultSchema;
# for every result whose `resolved.change` is truthy, the dicts listed in
# `change.field` are merged into a single dict and collected as {'data': ...}.
# Returns the list of collected entries. An item that raises is reported with
# print() and does not abort the run.
# NOTE(review): indentation is not visible in this view, so the exact nesting of
# the append and of the mapping check below should be confirmed against the file.
all_changes = []
for i, item in enumerate(DATA):
try:
result = ReflexionResultSchema(**item)
# j and k are not referenced anywhere in the visible body.
for j, res in enumerate(result.results):
if res.resolved and res.resolved.change:
for k, change in enumerate(res.resolved.change):
change_data = {}
for field_item in change.field:
for key, value in field_item.items():
change_data[key] = value
# List values are printed as "first -> second"; other values are printed as-is.
if isinstance(value, list):
print(f" - {key}: {value[0]} -> {value[1]}")
else:
print(f" - {key}: {value}")
all_changes.append({
'data': change_data
})
# Test the field mapping
# The mapped result is only printed; it is not stored in all_changes.
try:
mapped = map_field_names(change_data)
print(f" 映射结果: {mapped}")
except Exception as e:
print(f" 映射失败: {e}")
except Exception as e:
print(f"处理结果 {i + 1} 失败: {e}")
return all_changes
if __name__ == "__main__":
import asyncio

View File

@@ -61,7 +61,7 @@
- 微信号: user123456 → use****3456
- 邮箱: zhang.san@example.com → zha****@example.com
**脱敏字段**: name、entity1_name、entity2_name、description
**脱敏字段**: name、entity1_name、entity2_name、description、relationship
## 4. 处理流程
@@ -97,21 +97,11 @@
### 处理规则
**情况1: 正确答案存在于data中**
- 保留正确记录不变
- 时间冲突: 修改错误记录的expired_at为当前时间(2025-12-16T12:00:00)
- 事实冲突: 同样处理
- resolved.resolved_memory只包含被设为失效的错误记录
- change字段只记录expired_at变更: `[{"expired_at": "2025-12-16T12:00:00"}]`
- 注意: 如果已存在时间则不需要修改
**情况2: 正确答案不存在于data中**
- 选择最合适记录进行修改
- 更新相关字段:
- description字段: 添加或修改描述信息{% if memory_verify %}(含隐私信息需脱敏){% endif %}
- name字段: 修改名称字段{% if memory_verify %}(含隐私信息需脱敏){% endif %}
- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %}
- change字段记录所有被修改字段{% if memory_verify %},包括脱敏变更{% endif %}
**baseline是TIME**
- 保留正确记录不变;将错误记录的expired_at修改为当前时间(2025-12-16T12:00:00),并将name修改为正确的值
**baseline不是TIME**
- 检查name、entity1_name、entity2_name、description、relationship字段的内容是否正确;如果不正确,需要重新生成这些字段的内容,此时不需要修改expired_at字段。
如果涉及修改entity1_name/entity2_name字段,同时也需要修改description字段;输出时将修改前和修改后的值放入change里面的field。
**核心原则**:
- 只输出需要修改的记录
@@ -126,8 +116,10 @@
"change": [
{
"field": [
{"字段名1": "修改后的值1"},
{"字段名2": "修改后的值2"}
{"id": 修改字段对应的ID},
{"statement_id": 需要修改的对象对应的statement_id},
{"字段名1": ["修改前的值1","修改后的值1"]},
{"字段名2": ["修改前的值2","修改后的值2"]}
]
}
]
@@ -149,7 +141,8 @@
**嵌套字段映射**(系统自动处理):
- `entity2.name` → 自动映射为 `name`
- `entity1.name` → 自动映射为 `name`
- `entity1.name` → 自动映射为 `name`
- `relationship` → 自动映射为 `statement`
- `entity1.description` → 自动映射为 `description`
- `entity2.description` → 自动映射为 `description`