反思优化1.0(优化隐私输出、时间检索)

This commit is contained in:
lixinyue
2026-01-19 16:28:01 +08:00
parent 871304c89b
commit 5a0d3df689
9 changed files with 173 additions and 170 deletions

View File

@@ -10,7 +10,7 @@ Classes:
import uuid
from typing import Dict, List, Optional, Tuple
from app.core.exceptions import BusinessException
from app.core.logging_config import get_config_logger, get_db_logger
from app.models.data_config_model import DataConfig
from app.schemas.memory_storage_schema import (
@@ -20,7 +20,7 @@ from app.schemas.memory_storage_schema import (
ConfigUpdateExtracted,
ConfigUpdateForget,
)
from sqlalchemy import desc
from sqlalchemy import desc, select
from sqlalchemy.orm import Session
# 获取数据库专用日志器
@@ -136,72 +136,88 @@ class DataConfigRepository:
id: m.id
} AS targetNode
"""
# ==================== SQLAlchemy ORM 数据库操作方法 ====================
@staticmethod
def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]:
def update_reflection_config(
db: Session,
config_id: int,
enable_self_reflexion: bool,
iteration_period: str,
reflexion_range: str,
baseline: str,
reflection_model_id: str,
memory_verify: bool,
quality_assessment: bool
) -> DataConfig:
"""构建反思配置更新语句SQLAlchemy text() 命名参数)
Args:
quality_assessment:
memory_verify:
reflection_model_id:
baseline:
reflexion_range:
iteration_period:
enable_self_reflexion:
db: database object
config_id: 配置ID
**kwargs: 反思配置参数
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
Data
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
data_config_obj = db.scalars(stmt).first()
if not data_config_obj:
raise BusinessException
data_config_obj.enable_self_reflexion = enable_self_reflexion
data_config_obj.iteration_period = iteration_period
data_config_obj.reflexion_range = reflexion_range
data_config_obj.baseline = baseline
data_config_obj.reflection_model_id = reflection_model_id
data_config_obj.memory_verify = memory_verify
data_config_obj.quality_assessment = quality_assessment
key_where = "config_id = :config_id"
set_fields: List[str] = []
params: Dict = {
"config_id": config_id,
}
# 反思配置字段映射
mapping = {
"enable_self_reflexion": "enable_self_reflexion",
"iteration_period": "iteration_period",
"reflexion_range": "reflexion_range",
"baseline": "baseline",
"reflection_model_id": "reflection_model_id",
"memory_verify": "memory_verify",
"quality_assessment": "quality_assessment",
}
for api_field, db_col in mapping.items():
if api_field in kwargs and kwargs[api_field] is not None:
set_fields.append(f"{db_col} = :{api_field}")
params[api_field] = kwargs[api_field]
if not set_fields:
raise ValueError("No fields to update")
set_fields.append("updated_at = timezone('Asia/Shanghai', now())")
query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}"
return query, params
return data_config_obj
@staticmethod
def build_select_reflection(config_id: int) -> Tuple[str, Dict]:
def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args:
db: database object
config_id: 配置ID
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
stmt = select(DataConfig).where(DataConfig.config_id == config_id)
data_config = db.scalars(stmt).first()
if not data_config:
raise RuntimeError("reflection config not found")
return data_config
@staticmethod
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig:
"""构建查询所有配置的语句SQLAlchemy text() 命名参数)
Args:
db: database object
workspace_id: 工作空间ID
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id)
data_config = db.scalars(stmt).first()
if not data_config:
raise RuntimeError("reflection config not found")
return data_config
query = (
f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
f"reflection_model_id, memory_verify, quality_assessment, user_id "
f"FROM {TABLE_NAME} WHERE config_id = :config_id"
)
params = {"config_id": config_id}
return query, params
@staticmethod
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:

View File

@@ -837,12 +837,14 @@ neo4j_query_part = """
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN
elementId(m) as id,
m.name as entity1_name,
m.description as description,
m.statement_id as statement_id,
m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
elementId(rel) as rel_id,
rel.predicate as predicate,
rel.statement as relationship,
rel.statement_id as relationship_statement_id,
@@ -855,12 +857,14 @@ neo4j_query_all = """
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN
elementId(m) as id,
m.name as entity1_name,
m.description as description,
m.statement_id as statement_id,
m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
elementId(rel) as rel_id,
rel.predicate as predicate,
rel.statement as relationship,
rel.statement_id as relationship_statement_id,

View File

@@ -11,22 +11,28 @@ async def update_neo4j_data(neo4j_dict_data, update_databases):
update_databases: update
"""
try:
# 构建WHERE条件
# 构建WHERE条件 - 只使用elementId
where_conditions = []
params = {}
for key, value in neo4j_dict_data.items():
if value is not None:
param_name = f"param_{key}"
where_conditions.append(f"e.{key} = ${param_name}")
params[param_name] = value
# 优先使用id作为elementId进行查询
if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None:
where_conditions.append(f"elementId(e) = $param_id")
params['param_id'] = neo4j_dict_data['id']
else:
# 如果没有id使用其他字段作为条件
for key, value in neo4j_dict_data.items():
if value is not None:
param_name = f"param_{key}"
where_conditions.append(f"e.{key} = ${param_name}")
params[param_name] = value
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
# 构建SET条件
# 构建SET条件 - 排除id字段
set_conditions = []
for key, value in update_databases.items():
if value is not None:
if value is not None and key != 'id': # 不更新id字段
param_name = f"update_{key}"
set_conditions.append(f"e.{key} = ${param_name}")
params[param_name] = value
@@ -76,22 +82,28 @@ async def update_neo4j_data_edge(neo4j_dict_data, update_databases):
update_databases: update
"""
try:
# 构建WHERE条件
# 构建WHERE条件 - 只使用elementId
where_conditions = []
params = {}
for key, value in neo4j_dict_data.items():
if value is not None:
param_name = f"param_{key}"
where_conditions.append(f"r.{key} = ${param_name}")
params[param_name] = value
# 优先使用id作为elementId进行查询
if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None:
where_conditions.append(f"elementId(r) = $param_id")
params['param_id'] = neo4j_dict_data['id']
else:
# 如果没有id使用其他字段作为条件
for key, value in neo4j_dict_data.items():
if value is not None:
param_name = f"param_{key}"
where_conditions.append(f"r.{key} = ${param_name}")
params[param_name] = value
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
# 构建SET条件
# 构建SET条件 - 排除id字段
set_conditions = []
for key, value in update_databases.items():
if value is not None:
if value is not None and key != 'id': # 不更新id字段
param_name = f"update_{key}"
set_conditions.append(f"r.{key} = ${param_name}")
params[param_name] = value
@@ -242,7 +254,16 @@ async def neo4j_data(solved_data):
if key=='expired_at':
updat_expired_at[key] = values[1]
elif key == 'statement_id':
elif key == 'id':
ori_edge[key] = values
updata_edge[key] = values
ori_entity[key] = values
updata_entity[key] = values
ori_expired_at[key] = values
elif key == 'rel_id':
key='id'
ori_edge[key] = values
updata_edge[key] = values