反思优化1.0(优化隐私输出、时间检索)
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||||
ReflectionConfig,
|
ReflectionConfig,
|
||||||
ReflectionEngine,
|
ReflectionEngine, ReflectionRange, ReflectionBaseline,
|
||||||
)
|
)
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
@@ -39,9 +40,6 @@ async def save_reflection_config(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Save reflection configuration to data_comfig table"""
|
"""Save reflection configuration to data_comfig table"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config_id = request.config_id
|
config_id = request.config_id
|
||||||
if not config_id:
|
if not config_id:
|
||||||
@@ -52,51 +50,30 @@ async def save_reflection_config(
|
|||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
update_params = {
|
data_config = DataConfigRepository.update_reflection_config(
|
||||||
"enable_self_reflexion": request.reflection_enabled,
|
db,
|
||||||
"iteration_period": request.reflection_period_in_hours,
|
config_id=config_id,
|
||||||
"reflexion_range": request.reflexion_range,
|
enable_self_reflexion=request.reflection_enabled,
|
||||||
"baseline": request.baseline,
|
iteration_period=request.reflection_period_in_hours,
|
||||||
"reflection_model_id": request.reflection_model_id,
|
reflexion_range=request.reflexion_range,
|
||||||
"memory_verify": request.memory_verify,
|
baseline=request.baseline,
|
||||||
"quality_assessment": request.quality_assessment,
|
reflection_model_id=request.reflection_model_id,
|
||||||
}
|
memory_verify=request.memory_verify,
|
||||||
|
quality_assessment=request.quality_assessment
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
query, params = DataConfigRepository.build_update_reflection(config_id, **update_params)
|
|
||||||
|
|
||||||
result = db.execute(text(query), params)
|
|
||||||
if result.rowcount == 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"未找到config_id为 {config_id} 的配置"
|
|
||||||
)
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
db.refresh(data_config)
|
||||||
# 查询更新后的配置
|
|
||||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
|
||||||
result = db.execute(text(select_query), select_params).fetchone()
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"更新后未找到config_id为 {config_id} 的配置"
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}")
|
|
||||||
|
|
||||||
reflection_result={
|
reflection_result={
|
||||||
"config_id": result.config_id,
|
"config_id": data_config.config_id,
|
||||||
"enable_self_reflexion": result.enable_self_reflexion,
|
"enable_self_reflexion": data_config.enable_self_reflexion,
|
||||||
"iteration_period": result.iteration_period,
|
"iteration_period": data_config.iteration_period,
|
||||||
"reflexion_range": result.reflexion_range,
|
"reflexion_range": data_config.reflexion_range,
|
||||||
"baseline": result.baseline,
|
"baseline": data_config.baseline,
|
||||||
"reflection_model_id": result.reflection_model_id,
|
"reflection_model_id": data_config.reflection_model_id,
|
||||||
"memory_verify": result.memory_verify,
|
"memory_verify": data_config.memory_verify,
|
||||||
"quality_assessment": result.quality_assessment,
|
"quality_assessment": data_config.quality_assessment}
|
||||||
"user_id": result.user_id}
|
|
||||||
|
|
||||||
return success(data=reflection_result, msg="反思配置成功")
|
return success(data=reflection_result, msg="反思配置成功")
|
||||||
|
|
||||||
@@ -116,9 +93,8 @@ async def save_reflection_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reflection")
|
@router.get("/reflection")
|
||||||
async def start_workspace_reflection(
|
async def start_workspace_reflection(
|
||||||
config_id: int,
|
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@@ -178,17 +154,7 @@ async def start_reflection_configs(
|
|||||||
"""通过config_id查询data_config表中的反思配置信息"""
|
"""通过config_id查询data_config表中的反思配置信息"""
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
|
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
# 使用DataConfigRepository查询反思配置
|
|
||||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
|
||||||
result = db.execute(text(select_query), select_params).fetchone()
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"未找到config_id为 {config_id} 的配置"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建返回数据
|
# 构建返回数据
|
||||||
reflection_config = {
|
reflection_config = {
|
||||||
"config_id": result.config_id,
|
"config_id": result.config_id,
|
||||||
@@ -198,8 +164,7 @@ async def start_reflection_configs(
|
|||||||
"baseline": result.baseline,
|
"baseline": result.baseline,
|
||||||
"reflection_model_id": result.reflection_model_id,
|
"reflection_model_id": result.reflection_model_id,
|
||||||
"memory_verify": result.memory_verify,
|
"memory_verify": result.memory_verify,
|
||||||
"quality_assessment": result.quality_assessment,
|
"quality_assessment": result.quality_assessment
|
||||||
"user_id": result.user_id
|
|
||||||
}
|
}
|
||||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||||
return success(data=reflection_config, msg="反思配置查询成功")
|
return success(data=reflection_config, msg="反思配置查询成功")
|
||||||
@@ -227,9 +192,7 @@ async def reflection_run(
|
|||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
# 使用DataConfigRepository查询反思配置
|
# 使用DataConfigRepository查询反思配置
|
||||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
result = db.execute(text(select_query), select_params).fetchone()
|
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
@@ -242,7 +205,7 @@ async def reflection_run(
|
|||||||
model_id = result.reflection_model_id
|
model_id = result.reflection_model_id
|
||||||
if model_id:
|
if model_id:
|
||||||
try:
|
try:
|
||||||
ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
ModelConfigService.get_model_by_id(db=db, model_id=uuid.UUID(model_id))
|
||||||
api_logger.info(f"模型ID验证成功: {model_id}")
|
api_logger.info(f"模型ID验证成功: {model_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}")
|
api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}")
|
||||||
@@ -252,8 +215,8 @@ async def reflection_run(
|
|||||||
config = ReflectionConfig(
|
config = ReflectionConfig(
|
||||||
enabled=result.enable_self_reflexion,
|
enabled=result.enable_self_reflexion,
|
||||||
iteration_period=result.iteration_period,
|
iteration_period=result.iteration_period,
|
||||||
reflexion_range=result.reflexion_range,
|
reflexion_range=ReflectionRange(result.reflexion_range),
|
||||||
baseline=result.baseline,
|
baseline=ReflectionBaseline(result.baseline),
|
||||||
output_example='',
|
output_example='',
|
||||||
memory_verify=result.memory_verify,
|
memory_verify=result.memory_verify,
|
||||||
quality_assessment=result.quality_assessment,
|
quality_assessment=result.quality_assessment,
|
||||||
|
|||||||
@@ -24,15 +24,9 @@ from app.core.memory.utils.config.get_data import (
|
|||||||
get_data,
|
get_data,
|
||||||
get_data_statement,
|
get_data_statement,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
|
||||||
from app.core.memory.utils.prompt.template_render import (
|
|
||||||
render_evaluate_prompt,
|
|
||||||
render_reflexion_prompt,
|
|
||||||
)
|
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
from app.core.response_utils import success
|
|
||||||
from app.repositories.neo4j.cypher_queries import (
|
from app.repositories.neo4j.cypher_queries import (
|
||||||
UPDATE_STATEMENT_INVALID_AT,
|
|
||||||
neo4j_query_all,
|
neo4j_query_all,
|
||||||
neo4j_query_part,
|
neo4j_query_part,
|
||||||
neo4j_statement_all,
|
neo4j_statement_all,
|
||||||
@@ -160,12 +154,11 @@ class ReflectionEngine:
|
|||||||
self.neo4j_connector = Neo4jConnector()
|
self.neo4j_connector = Neo4jConnector()
|
||||||
|
|
||||||
if self.llm_client is None:
|
if self.llm_client is None:
|
||||||
from app.core.memory.utils.config import definitions as config_defs
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
factory = MemoryClientFactory(db)
|
factory = MemoryClientFactory(db)
|
||||||
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
self.llm_client = factory.get_llm_client(self.config.model_id)
|
||||||
elif isinstance(self.llm_client, str):
|
elif isinstance(self.llm_client, str):
|
||||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -263,25 +256,23 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
# 2. 检测冲突(基于事实的反思)
|
# 2. 检测冲突(基于事实的反思)
|
||||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||||
print(100 * '-')
|
conflict_list=[]
|
||||||
print(conflict_data)
|
for i in conflict_data:
|
||||||
print(100 * '-')
|
conflict_list.append(i['data'])
|
||||||
# # 检查是否真的有冲突
|
|
||||||
conflicts_found=''
|
|
||||||
|
|
||||||
conflicts_found=''
|
|
||||||
|
|
||||||
|
conflicts_found=0
|
||||||
# 3. 解决冲突
|
# 3. 解决冲突
|
||||||
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
|
||||||
|
|
||||||
if not solved_data:
|
if not solved_data:
|
||||||
return ReflectionResult(
|
return ReflectionResult(
|
||||||
success=False,
|
success=False,
|
||||||
message="反思失败,未解决冲突",
|
message=f"没有{self.config.baseline}相关的冲突数据",
|
||||||
conflicts_found=conflicts_found,
|
conflicts_found=conflicts_found,
|
||||||
execution_time=asyncio.get_event_loop().time() - start_time
|
execution_time=asyncio.get_event_loop().time() - start_time
|
||||||
)
|
)
|
||||||
print(100 * '*')
|
|
||||||
print(solved_data)
|
|
||||||
print(100 * '*')
|
|
||||||
|
|
||||||
conflicts_resolved = len(solved_data)
|
conflicts_resolved = len(solved_data)
|
||||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||||
@@ -386,7 +377,7 @@ class ReflectionEngine:
|
|||||||
memory_verifies.append(item['memory_verify'])
|
memory_verifies.append(item['memory_verify'])
|
||||||
result_data['memory_verifies'] = memory_verifies
|
result_data['memory_verifies'] = memory_verifies
|
||||||
result_data['quality_assessments'] = quality_assessments
|
result_data['quality_assessments'] = quality_assessments
|
||||||
conflicts_found=''
|
conflicts_found = 0 # 初始化为整数0而不是空字符串
|
||||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||||
cleaned_conflict_data = []
|
cleaned_conflict_data = []
|
||||||
@@ -414,7 +405,7 @@ class ReflectionEngine:
|
|||||||
cleaned_conflict_data_.append(cleaned_item)
|
cleaned_conflict_data_.append(cleaned_item)
|
||||||
print(cleaned_conflict_data_)
|
print(cleaned_conflict_data_)
|
||||||
# 3. 解决冲突
|
# 3. 解决冲突
|
||||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
|
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
|
||||||
if not solved_data:
|
if not solved_data:
|
||||||
return ReflectionResult(
|
return ReflectionResult(
|
||||||
success=False,
|
success=False,
|
||||||
@@ -739,4 +730,3 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,8 @@
|
|||||||
- **身份冲突**: 同一实体被赋予不同类型或角色
|
- **身份冲突**: 同一实体被赋予不同类型或角色
|
||||||
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
|
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
|
||||||
### 混合冲突
|
### 混合冲突
|
||||||
检测所有逻辑不一致或相互矛盾的记录。
|
- 检测所有逻辑不一致或相互矛盾的记录。
|
||||||
|
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
|
||||||
**检测原则**:
|
**检测原则**:
|
||||||
- 重点检查相同实体的记录
|
- 重点检查相同实体的记录
|
||||||
- 分析description字段语义冲突
|
- 分析description字段语义冲突
|
||||||
|
|||||||
@@ -63,7 +63,7 @@
|
|||||||
**脱敏字段**: name、entity1_name、entity2_name、description、relationship
|
**脱敏字段**: name、entity1_name、entity2_name、description、relationship
|
||||||
|
|
||||||
## 4. 处理流程
|
## 4. 处理流程
|
||||||
|
###如果存在冲突数据执行以下步骤,不存在返回【】在data中
|
||||||
### 步骤1: 类型匹配验证
|
### 步骤1: 类型匹配验证
|
||||||
**匹配规则**:
|
**匹配规则**:
|
||||||
- baseline="TIME": 只处理时间相关冲突(涉及时间表达式、日期、时间点)
|
- baseline="TIME": 只处理时间相关冲突(涉及时间表达式、日期、时间点)
|
||||||
@@ -78,7 +78,7 @@
|
|||||||
|
|
||||||
### 步骤2: 冲突数据分组
|
### 步骤2: 冲突数据分组
|
||||||
**分组策略**:
|
**分组策略**:
|
||||||
- 时间冲突组: 涉及用户时间的记录
|
- 时间冲突组: 涉及用户时间的记录比如(生日在2月17...)
|
||||||
- 活动时间冲突组: 同一活动不同时间的记录
|
- 活动时间冲突组: 同一活动不同时间的记录
|
||||||
- 事实冲突组: 同一实体不同属性的记录
|
- 事实冲突组: 同一实体不同属性的记录
|
||||||
- 其他冲突组: 其他类型冲突记录
|
- 其他冲突组: 其他类型冲突记录
|
||||||
@@ -97,11 +97,12 @@
|
|||||||
### 处理规则
|
### 处理规则
|
||||||
|
|
||||||
** baseline是TIME
|
** baseline是TIME
|
||||||
-保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的
|
- 只处理时间相关的内容,比如时间表达式、日期、时间点
|
||||||
** baseline不是TIME
|
-保留正确记录不变修改错误记录的expired_at为当前时间,比如(2025-12-16T12:00:00)
|
||||||
|
** baseline是FACT或者HYBRID
|
||||||
|
- 处理不是时间相关的内容
|
||||||
- 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段,
|
- 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段,
|
||||||
如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field
|
如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field
|
||||||
|
|
||||||
**核心原则**:
|
**核心原则**:
|
||||||
- 只输出需要修改的记录
|
- 只输出需要修改的记录
|
||||||
- 优先保留策略: 时间冲突保留最可信created_at时间,事实冲突选择最新且可信度最高记录
|
- 优先保留策略: 时间冲突保留最可信created_at时间,事实冲突选择最新且可信度最高记录
|
||||||
@@ -110,22 +111,26 @@
|
|||||||
- 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %}
|
- 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %}
|
||||||
- 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空
|
- 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空
|
||||||
- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true、memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容,
|
- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true、memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容,
|
||||||
,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据
|
,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据,如果存在隐私审核,隐私审核是true,也需要放到reflexion的reason字段和solution
|
||||||
|
|
||||||
**变更记录格式**:
|
**变更记录格式**:
|
||||||
```json
|
```json
|
||||||
"change": [
|
"change": [
|
||||||
{
|
{
|
||||||
"field": [
|
"field": [
|
||||||
{"id":修改字段对应的ID}
|
{"id": "修改字段对应的ID"},
|
||||||
{"statement_id":需要修改的对象对应的statement_id}
|
{"字段名1": ["修改前的值1", "修改后的值1"]},
|
||||||
{"字段名1": ["修改前的值1","修改后的值1"]},
|
{"字段名2": ["修改前的值2", "修改后的值2"]}
|
||||||
{"字段名2": ["修改前的值2","修改后的值2"]}
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**resolved_memory格式说明**:
|
||||||
|
- 对于TIME类型冲突: 只需expired_at字段即可
|
||||||
|
- 对于FACT/HYBRID类型冲突: 需要包含完整的记录对象(包括name、entity1_name、entity2_name、description、relationship等所有相关字段)
|
||||||
|
- resolved_memory中只包含需要修改的记录,不需要修改的记录不要包含在内
|
||||||
|
|
||||||
**类型不匹配处理**:
|
**类型不匹配处理**:
|
||||||
- 冲突类型与baseline不匹配时,resolved设为null
|
- 冲突类型与baseline不匹配时,resolved设为null
|
||||||
- reflexion.reason说明类型不匹配原因
|
- reflexion.reason说明类型不匹配原因
|
||||||
@@ -157,7 +162,8 @@
|
|||||||
"conflict": true
|
"conflict": true
|
||||||
},
|
},
|
||||||
"reflexion": {
|
"reflexion": {
|
||||||
"reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析,
|
"reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析,如果
|
||||||
|
隐私审核打开的时候如果存在冲突,分析该冲突的原因
|
||||||
不可以随意分配冲突类型以及原因,不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种",
|
不可以随意分配冲突类型以及原因,不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种",
|
||||||
"solution": "该冲突类型的解决方案(不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种)"
|
"solution": "该冲突类型的解决方案(不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种)"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ Classes:
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_config_logger, get_db_logger
|
from app.core.logging_config import get_config_logger, get_db_logger
|
||||||
from app.models.data_config_model import DataConfig
|
from app.models.data_config_model import DataConfig
|
||||||
from app.schemas.memory_storage_schema import (
|
from app.schemas.memory_storage_schema import (
|
||||||
@@ -20,7 +20,7 @@ from app.schemas.memory_storage_schema import (
|
|||||||
ConfigUpdateExtracted,
|
ConfigUpdateExtracted,
|
||||||
ConfigUpdateForget,
|
ConfigUpdateForget,
|
||||||
)
|
)
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import desc, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
@@ -136,72 +136,88 @@ class DataConfigRepository:
|
|||||||
id: m.id
|
id: m.id
|
||||||
} AS targetNode
|
} AS targetNode
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ==================== SQLAlchemy ORM 数据库操作方法 ====================
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]:
|
def update_reflection_config(
|
||||||
|
db: Session,
|
||||||
|
config_id: int,
|
||||||
|
enable_self_reflexion: bool,
|
||||||
|
iteration_period: str,
|
||||||
|
reflexion_range: str,
|
||||||
|
baseline: str,
|
||||||
|
reflection_model_id: str,
|
||||||
|
memory_verify: bool,
|
||||||
|
quality_assessment: bool
|
||||||
|
) -> DataConfig:
|
||||||
"""构建反思配置更新语句(SQLAlchemy text() 命名参数)
|
"""构建反思配置更新语句(SQLAlchemy text() 命名参数)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
quality_assessment:
|
||||||
|
memory_verify:
|
||||||
|
reflection_model_id:
|
||||||
|
baseline:
|
||||||
|
reflexion_range:
|
||||||
|
iteration_period:
|
||||||
|
enable_self_reflexion:
|
||||||
|
db: database object
|
||||||
config_id: 配置ID
|
config_id: 配置ID
|
||||||
**kwargs: 反思配置参数
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
Data
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 没有字段需要更新时抛出
|
ValueError: 没有字段需要更新时抛出
|
||||||
"""
|
"""
|
||||||
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
|
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
|
||||||
|
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
|
||||||
|
data_config_obj = db.scalars(stmt).first()
|
||||||
|
if not data_config_obj:
|
||||||
|
raise BusinessException
|
||||||
|
data_config_obj.enable_self_reflexion = enable_self_reflexion
|
||||||
|
data_config_obj.iteration_period = iteration_period
|
||||||
|
data_config_obj.reflexion_range = reflexion_range
|
||||||
|
data_config_obj.baseline = baseline
|
||||||
|
data_config_obj.reflection_model_id = reflection_model_id
|
||||||
|
data_config_obj.memory_verify = memory_verify
|
||||||
|
data_config_obj.quality_assessment = quality_assessment
|
||||||
|
|
||||||
key_where = "config_id = :config_id"
|
return data_config_obj
|
||||||
set_fields: List[str] = []
|
|
||||||
params: Dict = {
|
|
||||||
"config_id": config_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 反思配置字段映射
|
|
||||||
mapping = {
|
|
||||||
"enable_self_reflexion": "enable_self_reflexion",
|
|
||||||
"iteration_period": "iteration_period",
|
|
||||||
"reflexion_range": "reflexion_range",
|
|
||||||
"baseline": "baseline",
|
|
||||||
"reflection_model_id": "reflection_model_id",
|
|
||||||
"memory_verify": "memory_verify",
|
|
||||||
"quality_assessment": "quality_assessment",
|
|
||||||
}
|
|
||||||
|
|
||||||
for api_field, db_col in mapping.items():
|
|
||||||
if api_field in kwargs and kwargs[api_field] is not None:
|
|
||||||
set_fields.append(f"{db_col} = :{api_field}")
|
|
||||||
params[api_field] = kwargs[api_field]
|
|
||||||
|
|
||||||
if not set_fields:
|
|
||||||
raise ValueError("No fields to update")
|
|
||||||
|
|
||||||
set_fields.append("updated_at = timezone('Asia/Shanghai', now())")
|
|
||||||
query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}"
|
|
||||||
return query, params
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_select_reflection(config_id: int) -> Tuple[str, Dict]:
|
def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig:
|
||||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
db: database object
|
||||||
config_id: 配置ID
|
config_id: 配置ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
||||||
"""
|
"""
|
||||||
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
|
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
|
||||||
|
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
|
||||||
|
data_config = db.scalars(stmt).first()
|
||||||
|
if not data_config:
|
||||||
|
raise RuntimeError("reflection config not found")
|
||||||
|
return data_config
|
||||||
|
@staticmethod
|
||||||
|
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig:
|
||||||
|
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: database object
|
||||||
|
workspace_id: 工作空间ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
||||||
|
"""
|
||||||
|
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
|
||||||
|
|
||||||
|
stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id)
|
||||||
|
data_config = db.scalars(stmt).first()
|
||||||
|
if not data_config:
|
||||||
|
raise RuntimeError("reflection config not found")
|
||||||
|
return data_config
|
||||||
|
|
||||||
query = (
|
|
||||||
f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
|
|
||||||
f"reflection_model_id, memory_verify, quality_assessment, user_id "
|
|
||||||
f"FROM {TABLE_NAME} WHERE config_id = :config_id"
|
|
||||||
)
|
|
||||||
params = {"config_id": config_id}
|
|
||||||
return query, params
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
|
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
|
||||||
|
|||||||
@@ -837,12 +837,14 @@ neo4j_query_part = """
|
|||||||
WITH DISTINCT m
|
WITH DISTINCT m
|
||||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||||
RETURN
|
RETURN
|
||||||
|
elementId(m) as id,
|
||||||
m.name as entity1_name,
|
m.name as entity1_name,
|
||||||
m.description as description,
|
m.description as description,
|
||||||
m.statement_id as statement_id,
|
m.statement_id as statement_id,
|
||||||
m.created_at as created_at,
|
m.created_at as created_at,
|
||||||
m.expired_at as expired_at,
|
m.expired_at as expired_at,
|
||||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||||
|
elementId(rel) as rel_id,
|
||||||
rel.predicate as predicate,
|
rel.predicate as predicate,
|
||||||
rel.statement as relationship,
|
rel.statement as relationship,
|
||||||
rel.statement_id as relationship_statement_id,
|
rel.statement_id as relationship_statement_id,
|
||||||
@@ -855,12 +857,14 @@ neo4j_query_all = """
|
|||||||
WITH DISTINCT m
|
WITH DISTINCT m
|
||||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||||
RETURN
|
RETURN
|
||||||
|
elementId(m) as id,
|
||||||
m.name as entity1_name,
|
m.name as entity1_name,
|
||||||
m.description as description,
|
m.description as description,
|
||||||
m.statement_id as statement_id,
|
m.statement_id as statement_id,
|
||||||
m.created_at as created_at,
|
m.created_at as created_at,
|
||||||
m.expired_at as expired_at,
|
m.expired_at as expired_at,
|
||||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||||
|
elementId(rel) as rel_id,
|
||||||
rel.predicate as predicate,
|
rel.predicate as predicate,
|
||||||
rel.statement as relationship,
|
rel.statement as relationship,
|
||||||
rel.statement_id as relationship_statement_id,
|
rel.statement_id as relationship_statement_id,
|
||||||
|
|||||||
@@ -11,22 +11,28 @@ async def update_neo4j_data(neo4j_dict_data, update_databases):
|
|||||||
update_databases: update
|
update_databases: update
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 构建WHERE条件
|
# 构建WHERE条件 - 只使用elementId
|
||||||
where_conditions = []
|
where_conditions = []
|
||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
for key, value in neo4j_dict_data.items():
|
# 优先使用id作为elementId进行查询
|
||||||
if value is not None:
|
if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None:
|
||||||
param_name = f"param_{key}"
|
where_conditions.append(f"elementId(e) = $param_id")
|
||||||
where_conditions.append(f"e.{key} = ${param_name}")
|
params['param_id'] = neo4j_dict_data['id']
|
||||||
params[param_name] = value
|
else:
|
||||||
|
# 如果没有id,使用其他字段作为条件
|
||||||
|
for key, value in neo4j_dict_data.items():
|
||||||
|
if value is not None:
|
||||||
|
param_name = f"param_{key}"
|
||||||
|
where_conditions.append(f"e.{key} = ${param_name}")
|
||||||
|
params[param_name] = value
|
||||||
|
|
||||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||||
|
|
||||||
# 构建SET条件
|
# 构建SET条件 - 排除id字段
|
||||||
set_conditions = []
|
set_conditions = []
|
||||||
for key, value in update_databases.items():
|
for key, value in update_databases.items():
|
||||||
if value is not None:
|
if value is not None and key != 'id': # 不更新id字段
|
||||||
param_name = f"update_{key}"
|
param_name = f"update_{key}"
|
||||||
set_conditions.append(f"e.{key} = ${param_name}")
|
set_conditions.append(f"e.{key} = ${param_name}")
|
||||||
params[param_name] = value
|
params[param_name] = value
|
||||||
@@ -76,22 +82,28 @@ async def update_neo4j_data_edge(neo4j_dict_data, update_databases):
|
|||||||
update_databases: update
|
update_databases: update
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 构建WHERE条件
|
# 构建WHERE条件 - 只使用elementId
|
||||||
where_conditions = []
|
where_conditions = []
|
||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
for key, value in neo4j_dict_data.items():
|
# 优先使用id作为elementId进行查询
|
||||||
if value is not None:
|
if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None:
|
||||||
param_name = f"param_{key}"
|
where_conditions.append(f"elementId(r) = $param_id")
|
||||||
where_conditions.append(f"r.{key} = ${param_name}")
|
params['param_id'] = neo4j_dict_data['id']
|
||||||
params[param_name] = value
|
else:
|
||||||
|
# 如果没有id,使用其他字段作为条件
|
||||||
|
for key, value in neo4j_dict_data.items():
|
||||||
|
if value is not None:
|
||||||
|
param_name = f"param_{key}"
|
||||||
|
where_conditions.append(f"r.{key} = ${param_name}")
|
||||||
|
params[param_name] = value
|
||||||
|
|
||||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||||
|
|
||||||
# 构建SET条件
|
# 构建SET条件 - 排除id字段
|
||||||
set_conditions = []
|
set_conditions = []
|
||||||
for key, value in update_databases.items():
|
for key, value in update_databases.items():
|
||||||
if value is not None:
|
if value is not None and key != 'id': # 不更新id字段
|
||||||
param_name = f"update_{key}"
|
param_name = f"update_{key}"
|
||||||
set_conditions.append(f"r.{key} = ${param_name}")
|
set_conditions.append(f"r.{key} = ${param_name}")
|
||||||
params[param_name] = value
|
params[param_name] = value
|
||||||
@@ -242,7 +254,16 @@ async def neo4j_data(solved_data):
|
|||||||
if key=='expired_at':
|
if key=='expired_at':
|
||||||
updat_expired_at[key] = values[1]
|
updat_expired_at[key] = values[1]
|
||||||
|
|
||||||
elif key == 'statement_id':
|
elif key == 'id':
|
||||||
|
ori_edge[key] = values
|
||||||
|
updata_edge[key] = values
|
||||||
|
|
||||||
|
ori_entity[key] = values
|
||||||
|
updata_entity[key] = values
|
||||||
|
|
||||||
|
ori_expired_at[key] = values
|
||||||
|
elif key == 'rel_id':
|
||||||
|
key='id'
|
||||||
ori_edge[key] = values
|
ori_edge[key] = values
|
||||||
updata_edge[key] = values
|
updata_edge[key] = values
|
||||||
|
|
||||||
|
|||||||
@@ -35,10 +35,10 @@ class BaseDataSchema(BaseModel):
|
|||||||
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
|
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
|
||||||
description: Optional[str] = Field(None, description="The description of the data entry.")
|
description: Optional[str] = Field(None, description="The description of the data entry.")
|
||||||
|
|
||||||
# 新增字段以匹配实际输入数据
|
# 新增字段以匹配实际输入数据 - 改为可选以支持resolved_memory场景
|
||||||
entity1_name: str = Field(..., description="The first entity name.")
|
entity1_name: Optional[str] = Field(None, description="The first entity name.")
|
||||||
entity2_name: Optional[str] = Field(None, description="The second entity name.")
|
entity2_name: Optional[str] = Field(None, description="The second entity name.")
|
||||||
statement_id: str = Field(..., description="The statement identifier.")
|
statement_id: Optional[str] = Field(None, description="The statement identifier.")
|
||||||
# 新增字段 - 设为可选以保持向后兼容性
|
# 新增字段 - 设为可选以保持向后兼容性
|
||||||
predicate: Optional[str] = Field(None, description="The predicate describing the relationship between entities.")
|
predicate: Optional[str] = Field(None, description="The predicate describing the relationship between entities.")
|
||||||
relationship_statement_id: Optional[str] = Field(None, description="The relationship statement identifier.")
|
relationship_statement_id: Optional[str] = Field(None, description="The relationship statement identifier.")
|
||||||
@@ -108,13 +108,13 @@ class ChangeRecordSchema(BaseModel):
|
|||||||
"""Schema for individual change records
|
"""Schema for individual change records
|
||||||
|
|
||||||
字段值格式说明:
|
字段值格式说明:
|
||||||
- id 和 statement_id: 字符串或 None
|
- id: 字符串,表示修改字段对应的记录ID
|
||||||
- 其他字段: 可以是字符串、None,数组 [修改前的值, 修改后的值],或嵌套字典结构
|
- 其他字段: 可以是字符串、None,数组 [修改前的值, 修改后的值],或嵌套字典结构
|
||||||
- entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式
|
- entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式
|
||||||
"""
|
"""
|
||||||
field: List[Dict[str, Any]] = Field(
|
field: List[Dict[str, Any]] = Field(
|
||||||
...,
|
...,
|
||||||
description="List of field changes. First item: {id: value or None}, second: {statement_id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
|
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
|
||||||
)
|
)
|
||||||
|
|
||||||
class ResolvedSchema(BaseModel):
|
class ResolvedSchema(BaseModel):
|
||||||
|
|||||||
@@ -120,10 +120,12 @@ class WorkspaceAppService:
|
|||||||
def _get_data_config(self, memory_content: str) -> Dict[str, Any]:
|
def _get_data_config(self, memory_content: str) -> Dict[str, Any]:
|
||||||
"""Retrieve data_comfig information based on memory_comtent"""
|
"""Retrieve data_comfig information based on memory_comtent"""
|
||||||
try:
|
try:
|
||||||
data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content)
|
data_config_result = DataConfigRepository.query_reflection_config_by_id(self.db, int(memory_content))
|
||||||
data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone()
|
|
||||||
if data_config_result is None:
|
# data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content)
|
||||||
return None
|
# data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone()
|
||||||
|
# if data_config_result is None:
|
||||||
|
# return None
|
||||||
|
|
||||||
if data_config_result:
|
if data_config_result:
|
||||||
return {
|
return {
|
||||||
|
|||||||
Reference in New Issue
Block a user