diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index 8dfa6c50..c4800941 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -213,6 +213,7 @@ async def start_reflection_configs( @router.get("/reflection/run") async def reflection_run( config_id: int, + language_type: str = "zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -252,7 +253,8 @@ async def reflection_run( memory_verify=result.memory_verify, quality_assessment=result.quality_assessment, violation_handling_strategy="block", - model_id=model_id + model_id=model_id, + language_type=language_type ) connector = Neo4jConnector() engine = ReflectionEngine( diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index aa284a95..224a9560 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -48,7 +48,9 @@ if not _root_logger.handlers: else: _root_logger.setLevel(logging.INFO) - +class TranslationResponse(BaseModel): + """翻译响应模型""" + data: str class ReflectionRange(str, Enum): """反思范围枚举""" PARTIAL = "partial" # 从检索结果中反思 @@ -76,6 +78,7 @@ class ReflectionConfig(BaseModel): memory_verify: bool = True # 记忆验证 quality_assessment: bool = True # 质量评估 violation_handling_strategy: str = "warn" # 违规处理策略 + language_type: str = "zh" class Config: use_enum_values = True @@ -234,13 +237,11 @@ class ReflectionEngine: print(conflict_data) print(100 * '-') # # 检查是否真的有冲突 - has_conflict = conflict_data[0].get('conflict', False) - conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0 - logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突") + conflicts_found='' # 记录冲突数据 await self._log_data("conflict", conflict_data) - + conflicts_found='' # 3. 解决冲突 solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) if not solved_data: @@ -285,10 +286,60 @@ class ReflectionEngine: execution_time=asyncio.get_event_loop().time() - start_time ) + async def Translate(self, text): + # 翻译中文为英文 + translation_messages = [ + { + "role": "user", + "content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}" + } + ] + + response = await self.llm_client.response_structured( + messages=translation_messages, + response_model=TranslationResponse + ) + return response.data + async def extract_translation(self,data): + end_datas={} + end_datas['source_data']=await self.Translate(data['source_data']) + quality_assessments = [] + memory_verifies = [] + reflexion_data=[] + if data['memory_verifies']!=[]: + for i in data['memory_verifies']: + end_data={} + end_data['has_privacy'] = i['has_privacy'] + privacy=i['privacy_types'] + privacy_types_=[] + for pri in privacy: + privacy_types_.append(await self.Translate(pri)) + end_data['privacy_types']=privacy_types_ + end_data['summary']=await self.Translate(i['summary']) + memory_verifies.append(end_data) + end_datas['memory_verifies']=memory_verifies + + if data['quality_assessments']!=[]: + for i in data['quality_assessments']: + end_data = {} + end_data['score']=i['score'] + end_data['summary'] = await self.Translate(i['summary']) + quality_assessments.append(end_data) + end_datas['quality_assessments'] = quality_assessments + for i in data['reflexion_data']: + end_data = {} + end_data['reason'] = await self.Translate(i['reason']) + end_data['solution'] = await self.Translate(i['solution']) + reflexion_data.append(end_data) + end_datas['reflexion_data'] = reflexion_data + return end_datas + async def reflection_run(self): self._lazy_init() start_time = time.time() memory_verifies_flag = self.config.memory_verify + quality_assessment=self.config.quality_assessment + language_type=self.config.language_type asyncio.get_event_loop().time() logging.info("====== 自我反思流程开始 ======") @@ -297,9 +348,8 @@ class ReflectionEngine: source_data, databasets = await self.extract_fields_from_json() result_data['baseline'] = self.config.baseline - result_data[ - 'source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合" + result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合" # 2. 检测冲突(基于事实的反思) conflict_data = await self._detect_conflicts(databasets, source_data) # 遍历数据提取字段 @@ -327,8 +377,6 @@ class ReflectionEngine: 'conflict': item['conflict'] } cleaned_conflict_data.append(cleaned_item) - print(cleaned_conflict_data) - # 3. 解决冲突 solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data) if not solved_data: @@ -347,7 +395,12 @@ class ReflectionEngine: reflexion_data.append(result['reflexion']) result_data['reflexion_data'] = reflexion_data if memory_verifies_flag==False: - result_data['memory_verifies']=[None] + result_data['memory_verifies']=[] + if quality_assessment==False: + result_data['quality_assessments']=[] + + if language_type=='en': + result_data=await self.extract_translation(result_data) print(time.time()-start_time,'----------') return result_data @@ -431,6 +484,7 @@ class ReflectionEngine: logging.info("====== 冲突检测开始 ======") start_time = asyncio.get_event_loop().time() quality_assessment = self.config.quality_assessment + language_type=self.config.language_type try: # 渲染冲突检测提示词 @@ -440,7 +494,8 @@ class ReflectionEngine: self.config.baseline, memory_verify, quality_assessment, - statement_databasets + statement_databasets, + language_type ) messages = [{"role": "user", "content": rendered_prompt}] @@ -664,4 +719,8 @@ class ReflectionEngine: execution_time=time_result.execution_time + fact_result.execution_time ) else: - raise ValueError(f"未知的反思基线: {self.config.baseline}") \ No newline at end of file + + raise ValueError(f"未知的反思基线: {self.config.baseline}") + + + diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index b1293c1d..b292c804 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -6,7 +6,7 @@ - **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID) - **隐私审核**: {{ memory_verify }} (true/false) - **质量评估**: {{ quality_assessment }} (true/false) - +- **语言类型**:{{language_type}}(zh/en) ## 任务目标 对用户记忆数据进行冲突检测、隐私审核和质量评估,输出结构化JSON结果。 **数据关系**: statement_databasets中的statement_id对应evaluate_data中的记录,代表句子拆分后的实体关系。 @@ -23,7 +23,7 @@ - **身份冲突**: 同一实体被赋予不同类型或角色 ### 混合冲突 检测所有逻辑不一致或相互矛盾的记录。 -**检测原则**: +**检测原则**: - 重点检查相同实体的记录 - 分析description字段语义冲突 - 验证时间字段逻辑一致性 @@ -54,7 +54,7 @@ 1. **conflict=true**: 存在冲突或隐私信息时,将所有相关记录放入data数组 2. **conflict=false**: 无冲突且无隐私信息时,data为空数组 3. **独立功能**: 冲突检测、隐私审核、质量评估三者完全独立 -4. **条件输出**: +4. **条件输出**: - quality_assessment=true时输出评估对象,否则为null - memory_verify=true时输出隐私检测对象,否则为null 5. **不输出conflict_memory字段** @@ -63,7 +63,6 @@ 2. 隐私审核(如启用) → 将隐私记录加入data 3. 质量评估(如启用) → 独立输出评估结果 4. 去重data数组中的记录 - **输出结构**: ```json { @@ -82,6 +81,8 @@ ``` **字段说明**: - **data**: 包含冲突记录和隐私信息记录,无则为空数组 -- **quality_assessment**: quality_assessment=true时输出评估对象,否则为null +- **quality_assessment**: + quality_assessment=true时输出评估对象,否则为null(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容) - **memory_verify**: memory_verify=true时输出隐私检测对象,否则为null + (注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容) 模式参考:{{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index 15f65fc3..36474d91 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -5,6 +5,7 @@ - **原始句子**: {{ statement_databasets }} - **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID) - **隐私审核**: {{ memory_verify }} (true/false) +- **语言类型**:{{language_type}}(zh/en) ## 任务目标 作为数据冲突解决专家,分析冲突原因,按类型分组处理,为每种冲突生成独立解决方案。 @@ -110,6 +111,7 @@ - 隐私保护优先: 所有输出记录必须完成隐私脱敏 - 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %} - 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空 +- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true)等原数据字段以及涉及需要修改的字段以及内容 **变更记录格式**: ```json @@ -181,5 +183,4 @@ - **resolved.change**: 包含详细变更信息 - 无需修改的冲突类型resolved为null - 与baseline不匹配的冲突类型不包含在results中 - 模式参考: {{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/template_render.py b/api/app/core/memory/utils/prompt/template_render.py index 818d456a..46bb64e8 100644 --- a/api/app/core/memory/utils/prompt/template_render.py +++ b/api/app/core/memory/utils/prompt/template_render.py @@ -9,7 +9,8 @@ prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any], baseline: str = "TIME", - memory_verify: bool = False,quality_assessment:bool = False,statement_databasets: List[str] = []) -> str: + memory_verify: bool = False,quality_assessment:bool = False, + statement_databasets: List[str] = [],language_type:str = "zh") -> str: """ Renders the evaluate prompt using the evaluate_optimized.jinja2 template. @@ -30,12 +31,13 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any baseline=baseline, memory_verify=memory_verify, quality_assessment=quality_assessment, - statement_databasets=statement_databasets + statement_databasets=statement_databasets, + language_type=language_type ) return rendered_prompt async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False, - statement_databasets: List[str] = []) -> str: + statement_databasets: List[str] = [],language_type:str = "zh") -> str: """ Renders the reflexion prompt using the reflexion_optimized.jinja2 template. @@ -51,6 +53,6 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], rendered_prompt = template.render(data=data, json_schema=schema, baseline=baseline,memory_verify=memory_verify, - statement_databasets=statement_databasets) + statement_databasets=statement_databasets,language_type=language_type) return rendered_prompt