From 5a0d3df689805be0d25c8a7356f97b855dde45a1 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Mon, 19 Jan 2026 16:28:01 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=8D=E6=80=9D=E4=BC=98=E5=8C=961.0?= =?UTF-8?q?=EF=BC=88=E4=BC=98=E5=8C=96=E9=9A=90=E7=A7=81=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E3=80=81=E6=97=B6=E9=97=B4=E6=A3=80=E7=B4=A2=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../memory_reflection_controller.py | 95 +++++----------- .../reflection_engine/self_reflexion.py | 36 +++---- .../utils/prompt/prompts/evaluate.jinja2 | 3 +- .../utils/prompt/prompts/reflexion.jinja2 | 28 +++-- .../repositories/data_config_repository.py | 102 ++++++++++-------- api/app/repositories/neo4j/cypher_queries.py | 4 + api/app/repositories/neo4j/neo4j_update.py | 55 +++++++--- api/app/schemas/memory_storage_schema.py | 10 +- api/app/services/memory_reflection_service.py | 10 +- 9 files changed, 173 insertions(+), 170 deletions(-) diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index b0287d80..24c143b9 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -1,10 +1,11 @@ import asyncio import time +import uuid from app.core.logging_config import get_api_logger from app.core.memory.storage_services.reflection_engine.self_reflexion import ( ReflectionConfig, - ReflectionEngine, + ReflectionEngine, ReflectionRange, ReflectionBaseline, ) from app.core.response_utils import success from app.db import get_db @@ -39,9 +40,6 @@ async def save_reflection_config( db: Session = Depends(get_db), ) -> dict: """Save reflection configuration to data_comfig table""" - - - try: config_id = request.config_id if not config_id: @@ -52,51 +50,30 @@ async def save_reflection_config( api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") - update_params = { - "enable_self_reflexion": request.reflection_enabled, - "iteration_period": request.reflection_period_in_hours, - "reflexion_range": request.reflexion_range, - "baseline": request.baseline, - "reflection_model_id": request.reflection_model_id, - "memory_verify": request.memory_verify, - "quality_assessment": request.quality_assessment, - } + data_config = DataConfigRepository.update_reflection_config( + db, + config_id=config_id, + enable_self_reflexion=request.reflection_enabled, + iteration_period=request.reflection_period_in_hours, + reflexion_range=request.reflexion_range, + baseline=request.baseline, + reflection_model_id=request.reflection_model_id, + memory_verify=request.memory_verify, + quality_assessment=request.quality_assessment + ) - - - query, params = DataConfigRepository.build_update_reflection(config_id, **update_params) - - result = db.execute(text(query), params) - if result.rowcount == 0: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"未找到config_id为 {config_id} 的配置" - ) - db.commit() - - # 查询更新后的配置 - select_query, select_params = DataConfigRepository.build_select_reflection(config_id) - result = db.execute(text(select_query), select_params).fetchone() - - if not result: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"更新后未找到config_id为 {config_id} 的配置" - ) - - api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}") + db.refresh(data_config) reflection_result={ - "config_id": result.config_id, - "enable_self_reflexion": result.enable_self_reflexion, - "iteration_period": result.iteration_period, - "reflexion_range": result.reflexion_range, - "baseline": result.baseline, - "reflection_model_id": result.reflection_model_id, - "memory_verify": result.memory_verify, - "quality_assessment": result.quality_assessment, - "user_id": result.user_id} + "config_id": data_config.config_id, + "enable_self_reflexion": data_config.enable_self_reflexion, + "iteration_period": data_config.iteration_period, + "reflexion_range": data_config.reflexion_range, + "baseline": data_config.baseline, + "reflection_model_id": data_config.reflection_model_id, + "memory_verify": data_config.memory_verify, + "quality_assessment": data_config.quality_assessment} return success(data=reflection_result, msg="反思配置成功") @@ -116,9 +93,8 @@ async def save_reflection_config( ) -@router.post("/reflection") +@router.get("/reflection") async def start_workspace_reflection( - config_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -178,17 +154,7 @@ async def start_reflection_configs( """通过config_id查询data_config表中的反思配置信息""" try: api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") - - # 使用DataConfigRepository查询反思配置 - select_query, select_params = DataConfigRepository.build_select_reflection(config_id) - result = db.execute(text(select_query), select_params).fetchone() - - if not result: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"未找到config_id为 {config_id} 的配置" - ) - + result = DataConfigRepository.query_reflection_config_by_id(db, config_id) # 构建返回数据 reflection_config = { "config_id": result.config_id, @@ -198,8 +164,7 @@ async def start_reflection_configs( "baseline": result.baseline, "reflection_model_id": result.reflection_model_id, "memory_verify": result.memory_verify, - "quality_assessment": result.quality_assessment, - "user_id": result.user_id + "quality_assessment": result.quality_assessment } api_logger.info(f"成功查询反思配置,config_id: {config_id}") return success(data=reflection_config, msg="反思配置查询成功") @@ -227,9 +192,7 @@ async def reflection_run( api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") # 使用DataConfigRepository查询反思配置 - select_query, select_params = DataConfigRepository.build_select_reflection(config_id) - result = db.execute(text(select_query), select_params).fetchone() - + result = DataConfigRepository.query_reflection_config_by_id(db, config_id) if not result: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -242,7 +205,7 @@ async def reflection_run( model_id = result.reflection_model_id if model_id: try: - ModelConfigService.get_model_by_id(db=db, model_id=model_id) + ModelConfigService.get_model_by_id(db=db, model_id=uuid.UUID(model_id)) api_logger.info(f"模型ID验证成功: {model_id}") except Exception as e: api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}") @@ -252,8 +215,8 @@ async def reflection_run( config = ReflectionConfig( enabled=result.enable_self_reflexion, iteration_period=result.iteration_period, - reflexion_range=result.reflexion_range, - baseline=result.baseline, + reflexion_range=ReflectionRange(result.reflexion_range), + baseline=ReflectionBaseline(result.baseline), output_example='', memory_verify=result.memory_verify, quality_assessment=result.quality_assessment, diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index e9fb8855..bd3a9190 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -24,15 +24,9 @@ from app.core.memory.utils.config.get_data import ( get_data, get_data_statement, ) -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.prompt.template_render import ( - render_evaluate_prompt, - render_reflexion_prompt, -) + from app.core.models.base import RedBearModelConfig -from app.core.response_utils import success from app.repositories.neo4j.cypher_queries import ( - UPDATE_STATEMENT_INVALID_AT, neo4j_query_all, neo4j_query_part, neo4j_statement_all, @@ -160,12 +154,11 @@ class ReflectionEngine: self.neo4j_connector = Neo4jConnector() if self.llm_client is None: - from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context with get_db_context() as db: factory = MemoryClientFactory(db) - self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) + self.llm_client = factory.get_llm_client(self.config.model_id) elif isinstance(self.llm_client, str): # 如果 llm_client 是字符串(model_id),则用它初始化客户端 from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -263,25 +256,23 @@ class ReflectionEngine: # 2. 检测冲突(基于事实的反思) conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets) - print(100 * '-') - print(conflict_data) - print(100 * '-') - # # 检查是否真的有冲突 - conflicts_found='' + conflict_list=[] + for i in conflict_data: + conflict_list.append(i['data']) - conflicts_found='' + + + conflicts_found=0 # 3. 解决冲突 - solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) + solved_data = await self._resolve_conflicts(conflict_list, statement_databasets) + if not solved_data: return ReflectionResult( success=False, - message="反思失败,未解决冲突", + message=f"没有{self.config.baseline}相关的冲突数据", conflicts_found=conflicts_found, execution_time=asyncio.get_event_loop().time() - start_time ) - print(100 * '*') - print(solved_data) - print(100 * '*') conflicts_resolved = len(solved_data) logging.info(f"解决了 {conflicts_resolved} 个冲突") @@ -386,7 +377,7 @@ class ReflectionEngine: memory_verifies.append(item['memory_verify']) result_data['memory_verifies'] = memory_verifies result_data['quality_assessments'] = quality_assessments - conflicts_found='' + conflicts_found = 0 # 初始化为整数0而不是空字符串 REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"} # Clearn conflict_data,And memory_verify和quality_assessment cleaned_conflict_data = [] @@ -414,7 +405,7 @@ class ReflectionEngine: cleaned_conflict_data_.append(cleaned_item) print(cleaned_conflict_data_) # 3. 解决冲突 - solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data) + solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data) if not solved_data: return ReflectionResult( success=False, @@ -739,4 +730,3 @@ class ReflectionEngine: raise ValueError(f"未知的反思基线: {self.config.baseline}") - diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index e649897a..5da6d4b5 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -24,7 +24,8 @@ - **身份冲突**: 同一实体被赋予不同类型或角色 - **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 ### 混合冲突 -检测所有逻辑不一致或相互矛盾的记录。 +- 检测所有逻辑不一致或相互矛盾的记录。 +- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 **检测原则**: - 重点检查相同实体的记录 - 分析description字段语义冲突 diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index ed3aad32..99660aa4 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -63,7 +63,7 @@ **脱敏字段**: name、entity1_name、entity2_name、description、relationship ## 4. 处理流程 - +###如果存在冲突数据执行以下步骤,不存在返回【】在data中 ### 步骤1: 类型匹配验证 **匹配规则**: - baseline="TIME": 只处理时间相关冲突(涉及时间表达式、日期、时间点) @@ -78,7 +78,7 @@ ### 步骤2: 冲突数据分组 **分组策略**: -- 时间冲突组: 涉及用户时间的记录 +- 时间冲突组: 涉及用户时间的记录比如(生日在2月17...) - 活动时间冲突组: 同一活动不同时间的记录 - 事实冲突组: 同一实体不同属性的记录 - 其他冲突组: 其他类型冲突记录 @@ -97,11 +97,12 @@ ### 处理规则 ** baseline是TIME - -保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的 -** baseline不是TIME + - 只处理时间相关的内容,比如时间表达式、日期、时间点 + -保留正确记录不变修改错误记录的expired_at为当前时间,比如(2025-12-16T12:00:00) +** baseline是FACT或者HYBRID + - 处理不是时间相关的内容 - 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段, 如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field - **核心原则**: - 只输出需要修改的记录 - 优先保留策略: 时间冲突保留最可信created_at时间,事实冲突选择最新且可信度最高记录 @@ -110,22 +111,26 @@ - 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %} - 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空 - 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true、memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容, - ,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据 + ,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据,如果存在隐私审核,隐私审核是true,也需要放到reflexion的reason字段和solution **变更记录格式**: ```json "change": [ { "field": [ - {"id":修改字段对应的ID} - {"statement_id":需要修改的对象对应的statement_id} - {"字段名1": ["修改前的值1","修改后的值1"]}, - {"字段名2": ["修改前的值2","修改后的值2"]} + {"id": "修改字段对应的ID"}, + {"字段名1": ["修改前的值1", "修改后的值1"]}, + {"字段名2": ["修改前的值2", "修改后的值2"]} ] } ] ``` +**resolved_memory格式说明**: +- 对于TIME类型冲突: 只需expired_at字段即可 +- 对于FACT/HYBRID类型冲突: 需要包含完整的记录对象(包括name、entity1_name、entity2_name、description、relationship等所有相关字段) +- resolved_memory中只包含需要修改的记录,不需要修改的记录不要包含在内 + **类型不匹配处理**: - 冲突类型与baseline不匹配时,resolved设为null - reflexion.reason说明类型不匹配原因 @@ -157,7 +162,8 @@ "conflict": true }, "reflexion": { - "reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析, + "reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析,如果 + 隐私审核打开的时候如果存在冲突,分析该冲突的原因 不可以随意分配冲突类型以及原因,不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种", "solution": "该冲突类型的解决方案(不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种)" }, diff --git a/api/app/repositories/data_config_repository.py b/api/app/repositories/data_config_repository.py index 7843acc2..ea9fadea 100644 --- a/api/app/repositories/data_config_repository.py +++ b/api/app/repositories/data_config_repository.py @@ -10,7 +10,7 @@ Classes: import uuid from typing import Dict, List, Optional, Tuple - +from app.core.exceptions import BusinessException from app.core.logging_config import get_config_logger, get_db_logger from app.models.data_config_model import DataConfig from app.schemas.memory_storage_schema import ( @@ -20,7 +20,7 @@ from app.schemas.memory_storage_schema import ( ConfigUpdateExtracted, ConfigUpdateForget, ) -from sqlalchemy import desc +from sqlalchemy import desc, select from sqlalchemy.orm import Session # 获取数据库专用日志器 @@ -136,72 +136,88 @@ class DataConfigRepository: id: m.id } AS targetNode """ - - # ==================== SQLAlchemy ORM 数据库操作方法 ==================== @staticmethod - def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]: + def update_reflection_config( + db: Session, + config_id: int, + enable_self_reflexion: bool, + iteration_period: str, + reflexion_range: str, + baseline: str, + reflection_model_id: str, + memory_verify: bool, + quality_assessment: bool + ) -> DataConfig: """构建反思配置更新语句(SQLAlchemy text() 命名参数) Args: + quality_assessment: + memory_verify: + reflection_model_id: + baseline: + reflexion_range: + iteration_period: + enable_self_reflexion: + db: database object config_id: 配置ID - **kwargs: 反思配置参数 Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) + Data Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"构建反思配置更新语句: config_id={config_id}") + stmt = select(DataConfig).where(DataConfig.config_id == config_id) + data_config_obj = db.scalars(stmt).first() + if not data_config_obj: + raise BusinessException + data_config_obj.enable_self_reflexion = enable_self_reflexion + data_config_obj.iteration_period = iteration_period + data_config_obj.reflexion_range = reflexion_range + data_config_obj.baseline = baseline + data_config_obj.reflection_model_id = reflection_model_id + data_config_obj.memory_verify = memory_verify + data_config_obj.quality_assessment = quality_assessment - key_where = "config_id = :config_id" - set_fields: List[str] = [] - params: Dict = { - "config_id": config_id, - } - - # 反思配置字段映射 - mapping = { - "enable_self_reflexion": "enable_self_reflexion", - "iteration_period": "iteration_period", - "reflexion_range": "reflexion_range", - "baseline": "baseline", - "reflection_model_id": "reflection_model_id", - "memory_verify": "memory_verify", - "quality_assessment": "quality_assessment", - } - - for api_field, db_col in mapping.items(): - if api_field in kwargs and kwargs[api_field] is not None: - set_fields.append(f"{db_col} = :{api_field}") - params[api_field] = kwargs[api_field] - - if not set_fields: - raise ValueError("No fields to update") - - set_fields.append("updated_at = timezone('Asia/Shanghai', now())") - query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}" - return query, params + return data_config_obj @staticmethod - def build_select_reflection(config_id: int) -> Tuple[str, Dict]: + def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig: """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) Args: + db: database object config_id: 配置ID Returns: Tuple[str, Dict]: (SQL查询字符串, 参数字典) """ db_logger.debug(f"构建反思配置查询语句: config_id={config_id}") + stmt = select(DataConfig).where(DataConfig.config_id == config_id) + data_config = db.scalars(stmt).first() + if not data_config: + raise RuntimeError("reflection config not found") + return data_config + @staticmethod + def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig: + """构建查询所有配置的语句(SQLAlchemy text() 命名参数) + + Args: + db: database object + workspace_id: 工作空间ID + + Returns: + Tuple[str, Dict]: (SQL查询字符串, 参数字典) + """ + db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}") + + stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id) + data_config = db.scalars(stmt).first() + if not data_config: + raise RuntimeError("reflection config not found") + return data_config - query = ( - f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, " - f"reflection_model_id, memory_verify, quality_assessment, user_id " - f"FROM {TABLE_NAME} WHERE config_id = :config_id" - ) - params = {"config_id": config_id} - return query, params @staticmethod def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]: diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index fb28b81e..aac591b3 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -837,12 +837,14 @@ neo4j_query_part = """ WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) RETURN + elementId(m) as id, m.name as entity1_name, m.description as description, m.statement_id as statement_id, m.created_at as created_at, m.expired_at as expired_at, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + elementId(rel) as rel_id, rel.predicate as predicate, rel.statement as relationship, rel.statement_id as relationship_statement_id, @@ -855,12 +857,14 @@ neo4j_query_all = """ WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) RETURN + elementId(m) as id, m.name as entity1_name, m.description as description, m.statement_id as statement_id, m.created_at as created_at, m.expired_at as expired_at, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + elementId(rel) as rel_id, rel.predicate as predicate, rel.statement as relationship, rel.statement_id as relationship_statement_id, diff --git a/api/app/repositories/neo4j/neo4j_update.py b/api/app/repositories/neo4j/neo4j_update.py index 73b44396..753ae256 100644 --- a/api/app/repositories/neo4j/neo4j_update.py +++ b/api/app/repositories/neo4j/neo4j_update.py @@ -11,22 +11,28 @@ async def update_neo4j_data(neo4j_dict_data, update_databases): update_databases: update """ try: - # 构建WHERE条件 + # 构建WHERE条件 - 只使用elementId where_conditions = [] params = {} - for key, value in neo4j_dict_data.items(): - if value is not None: - param_name = f"param_{key}" - where_conditions.append(f"e.{key} = ${param_name}") - params[param_name] = value + # 优先使用id作为elementId进行查询 + if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None: + where_conditions.append(f"elementId(e) = $param_id") + params['param_id'] = neo4j_dict_data['id'] + else: + # 如果没有id,使用其他字段作为条件 + for key, value in neo4j_dict_data.items(): + if value is not None: + param_name = f"param_{key}" + where_conditions.append(f"e.{key} = ${param_name}") + params[param_name] = value where_clause = " AND ".join(where_conditions) if where_conditions else "1=1" - # 构建SET条件 + # 构建SET条件 - 排除id字段 set_conditions = [] for key, value in update_databases.items(): - if value is not None: + if value is not None and key != 'id': # 不更新id字段 param_name = f"update_{key}" set_conditions.append(f"e.{key} = ${param_name}") params[param_name] = value @@ -76,22 +82,28 @@ async def update_neo4j_data_edge(neo4j_dict_data, update_databases): update_databases: update """ try: - # 构建WHERE条件 + # 构建WHERE条件 - 只使用elementId where_conditions = [] params = {} - for key, value in neo4j_dict_data.items(): - if value is not None: - param_name = f"param_{key}" - where_conditions.append(f"r.{key} = ${param_name}") - params[param_name] = value + # 优先使用id作为elementId进行查询 + if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None: + where_conditions.append(f"elementId(r) = $param_id") + params['param_id'] = neo4j_dict_data['id'] + else: + # 如果没有id,使用其他字段作为条件 + for key, value in neo4j_dict_data.items(): + if value is not None: + param_name = f"param_{key}" + where_conditions.append(f"r.{key} = ${param_name}") + params[param_name] = value where_clause = " AND ".join(where_conditions) if where_conditions else "1=1" - # 构建SET条件 + # 构建SET条件 - 排除id字段 set_conditions = [] for key, value in update_databases.items(): - if value is not None: + if value is not None and key != 'id': # 不更新id字段 param_name = f"update_{key}" set_conditions.append(f"r.{key} = ${param_name}") params[param_name] = value @@ -242,7 +254,16 @@ async def neo4j_data(solved_data): if key=='expired_at': updat_expired_at[key] = values[1] - elif key == 'statement_id': + elif key == 'id': + ori_edge[key] = values + updata_edge[key] = values + + ori_entity[key] = values + updata_entity[key] = values + + ori_expired_at[key] = values + elif key == 'rel_id': + key='id' ori_edge[key] = values updata_edge[key] = values diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index ecb1570f..d17a9f2c 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -35,10 +35,10 @@ class BaseDataSchema(BaseModel): expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.") description: Optional[str] = Field(None, description="The description of the data entry.") - # 新增字段以匹配实际输入数据 - entity1_name: str = Field(..., description="The first entity name.") + # 新增字段以匹配实际输入数据 - 改为可选以支持resolved_memory场景 + entity1_name: Optional[str] = Field(None, description="The first entity name.") entity2_name: Optional[str] = Field(None, description="The second entity name.") - statement_id: str = Field(..., description="The statement identifier.") + statement_id: Optional[str] = Field(None, description="The statement identifier.") # 新增字段 - 设为可选以保持向后兼容性 predicate: Optional[str] = Field(None, description="The predicate describing the relationship between entities.") relationship_statement_id: Optional[str] = Field(None, description="The relationship statement identifier.") @@ -108,13 +108,13 @@ class ChangeRecordSchema(BaseModel): """Schema for individual change records 字段值格式说明: - - id 和 statement_id: 字符串或 None + - id: 字符串,表示修改字段对应的记录ID - 其他字段: 可以是字符串、None,数组 [修改前的值, 修改后的值],或嵌套字典结构 - entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式 """ field: List[Dict[str, Any]] = Field( ..., - description="List of field changes. First item: {id: value or None}, second: {statement_id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" + description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" ) class ResolvedSchema(BaseModel): diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index 0f8fb569..015cc08a 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -120,10 +120,12 @@ class WorkspaceAppService: def _get_data_config(self, memory_content: str) -> Dict[str, Any]: """Retrieve data_comfig information based on memory_comtent""" try: - data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content) - data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone() - if data_config_result is None: - return None + data_config_result = DataConfigRepository.query_reflection_config_by_id(self.db, int(memory_content)) + + # data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content) + # data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone() + # if data_config_result is None: + # return None if data_config_result: return {