Merge #32 into develop from fix/memory_reflection
更新 self_reflexion.py
* fix/memory_reflection: (50 commits squashed)
- 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期)
- 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期)
- 新增反思功能(检测代码/规范化程序)
- 新增反思功能(检测代码/规范化程序)
- 新增反思功能(检测代码/规范化程序)
- 新增反思功能(检测代码/规范化程序)
- 新增反思功能(检测代码/规范化程序)
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- 反思优化
- Merge branch develop into fix/memory_reflection (Conflict resolved online)
# Conflicts:
# api/app/controllers/memory_reflection_controller.py
# api/app/schemas/memory_reflection_schemas.py
- 反思优化
- Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection
- 统一输出
- 统一输出
- 统一输出
- Merge branch develop into fix/memory_reflection (Conflict resolved online)
# Conflicts:
# api/app/controllers/memory_reflection_controller.py
- 统一输出
- Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection
- 统一输出
- 反思速度提升,从4分钟优化成1分10-40秒
- 反思速度提升,从4分钟优化成1分10-40秒
- 反思速度提升,从4分钟优化成1分10-40秒
- Merge branch develop into fix/memory_reflection (Conflict resolved online)
# Conflicts:
# api/app/core/memory/storage_services/reflection_engine/self_reflexion.py
- 反思速度提升,从4分钟优化成1分10-40秒
- Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection
# Conflicts:
#	api/app/core/memory/storage_services/reflection_engine/self_reflexion.py
- 更新 self_reflexion.py
- 反思图谱添加边的修改
- Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection
# Conflicts:
#	api/app/core/memory/storage_services/reflection_engine/self_reflexion.py
- 反思图谱添加边的修改
- 反思图谱添加边的修改
- 反思图谱添加边的修改
- 反思图谱添加边的修改
- 反思图谱添加边的修改
- update
# Conflicts:
# api/app/core/memory/storage_services/reflection_engine/self_reflexion.py
# api/app/core/memory/utils/prompt/prompts/reflexion.jinja2
Signed-off-by: aliyun8644380055 <accounts_68c0f5d519f260d93ee2997e@mail.teambition.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/32
This commit is contained in:
@@ -18,17 +18,10 @@ from enum import Enum
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
from app.core.response_utils import success
|
|
||||||
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
|
|
||||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
|
||||||
|
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
from app.core.memory.utils.config import definitions as config_defs
|
from app.core.memory.utils.config import definitions as config_defs
|
||||||
from app.core.memory.utils.config import get_model_config
|
from app.core.memory.utils.config import get_model_config
|
||||||
from app.core.memory.utils.config.get_data import get_data
|
from app.core.memory.utils.config.get_data import get_data,get_data_statement,extract_and_process_changes
|
||||||
from app.core.memory.utils.config.get_data import get_data_statement
|
|
||||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||||
@@ -45,7 +38,6 @@ from app.repositories.neo4j.neo4j_update import neo4j_data
|
|||||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||||
|
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
_root_logger = logging.getLogger()
|
_root_logger = logging.getLogger()
|
||||||
if not _root_logger.handlers:
|
if not _root_logger.handlers:
|
||||||
@@ -241,8 +233,7 @@ class ReflectionEngine:
|
|||||||
print(100 * '-')
|
print(100 * '-')
|
||||||
print(conflict_data)
|
print(conflict_data)
|
||||||
print(100 * '-')
|
print(100 * '-')
|
||||||
|
# # 检查是否真的有冲突
|
||||||
# 检查是否真的有冲突
|
|
||||||
has_conflict = conflict_data[0].get('conflict', False)
|
has_conflict = conflict_data[0].get('conflict', False)
|
||||||
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
||||||
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
||||||
@@ -270,7 +261,7 @@ class ReflectionEngine:
|
|||||||
await self._log_data("solved_data", solved_data)
|
await self._log_data("solved_data", solved_data)
|
||||||
|
|
||||||
# 4. 应用反思结果(更新记忆库)
|
# 4. 应用反思结果(更新记忆库)
|
||||||
memories_updated = await self._apply_reflection_results(solved_data)
|
memories_updated=await self._apply_reflection_results(solved_data)
|
||||||
|
|
||||||
execution_time = asyncio.get_event_loop().time() - start_time
|
execution_time = asyncio.get_event_loop().time() - start_time
|
||||||
|
|
||||||
@@ -297,6 +288,7 @@ class ReflectionEngine:
|
|||||||
async def reflection_run(self):
|
async def reflection_run(self):
|
||||||
self._lazy_init()
|
self._lazy_init()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
memory_verifies_flag = self.config.memory_verify
|
||||||
|
|
||||||
asyncio.get_event_loop().time()
|
asyncio.get_event_loop().time()
|
||||||
logging.info("====== 自我反思流程开始 ======")
|
logging.info("====== 自我反思流程开始 ======")
|
||||||
@@ -316,8 +308,8 @@ class ReflectionEngine:
|
|||||||
for item in conflict_data:
|
for item in conflict_data:
|
||||||
quality_assessments.append(item['quality_assessment'])
|
quality_assessments.append(item['quality_assessment'])
|
||||||
memory_verifies.append(item['memory_verify'])
|
memory_verifies.append(item['memory_verify'])
|
||||||
result_data['quality_assessments'] = quality_assessments
|
|
||||||
result_data['memory_verifies'] = memory_verifies
|
result_data['memory_verifies'] = memory_verifies
|
||||||
|
result_data['quality_assessments'] = quality_assessments
|
||||||
|
|
||||||
# 检查是否真的有冲突
|
# 检查是否真的有冲突
|
||||||
has_conflict = conflict_data[0].get('conflict', False)
|
has_conflict = conflict_data[0].get('conflict', False)
|
||||||
@@ -354,6 +346,9 @@ class ReflectionEngine:
|
|||||||
for result in item['results']:
|
for result in item['results']:
|
||||||
reflexion_data.append(result['reflexion'])
|
reflexion_data.append(result['reflexion'])
|
||||||
result_data['reflexion_data'] = reflexion_data
|
result_data['reflexion_data'] = reflexion_data
|
||||||
|
if memory_verifies_flag==False:
|
||||||
|
result_data['memory_verifies']=[None]
|
||||||
|
print(time.time()-start_time,'----------')
|
||||||
return result_data
|
return result_data
|
||||||
|
|
||||||
|
|
||||||
@@ -561,7 +556,8 @@ class ReflectionEngine:
|
|||||||
Returns:
|
Returns:
|
||||||
int: 成功更新的记忆数量
|
int: 成功更新的记忆数量
|
||||||
"""
|
"""
|
||||||
success_count = await neo4j_data(solved_data)
|
changes = extract_and_process_changes(solved_data)
|
||||||
|
success_count = await neo4j_data(changes)
|
||||||
return success_count
|
return success_count
|
||||||
|
|
||||||
async def _log_data(self, label: str, data: Any) -> None:
|
async def _log_data(self, label: str, data: Any) -> None:
|
||||||
@@ -668,4 +664,4 @@ class ReflectionEngine:
|
|||||||
execution_time=time_result.execution_time + fact_result.execution_time
|
execution_time=time_result.execution_time + fact_result.execution_time
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||||
@@ -3,6 +3,20 @@ import uuid
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from openai import BaseModel
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from pydantic import model_validator, Field
|
||||||
|
|
||||||
|
from app.schemas.memory_storage_schema import SingleReflexionResultSchema
|
||||||
|
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||||
|
from app.repositories.neo4j.neo4j_update import map_field_names
|
||||||
|
# 添加项目根目录到 Python 路径
|
||||||
|
sys.path.append(str(Path(__file__).parent))
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def _load_(data: List[Any]) -> List[Dict]:
|
async def _load_(data: List[Any]) -> List[Dict]:
|
||||||
@@ -59,6 +73,14 @@ async def get_data(result):
|
|||||||
"""
|
"""
|
||||||
从数据库中获取数据
|
从数据库中获取数据
|
||||||
"""
|
"""
|
||||||
|
EXCLUDE_FIELDS = {
|
||||||
|
"user_id",
|
||||||
|
"group_id",
|
||||||
|
"entity_type",
|
||||||
|
"connect_strength",
|
||||||
|
"relationship_type",
|
||||||
|
"apply_id"
|
||||||
|
}
|
||||||
neo4j_databasets=[]
|
neo4j_databasets=[]
|
||||||
for item in result:
|
for item in result:
|
||||||
filtered_item = {}
|
filtered_item = {}
|
||||||
@@ -73,14 +95,17 @@ async def get_data(result):
|
|||||||
rel_filtered['statement_id'] = value.get('statement_id')
|
rel_filtered['statement_id'] = value.get('statement_id')
|
||||||
rel_filtered['expired_at'] = value.get('expired_at')
|
rel_filtered['expired_at'] = value.get('expired_at')
|
||||||
rel_filtered['created_at'] = value.get('created_at')
|
rel_filtered['created_at'] = value.get('created_at')
|
||||||
filtered_item[key] = rel_filtered
|
filtered_item[key] = value
|
||||||
elif key == 'entity2' and value is not None:
|
elif key == 'entity2' and value is not None:
|
||||||
# 过滤entity2的name_embedding字段
|
# 过滤entity2的name_embedding字段
|
||||||
entity2_filtered = {}
|
entity2_filtered = {}
|
||||||
if hasattr(value, 'items'):
|
if hasattr(value, 'items'):
|
||||||
for e_key, e_value in value.items():
|
for e_key, e_value in value.items():
|
||||||
if 'name_embedding' not in e_key.lower():
|
if e_key in EXCLUDE_FIELDS:
|
||||||
entity2_filtered[e_key] = e_value
|
continue
|
||||||
|
if 'name_embedding' in e_key.lower():
|
||||||
|
continue
|
||||||
|
entity2_filtered[e_key] = e_value
|
||||||
filtered_item[key] = entity2_filtered
|
filtered_item[key] = entity2_filtered
|
||||||
else:
|
else:
|
||||||
filtered_item[key] = value
|
filtered_item[key] = value
|
||||||
@@ -94,8 +119,57 @@ async def get_data_statement( result):
|
|||||||
neo4j_databasets.append(i)
|
neo4j_databasets.append(i)
|
||||||
return neo4j_databasets
|
return neo4j_databasets
|
||||||
|
|
||||||
|
class ReflexionResultSchema(BaseModel):
|
||||||
|
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||||
|
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def _normalize_resolved(cls, v):
|
||||||
|
if isinstance(v, dict):
|
||||||
|
conflict = v.get("conflict")
|
||||||
|
if isinstance(conflict, dict) and conflict.get("conflict") is False:
|
||||||
|
v["resolved"] = None
|
||||||
|
else:
|
||||||
|
resolved = v.get("resolved")
|
||||||
|
if isinstance(resolved, dict):
|
||||||
|
orig = resolved.get("original_memory_id")
|
||||||
|
mem = resolved.get("resolved_memory")
|
||||||
|
if orig is None and (mem is None or mem == {}):
|
||||||
|
v["resolved"] = None
|
||||||
|
return v
|
||||||
|
def extract_and_process_changes(DATA):
|
||||||
|
"""提取并处理 change 字段"""
|
||||||
|
all_changes = []
|
||||||
|
for i, item in enumerate(DATA):
|
||||||
|
try:
|
||||||
|
result = ReflexionResultSchema(**item)
|
||||||
|
for j, res in enumerate(result.results):
|
||||||
|
if res.resolved and res.resolved.change:
|
||||||
|
for k, change in enumerate(res.resolved.change):
|
||||||
|
change_data = {}
|
||||||
|
for field_item in change.field:
|
||||||
|
for key, value in field_item.items():
|
||||||
|
change_data[key] = value
|
||||||
|
if isinstance(value, list):
|
||||||
|
print(f" - {key}: {value[0]} -> {value[1]}")
|
||||||
|
else:
|
||||||
|
print(f" - {key}: {value}")
|
||||||
|
|
||||||
|
all_changes.append({
|
||||||
|
'data': change_data
|
||||||
|
})
|
||||||
|
|
||||||
|
# 测试字段映射
|
||||||
|
try:
|
||||||
|
mapped = map_field_names(change_data)
|
||||||
|
print(f" 映射结果: {mapped}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" 映射失败: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理结果 {i + 1} 失败: {e}")
|
||||||
|
|
||||||
|
return all_changes
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|||||||
@@ -61,7 +61,7 @@
|
|||||||
- 微信号: user123456 → use****3456
|
- 微信号: user123456 → use****3456
|
||||||
- 邮箱: zhang.san@example.com → zha****@example.com
|
- 邮箱: zhang.san@example.com → zha****@example.com
|
||||||
|
|
||||||
**脱敏字段**: name、entity1_name、entity2_name、description
|
**脱敏字段**: name、entity1_name、entity2_name、description、relationship
|
||||||
|
|
||||||
## 4. 处理流程
|
## 4. 处理流程
|
||||||
|
|
||||||
@@ -97,21 +97,11 @@
|
|||||||
|
|
||||||
### 处理规则
|
### 处理规则
|
||||||
|
|
||||||
**情况1: 正确答案存在于data中**
|
** baseline是TIME
|
||||||
- 保留正确记录不变
|
-保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的
|
||||||
- 时间冲突: 修改错误记录的expired_at为当前时间(2025-12-16T12:00:00)
|
** baseline不是TIME
|
||||||
- 事实冲突: 同样处理
|
- 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段,
|
||||||
- resolved.resolved_memory只包含被设为失效的错误记录
|
如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field
|
||||||
- change字段只记录expired_at变更: `[{"expired_at": "2025-12-16T12:00:00"}]`
|
|
||||||
- 注意: 如果已存在时间则不需要修改
|
|
||||||
|
|
||||||
**情况2: 正确答案不存在于data中**
|
|
||||||
- 选择最合适记录进行修改
|
|
||||||
- 更新相关字段:
|
|
||||||
- description字段: 添加或修改描述信息{% if memory_verify %}(含隐私信息需脱敏){% endif %}
|
|
||||||
- name字段: 修改名称字段{% if memory_verify %}(含隐私信息需脱敏){% endif %}
|
|
||||||
- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %}
|
|
||||||
- change字段记录所有被修改字段{% if memory_verify %},包括脱敏变更{% endif %}
|
|
||||||
|
|
||||||
**核心原则**:
|
**核心原则**:
|
||||||
- 只输出需要修改的记录
|
- 只输出需要修改的记录
|
||||||
@@ -126,8 +116,10 @@
|
|||||||
"change": [
|
"change": [
|
||||||
{
|
{
|
||||||
"field": [
|
"field": [
|
||||||
{"字段名1": "修改后的值1"},
|
{"id":修改字段对应的ID}
|
||||||
{"字段名2": "修改后的值2"}
|
{"statement_id":需要修改的对象对应的statement_id}
|
||||||
|
{"字段名1": ["修改前的值1","修改后的值1"]},
|
||||||
|
{"字段名2": ["修改前的值2","修改后的值2"]}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -149,7 +141,8 @@
|
|||||||
|
|
||||||
**嵌套字段映射**(系统自动处理):
|
**嵌套字段映射**(系统自动处理):
|
||||||
- `entity2.name` → 自动映射为 `name`
|
- `entity2.name` → 自动映射为 `name`
|
||||||
- `entity1.name` → 自动映射为 `name`
|
- `entity1.name` → 自动映射为 `name`
|
||||||
|
- `relationship` → 自动映射为 `statement`
|
||||||
- `entity1.description` → 自动映射为 `description`
|
- `entity1.description` → 自动映射为 `description`
|
||||||
- `entity2.description` → 自动映射为 `description`
|
- `entity2.description` → 自动映射为 `description`
|
||||||
|
|
||||||
|
|||||||
@@ -783,7 +783,9 @@ neo4j_query_part = """
|
|||||||
m.created_at as created_at,
|
m.created_at as created_at,
|
||||||
m.expired_at as expired_at,
|
m.expired_at as expired_at,
|
||||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||||
rel as relationship,
|
rel.predicate as predicate,
|
||||||
|
rel.statement as relationship,
|
||||||
|
rel.statement_id as relationship_statement_id,
|
||||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||||
other as entity2
|
other as entity2
|
||||||
"""
|
"""
|
||||||
@@ -799,7 +801,9 @@ neo4j_query_all = """
|
|||||||
m.created_at as created_at,
|
m.created_at as created_at,
|
||||||
m.expired_at as expired_at,
|
m.expired_at as expired_at,
|
||||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||||
rel as relationship,
|
rel.predicate as predicate,
|
||||||
|
rel.statement as relationship,
|
||||||
|
rel.statement_id as relationship_statement_id,
|
||||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||||
other as entity2
|
other as entity2
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -67,11 +67,81 @@ async def update_neo4j_data(neo4j_dict_data, update_databases):
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def update_neo4j_data_edge(neo4j_dict_data, update_databases):
|
||||||
|
"""
|
||||||
|
Update Neo4j data based on query criteria and update parameters
|
||||||
|
|
||||||
|
Args:
|
||||||
|
neo4j_dict_data: find
|
||||||
|
update_databases: update
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建WHERE条件
|
||||||
|
where_conditions = []
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
for key, value in neo4j_dict_data.items():
|
||||||
|
if value is not None:
|
||||||
|
param_name = f"param_{key}"
|
||||||
|
where_conditions.append(f"r.{key} = ${param_name}")
|
||||||
|
params[param_name] = value
|
||||||
|
|
||||||
|
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||||
|
|
||||||
|
# 构建SET条件
|
||||||
|
set_conditions = []
|
||||||
|
for key, value in update_databases.items():
|
||||||
|
if value is not None:
|
||||||
|
param_name = f"update_{key}"
|
||||||
|
set_conditions.append(f"r.{key} = ${param_name}")
|
||||||
|
params[param_name] = value
|
||||||
|
|
||||||
|
set_clause = ", ".join(set_conditions)
|
||||||
|
|
||||||
|
if not set_clause:
|
||||||
|
print("警告: 没有需要更新的字段")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 构建Cypher查询
|
||||||
|
cypher_query = f"""
|
||||||
|
MATCH (n)-[r]->(m)
|
||||||
|
WHERE {where_clause}
|
||||||
|
SET {set_clause}
|
||||||
|
RETURN count(r) as updated_count, collect(type(r)) as relation_types
|
||||||
|
"""
|
||||||
|
|
||||||
|
print(f"\n执行Cypher查询: {cypher_query}")
|
||||||
|
print(f"参数: {params}")
|
||||||
|
|
||||||
|
# 执行更新
|
||||||
|
result = await neo4j_connector.execute_query(cypher_query, **params)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
updated_count = result[0].get('updated_count', 0)
|
||||||
|
updated_names = result[0].get('updated_names', [])
|
||||||
|
print(f"成功更新 {updated_count} 个节点")
|
||||||
|
if updated_names:
|
||||||
|
print(f"更新的实体名称: {updated_names}")
|
||||||
|
return updated_count > 0
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"更新过程中出现错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
def map_field_names(data_dict):
|
def map_field_names(data_dict):
|
||||||
mapped_dict = {}
|
mapped_dict = {}
|
||||||
has_name_field = False
|
has_name_field = False
|
||||||
|
|
||||||
|
# 辅助函数:提取值(如果是数组则取最后一个值,否则直接返回)
|
||||||
|
def extract_value(value):
|
||||||
|
if isinstance(value, list) and len(value) > 0:
|
||||||
|
# 如果是数组 [old_value, new_value],取新值(最后一个)
|
||||||
|
return value[-1]
|
||||||
|
return value
|
||||||
|
|
||||||
# 第一遍:检查是否有name相关字段
|
# 第一遍:检查是否有name相关字段
|
||||||
for key, value in data_dict.items():
|
for key, value in data_dict.items():
|
||||||
if key in ['name', 'entity2.name', 'entity1.name']:
|
if key in ['name', 'entity2.name', 'entity1.name']:
|
||||||
@@ -82,22 +152,25 @@ def map_field_names(data_dict):
|
|||||||
|
|
||||||
# 第二遍:根据规则映射和过滤字段
|
# 第二遍:根据规则映射和过滤字段
|
||||||
for key, value in data_dict.items():
|
for key, value in data_dict.items():
|
||||||
|
# 提取实际值(处理数组格式)
|
||||||
|
actual_value = extract_value(value)
|
||||||
|
|
||||||
if key == 'entity2.name' or key == 'entity2_name':
|
if key == 'entity2.name' or key == 'entity2_name':
|
||||||
# 将 entity2.name 映射为 name
|
# 将 entity2.name 映射为 name
|
||||||
mapped_dict['name'] = value
|
mapped_dict['name'] = actual_value
|
||||||
print(f"字段名映射: {key} -> name")
|
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
|
||||||
elif key == 'entity1.name' or key == 'entity1_name':
|
elif key == 'entity1.name' or key == 'entity1_name':
|
||||||
# 将 entity1.name 映射为 name
|
# 将 entity1.name 映射为 name
|
||||||
mapped_dict['name'] = value
|
mapped_dict['name'] = actual_value
|
||||||
print(f"字段名映射: {key} -> name")
|
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
|
||||||
elif key == 'entity1.description':
|
elif key == 'entity1.description':
|
||||||
# 将 entity1.description 映射为 description
|
# 将 entity1.description 映射为 description
|
||||||
mapped_dict['description'] = value
|
mapped_dict['description'] = actual_value
|
||||||
print(f"字段名映射: {key} -> description")
|
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
|
||||||
elif key == 'entity2.description':
|
elif key == 'entity2.description':
|
||||||
# 将 entity2.description 映射为 description
|
# 将 entity2.description 映射为 description
|
||||||
mapped_dict['description'] = value
|
mapped_dict['description'] = actual_value
|
||||||
print(f"字段名映射: {key} -> description")
|
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
|
||||||
elif key == 'relationship_type':
|
elif key == 'relationship_type':
|
||||||
# 跳过relationship_type字段
|
# 跳过relationship_type字段
|
||||||
print(f"字段过滤: 跳过不需要的字段 '{key}'")
|
print(f"字段过滤: 跳过不需要的字段 '{key}'")
|
||||||
@@ -109,8 +182,8 @@ def map_field_names(data_dict):
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
# 如果没有name字段,保留entity1_name
|
# 如果没有name字段,保留entity1_name
|
||||||
mapped_dict[key] = value
|
mapped_dict[key] = actual_value
|
||||||
print(f"字段保留: {key}")
|
print(f"字段保留: {key} (值: {value} -> {actual_value})")
|
||||||
elif key == 'entity2_name':
|
elif key == 'entity2_name':
|
||||||
if has_name_field:
|
if has_name_field:
|
||||||
# 如果有name字段,跳过entity2_name
|
# 如果有name字段,跳过entity2_name
|
||||||
@@ -122,7 +195,11 @@ def map_field_names(data_dict):
|
|||||||
continue
|
continue
|
||||||
elif '.' not in key:
|
elif '.' not in key:
|
||||||
# 不包含点号的其他字段直接保留
|
# 不包含点号的其他字段直接保留
|
||||||
mapped_dict[key] = value
|
mapped_dict[key] = actual_value
|
||||||
|
if isinstance(value, list):
|
||||||
|
print(f"字段保留: {key} (数组值: {value} -> {actual_value})")
|
||||||
|
else:
|
||||||
|
print(f"字段保留: {key}")
|
||||||
else:
|
else:
|
||||||
# 其他包含点号的字段跳过并警告
|
# 其他包含点号的字段跳过并警告
|
||||||
print(f"警告: 跳过不支持的嵌套字段 '{key}'")
|
print(f"警告: 跳过不支持的嵌套字段 '{key}'")
|
||||||
@@ -139,89 +216,57 @@ async def neo4j_data(solved_data):
|
|||||||
"""
|
"""
|
||||||
success_count = 0
|
success_count = 0
|
||||||
|
|
||||||
|
ori_entity = {}
|
||||||
|
updata_entity = {}
|
||||||
|
ori_edge = {}
|
||||||
|
updata_edge = {}
|
||||||
|
ori_expired_at={}
|
||||||
|
updat_expired_at={}
|
||||||
for i in solved_data:
|
for i in solved_data:
|
||||||
neo4j_dict_data = {}
|
databasets = i['data']
|
||||||
update_databases = {}
|
for key, values in databasets.items():
|
||||||
results = i['results']
|
if str(values)=='NONE':
|
||||||
for data in results:
|
|
||||||
resolved = data.get('resolved')
|
|
||||||
if not resolved:
|
|
||||||
print("跳过:resolved为None")
|
|
||||||
continue
|
continue
|
||||||
|
if isinstance(values, list):
|
||||||
|
if key == 'description':
|
||||||
|
ori_entity[key] = values[0]
|
||||||
|
updata_entity[key] = values[1]
|
||||||
|
if key == 'entity2_name' or key == 'entity1_name':
|
||||||
|
key = 'name'
|
||||||
|
ori_entity[key] = values[0]
|
||||||
|
updata_entity[key] = values[1]
|
||||||
|
ori_expired_at[key] = values[0]
|
||||||
|
if key == 'statement':
|
||||||
|
ori_edge[key] = values[0]
|
||||||
|
updata_edge[key] = values[1]
|
||||||
|
if key=='expired_at':
|
||||||
|
updat_expired_at[key] = values[1]
|
||||||
|
|
||||||
try:
|
elif key == 'statement_id':
|
||||||
change_list = resolved.get('change', [])
|
ori_edge[key] = values
|
||||||
except (AttributeError, TypeError):
|
updata_edge[key] = values
|
||||||
change_list = []
|
|
||||||
|
|
||||||
if change_list == []:
|
ori_entity[key] = values
|
||||||
print("跳过:change_list为空")
|
updata_entity[key] = values
|
||||||
continue
|
|
||||||
|
|
||||||
if change_list and len(change_list) > 0:
|
ori_expired_at[key] = values
|
||||||
change = change_list[0]
|
|
||||||
print(f"change: {change}")
|
|
||||||
field_data = change.get('field', [])
|
|
||||||
print(f"field_data: {field_data}")
|
|
||||||
print(f"field_data type: {type(field_data)}")
|
|
||||||
|
|
||||||
# 字段名映射和过滤函数
|
|
||||||
|
|
||||||
|
|
||||||
# 处理field数据,可能是字典或列表
|
print(ori_entity)
|
||||||
if isinstance(field_data, dict):
|
print(updata_entity)
|
||||||
# 如果是字典,映射字段名后更新
|
print(100*'-')
|
||||||
mapped_data = map_field_names(field_data)
|
print(ori_edge)
|
||||||
update_databases.update(mapped_data)
|
print(updata_edge)
|
||||||
elif isinstance(field_data, list):
|
expired_at_ = updat_expired_at.get('expired_at', None)
|
||||||
# 如果是列表,遍历每个字典并更新
|
if expired_at_ is not None:
|
||||||
for field_item in field_data:
|
await update_neo4j_data(ori_expired_at, updat_expired_at)
|
||||||
if isinstance(field_item, dict):
|
success_count += 1
|
||||||
mapped_item = map_field_names(field_item)
|
if ori_entity != updata_entity:
|
||||||
update_databases.update(mapped_item)
|
await update_neo4j_data(ori_entity, updata_entity)
|
||||||
else:
|
success_count += 1
|
||||||
print(f"警告: field_item不是字典: {field_item}")
|
if ori_edge != updata_edge:
|
||||||
else:
|
await update_neo4j_data_edge(ori_edge, updata_edge)
|
||||||
print(f"警告: field_data类型不支持: {type(field_data)}")
|
success_count += 1
|
||||||
|
|
||||||
if 'entity1_name' in data:
|
|
||||||
data['name'] = data.pop('entity1_name')
|
|
||||||
if 'entity2_name' in data:
|
|
||||||
data.pop('entity2_name', None)
|
|
||||||
|
|
||||||
resolved_memory = resolved.get('resolved_memory', {})
|
|
||||||
|
|
||||||
entity2 = None
|
|
||||||
if isinstance(resolved_memory, dict):
|
|
||||||
entity2 = resolved_memory.get('entity2')
|
|
||||||
|
|
||||||
if entity2 and isinstance(entity2, dict) and len(entity2) >= 5:
|
|
||||||
stat_id = resolved.get('original_memory_id')
|
|
||||||
# 安全地获取description
|
|
||||||
statement_id = None
|
|
||||||
if isinstance(resolved_memory, dict):
|
|
||||||
statement_id = resolved_memory.get('statement_id')
|
|
||||||
|
|
||||||
# 只有当neo4j_dict_data中还没有statement_id时才使用original_memory_id
|
|
||||||
if statement_id and 'id' not in neo4j_dict_data:
|
|
||||||
neo4j_dict_data['id'] = stat_id
|
|
||||||
neo4j_dict_data['statement_id'] = statement_id
|
|
||||||
else:
|
|
||||||
# 处理original_memory_id,它可能是字符串或字典
|
|
||||||
try:
|
|
||||||
for key, value in resolved_memory.items():
|
|
||||||
if key == 'statement_id':
|
|
||||||
neo4j_dict_data['statement_id'] = value
|
|
||||||
if key == 'description':
|
|
||||||
neo4j_dict_data['description'] = value
|
|
||||||
except AttributeError:
|
|
||||||
neo4j_dict_data=[]
|
|
||||||
|
|
||||||
print(neo4j_dict_data)
|
|
||||||
print(update_databases)
|
|
||||||
if neo4j_dict_data!=[]:
|
|
||||||
await update_neo4j_data(neo4j_dict_data, update_databases)
|
|
||||||
success_count += 1
|
|
||||||
|
|
||||||
return success_count
|
return success_count
|
||||||
|
|
||||||
|
|||||||
@@ -39,8 +39,11 @@ class BaseDataSchema(BaseModel):
|
|||||||
entity1_name: str = Field(..., description="The first entity name.")
|
entity1_name: str = Field(..., description="The first entity name.")
|
||||||
entity2_name: Optional[str] = Field(None, description="The second entity name.")
|
entity2_name: Optional[str] = Field(None, description="The second entity name.")
|
||||||
statement_id: str = Field(..., description="The statement identifier.")
|
statement_id: str = Field(..., description="The statement identifier.")
|
||||||
relationship_type: str = Field(..., description="The relationship type.")
|
# 新增字段 - 设为可选以保持向后兼容性
|
||||||
relationship: Optional[Dict[str, Any]] = Field(None, description="The relationship object.")
|
predicate: Optional[str] = Field(None, description="The predicate describing the relationship between entities.")
|
||||||
|
relationship_statement_id: Optional[str] = Field(None, description="The relationship statement identifier.")
|
||||||
|
# 保留原有字段 - 修改relationship字段类型以支持字符串和字典
|
||||||
|
relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.")
|
||||||
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
|
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
|
||||||
|
|
||||||
|
|
||||||
@@ -94,8 +97,17 @@ class ReflexionSchema(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChangeRecordSchema(BaseModel):
|
class ChangeRecordSchema(BaseModel):
|
||||||
"""Schema for individual change records"""
|
"""Schema for individual change records
|
||||||
field: List[Dict[str, str]] = Field(..., description="List of field changes, each containing field name and new value.")
|
|
||||||
|
字段值格式说明:
|
||||||
|
- id 和 statement_id: 字符串或 None
|
||||||
|
- 其他字段: 可以是字符串、None,数组 [修改前的值, 修改后的值],或嵌套字典结构
|
||||||
|
- entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式
|
||||||
|
"""
|
||||||
|
field: List[Dict[str, Any]] = Field(
|
||||||
|
...,
|
||||||
|
description="List of field changes. First item: {id: value or None}, second: {statement_id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
|
||||||
|
)
|
||||||
|
|
||||||
class ResolvedSchema(BaseModel):
|
class ResolvedSchema(BaseModel):
|
||||||
"""Schema for the resolved memory data in the reflexion_data"""
|
"""Schema for the resolved memory data in the reflexion_data"""
|
||||||
|
|||||||
Reference in New Issue
Block a user