From b95b39e4351adc69cfe6aa2fc33200fd8ebf576a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=96=B0=E6=9C=88?= Date: Mon, 22 Dec 2025 02:21:50 +0000 Subject: [PATCH 01/15] Merge #24 into develop from fix/memory_reflection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 反思速度提升,从4分钟优化成1分10-40秒 * fix/memory_reflection: (40 commits squashed) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py # api/app/schemas/memory_reflection_schemas.py - 反思优化 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 - 统一输出 - 统一输出 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py - 统一输出 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 - 反思速度提升,从4分钟优化成1分10-40秒 - 反思速度提升,从4分钟优化成1分10-40秒 - 反思速度提升,从4分钟优化成1分10-40秒 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 更新 self_reflexion.py Signed-off-by: aliyun8644380055 Commented-by: aliyun8644380055 Commented-by: aliyun6762716068 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/24 --- .../reflection_engine/example/example.json | 107 +----- .../reflection_engine/self_reflexion.py | 73 +++- .../utils/prompt/prompts/evaluate.jinja2 | 285 ++++---------- .../utils/prompt/prompts/reflexion.jinja2 | 358 ++++++------------ 
api/app/schemas/memory_storage_schema.py | 5 - 5 files changed, 266 insertions(+), 562 deletions(-) diff --git a/api/app/core/memory/storage_services/reflection_engine/example/example.json b/api/app/core/memory/storage_services/reflection_engine/example/example.json index 6528da60..fe7a3816 100644 --- a/api/app/core/memory/storage_services/reflection_engine/example/example.json +++ b/api/app/core/memory/storage_services/reflection_engine/example/example.json @@ -3,53 +3,43 @@ "source_data": [ { "statement_name": "用户是2023年春天去北京工作的。", - "statement_id": "62beac695b1346f4871740a45db88782", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "62beac695b1346f4871740a45db88782" }, { "statement_name": "用户后来基本一直都在北京上班。", - "statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b" }, { "statement_name": "用户从2023年开始就一直在北京生活。", - "statement_id": "e612a44da4db483993c350df7c97a1a1", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "e612a44da4db483993c350df7c97a1a1" }, { "statement_name": "用户从来没有长期离开过北京。", - "statement_id": "b3c787a2e33c49f7981accabbbb4538a", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "b3c787a2e33c49f7981accabbbb4538a" }, { "statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。", - "statement_id": "64cde4230cb24a4da726e7db9e7aa616", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "64cde4230cb24a4da726e7db9e7aa616" }, { "statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。", - "statement_id": "8b1b12e23b844b8088dfeb67da6ad669", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "8b1b12e23b844b8088dfeb67da6ad669" }, { "statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。", - "statement_id": "030afd362e9b4110b139e68e5d3e7143", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "030afd362e9b4110b139e68e5d3e7143" }, { 
"statement_name": "用户的银行卡号是6222023847595898。", - "statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f" }, { "statement_name": "用户的身份信息和银行卡信息一直没变。", - "statement_id": "b3ca618e1e204b83bebd70e75cf2073f", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "b3ca618e1e204b83bebd70e75cf2073f" }, { "statement_name": "用户认为在上海的那段时间更多算是远程配合。", - "statement_id": "150af89d2c154e6eb41ff1a91e37f962", - "statement_created_at": "2025-12-19T10:31:15.239252" + "statement_id": "150af89d2c154e6eb41ff1a91e37f962" } ], "databasets": [ @@ -57,24 +47,11 @@ "entity1_name": "Person", "description": "表示人类个体的通用类型", "statement_id": "62beac695b1346f4871740a45db88782", - "created_at": "2025-12-19T10:31:15.239252000", - "expired_at": "9999-12-31T00:00:00.000000000", - "relationship_type": "EXTRACTED_RELATIONSHIP", - "relationship": {}, "entity2_name": "用户", "entity2": { - "entity_idx": 0, - "run_id": "62b59cfebeea43dd94d91763056f069a", - "connect_strength": "strong", - "created_at": "2025-12-19T10:31:15.239252000", "description": "叙述者,讲述个人工作与生活经历的个体", "statement_id": "62beac695b1346f4871740a45db88782", - "expired_at": "9999-12-31T00:00:00.000000000", - "entity_type": "Person", - "group_id": "88a459f5_text08", - "user_id": "88a459f5_text08", "name": "用户", - "apply_id": "88a459f5_text08", "id": "3d3896797b334572a80d57590026063d" } }, @@ -82,24 +59,11 @@ "entity1_name": "用户", "description": "叙述者,讲述个人工作与生活经历的个体", "statement_id": "62beac695b1346f4871740a45db88782", - "created_at": "2025-12-19T10:31:15.239252000", - "expired_at": "9999-12-31T00:00:00.000000000", - "relationship_type": "EXTRACTED_RELATIONSHIP", - "relationship": {}, "entity2_name": "身份信息", "entity2": { - "entity_idx": 1, - "run_id": "62b59cfebeea43dd94d91763056f069a", - "connect_strength": "Strong", "description": "用于个人身份识别的数据", - "created_at": "2025-12-19T10:31:15.239252000", "statement_id": 
"030afd362e9b4110b139e68e5d3e7143", - "expired_at": "9999-12-31T00:00:00.000000000", - "entity_type": "Information", - "group_id": "88a459f5_text08", - "user_id": "88a459f5_text08", "name": "身份信息", - "apply_id": "88a459f5_text08", "id": "aa766a517e82490599a9b3af54cfd933" } }, @@ -107,24 +71,11 @@ "entity1_name": "用户", "description": "叙述者,讲述个人工作与生活经历的个体", "statement_id": "62beac695b1346f4871740a45db88782", - "created_at": "2025-12-19T10:31:15.239252000", - "expired_at": "9999-12-31T00:00:00.000000000", - "relationship_type": "EXTRACTED_RELATIONSHIP", - "relationship": {}, "entity2_name": "6222023847595898", "entity2": { - "entity_idx": 1, - "run_id": "62b59cfebeea43dd94d91763056f069a", - "connect_strength": "Strong", "description": "用户的银行卡号码", - "created_at": "2025-12-19T10:31:15.239252000", "statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f", - "expired_at": "9999-12-31T00:00:00.000000000", - "entity_type": "Numeric", - "group_id": "88a459f5_text08", - "user_id": "88a459f5_text08", "name": "6222023847595898", - "apply_id": "88a459f5_text08", "id": "610ba361918f4e68a65ce6ad06e5c7a0" } }, @@ -132,25 +83,13 @@ "entity1_name": "用户", "description": "叙述者,讲述个人工作与生活经历的个体", "statement_id": "62beac695b1346f4871740a45db88782", - "created_at": "2025-12-19T10:31:15.239252000", - "expired_at": "9999-12-31T00:00:00.000000000", - "relationship_type": "EXTRACTED_RELATIONSHIP", - "relationship": {}, "entity2_name": "上海办公室", "entity2": { "entity_idx": 1, - "run_id": "62b59cfebeea43dd94d91763056f069a", "aliases": ["上海办"], - "connect_strength": "Strong", - "created_at": "2025-12-19T10:31:15.239252000", "description": "位于上海的工作办公场所", "statement_id": "8b1b12e23b844b8088dfeb67da6ad669", - "expired_at": "9999-12-31T00:00:00.000000000", - "entity_type": "Location", - "group_id": "88a459f5_text08", - "user_id": "88a459f5_text08", "name": "上海办公室", - "apply_id": "88a459f5_text08", "id": "fb702ef695c14e14af3e56786bc8815b" } }, @@ -158,25 +97,12 @@ "entity1_name": "用户", "description": 
"叙述者,讲述个人工作与生活经历的个体", "statement_id": "62beac695b1346f4871740a45db88782", - "created_at": "2025-12-19T10:31:15.239252000", - "expired_at": "9999-12-31T00:00:00.000000000", - "relationship_type": "EXTRACTED_RELATIONSHIP", - "relationship": {}, "entity2_name": "北京", "entity2": { - "entity_idx": 2, - "run_id": "62b59cfebeea43dd94d91763056f069a", "aliases": ["京", "京城", "北平"], - "connect_strength": "strong", - "created_at": "2025-12-19T10:31:15.239252000", "description": "中国的首都城市,用户主要工作和生活所在地", "statement_id": "62beac695b1346f4871740a45db88782", - "expired_at": "9999-12-31T00:00:00.000000000", - "entity_type": "Location", - "group_id": "88a459f5_text08", - "user_id": "88a459f5_text08", "name": "北京", - "apply_id": "88a459f5_text08", "id": "81b2d1a571bb46a08a2d7a1e87efb945" } }, @@ -184,24 +110,11 @@ "entity1_name": "11010119950308123X", "description": "具体的身份证号码值", "statement_id": "030afd362e9b4110b139e68e5d3e7143", - "created_at": "2025-12-19T10:31:15.239252000", - "expired_at": "9999-12-31T00:00:00.000000000", - "relationship_type": "EXTRACTED_RELATIONSHIP", - "relationship": {}, "entity2_name": "身份证号", "entity2": { - "entity_idx": 2, - "run_id": "62b59cfebeea43dd94d91763056f069a", - "connect_strength": "strong", "description": "中华人民共和国公民的身份号码", - "created_at": "2025-12-19T10:31:15.239252000", "statement_id": "030afd362e9b4110b139e68e5d3e7143", - "expired_at": "9999-12-31T00:00:00.000000000", - "entity_type": "Identifier", - "group_id": "88a459f5_text08", - "user_id": "88a459f5_text08", "name": "身份证号", - "apply_id": "88a459f5_text08", "id": "3e5f920645b2404fadb0e9ff60d1306e" } } diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index 6ccec500..864c91a7 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -19,10 +19,32 @@ import uuid from pydantic 
import BaseModel + from app.core.response_utils import success from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all from app.repositories.neo4j.neo4j_update import neo4j_data + +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.memory.utils.config import definitions as config_defs +from app.core.memory.utils.config import get_model_config +from app.core.memory.utils.config.get_data import get_data +from app.core.memory.utils.config.get_data import get_data_statement +from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.utils.prompt.template_render import render_evaluate_prompt +from app.core.memory.utils.prompt.template_render import render_reflexion_prompt +from app.core.models.base import RedBearModelConfig +from app.repositories.neo4j.cypher_queries import ( + neo4j_query_all, + neo4j_query_part, + neo4j_statement_all, + neo4j_statement_part, +) +from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.repositories.neo4j.neo4j_update import neo4j_data +from app.schemas.memory_storage_schema import ConflictResultSchema +from app.schemas.memory_storage_schema import ReflexionResultSchema + # 配置日志 _root_logger = logging.getLogger() @@ -122,6 +144,7 @@ class ReflectionEngine: self.update_query = update_query self._semaphore = asyncio.Semaphore(5) # 默认并发数为5 + # 延迟导入以避免循环依赖 self._lazy_init_done = False @@ -131,46 +154,53 @@ class ReflectionEngine: return if self.neo4j_connector is None: - from app.repositories.neo4j.neo4j_connector import Neo4jConnector self.neo4j_connector = Neo4jConnector() if self.llm_client is None: - from app.core.memory.utils.llm.llm_utils import get_llm_client - from app.core.memory.utils.config import definitions as config_defs self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) elif 
isinstance(self.llm_client, str): # 如果 llm_client 是字符串(model_id),则用它初始化客户端 - from app.core.memory.utils.llm.llm_utils import get_llm_client - model_id = self.llm_client - self.llm_client = get_llm_client(model_id) + # from app.core.memory.utils.llm.llm_utils import get_llm_client + # model_id = self.llm_client + # self.llm_client = get_llm_client(model_id) + extra_params={ + "temperature": 0.2, # 降低温度提高响应速度和一致性 + "max_tokens": 600, # 限制最大token数 + "top_p": 0.8, # 优化采样参数 + "stream": False, # 确保非流式输出以获得最快响应 + } + + model_config = get_model_config(self.llm_client) + self.llm_client = OpenAIClient(RedBearModelConfig( + model_name=model_config.get("model_name"), + provider=model_config.get("provider"), + api_key=model_config.get("api_key"), + base_url=model_config.get("base_url"), + timeout=model_config.get("timeout", 30), + max_retries=model_config.get("max_retries", 2), + extra_params=extra_params + ), type_=model_config.get("type")) if self.get_data_func is None: - from app.core.memory.utils.config.get_data import get_data self.get_data_func = get_data # 导入get_data_statement函数 if not hasattr(self, 'get_data_statement'): - from app.core.memory.utils.config.get_data import get_data_statement self.get_data_statement = get_data_statement if self.render_evaluate_prompt_func is None: - from app.core.memory.utils.prompt.template_render import render_evaluate_prompt self.render_evaluate_prompt_func = render_evaluate_prompt if self.render_reflexion_prompt_func is None: - from app.core.memory.utils.prompt.template_render import render_reflexion_prompt self.render_reflexion_prompt_func = render_reflexion_prompt if self.conflict_schema is None: - from app.schemas.memory_storage_schema import ConflictResultSchema self.conflict_schema = ConflictResultSchema if self.reflexion_schema is None: - from app.schemas.memory_storage_schema import ReflexionResultSchema self.reflexion_schema = ReflexionResultSchema if self.update_query is None: - from app.repositories.neo4j.cypher_queries 
import UPDATE_STATEMENT_INVALID_AT self.update_query = UPDATE_STATEMENT_INVALID_AT self._lazy_init_done = True @@ -284,7 +314,6 @@ class ReflectionEngine: quality_assessments = [] memory_verifies = [] for item in conflict_data: - print(item) quality_assessments.append(item['quality_assessment']) memory_verifies.append(item['memory_verify']) result_data['quality_assessments'] = quality_assessments @@ -298,8 +327,18 @@ class ReflectionEngine: # 记录冲突数据 await self._log_data("conflict", conflict_data) + # Clean conflict_data: keep only 'data' and 'conflict', dropping memory_verify and quality_assessment + cleaned_conflict_data = [] + for item in conflict_data: + cleaned_item = { + 'data': item['data'], + 'conflict': item['conflict'] + } + cleaned_conflict_data.append(cleaned_item) + print(cleaned_conflict_data) + # 3. 解决冲突 - solved_data = await self._resolve_conflicts(conflict_data, source_data) + solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data) if not solved_data: return ReflectionResult( success=False, @@ -391,7 +430,7 @@ class ReflectionEngine: return [] # 使用转换后的数据 - print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长 + # print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长 memory_verify = self.config.memory_verify logging.info("====== 冲突检测开始 ======") @@ -469,6 +508,7 @@ class ReflectionEngine: memory_verify, statement_databasets ) + logging.info(f"提示词长度: {len(rendered_prompt)}") messages = [{"role": "user", "content": rendered_prompt}] @@ -629,4 +669,3 @@ class ReflectionEngine: ) else: raise ValueError(f"未知的反思基线: {self.config.baseline}") - diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index e1ecf820..b1293c1d 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -1,222 +1,87 @@ -你将收到一组用户历史记忆原始数据(来源于 Neo4j),以及相关配置参数: -原本的输入句子:{{statement_databasets}} -需要检测冲突对象:{{ evaluate_data }} 
-冲突判定类型:{{ baseline }}(取值为 TIME / FACT / HYBRID) -记忆审核开关:{{ memory_verify }}(取值为 true / false) -记忆质量评估开关开关:{{ quality_assessment }}(取值为 true / false) +# 记忆数据分析任务 -你的任务是: -对用户历史记忆数据进行冲突检测和记忆审核,并输出严格结构化的 JSON 分析结果 -数据的结构: - statement_databasets里面statement_name是输入的句子,statement_id是连接evaluate_data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容, - 需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估) -## 冲突定义 +## 输入数据 +- **原始句子**: {{statement_databasets}} +- **检测对象**: {{ evaluate_data }} +- **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID) +- **隐私审核**: {{ memory_verify }} (true/false) +- **质量评估**: {{ quality_assessment }} (true/false) +## 任务目标 +对用户记忆数据进行冲突检测、隐私审核和质量评估,输出结构化JSON结果。 +**数据关系**: statement_databasets中的statement_id对应evaluate_data中的记录,代表句子拆分后的实体关系。 +## 1. 冲突检测 ### 时间冲突 -时间冲突是指同一用户的相关事件在时间维度上存在逻辑矛盾: - -1. **同一活动的时间冲突**: - - 同一用户的同一活动在不同时间点被记录(如"周五打球"和"周六打球") - - 同一用户在同一时间段内被记录进行不同的互斥活动 - -2. **时间逻辑错误**: - - expired_at 早于 created_at - - 同一事实的 created_at 时间差异超过合理误差范围(>5分钟) - -3. **日期属性冲突**: - - 同一人的生日记录为不同日期(如"2月10号"和"2月16号") -4.存在明确先后约束 A -> B,但 t(A) > t(B) - -例:入学时间晚于毕业时间。 - -处理:标记异常、降权、触发逻辑反思或人工审查。 -5.时间属性冲突 - -单值日期属性出现多值(生日、入职日期) - -注意:本质属于事实冲突的日期特例,归入事实冲突仲裁框架。 -6.互斥重叠冲突 - -例:同一主体的两个事件区间重叠且互斥(如同一时间出现在两地) - -处理:证据仲裁、保留多版本(active + candidate)。 - - - +- **同一活动时间矛盾**: 同一用户同一活动的不同时间记录 +- **时间逻辑错误**: expired_at < created_at,created_at时间差>5分钟 +- **日期属性冲突**: 同一人的生日等单值属性出现多值 +- **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业) +- **互斥重叠**: 同一时间出现在不同地点等互斥事件 ### 事实冲突 -事实冲突是指同一实体的属性或关系存在相互矛盾的陈述: +- **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢) +- **关系矛盾**: 同一实体在相同语境下的不同关系描述 +- **身份冲突**: 同一实体被赋予不同类型或角色 +### 混合冲突 +检测所有逻辑不一致或相互矛盾的记录。 +**检测原则**: +- 重点检查相同实体的记录 +- 分析description字段语义冲突 +- 验证时间字段逻辑一致性 +## 2. 隐私审核 (memory_verify=true时) +### 隐私信息类型 +- **身份信息**: 身份证号码、身份证相关描述 +- **联系方式**: 手机号、电话号码 +- **社交账号**: 微信号、QQ号、邮箱地址 +- **金融信息**: 银行卡号、账户信息、支付信息 +- **税务信息**: 税号、纳税信息、发票信息 +- **贷款信息**: 贷款记录、信贷信息 +- **安全信息**: 密码、PIN码、验证码 +### 检测方法 +- 检测description、entity1_name、entity2_name、name等字段 +- 识别数字模式(手机号11位、身份证18位等) +- 识别关键词("身份证"、"银行卡"、"密码"等) +## 3. 
质量评估 (quality_assessment=true时) +### 评估标准 +- **数据完整性**: 必要字段完整性、关系描述清晰度、时间字段有效性 +- **重复检测**: 相同或高度相似记录、冗余实体关系、描述重复度 +- **无意义检测**: 空值/无效值、过于简单的描述、格式错误 +- **上下文依赖**: 记录自包含性、实体名称明确性 +### 输出内容 +- **质量分数**: 0-100的整体质量百分比 +- **质量概述**: 简要描述数据质量状况和主要问题 +## 输出规则 +### 核心原则 +1. **conflict=true**: 存在冲突或隐私信息时,将所有相关记录放入data数组 +2. **conflict=false**: 无冲突且无隐私信息时,data为空数组 +3. **独立功能**: 冲突检测、隐私审核、质量评估三者完全独立 +4. **条件输出**: + - quality_assessment=true时输出评估对象,否则为null + - memory_verify=true时输出隐私检测对象,否则为null +5. **不输出conflict_memory字段** +### 处理流程 +1. 冲突检测 → 将冲突记录加入data +2. 隐私审核(如启用) → 将隐私记录加入data +3. 质量评估(如启用) → 独立输出评估结果 +4. 去重data数组中的记录 -1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是) -2. **关系矛盾**:同一实体在相同语境下的不同关系描述 -3. **身份冲突**:同一实体被赋予不同的类型或角色 - -### 混合冲突检测 -检测所有类型的冲突,包括但不限于时间冲突和事实冲突: -检测任何逻辑上不一致或相互矛盾的记录 -## 记忆审核定义 - -### 隐私信息检测(隐私冲突) -当memory_verify为true时,需要额外检测包含个人隐私信息的记录: - -1. **身份证信息**:包含身份证号码、身份证相关描述 -2. **手机号码**:包含手机号、电话号码等联系方式 -3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息 -4. **银行信息**:包含银行卡号、账户信息、支付信息 -5. **税务信息**:包含税号、纳税信息、发票信息 -6. **贷款信息**:包含贷款记录、信贷信息、借款信息 -7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息 - -### 隐私检测原则 -- 检测description、entity1_name、entity2_name等字段中的隐私信息 -- 识别数字模式(如手机号11位数字、身份证18位等) -- 识别关键词(如"身份证"、"银行卡"、"密码"等) -- 检测敏感实体类型和关系 - -## 冲突检测原则 - -**全面检测**:不区分冲突类型,检测所有可能的冲突 -**完整输出**:如果发现任何冲突或隐私信息,必须将所有相关记录都放入data字段 -**实体关联**:重点检查涉及相同实体(entity1_name, entity2_name)的记录 -**语义分析**:分析description字段的语义相似性和冲突性 -**时间逻辑**:检查时间字段的逻辑一致性 -**隐私检测**:当memory_verify为true时,检测所有包含隐私信息的记录 - -## 不符合冲突检测 - -称呼 -## 重要检测示例 - -### 冲突检测示例 -- 用户与不同时间点的关系(周五 vs 周六,2月10号 vs 2月16号) -- 同一实体的重复定义但描述不同 -- 同一关系的不同表述但含义冲突 -- 任何逻辑上不可能同时为真的记录 - -### 隐私信息检测示例 -- 包含手机号的记录:"用户的手机号是13812345678" -- 包含身份证的记录:"身份证号码为110101199001011234" -- 包含银行卡的记录:"银行卡号6222021234567890" -- 包含社交账号的记录:"微信号是user123456" -- 包含敏感信息的实体名称或描述 - -## 输出要求 - -**关键原则**: -1. 当存在冲突或检测到隐私信息时,conflict才为true,data字段才包含相关记录 -2. 如果发现冲突,必须将所有相关的冲突记录都放入data数组中 -3. 如果memory_verify为true且检测到隐私信息,必须将包含隐私信息的记录也放入data数组中 -4. 既没有冲突也没有隐私信息时,conflict为false,data为空数组 -5. 
如果quality_assessment为true,独立分析数据质量并输出评估结果;如果为false,quality_assessment字段输出null -6. 冲突检测、隐私审核和质量评估三个功能完全独立,互不影响 -7. 不输出conflict_memory字段 - -**处理逻辑**: -- 首先进行冲突检测,将冲突记录加入data数组 -- 如果memory_verify为true,再进行隐私信息检测,将包含隐私信息的记录也加入data数组 -- 如果quality_assessment为true,独立进行质量评估,分析所有输入数据的质量并输出评估结果 -- 最终data数组包含所有冲突记录和隐私信息记录(去重) -- quality_assessment字段独立输出,不影响冲突检测和隐私审核结果 -- memory_verify字段独立输出隐私检测结果,包含检测到的隐私信息类型和概述 - -返回数据格式以json方式输出: -- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。 -- 关键的JSON格式要求{"statement":识别出的文本内容} -1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号("")或其他Unicode引号 -2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们 -3.确保所有JSON字符串都正确关闭并以逗号分隔 -4.JSON字符串值中不包括换行符 -5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\"" -6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby``` - -## 记忆质量评估定义 - -### 质量评估标准 -当quality_assessment为true时,需要对记忆数据进行质量评估: - -1. **数据完整性**: - - 检查必要字段是否完整(entity1_name、entity2_name、description等) - - 检查关系描述是否清晰明确 - - 检查时间字段的有效性 - -2. **重复字段检测**: - - 识别相同或高度相似的记录 - - 检测冗余的实体关系 - - 分析描述内容的重复度 - -3. **无意义字段检测**: - - 识别空值、无效值或占位符内容 - - 检测过于简单或无信息量的描述 - - 识别格式错误或不规范的数据 - -4. 
**上下文依赖性**: - - 评估记录是否需要额外上下文才能理解 - - 检查实体名称的明确性 - - 分析关系描述的自包含性 - -### 质量评估输出 -- **质量百分比**:基于上述标准计算的整体质量分数(0-100) -- **质量概述**:简要描述数据质量状况,包括主要问题和优点 - -输出是仅输出一个合法 JSON 对象,严格遵循下述结构: +**输出结构**: +```json { - "data": [ - { - "entity1_name": "实体1名称", - "description": "描述信息", - "statement_id": "陈述ID", - "created_at": "创建时间戳", - "expired_at": "过期时间戳", - "relationship_type": "关系类型", - "relationship": "关系对象", - "entity2_name": "实体2名称", - "entity2": "实体2对象" - } - ], - "conflict": true或false, + "data": [记录数组], + "conflict": true/false, "quality_assessment": { - "score": 质量百分比数字, - "summary": "质量概述文本" + "score": 数字, + "summary": "文本" } 或 null, "memory_verify": { - "has_privacy": true或false, - "privacy_types": ["检测到的隐私信息类型列表"], - "summary": "隐私检测结果概述" + "has_privacy": true/false, + "privacy_types": ["类型数组"], + "summary": "概述文本" } 或 null } - -必须遵守: -- 只输出 JSON,不要添加解释或多余文本。 -- 使用标准双引号,必要时对内部引号进行转义。 -- 字段名与结构必须与给定模式一致。 -- data数组中包含冲突记录和隐私信息记录,如果都没有则为空数组。 -- quality_assessment字段:当quality_assessment参数为true时输出评估对象,为false时输出null。 -- memory_verify字段:当memory_verify参数为true时输出隐私检测结果对象,为false时输出null。 - -### memory_verify字段说明 -当memory_verify为true时,需要输出隐私检测结果: -- **has_privacy**: 布尔值,表示是否检测到隐私信息 -- **privacy_types**: 字符串数组,包含检测到的隐私信息类型(如["手机号码", "身份证信息"]) -- **summary**: 字符串,简要描述隐私检测结果 - -当memory_verify为false时,memory_verify字段输出null。 - -### memory_verify字段示例 - -**示例1:检测到隐私信息** -```json -"memory_verify": { - "has_privacy": true, - "privacy_types": ["手机号码", "身份证信息"], - "summary": "检测到2条记录包含隐私信息:1个手机号码,1个身份证号码" -} ``` - -**示例2:未检测到隐私信息** -```json -"memory_verify": { - "has_privacy": false, - "privacy_types": [], - "summary": "未检测到隐私信息" -} -``` - -**示例3:memory_verify为false时** -```json -"memory_verify": null -``` - -模式参考: -{{ json_schema }} \ No newline at end of file +**字段说明**: +- **data**: 包含冲突记录和隐私信息记录,无则为空数组 +- **quality_assessment**: quality_assessment=true时输出评估对象,否则为null +- **memory_verify**: memory_verify=true时输出隐私检测对象,否则为null +模式参考:{{ json_schema }} \ No newline at end of file diff --git 
a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index 43e8e100..5b831c02 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -1,160 +1,127 @@ -你将收到一组用户历史记忆原始数据(来源于 Neo4j) -你将收到一条冲突判定对象:{{ data }}。 -需要检测冲突对象:{{ statement_databasets }} -以及需要识别的冲突对象为:{{ baseline }} -记忆审核开关:{{ memory_verify }}(取值为 true / false) +# 记忆冲突解决任务 -角色: -- 你是数据领域中解决数据冲突的专家 +## 输入数据 +- **冲突数据**: {{ data }} +- **原始句子**: {{ statement_databasets }} +- **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID) +- **隐私审核**: {{ memory_verify }} (true/false) -任务:分析冲突产生原因,按冲突类型分组处理,为每种冲突类型生成独立的解决方案。 +## 任务目标 +作为数据冲突解决专家,分析冲突原因,按类型分组处理,为每种冲突生成独立解决方案。 -数据的结构: - statement_databasets里面statement_name是输入的句子,statement_id是连接data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容, - 需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估),data里面的statement_created_at是用户输入的时间 +**数据关系**: statement_databasets中的statement_id对应data中的记录,statement_created_at为用户输入时间。 -**处理模式**: -- 当memory_verify为false时:仅处理数据冲突 -- 当memory_verify为true时:处理数据冲突 + 隐私信息脱敏 +**处理模式**: +- memory_verify=false: 仅处理数据冲突 +- memory_verify=true: 处理数据冲突 + 隐私脱敏 -## 分组处理原则 +## 1. 冲突类型定义 -**冲突类型识别与分组**: -1. **日期冲突**: - 1.1.涉及用户生日的不同日期记录(如2月10号 vs 2月16号), - 1.2.涉及同一活动的不同时间记录(如周五打球 vs 周六打球) -3. **事实属性冲突**: - 3.1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是) - 3.2. **关系矛盾**:同一实体在相同语境下的不同关系描述 - 3.3. **身份冲突**:同一实体被赋予不同的类型或角色 -4. **其他冲突类型/混合冲突(时间+事实)**:根据具体数据识别 +### 时间冲突 (TIME) +时间维度冲突:两个事件时间重叠,或同一事情在不同时间场景下的变化。 -**分组输出要求**: -- 每种冲突类型生成一个独立的reflexion_result对象 -- 同一类型的多个冲突记录归并到一个结果中 -- 不同类型的冲突分别处理,各自生成独立结果 +### 事实冲突 (FACT) +同一事实对象的陈述内容相互矛盾,真假不能共存的情况。 -## 冲突类型定义 +### 混合冲突 (HYBRID) +检测所有类型冲突,包括时间和事实冲突的任何逻辑不一致记录。 -### 时间冲突(TIME) -时间维度冲突是指两个事件发生时间重叠,或者用户同一件事情和场景等情况下,时间出现了变化。 +## 2. 
分组处理原则 -### 事实冲突(FACT) -事实冲突是指同一事实对象(同一个人、同一个时间、同一个状态)但陈述内容相互矛盾,主要为真假不能共存的情况。 -### 混合冲突(HYBRID) -检测所有类型的冲突,包括但不限于时间冲突和事实冲突:检测任何逻辑上不一致或相互矛盾的记录 -{% if memory_verify %} -## 隐私信息处理(memory_verify为true时启用) +### 冲突类型识别 +- **日期冲突**: 用户生日不同日期(2月10号 vs 2月16号)、同一活动不同时间(周五 vs 周六打球) +- **事实属性冲突**: + - 属性互斥(喜欢↔不喜欢) + - 关系矛盾(同一实体不同关系描述) + - 身份冲突(同一实体不同类型/角色) +- **其他/混合冲突**: 根据具体数据识别 -### 隐私信息识别 -需要识别并处理以下类型的隐私信息: +### 分组输出要求 +- 每种冲突类型生成独立的reflexion_result对象 +- 同类型多个冲突归并到一个结果 +- 不同类型分别处理,各自生成独立结果 +## 3. 隐私信息处理 (memory_verify=true时) -1. **身份证信息**:包含身份证号码、身份证相关描述 -2. **手机号码**:包含手机号、电话号码等联系方式 -3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息 -4. **银行信息**:包含银行卡号、账户信息、支付信息 -5. **税务信息**:包含税号、纳税信息、发票信息 -6. **贷款信息**:包含贷款记录、信贷信息、借款信息 -7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息 +### 隐私信息类型 +- **身份信息**: 身份证号码、身份证相关描述 +- **联系方式**: 手机号、电话号码 +- **社交账号**: 微信号、QQ号、邮箱地址 +- **金融信息**: 银行卡号、账户信息、支付信息 +- **税务信息**: 税号、纳税信息、发票信息 +- **贷款信息**: 贷款记录、信贷信息 +- **安全信息**: 密码、PIN码、验证码 -### 隐私数据脱敏规则 -对于检测到的隐私信息,按以下规则进行脱敏处理: +### 脱敏规则 +**数字类**: 保留前三位和后四位,中间用*代替 +- 手机号: 13812345678 → 138****5678 +- 身份证: 110101199001011234 → 110***********1234 +- 银行卡: 6222021234567890 → 622***********7890 -**数字类隐私信息脱敏**: -- 保留前三位和后四位,中间用*代替 -- 示例:手机号13812345678 → 138****5678 -- 示例:身份证110101199001011234 → 110***********1234 -- 示例:银行卡6222021234567890 → 622***********7890 +**文本类**: 保留前三后四位字符,中间用*代替 +- 微信号: user123456 → use****3456 +- 邮箱: zhang.san@example.com → zha****@example.com -**文本类隐私信息脱敏**: -- 社交账号:保留前三后四位字符,中间用*代替 -- 示例:微信号user123456 → use****3456 -- 示例:邮箱zhang.san@example.com → zha****@example.com +**脱敏字段**: name、entity1_name、entity2_name、description -**脱敏处理字段**: -- name字段:如包含隐私信息需脱敏 -- entity1_name字段:如包含隐私信息需脱敏 -- entity2_name字段:如包含隐私信息需脱敏 -- description字段:如包含隐私信息需脱敏 -{% endif %} +## 4. 
处理流程 -## 工作步骤 +### 步骤1: 类型匹配验证 +**匹配规则**: +- baseline="TIME": 只处理时间相关冲突(涉及时间表达式、日期、时间点) +- baseline="FACT": 只处理事实相关冲突(属性矛盾、关系冲突、描述不一致) +- baseline="HYBRID": 处理所有类型冲突 -### 第一步:分析冲突类型匹配 -首先判断输入的冲突数据是否符合baseline要求的类型: +**类型识别**: +- 时间冲突: entity2的entity_type包含"TimeExpression"/"TemporalExpression",或entity2_name包含时间词汇 +- 事实冲突: 相同实体的不同属性描述、互斥关系陈述 -**类型匹配规则**: -- 如果baseline是"TIME":只处理时间相关的冲突(涉及时间表达式、日期、时间点的冲突) -- 如果baseline是"FACT":只处理事实相关的冲突(属性矛盾、关系冲突、描述不一致) -- 如果baseline是"HYBRID":处理所有类型的冲突,也可以当作混合冲突类型处理 +**重要**: 类型不匹配时必须输出空结果(resolved为null) -**类型识别**: -- 时间冲突标识:entity2的entity_type包含"TimeExpression"、"TemporalExpression",或entity2_name包含时间词汇(周一到周日、月份日期等) -- 事实冲突标识:相同实体的不同属性描述、互斥的关系陈述 +### 步骤2: 冲突数据分组 +**分组策略**: +- 时间冲突组: 涉及用户时间的记录 +- 活动时间冲突组: 同一活动不同时间的记录 +- 事实冲突组: 同一实体不同属性的记录 +- 其他冲突组: 其他类型冲突记录 -**重要**:如果输入的冲突类型与baseline不匹配,必须输出空结果(resolved为null) +**筛选条件**: 只处理与baseline匹配的冲突类型 -### 第二步:筛选并分组冲突数据 -按冲突类型对数据进行分组: +### 步骤3: 冲突解决策略 +**重要**: 数据被判定为正确时不可修改 -**分组策略**: -1. **时间冲突组**:筛选涉及用户时间的所有记录 -2. **活动时间冲突组**:筛选涉及同一活动不同时间的记录 -3. **事实冲突组**:筛选涉及同一实体不同属性的记录 -4. **其他冲突组**:其他类型的冲突记录 +**智能解决**: +1. 分析冲突数据,结合statement_databasets原文判定正确性 +2. 判断正确答案是否存在于data中 +3. 根据情况选择处理方式{% if memory_verify %} +4. 隐私脱敏处理在冲突解决后进行{% endif %} -**筛选条件**: -- 只处理与baseline匹配的冲突类型 -- 相同entity1_name但entity2_name不同的记录 -- 相同关系但描述矛盾的记录 -- 时间逻辑不一致的记录 +### 处理规则 -### 第三步:冲突解决策略 -** 不可以解决的冲突情况 - 1. 数据被判定为正确的情况下,不可以进行修改 -**仅当冲突类型与baseline匹配时**,对筛选出的冲突数据进行处理: - -**智能解决策略**: -1. **分析冲突数据**:识别哪些记录是正确的,哪些是错误的,需要结合statement_databasets的输入原文来判定 -2. 
**判断正确答案是否存在**: - - 如果正确答案已存在于data中:只需将错误记录的expired_at设为当前日期(2025-12-16T12:00:00) - - 如果正确答案已存在于data中:错误记录的expired_at已经设为日期,则不需要对正确的数据进行修改 - - 如果正确答案不存在于data中:需要修改现有记录的内容以包含正确信息 - -{% if memory_verify %} -**隐私处理集成**: -- 在处理冲突的同时,需要对涉及的记录进行隐私脱敏 -- 脱敏处理应该在冲突解决之后进行,确保最终输出的记录都已脱敏 -- 在change字段中记录隐私脱敏的变更 -{% endif %} - -**具体处理规则**: - -**情况1:正确答案存在于data中** -- 保留正确的记录不变 -- 基于时间关系的冲突: - 需要只修改错误记录的expired_at为当前时间(2025-12-16T12:00:00) -- 基于事实的关系冲突 +**情况1: 正确答案存在于data中** +- 保留正确记录不变 +- 时间冲突: 修改错误记录的expired_at为当前时间(2025-12-16T12:00:00) +- 事实冲突: 同样处理 - resolved.resolved_memory只包含被设为失效的错误记录 -- change字段只记录expired_at的变更:`[{"expired_at": "2025-12-16T12:00:00"}]`(注意:如果已存在时间,则不需要对其修改,也不需要变更 时间) +- change字段只记录expired_at变更: `[{"expired_at": "2025-12-16T12:00:00"}]` +- 注意: 如果已存在时间则不需要修改 -**情况2:正确答案不存在于data中** -- 选择最合适的记录进行修改 -- 更新该记录的相关字段: - - description字段:添加或修改描述信息{% if memory_verify %}(如包含隐私信息,需脱敏处理){% endif %} - - name字段:修改名称字段{% if memory_verify %}(如需要,包含隐私信息时需脱敏){% endif %} -- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %} -- change字段记录所有被修改的字段{% if memory_verify %},包括脱敏变更{% endif %},例如:`[{"description": "新描述"{% if memory_verify %}, "entity2_name": "138****5678"{% endif %}}]` +**情况2: 正确答案不存在于data中** +- 选择最合适记录进行修改 +- 更新相关字段: + - description字段: 添加或修改描述信息{% if memory_verify %}(含隐私信息需脱敏){% endif %} + - name字段: 修改名称字段{% if memory_verify %}(含隐私信息需脱敏){% endif %} +- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %} +- change字段记录所有被修改字段{% if memory_verify %},包括脱敏变更{% endif %} -**重要原则**: -- **只输出需要修改的记录**:resolved.resolved_memory只包含实际需要修改的数据 -- **优先保留策略**:时间冲突保留最可信的created_at时间的记录,事实冲突选择最新且可信度最高的记录 -- **精确记录变更**:change字段必须包含记录ID、字段名称、新值和旧值 -{% if memory_verify %}- **隐私保护优先**:所有输出的记录必须完成隐私脱敏处理 -- **脱敏变更记录**:隐私脱敏的变更也必须在change字段中详细记录{% endif %} -- **不可修改数据**:数据被判定为正确时,不可以进行修改,如果没有数据可输出空 +**核心原则**: +- 只输出需要修改的记录 +- 优先保留策略: 时间冲突保留最可信created_at时间,事实冲突选择最新且可信度最高记录 +- 精确记录变更: change字段包含记录ID、字段名称、新值和旧值{% if memory_verify %} +- 隐私保护优先: 所有输出记录必须完成隐私脱敏 
+- 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %} +- 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空 -**变更记录格式**: +**变更记录格式**: ```json "change": [ { @@ -166,35 +133,28 @@ ] ``` -**类型不匹配处理**: -- 如果冲突类型与baseline不匹配,resolved必须设为null -- reflexion.reason说明类型不匹配的原因 +**类型不匹配处理**: +- 冲突类型与baseline不匹配时,resolved设为null +- reflexion.reason说明类型不匹配原因 - reflexion.solution说明无需处理 -### 第四步:输出解决方案 +## 5. JSON输出格式 -## 输出要求 -**嵌套字段映射**(系统会自动处理): +**格式要求**: +- 输出有效JSON对象,通过json.loads()解析 +- 使用标准ASCII双引号(") +- 内部引号用反斜杠转义(\") +- 字符串值不包含换行符 +- 不输出```json```等代码块标记 + +**嵌套字段映射**(系统自动处理): - `entity2.name` → 自动映射为 `name` -- `entity1.name` → 自动映射为 `name` +- `entity1.name` → 自动映射为 `name` - `entity1.description` → 自动映射为 `description` - `entity2.description` → 自动映射为 `description` -返回数据格式以json方式输出: -- 必须通过json.loads()的格式支持的形式输出 -- 响应必须是与此确切模式匹配的有效JSON对象 -- 不要在JSON之前或之后包含任何文本 - -JSON格式要求: -1. JSON结构仅使用标准ASCII双引号(") -2. 如果提取的语句文本包含引号,请使用反斜杠(\")正确转义 -3. 确保所有JSON字符串都正确关闭并以逗号分隔 -4. JSON字符串值中不包括换行符 -5. 不允许输出```json```相关符号 - -仅输出一个合法 JSON 对象,严格遵循下述结构: - -**输出格式:按冲突类型分组的列表** +**输出结构**: 按冲突类型分组的列表 +```json { "results": [ { @@ -208,93 +168,25 @@ JSON格式要求: }, "resolved": { "original_memory_id": "被设为失效的记忆id", - "resolved_memory": { - "entity1_name": "实体1名称", - "entity2_name": "实体2名称", - "description": "描述信息", - "statement_id": "陈述ID", - "created_at": "创建时间", - "expired_at": "过期时间", - "relationship_type": "关系类型", - "relationship": {}, - "entity2": {...} - }, - "change": [ - { - "field": [ - {"字段名1": "修改后的值1"}, - {"字段名2": "修改后的值2"} - ] - } - ] + "resolved_memory": {记录对象}, + "change": [变更记录数组] }, "type": "reflexion_result" } ] } +``` -**示例:多种冲突类型的输出** -{ - "results": [ - { - "conflict": { - "data": [生日冲突相关的记录], - "conflict": true - }, - "reflexion": { - "reason": "检测到生日冲突:用户同时关联2月10号和2月16号两个不同日期", - "solution": "保留最新记录(2月16号),将旧记录(2月10号)设为失效" - }, - "resolved": { - "original_memory_id": "df066210883545a08e727ccd8ad4ec77", - "resolved_memory": {...}, - "change": [ - { - "field": [ - {"expired_at": "2025-12-16T12:00:00"} - ] - } - ] - 
}, - "type": "reflexion_result" - }, - { - "conflict": { - "data": [篮球时间冲突相关的记录], - "conflict": true - }, - "reflexion": { - "reason": "检测到活动时间冲突:用户打篮球时间存在周五和周六的冲突", - "solution": "保留最可信的时间记录,将冲突记录设为失效" - }, - "resolved": { - "original_memory_id": "另一个记录ID", - "resolved_memory": {...}, - "change": [ - { - "field": [ - {"description": "使用系统的个人,指代说话者本人,篮球时间为周六"}, - {"entity2_name": "周六"} - ] - } - ] - }, - "type": "reflexion_result" - } - ] -} +**输出要求**: +- 只输出JSON,不添加解释文本 +- 使用标准双引号,必要时转义 +- 字段名与结构必须与模式一致 +- **results数组格式**: 每个冲突类型作为独立对象 +- **按冲突类型分组**: 相同类型冲突归并到一个result对象 +- **conflict.data**: 只包含该冲突类型相关记录 +- **resolved.resolved_memory**: 只包含需要修改的记录 +- **resolved.change**: 包含详细变更信息 +- 无需修改的冲突类型resolved为null +- 与baseline不匹配的冲突类型不包含在results中 -必须遵守: -- 只输出 JSON,不要添加解释或多余文本 -- 使用标准双引号,必要时对内部引号进行转义 -- 字段名与结构必须与给定模式一致 -- **输出必须是results数组格式**,每个冲突类型作为一个独立的对象 -- **按冲突类型分组**:相同类型的冲突记录归并到一个result对象中 -- **每个result对象的conflict.data**只包含该冲突类型相关的记录 -- **resolved.resolved_memory 只包含需要修改的记录**,不需要修改的记录不要输出 -- **resolved.change 必须包含详细的变更信息**:field数组包含所有被修改的字段及其新值 -- 如果某个冲突类型经分析无需修改任何数据,该类型的resolved 必须为 null -- 如果与baseline不匹配的冲突类型,不要在results中包含该类型 - -模式参考: -{{ json_schema }} \ No newline at end of file +模式参考: {{ json_schema }} \ No newline at end of file diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index ab6b0512..4d8f317a 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -31,13 +31,8 @@ class BaseDataSchema(BaseModel): # 保持原有必需字段为可选,以兼容不同数据源 id: Optional[str] = Field(None, description="The unique identifier for the data entry.") statement: Optional[str] = Field(None, description="The statement text.") - group_id: Optional[str] = Field(None, description="The group identifier.") - chunk_id: Optional[str] = Field(None, description="The chunk identifier.") created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.") expired_at: Optional[str] = Field(None, 
description="The expiration timestamp in ISO 8601 format.") - valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.") - invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.") - entity_ids: List[str] = Field([], description="The list of entity identifiers.") description: Optional[str] = Field(None, description="The description of the data entry.") # 新增字段以匹配实际输入数据 From bd4f49fcce904f3150970c21725563591b05bbf0 Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Mon, 22 Dec 2025 17:52:00 +0800 Subject: [PATCH 02/15] [fix]update dockerfile --- api/Dockerfile | 28 ++++++--------------- api/docker-compose.yml | 55 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/api/Dockerfile b/api/Dockerfile index a5c818ea..f6c082d2 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -25,17 +25,13 @@ ENV DEBIAN_FRONTEND=noninteractive # 4. Setup apt # Python package and implicit dependencies: # opencv-python: libglib2.0-0 libglx-mesa0 libgl1 -# aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb -# python-pptx: default-jdk tika-server-standard-3.0.0.jar +# libreoffice: libreoffice libreoffice-writer libreoffice-impress fonts-wqy-zenhei fonts-noto-cjk +# python-docx: default-jdk tika-server-standard-3.0.0.jar # Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \ - apt install -y libicu-dev && \ if [ "$NEED_MIRROR" == "1" ]; then \ - rm -f /etc/apt/sources.list.d/debian.sources && \ - echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm main contrib non-free non-free-firmware" > /etc/apt/sources.list && \ - echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-updates main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \ - echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ 
bookworm-backports main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \ - echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian-security bookworm-security main contrib non-free non-free-firmware" >> /etc/apt/sources.list; \ + sed -i 's|http://ports.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \ + sed -i 's|http://archive.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \ fi; \ rm -f /etc/apt/apt.conf.d/docker-clean && \ echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \ @@ -44,7 +40,7 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \ apt --no-install-recommends install -y ca-certificates && \ apt update && \ apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \ - apt install -y pkg-config libgdiplus && \ + apt install -y libreoffice libreoffice-writer libreoffice-impress fonts-wqy-zenhei fonts-noto-cjk && \ apt install -y default-jdk && \ apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \ apt install -y libjemalloc-dev && \ @@ -64,21 +60,13 @@ RUN if [ "$NEED_MIRROR" == "1" ]; then \ ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 ENV PATH=/root/.local/bin:$PATH -# https://forum.aspose.com/t/aspose-slides-for-net-no-usable-version-of-libssl-found-with-linux-server/271344/13 -# 5. aspose-slides on linux/arm64 is unavailable -COPY libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb /tmp/ -RUN if [ "$(uname -m)" = "x86_64" ]; then \ - dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_amd64.deb; \ - elif [ "$(uname -m)" = "aarch64" ]; then \ - dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_arm64.deb; \ - fi && \ - rm -f /tmp/libssl1.1_*.deb - -# 6. install dependencies from uv.lock file +# 5. 
install dependencies from uv.lock file COPY ./pyproject.toml /code/pyproject.toml COPY ./uv.lock /code/uv.lock COPY ./app /code/app +COPY ./alembic.ini /code/alembic.ini +COPY ./migrations /code/migrations # https://github.com/astral-sh/uv/issues/10462 # uv records index url into uv.lock but doesn't failover among multiple indexes diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 74c69353..48ec137d 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -1,22 +1,71 @@ -version: '3.8' +version: '3.9' services: + # MCP Server - standalone service + mcp-server: + image: redbear-mem:latest + container_name: mcp-server + ports: + - "8081:8081" # MCP server port + env_file: + - .env + environment: + - SERVER_IP=0.0.0.0 # Bind to all interfaces + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: python -m app.core.memory.agent.mcp_server.server + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8081/sse')"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 30s + restart: unless-stopped + networks: + - default + - celery + + # FastAPI application - connects to MCP server api: image: redbear-mem:latest container_name: api ports: - - "8000:8000" + - "8002:8000" env_file: - .env + environment: + - MCP_SERVER_URL=http://mcp-server:8081 + - SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces volumes: - ./files:/files + - /etc/localtime:/etc/localtime:ro command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-level debug + depends_on: + mcp-server: + condition: service_healthy + restart: unless-stopped + networks: + - default + - celery + # Celery worker - connects to MCP server worker: image: redbear-mem:latest container_name: worker env_file: - .env + environment: + - MCP_SERVER_URL=http://mcp-server:8081 volumes: - ./files:/files - command: celery -A app.celery_worker.celery_app worker --loglevel=info \ No newline at end of 
file + - /etc/localtime:/etc/localtime:ro + command: celery -A app.celery_worker.celery_app worker --loglevel=info + depends_on: + mcp-server: + condition: service_healthy + restart: unless-stopped + networks: + - celery +networks: + celery: \ No newline at end of file From da6b17de2b9415b36682f2d0ef78703da392ab2b Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 22 Dec 2025 18:24:36 +0800 Subject: [PATCH 03/15] [modify] fix workflow execute logic --- api/app/core/workflow/executor.py | 12 ++++++++++- api/app/core/workflow/nodes/base_node.py | 23 +++++++++++++++++++--- api/app/core/workflow/template_renderer.py | 19 ++++++++++++------ 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 3555d179..d73e25eb 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -69,7 +69,17 @@ class WorkflowExecutor: 初始化的工作流状态 """ user_message = input_data.get("message") or "" - conversation_vars = input_data.get("conversation_vars") or {} + + # 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value) + config_variables_list = self.workflow_config.get("variables") or [] + conversation_vars = {} + for var_def in config_variables_list: + if isinstance(var_def, dict): + var_name = var_def.get("name") + var_default = var_def.get("default") + if var_name: + conversation_vars[var_name] = var_default + input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 # 构建分层的变量结构 diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 25fdd29e..44c92755 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -26,7 +26,12 @@ class WorkflowState(TypedDict): messages: Annotated[list[AnyMessage], add] # 输入变量(从配置的 variables 传入) - variables: dict[str, Any] + # 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx) + variables: Annotated[dict[str, Any], lambda x, y: { + **x, + **{k: {**x.get(k, {}), **v} if 
isinstance(v, dict) and isinstance(x.get(k), dict) else v + for k, v in y.items()} + }] # 节点输出(存储每个节点的执行结果,用于变量引用) # 使用自定义合并函数,将新的节点输出合并到现有字典中 @@ -544,9 +549,15 @@ class BaseNode(ABC): # 使用变量池获取变量 pool = VariablePool(state) + # 构建完整的 variables 结构 + variables = { + "sys": pool.get_all_system_vars(), + "conv": pool.get_all_conversation_vars() + } + return render_template( template=template, - variables=pool.get_all_conversation_vars(), + variables=variables, node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars() ) @@ -575,9 +586,15 @@ class BaseNode(ABC): # 使用变量池获取变量 pool = VariablePool(state) + # 构建完整的 variables 结构(包含 sys 和 conv) + variables = { + "sys": pool.get_all_system_vars(), + "conv": pool.get_all_conversation_vars() + } + return evaluate_condition( expression=expression, - variables=pool.get_all_conversation_vars(), + variables=variables, node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars() ) diff --git a/api/app/core/workflow/template_renderer.py b/api/app/core/workflow/template_renderer.py index e9efec0b..b927bd98 100644 --- a/api/app/core/workflow/template_renderer.py +++ b/api/app/core/workflow/template_renderer.py @@ -66,19 +66,26 @@ class TemplateRenderer: '分析结果: 正面情绪' """ # 构建命名空间上下文 + # variables 的结构:{"sys": {...}, "conv": {...}} + sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {} + conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {} + context = { - "var": variables, # 用户变量:{{var.user_input}} + "conv": conv_vars, # 会话变量:{{conv.user_name}} "node": node_outputs, # 节点输出:{{node.node_1.output}} - "sys": system_vars or {}, # 系统变量:{{sys.execution_id}} + "sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源) } # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} # 将所有节点输出添加到顶层上下文 - context.update(node_outputs) + if node_outputs: + context.update(node_outputs) - # 为了向后兼容,也支持直接访问用户变量 - context.update(variables) - context["nodes"] = node_outputs # 
旧语法兼容 + # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} + if conv_vars: + context.update(conv_vars) + + context["nodes"] = node_outputs or {} # 旧语法兼容 try: tmpl = self.env.from_string(template) From cd644b6eab96c6d113b863ac0e9e75061ba9a8b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=96=B0=E6=9C=88?= Date: Mon, 22 Dec 2025 12:26:45 +0000 Subject: [PATCH 04/15] Merge #32 into develop from fix/memory_reflection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更新 self_reflexion.py * fix/memory_reflection: (50 commits squashed) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py # api/app/schemas/memory_reflection_schemas.py - 反思优化 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 - 统一输出 - 统一输出 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py - 统一输出 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 - 反思速度提升,从4分钟优化成1分10-40秒 - 反思速度提升,从4分钟优化成1分10-40秒 - 反思速度提升,从4分钟优化成1分10-40秒 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 反思速度提升,从4分钟优化成1分10-40秒 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 更新 self_reflexion.py - 反思图谱添加边的修改 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection 
# Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 反思图谱添加边的修改 - 反思图谱添加边的修改 - 反思图谱添加边的修改 - 反思图谱添加边的修改 - 反思图谱添加边的修改 - update # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py # api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 Signed-off-by: aliyun8644380055 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/32 --- .../reflection_engine/self_reflexion.py | 26 +-- api/app/core/memory/utils/config/get_data.py | 80 ++++++- .../utils/prompt/prompts/reflexion.jinja2 | 31 +-- api/app/repositories/neo4j/cypher_queries.py | 8 +- api/app/repositories/neo4j/neo4j_update.py | 219 +++++++++++------- api/app/schemas/memory_storage_schema.py | 20 +- 6 files changed, 254 insertions(+), 130 deletions(-) diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index 864c91a7..aa284a95 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -18,17 +18,10 @@ from enum import Enum import uuid from pydantic import BaseModel - - -from app.core.response_utils import success -from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all -from app.repositories.neo4j.neo4j_update import neo4j_data - from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config import get_model_config -from app.core.memory.utils.config.get_data import get_data -from app.core.memory.utils.config.get_data import get_data_statement +from app.core.memory.utils.config.get_data import get_data,get_data_statement,extract_and_process_changes from 
app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.prompt.template_render import render_evaluate_prompt from app.core.memory.utils.prompt.template_render import render_reflexion_prompt @@ -45,7 +38,6 @@ from app.repositories.neo4j.neo4j_update import neo4j_data from app.schemas.memory_storage_schema import ConflictResultSchema from app.schemas.memory_storage_schema import ReflexionResultSchema - # 配置日志 _root_logger = logging.getLogger() if not _root_logger.handlers: @@ -241,8 +233,7 @@ class ReflectionEngine: print(100 * '-') print(conflict_data) print(100 * '-') - - # 检查是否真的有冲突 + # # 检查是否真的有冲突 has_conflict = conflict_data[0].get('conflict', False) conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0 logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突") @@ -270,7 +261,7 @@ class ReflectionEngine: await self._log_data("solved_data", solved_data) # 4. 应用反思结果(更新记忆库) - memories_updated = await self._apply_reflection_results(solved_data) + memories_updated=await self._apply_reflection_results(solved_data) execution_time = asyncio.get_event_loop().time() - start_time @@ -297,6 +288,7 @@ class ReflectionEngine: async def reflection_run(self): self._lazy_init() start_time = time.time() + memory_verifies_flag = self.config.memory_verify asyncio.get_event_loop().time() logging.info("====== 自我反思流程开始 ======") @@ -316,8 +308,8 @@ class ReflectionEngine: for item in conflict_data: quality_assessments.append(item['quality_assessment']) memory_verifies.append(item['memory_verify']) - result_data['quality_assessments'] = quality_assessments result_data['memory_verifies'] = memory_verifies + result_data['quality_assessments'] = quality_assessments # 检查是否真的有冲突 has_conflict = conflict_data[0].get('conflict', False) @@ -354,6 +346,9 @@ class ReflectionEngine: for result in item['results']: reflexion_data.append(result['reflexion']) result_data['reflexion_data'] = reflexion_data + if memory_verifies_flag==False: + 
result_data['memory_verifies']=[None] + print(time.time()-start_time,'----------') return result_data @@ -561,7 +556,8 @@ class ReflectionEngine: Returns: int: 成功更新的记忆数量 """ - success_count = await neo4j_data(solved_data) + changes = extract_and_process_changes(solved_data) + success_count = await neo4j_data(changes) return success_count async def _log_data(self, label: str, data: Any) -> None: @@ -668,4 +664,4 @@ class ReflectionEngine: execution_time=time_result.execution_time + fact_result.execution_time ) else: - raise ValueError(f"未知的反思基线: {self.config.baseline}") + raise ValueError(f"未知的反思基线: {self.config.baseline}") \ No newline at end of file diff --git a/api/app/core/memory/utils/config/get_data.py b/api/app/core/memory/utils/config/get_data.py index a099694e..1de6f6aa 100644 --- a/api/app/core/memory/utils/config/get_data.py +++ b/api/app/core/memory/utils/config/get_data.py @@ -3,6 +3,20 @@ import uuid import logging from typing import List, Dict, Any + +from openai import BaseModel +import json +import sys +from pathlib import Path +from pydantic import model_validator, Field + +from app.schemas.memory_storage_schema import SingleReflexionResultSchema +from app.schemas.memory_storage_schema import ReflexionResultSchema +from app.repositories.neo4j.neo4j_update import map_field_names +# 添加项目根目录到 Python 路径 +sys.path.append(str(Path(__file__).parent)) + + logger = logging.getLogger(__name__) async def _load_(data: List[Any]) -> List[Dict]: @@ -59,6 +73,14 @@ async def get_data(result): """ 从数据库中获取数据 """ + EXCLUDE_FIELDS = { + "user_id", + "group_id", + "entity_type", + "connect_strength", + "relationship_type", + "apply_id" + } neo4j_databasets=[] for item in result: filtered_item = {} @@ -73,14 +95,17 @@ async def get_data(result): rel_filtered['statement_id'] = value.get('statement_id') rel_filtered['expired_at'] = value.get('expired_at') rel_filtered['created_at'] = value.get('created_at') - filtered_item[key] = rel_filtered + filtered_item[key] = value 
elif key == 'entity2' and value is not None: # 过滤entity2的name_embedding字段 entity2_filtered = {} if hasattr(value, 'items'): for e_key, e_value in value.items(): - if 'name_embedding' not in e_key.lower(): - entity2_filtered[e_key] = e_value + if e_key in EXCLUDE_FIELDS: + continue + if 'name_embedding' in e_key.lower(): + continue + entity2_filtered[e_key] = e_value filtered_item[key] = entity2_filtered else: filtered_item[key] = value @@ -94,8 +119,57 @@ async def get_data_statement( result): neo4j_databasets.append(i) return neo4j_databasets +class ReflexionResultSchema(BaseModel): + """Schema for the complete reflexion result data - a list of individual conflict resolutions.""" + results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") + @model_validator(mode="before") + def _normalize_resolved(cls, v): + if isinstance(v, dict): + conflict = v.get("conflict") + if isinstance(conflict, dict) and conflict.get("conflict") is False: + v["resolved"] = None + else: + resolved = v.get("resolved") + if isinstance(resolved, dict): + orig = resolved.get("original_memory_id") + mem = resolved.get("resolved_memory") + if orig is None and (mem is None or mem == {}): + v["resolved"] = None + return v +def extract_and_process_changes(DATA): + """提取并处理 change 字段""" + all_changes = [] + for i, item in enumerate(DATA): + try: + result = ReflexionResultSchema(**item) + for j, res in enumerate(result.results): + if res.resolved and res.resolved.change: + for k, change in enumerate(res.resolved.change): + change_data = {} + for field_item in change.field: + for key, value in field_item.items(): + change_data[key] = value + if isinstance(value, list): + print(f" - {key}: {value[0]} -> {value[1]}") + else: + print(f" - {key}: {value}") + all_changes.append({ + 'data': change_data + }) + + # 测试字段映射 + try: + mapped = map_field_names(change_data) + print(f" 映射结果: {mapped}") + except Exception as e: + 
print(f" 映射失败: {e}") + + except Exception as e: + print(f"处理结果 {i + 1} 失败: {e}") + + return all_changes if __name__ == "__main__": import asyncio diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index 5b831c02..15f65fc3 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -61,7 +61,7 @@ - 微信号: user123456 → use****3456 - 邮箱: zhang.san@example.com → zha****@example.com -**脱敏字段**: name、entity1_name、entity2_name、description +**脱敏字段**: name、entity1_name、entity2_name、description、relationship ## 4. 处理流程 @@ -97,21 +97,11 @@ ### 处理规则 -**情况1: 正确答案存在于data中** -- 保留正确记录不变 -- 时间冲突: 修改错误记录的expired_at为当前时间(2025-12-16T12:00:00) -- 事实冲突: 同样处理 -- resolved.resolved_memory只包含被设为失效的错误记录 -- change字段只记录expired_at变更: `[{"expired_at": "2025-12-16T12:00:00"}]` -- 注意: 如果已存在时间则不需要修改 - -**情况2: 正确答案不存在于data中** -- 选择最合适记录进行修改 -- 更新相关字段: - - description字段: 添加或修改描述信息{% if memory_verify %}(含隐私信息需脱敏){% endif %} - - name字段: 修改名称字段{% if memory_verify %}(含隐私信息需脱敏){% endif %} -- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %} -- change字段记录所有被修改字段{% if memory_verify %},包括脱敏变更{% endif %} +** baseline是TIME + -保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的 +** baseline不是TIME + - 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段, + 如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field **核心原则**: - 只输出需要修改的记录 @@ -126,8 +116,10 @@ "change": [ { "field": [ - {"字段名1": "修改后的值1"}, - {"字段名2": "修改后的值2"} + {"id":修改字段对应的ID} + {"statement_id":需要修改的对象对应的statement_id} + {"字段名1": ["修改前的值1","修改后的值1"]}, + {"字段名2": ["修改前的值2","修改后的值2"]} ] } ] @@ -149,7 +141,8 @@ **嵌套字段映射**(系统自动处理): - `entity2.name` → 自动映射为 `name` -- `entity1.name` → 自动映射为 `name` +- `entity1.name` → 自动映射为 `name` +- `relationship` → 自动映射为 `statement` - 
`entity1.description` → 自动映射为 `description` - `entity2.description` → 自动映射为 `description` diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 0f6e32aa..02e96694 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -783,7 +783,9 @@ neo4j_query_part = """ m.created_at as created_at, m.expired_at as expired_at, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, - rel as relationship, + rel.predicate as predicate, + rel.statement as relationship, + rel.statement_id as relationship_statement_id, CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name, other as entity2 """ @@ -799,7 +801,9 @@ neo4j_query_all = """ m.created_at as created_at, m.expired_at as expired_at, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, - rel as relationship, + rel.predicate as predicate, + rel.statement as relationship, + rel.statement_id as relationship_statement_id, CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name, other as entity2 """ diff --git a/api/app/repositories/neo4j/neo4j_update.py b/api/app/repositories/neo4j/neo4j_update.py index 9644224c..73b44396 100644 --- a/api/app/repositories/neo4j/neo4j_update.py +++ b/api/app/repositories/neo4j/neo4j_update.py @@ -67,11 +67,81 @@ async def update_neo4j_data(neo4j_dict_data, update_databases): traceback.print_exc() return False +async def update_neo4j_data_edge(neo4j_dict_data, update_databases): + """ + Update Neo4j data based on query criteria and update parameters + Args: + neo4j_dict_data: find + update_databases: update + """ + try: + # 构建WHERE条件 + where_conditions = [] + params = {} + + for key, value in neo4j_dict_data.items(): + if value is not None: + param_name = f"param_{key}" + where_conditions.append(f"r.{key} = ${param_name}") + params[param_name] = value + + where_clause = " AND 
".join(where_conditions) if where_conditions else "1=1" + + # 构建SET条件 + set_conditions = [] + for key, value in update_databases.items(): + if value is not None: + param_name = f"update_{key}" + set_conditions.append(f"r.{key} = ${param_name}") + params[param_name] = value + + set_clause = ", ".join(set_conditions) + + if not set_clause: + print("警告: 没有需要更新的字段") + return False + + # 构建Cypher查询 + cypher_query = f""" + MATCH (n)-[r]->(m) + WHERE {where_clause} + SET {set_clause} + RETURN count(r) as updated_count, collect(type(r)) as relation_types + """ + + print(f"\n执行Cypher查询: {cypher_query}") + print(f"参数: {params}") + + # 执行更新 + result = await neo4j_connector.execute_query(cypher_query, **params) + + if result: + updated_count = result[0].get('updated_count', 0) + updated_names = result[0].get('updated_names', []) + print(f"成功更新 {updated_count} 个节点") + if updated_names: + print(f"更新的实体名称: {updated_names}") + return updated_count > 0 + else: + return False + + except Exception as e: + print(f"更新过程中出现错误: {e}") + import traceback + traceback.print_exc() + return False def map_field_names(data_dict): mapped_dict = {} has_name_field = False + # 辅助函数:提取值(如果是数组则取最后一个值,否则直接返回) + def extract_value(value): + if isinstance(value, list) and len(value) > 0: + # 如果是数组 [old_value, new_value],取新值(最后一个) + return value[-1] + return value + # 第一遍:检查是否有name相关字段 for key, value in data_dict.items(): if key in ['name', 'entity2.name', 'entity1.name']: @@ -82,22 +152,25 @@ def map_field_names(data_dict): # 第二遍:根据规则映射和过滤字段 for key, value in data_dict.items(): + # 提取实际值(处理数组格式) + actual_value = extract_value(value) + if key == 'entity2.name' or key == 'entity2_name': # 将 entity2.name 映射为 name - mapped_dict['name'] = value - print(f"字段名映射: {key} -> name") + mapped_dict['name'] = actual_value + print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})") elif key == 'entity1.name' or key == 'entity1_name': # 将 entity1.name 映射为 name - mapped_dict['name'] = value - print(f"字段名映射: {key} -> 
name") + mapped_dict['name'] = actual_value + print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})") elif key == 'entity1.description': # 将 entity1.description 映射为 description - mapped_dict['description'] = value - print(f"字段名映射: {key} -> description") + mapped_dict['description'] = actual_value + print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})") elif key == 'entity2.description': # 将 entity2.description 映射为 description - mapped_dict['description'] = value - print(f"字段名映射: {key} -> description") + mapped_dict['description'] = actual_value + print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})") elif key == 'relationship_type': # 跳过relationship_type字段 print(f"字段过滤: 跳过不需要的字段 '{key}'") @@ -109,8 +182,8 @@ def map_field_names(data_dict): continue else: # 如果没有name字段,保留entity1_name - mapped_dict[key] = value - print(f"字段保留: {key}") + mapped_dict[key] = actual_value + print(f"字段保留: {key} (值: {value} -> {actual_value})") elif key == 'entity2_name': if has_name_field: # 如果有name字段,跳过entity2_name @@ -122,7 +195,11 @@ def map_field_names(data_dict): continue elif '.' 
not in key: # 不包含点号的其他字段直接保留 - mapped_dict[key] = value + mapped_dict[key] = actual_value + if isinstance(value, list): + print(f"字段保留: {key} (数组值: {value} -> {actual_value})") + else: + print(f"字段保留: {key}") else: # 其他包含点号的字段跳过并警告 print(f"警告: 跳过不支持的嵌套字段 '{key}'") @@ -139,89 +216,57 @@ async def neo4j_data(solved_data): """ success_count = 0 + ori_entity = {} + updata_entity = {} + ori_edge = {} + updata_edge = {} + ori_expired_at={} + updat_expired_at={} for i in solved_data: - neo4j_dict_data = {} - update_databases = {} - results = i['results'] - for data in results: - resolved = data.get('resolved') - if not resolved: - print("跳过:resolved为None") + databasets = i['data'] + for key, values in databasets.items(): + if str(values)=='NONE': continue + if isinstance(values, list): + if key == 'description': + ori_entity[key] = values[0] + updata_entity[key] = values[1] + if key == 'entity2_name' or key == 'entity1_name': + key = 'name' + ori_entity[key] = values[0] + updata_entity[key] = values[1] + ori_expired_at[key] = values[0] + if key == 'statement': + ori_edge[key] = values[0] + updata_edge[key] = values[1] + if key=='expired_at': + updat_expired_at[key] = values[1] - try: - change_list = resolved.get('change', []) - except (AttributeError, TypeError): - change_list = [] + elif key == 'statement_id': + ori_edge[key] = values + updata_edge[key] = values - if change_list == []: - print("跳过:change_list为空") - continue + ori_entity[key] = values + updata_entity[key] = values - if change_list and len(change_list) > 0: - change = change_list[0] - print(f"change: {change}") - field_data = change.get('field', []) - print(f"field_data: {field_data}") - print(f"field_data type: {type(field_data)}") - - # 字段名映射和过滤函数 + ori_expired_at[key] = values - # 处理field数据,可能是字典或列表 - if isinstance(field_data, dict): - # 如果是字典,映射字段名后更新 - mapped_data = map_field_names(field_data) - update_databases.update(mapped_data) - elif isinstance(field_data, list): - # 如果是列表,遍历每个字典并更新 - for 
field_item in field_data: - if isinstance(field_item, dict): - mapped_item = map_field_names(field_item) - update_databases.update(mapped_item) - else: - print(f"警告: field_item不是字典: {field_item}") - else: - print(f"警告: field_data类型不支持: {type(field_data)}") - - if 'entity1_name' in data: - data['name'] = data.pop('entity1_name') - if 'entity2_name' in data: - data.pop('entity2_name', None) - - resolved_memory = resolved.get('resolved_memory', {}) - - entity2 = None - if isinstance(resolved_memory, dict): - entity2 = resolved_memory.get('entity2') - - if entity2 and isinstance(entity2, dict) and len(entity2) >= 5: - stat_id = resolved.get('original_memory_id') - # 安全地获取description - statement_id = None - if isinstance(resolved_memory, dict): - statement_id = resolved_memory.get('statement_id') - - # 只有当neo4j_dict_data中还没有statement_id时才使用original_memory_id - if statement_id and 'id' not in neo4j_dict_data: - neo4j_dict_data['id'] = stat_id - neo4j_dict_data['statement_id'] = statement_id - else: - # 处理original_memory_id,它可能是字符串或字典 - try: - for key, value in resolved_memory.items(): - if key == 'statement_id': - neo4j_dict_data['statement_id'] = value - if key == 'description': - neo4j_dict_data['description'] = value - except AttributeError: - neo4j_dict_data=[] - - print(neo4j_dict_data) - print(update_databases) - if neo4j_dict_data!=[]: - await update_neo4j_data(neo4j_dict_data, update_databases) - success_count += 1 + print(ori_entity) + print(updata_entity) + print(100*'-') + print(ori_edge) + print(updata_edge) + expired_at_ = updat_expired_at.get('expired_at', None) + if expired_at_ is not None: + await update_neo4j_data(ori_expired_at, updat_expired_at) + success_count += 1 + if ori_entity != updata_entity: + await update_neo4j_data(ori_entity, updata_entity) + success_count += 1 + if ori_edge != updata_edge: + await update_neo4j_data_edge(ori_edge, updata_edge) + success_count += 1 return success_count diff --git a/api/app/schemas/memory_storage_schema.py 
b/api/app/schemas/memory_storage_schema.py index 4d8f317a..be249b5e 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -39,8 +39,11 @@ class BaseDataSchema(BaseModel): entity1_name: str = Field(..., description="The first entity name.") entity2_name: Optional[str] = Field(None, description="The second entity name.") statement_id: str = Field(..., description="The statement identifier.") - relationship_type: str = Field(..., description="The relationship type.") - relationship: Optional[Dict[str, Any]] = Field(None, description="The relationship object.") + # 新增字段 - 设为可选以保持向后兼容性 + predicate: Optional[str] = Field(None, description="The predicate describing the relationship between entities.") + relationship_statement_id: Optional[str] = Field(None, description="The relationship statement identifier.") + # 保留原有字段 - 修改relationship字段类型以支持字符串和字典 + relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.") entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.") @@ -94,8 +97,17 @@ class ReflexionSchema(BaseModel): class ChangeRecordSchema(BaseModel): - """Schema for individual change records""" - field: List[Dict[str, str]] = Field(..., description="List of field changes, each containing field name and new value.") + """Schema for individual change records + + 字段值格式说明: + - id 和 statement_id: 字符串或 None + - 其他字段: 可以是字符串、None,数组 [修改前的值, 修改后的值],或嵌套字典结构 + - entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式 + """ + field: List[Dict[str, Any]] = Field( + ..., + description="List of field changes. 
First item: {id: value or None}, second: {statement_id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" + ) class ResolvedSchema(BaseModel): """Schema for the resolved memory data in the reflexion_data""" From b8c13b80853f0b17c5ecade02260824e2dddb641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=96=B0=E6=9C=88?= Date: Tue, 23 Dec 2025 08:00:14 +0000 Subject: [PATCH 05/15] Merge #36 into develop from fix/memory_reflection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection * fix/memory_reflection: (54 commits squashed) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py # api/app/schemas/memory_reflection_schemas.py - 反思优化 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 - 统一输出 - 统一输出 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/controllers/memory_reflection_controller.py - 统一输出 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 统一输出 - 反思速度提升,从4分钟优化成1分10-40秒 - 反思速度提升,从4分钟优化成1分10-40秒 - 反思速度提升,从4分钟优化成1分10-40秒 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 反思速度提升,从4分钟优化成1分10-40秒 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection # 
Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 更新 self_reflexion.py - 反思图谱添加边的修改 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py - 反思图谱添加边的修改 - 反思图谱添加边的修改 - 反思图谱添加边的修改 - 反思图谱添加边的修改 - 反思图谱添加边的修改 - update # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py # api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 - 反思BUG修复 - Merge remote-tracking branch 'origin/fix/memory_reflection' into fix/memory_reflection - 反思BUG修复 - Merge branch develop into fix/memory_reflection (Conflict resolved online) # Conflicts: # api/app/core/memory/storage_services/reflection_engine/self_reflexion.py Signed-off-by: aliyun8644380055 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/36 --- .../memory_reflection_controller.py | 4 +- .../reflection_engine/self_reflexion.py | 83 ++++++++++++++++--- .../utils/prompt/prompts/evaluate.jinja2 | 11 +-- .../utils/prompt/prompts/reflexion.jinja2 | 3 +- .../memory/utils/prompt/template_render.py | 10 ++- 5 files changed, 88 insertions(+), 23 deletions(-) diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index 8dfa6c50..c4800941 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -213,6 +213,7 @@ async def start_reflection_configs( @router.get("/reflection/run") async def reflection_run( config_id: int, + language_type: str = "zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -252,7 +253,8 @@ async def reflection_run( memory_verify=result.memory_verify, quality_assessment=result.quality_assessment, violation_handling_strategy="block", - model_id=model_id + model_id=model_id, 
+ language_type=language_type ) connector = Neo4jConnector() engine = ReflectionEngine( diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index aa284a95..224a9560 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -48,7 +48,9 @@ if not _root_logger.handlers: else: _root_logger.setLevel(logging.INFO) - +class TranslationResponse(BaseModel): + """翻译响应模型""" + data: str class ReflectionRange(str, Enum): """反思范围枚举""" PARTIAL = "partial" # 从检索结果中反思 @@ -76,6 +78,7 @@ class ReflectionConfig(BaseModel): memory_verify: bool = True # 记忆验证 quality_assessment: bool = True # 质量评估 violation_handling_strategy: str = "warn" # 违规处理策略 + language_type: str = "zh" class Config: use_enum_values = True @@ -234,13 +237,11 @@ class ReflectionEngine: print(conflict_data) print(100 * '-') # # 检查是否真的有冲突 - has_conflict = conflict_data[0].get('conflict', False) - conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0 - logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突") + conflicts_found='' # 记录冲突数据 await self._log_data("conflict", conflict_data) - + conflicts_found='' # 3. 
解决冲突 solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) if not solved_data: @@ -285,10 +286,60 @@ class ReflectionEngine: execution_time=asyncio.get_event_loop().time() - start_time ) + async def Translate(self, text): + # 翻译中文为英文 + translation_messages = [ + { + "role": "user", + "content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}" + } + ] + + response = await self.llm_client.response_structured( + messages=translation_messages, + response_model=TranslationResponse + ) + return response.data + async def extract_translation(self,data): + end_datas={} + end_datas['source_data']=await self.Translate(data['source_data']) + quality_assessments = [] + memory_verifies = [] + reflexion_data=[] + if data['memory_verifies']!=[]: + for i in data['memory_verifies']: + end_data={} + end_data['has_privacy'] = i['has_privacy'] + privacy=i['privacy_types'] + privacy_types_=[] + for pri in privacy: + privacy_types_.append(await self.Translate(pri)) + end_data['privacy_types']=privacy_types_ + end_data['summary']=await self.Translate(i['summary']) + memory_verifies.append(end_data) + end_datas['memory_verifies']=memory_verifies + + if data['quality_assessments']!=[]: + for i in data['quality_assessments']: + end_data = {} + end_data['score']=i['score'] + end_data['summary'] = await self.Translate(i['summary']) + quality_assessments.append(end_data) + end_datas['quality_assessments'] = quality_assessments + for i in data['reflexion_data']: + end_data = {} + end_data['reason'] = await self.Translate(i['reason']) + end_data['solution'] = await self.Translate(i['solution']) + reflexion_data.append(end_data) + end_datas['reflexion_data'] = reflexion_data + return end_datas + async def reflection_run(self): self._lazy_init() start_time = time.time() memory_verifies_flag = self.config.memory_verify + quality_assessment=self.config.quality_assessment + language_type=self.config.language_type asyncio.get_event_loop().time() logging.info("====== 自我反思流程开始 
======") @@ -297,9 +348,8 @@ class ReflectionEngine: source_data, databasets = await self.extract_fields_from_json() result_data['baseline'] = self.config.baseline - result_data[ - 'source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合" + result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合" # 2. 检测冲突(基于事实的反思) conflict_data = await self._detect_conflicts(databasets, source_data) # 遍历数据提取字段 @@ -327,8 +377,6 @@ class ReflectionEngine: 'conflict': item['conflict'] } cleaned_conflict_data.append(cleaned_item) - print(cleaned_conflict_data) - # 3. 解决冲突 solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data) if not solved_data: @@ -347,7 +395,12 @@ class ReflectionEngine: reflexion_data.append(result['reflexion']) result_data['reflexion_data'] = reflexion_data if memory_verifies_flag==False: - result_data['memory_verifies']=[None] + result_data['memory_verifies']=[] + if quality_assessment==False: + result_data['quality_assessments']=[] + + if language_type=='en': + result_data=await self.extract_translation(result_data) print(time.time()-start_time,'----------') return result_data @@ -431,6 +484,7 @@ class ReflectionEngine: logging.info("====== 冲突检测开始 ======") start_time = asyncio.get_event_loop().time() quality_assessment = self.config.quality_assessment + language_type=self.config.language_type try: # 渲染冲突检测提示词 @@ -440,7 +494,8 @@ class ReflectionEngine: self.config.baseline, memory_verify, quality_assessment, - statement_databasets + statement_databasets, + language_type ) messages = [{"role": "user", "content": rendered_prompt}] @@ -664,4 +719,8 @@ class ReflectionEngine: 
execution_time=time_result.execution_time + fact_result.execution_time ) else: - raise ValueError(f"未知的反思基线: {self.config.baseline}") \ No newline at end of file + + raise ValueError(f"未知的反思基线: {self.config.baseline}") + + + diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index b1293c1d..b292c804 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -6,7 +6,7 @@ - **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID) - **隐私审核**: {{ memory_verify }} (true/false) - **质量评估**: {{ quality_assessment }} (true/false) - +- **语言类型**:{{language_type}}(zh/en) ## 任务目标 对用户记忆数据进行冲突检测、隐私审核和质量评估,输出结构化JSON结果。 **数据关系**: statement_databasets中的statement_id对应evaluate_data中的记录,代表句子拆分后的实体关系。 @@ -23,7 +23,7 @@ - **身份冲突**: 同一实体被赋予不同类型或角色 ### 混合冲突 检测所有逻辑不一致或相互矛盾的记录。 -**检测原则**: +**检测原则**: - 重点检查相同实体的记录 - 分析description字段语义冲突 - 验证时间字段逻辑一致性 @@ -54,7 +54,7 @@ 1. **conflict=true**: 存在冲突或隐私信息时,将所有相关记录放入data数组 2. **conflict=false**: 无冲突且无隐私信息时,data为空数组 3. **独立功能**: 冲突检测、隐私审核、质量评估三者完全独立 -4. **条件输出**: +4. **条件输出**: - quality_assessment=true时输出评估对象,否则为null - memory_verify=true时输出隐私检测对象,否则为null 5. **不输出conflict_memory字段** @@ -63,7 +63,6 @@ 2. 隐私审核(如启用) → 将隐私记录加入data 3. 质量评估(如启用) → 独立输出评估结果 4. 
去重data数组中的记录 - **输出结构**: ```json { @@ -82,6 +81,8 @@ ``` **字段说明**: - **data**: 包含冲突记录和隐私信息记录,无则为空数组 -- **quality_assessment**: quality_assessment=true时输出评估对象,否则为null +- **quality_assessment**: + quality_assessment=true时输出评估对象,否则为null(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容) - **memory_verify**: memory_verify=true时输出隐私检测对象,否则为null + (注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容) 模式参考:{{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index 15f65fc3..36474d91 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -5,6 +5,7 @@ - **原始句子**: {{ statement_databasets }} - **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID) - **隐私审核**: {{ memory_verify }} (true/false) +- **语言类型**:{{language_type}}(zh/en) ## 任务目标 作为数据冲突解决专家,分析冲突原因,按类型分组处理,为每种冲突生成独立解决方案。 @@ -110,6 +111,7 @@ - 隐私保护优先: 所有输出记录必须完成隐私脱敏 - 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %} - 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空 +- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true)等原数据字段以及涉及需要修改的字段以及内容 **变更记录格式**: ```json @@ -181,5 +183,4 @@ - **resolved.change**: 包含详细变更信息 - 无需修改的冲突类型resolved为null - 与baseline不匹配的冲突类型不包含在results中 - 模式参考: {{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/template_render.py b/api/app/core/memory/utils/prompt/template_render.py index 818d456a..46bb64e8 100644 --- a/api/app/core/memory/utils/prompt/template_render.py +++ b/api/app/core/memory/utils/prompt/template_render.py @@ -9,7 +9,8 @@ prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any], baseline: str = "TIME", - memory_verify: bool = False,quality_assessment:bool = False,statement_databasets: List[str] = 
[]) -> str: + memory_verify: bool = False,quality_assessment:bool = False, + statement_databasets: List[str] = [],language_type:str = "zh") -> str: """ Renders the evaluate prompt using the evaluate_optimized.jinja2 template. @@ -30,12 +31,13 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any baseline=baseline, memory_verify=memory_verify, quality_assessment=quality_assessment, - statement_databasets=statement_databasets + statement_databasets=statement_databasets, + language_type=language_type ) return rendered_prompt async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False, - statement_databasets: List[str] = []) -> str: + statement_databasets: List[str] = [],language_type:str = "zh") -> str: """ Renders the reflexion prompt using the reflexion_optimized.jinja2 template. @@ -51,6 +53,6 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], rendered_prompt = template.render(data=data, json_schema=schema, baseline=baseline,memory_verify=memory_verify, - statement_databasets=statement_databasets) + statement_databasets=statement_databasets,language_type=language_type) return rendered_prompt From 42e569b8e59f84be60f73fbf5f33ad23b0ae5677 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= Date: Tue, 23 Dec 2025 08:05:06 +0000 Subject: [PATCH 06/15] Merge #31 into develop from memory-summary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [feature]开发用户记忆详情的接口 * memory-summary: (69 commits squashed) - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [feature]Develop the relationship graph 
interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Develop the relationship graph interface - Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [feature]Develop the end_user/profile interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [updated]Base change operation - [refactor]1.Convert timestamp;2.Remove unnecessary code - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [updated]Base change operation - [refactor]1.Convert timestamp;2.Remove unnecessary code - Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [check]check_code.py checks the quality of the code - [fix]Fix insecure database connections - 
[refactor]refactor memory_storage_controller and memory_storage_service - [add]The /total_memory_count interface returns the "name" field. - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [updated]Base change operation - [refactor]1.Convert timestamp;2.Remove unnecessary code - [feature]Memory Insights and User Summary Cache Storage Ingestion - [featrue]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [updated]Base change operation - [refactor]1.Convert timestamp;2.Remove unnecessary code - [check]check_code.py checks the quality of the code - [fix]Fix insecure database connections - [refactor]refactor memory_storage_controller and memory_storage_service - [add]The /total_memory_count interface returns the "name" field. 
- Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [refactor]Reconstruct the user's memory location - add uv.lock Signed-off-by: 乐力齐 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/31 --- api/app/celery_app.py | 6 + api/app/controllers/__init__.py | 2 + .../memory_dashboard_controller.py | 2 +- .../controllers/memory_storage_controller.py | 45 +- .../controllers/user_memory_controllers.py | 382 ++++++++ api/app/core/config.py | 3 + api/app/models/end_user_model.py | 17 +- api/app/repositories/end_user_repository.py | 182 +++- api/app/schemas/end_user_schema.py | 34 + api/app/schemas/memory_storage_schema.py | 9 + api/app/services/memory_dashboard_service.py | 20 +- api/app/services/memory_storage_service.py | 20 +- api/app/services/user_memory_service.py | 831 ++++++++++++++++++ api/app/tasks.py | 708 +++++++++------ api/env.example | 5 + 15 files changed, 1948 insertions(+), 318 deletions(-) create mode 100644 api/app/controllers/user_memory_controllers.py create mode 100644 api/app/services/user_memory_service.py diff --git a/api/app/celery_app.py b/api/app/celery_app.py index ce7e9300..44ae9ab2 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -83,6 +83,7 @@ celery_app.autodiscover_tasks(['app']) reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS) health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) +memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME # 构建定时任务配置 beat_schedule_config = { @@ -97,6 +98,11 @@ beat_schedule_config = { "schedule": workspace_reflection_schedule, "args": (), }, + "regenerate-memory-cache": { + "task": 
"app.tasks.regenerate_memory_cache", + "schedule": memory_cache_regeneration_schedule, + "args": (), + }, } # 如果配置了默认工作空间ID,则添加记忆总量统计任务 diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 5cfbe536..c72072eb 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -35,6 +35,7 @@ from . import ( tool_controller, tool_execution_controller, ) +from . import user_memory_controllers # 创建管理端 API 路由器 manager_router = APIRouter() @@ -58,6 +59,7 @@ manager_router.include_router(upload_controller.router) manager_router.include_router(memory_agent_controller.router) manager_router.include_router(memory_dashboard_controller.router) manager_router.include_router(memory_storage_controller.router) +manager_router.include_router(user_memory_controllers.router) manager_router.include_router(api_key_controller.router) manager_router.include_router(release_share_controller.router) manager_router.include_router(public_share_controller.router) # 公开路由(无需认证) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 4a01c575..5166d012 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -287,7 +287,7 @@ async def get_workspace_total_memory_count( "total_memory_count": int, "host_count": int, "details": [ - {"host_id": "uuid", "count": 100}, + {"end_user_id": "uuid", "count": 100, "name": "用户名称"}, ... 
] } diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index 89daf9ce..0fae66fb 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -1,8 +1,9 @@ -from typing import Optional, Union +from typing import Optional import os import uuid +import datetime from sqlalchemy.orm import Session -from fastapi import APIRouter, Depends, UploadFile +from fastapi import APIRouter, Depends from fastapi.responses import StreamingResponse @@ -10,6 +11,7 @@ from app.db import get_db from app.core.logging_config import get_api_logger from app.core.response_utils import success, fail from app.core.error_codes import BizCode +from app.core.memory.utils.self_reflexion_utils import self_reflexion from app.services.memory_storage_service import ( MemoryStorageService, DataConfigService, @@ -23,9 +25,7 @@ from app.services.memory_storage_service import ( search_edges, search_entity_graph, analytics_hot_memory_tags, - analytics_memory_insight_report, analytics_recent_activity_stats, - analytics_user_summary, ) from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import ( @@ -36,10 +36,16 @@ from app.schemas.memory_storage_schema import ( ConfigUpdateForget, ConfigKey, ConfigPilotRun, + GenerateCacheRequest, ) -from app.core.memory.utils.config.definitions import reload_configuration_from_database +from app.schemas.end_user_schema import ( + EndUserProfileResponse, + EndUserProfileUpdate, +) +from app.models.end_user_model import EndUser from app.dependencies import get_current_user from app.models.user_model import User + # Get API logger api_logger = get_api_logger() @@ -489,20 +495,6 @@ async def get_hot_memory_tags_api( return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) -@router.get("/analytics/memory_insight/report", response_model=ApiResponse) -async def get_memory_insight_report_api( - end_user_id: Optional[str] = 
None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Memory insight report requested for end_user_id: {end_user_id}") - try: - result = await analytics_memory_insight_report(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Memory insight report failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告生成失败", str(e)) - - @router.get("/analytics/recent_activity_stats", response_model=ApiResponse) async def get_recent_activity_stats_api( current_user: User = Depends(get_current_user), @@ -516,20 +508,6 @@ async def get_recent_activity_stats_api( return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) -@router.get("/analytics/user_summary", response_model=ApiResponse) -async def get_user_summary_api( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"User summary requested for end_user_id: {end_user_id}") - try: - result = await analytics_user_summary(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"User summary failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e)) - -from app.core.memory.utils.self_reflexion_utils import self_reflexion @router.get("/self_reflexion") async def self_reflexion_endpoint(host_id: uuid.UUID) -> str: """ @@ -541,3 +519,4 @@ async def self_reflexion_endpoint(host_id: uuid.UUID) -> str: 自我反思结果。 """ return await self_reflexion(host_id) + diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py new file mode 100644 index 00000000..5ff34d21 --- /dev/null +++ b/api/app/controllers/user_memory_controllers.py @@ -0,0 +1,382 @@ +""" +用户记忆相关的控制器 +包含用户摘要、记忆洞察、节点统计、图数据和用户档案等接口 +""" +from typing import Optional +import datetime +from sqlalchemy.orm import Session +from fastapi import APIRouter, Depends + +from app.db import get_db +from app.core.logging_config import 
get_api_logger +from app.core.response_utils import success, fail +from app.core.error_codes import BizCode +from app.services.user_memory_service import ( + UserMemoryService, + analytics_node_statistics, + analytics_graph_data, +) +from app.schemas.response_schema import ApiResponse +from app.schemas.memory_storage_schema import GenerateCacheRequest +from app.schemas.end_user_schema import ( + EndUserProfileResponse, + EndUserProfileUpdate, +) +from app.models.end_user_model import EndUser +from app.dependencies import get_current_user +from app.models.user_model import User + +# Get API logger +api_logger = get_api_logger() + +# Initialize service +user_memory_service = UserMemoryService() + +router = APIRouter( + prefix="/memory-storage", + tags=["User Memory"], +) + + +@router.get("/analytics/memory_insight/report", response_model=ApiResponse) +async def get_memory_insight_report_api( + end_user_id: str, # 使用 end_user_id + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ) -> dict: + """获取缓存的记忆洞察报告""" + api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}") + try: + # 调用服务层获取缓存数据 + result = await user_memory_service.get_cached_memory_insight(db, end_user_id) + + if result["is_cached"]: + # 缓存存在,返回缓存数据 + api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}") + return success(data=result, msg="查询成功") + else: + # 缓存不存在,返回提示消息 + api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}") + return success(data=result, msg="查询成功") + except Exception as e: + api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e)) + + +@router.get("/analytics/user_summary", response_model=ApiResponse) +async def get_user_summary_api( + end_user_id: str, # 使用 end_user_id + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ) -> dict: + """获取缓存的用户摘要""" + api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, 
user={current_user.username}") + try: + # 调用服务层获取缓存数据 + result = await user_memory_service.get_cached_user_summary(db, end_user_id) + + if result["is_cached"]: + # 缓存存在,返回缓存数据 + api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") + return success(data=result, msg="查询成功") + else: + # 缓存不存在,返回提示消息 + api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}") + return success(data=result, msg="查询成功") + except Exception as e: + api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e)) + + +@router.post("/analytics/generate_cache", response_model=ApiResponse) +async def generate_cache_api( + request: GenerateCacheRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 手动触发缓存生成 + + - 如果提供 end_user_id,只为该用户生成 + - 如果不提供,为当前工作空间的所有用户生成 + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + group_id = request.end_user_id + + api_logger.info( + f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, " + f"end_user_id={group_id if group_id else '全部用户'}" + ) + + try: + if group_id: + # 为单个用户生成 + api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}") + + # 生成记忆洞察 + insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id) + + # 生成用户摘要 + summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id) + + # 构建响应 + result = { + "end_user_id": group_id, + "insight_success": insight_result["success"], + "summary_success": summary_result["success"], + "errors": [] + } + + # 收集错误信息 + if not insight_result["success"]: + result["errors"].append({ + "type": "insight", + "error": insight_result.get("error") + }) + if not summary_result["success"]: + 
result["errors"].append({ + "type": "summary", + "error": summary_result.get("error") + }) + + # 记录结果 + if result["insight_success"] and result["summary_success"]: + api_logger.info(f"成功为用户 {group_id} 生成缓存") + else: + api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}") + + return success(data=result, msg="生成完成") + + else: + # 为整个工作空间生成 + api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存") + + result = await user_memory_service.generate_cache_for_workspace(db, workspace_id) + + # 记录统计信息 + api_logger.info( + f"工作空间 {workspace_id} 批量生成完成: " + f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}" + ) + + return success(data=result, msg="批量生成完成") + + except Exception as e: + api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "缓存生成失败", str(e)) + + +@router.get("/analytics/node_statistics", response_model=ApiResponse) +async def get_node_statistics_api( + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info(f"节点统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") + + try: + result = await analytics_node_statistics(db, end_user_id) + + # 检查是否有错误消息 + if "message" in result and result["total"] == 0: + api_logger.warning(f"节点统计查询返回空结果: {result.get('message')}") + return success(data=result, msg=result.get("message", "查询成功")) + + api_logger.info(f"成功获取节点统计: end_user_id={end_user_id}, total={result['total']}") + return success(data=result, msg="查询成功") + except Exception as e: + api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e)) + 
+@router.get("/analytics/graph_data", response_model=ApiResponse) +async def get_graph_data_api( + end_user_id: str, + node_types: Optional[str] = None, + limit: int = 100, + depth: int = 1, + center_node_id: Optional[str] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询图数据但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + # 参数验证 + if limit > 1000: + limit = 1000 + api_logger.warning("limit 参数超过最大值,已调整为 1000") + + if depth > 3: + depth = 3 + api_logger.warning("depth 参数超过最大值,已调整为 3") + + # 解析 node_types 参数 + node_types_list = None + if node_types: + node_types_list = [t.strip() for t in node_types.split(",") if t.strip()] + + api_logger.info( + f"图数据查询请求: end_user_id={end_user_id}, user={current_user.username}, " + f"workspace={workspace_id}, node_types={node_types_list}, limit={limit}, depth={depth}" + ) + + try: + result = await analytics_graph_data( + db=db, + end_user_id=end_user_id, + node_types=node_types_list, + limit=limit, + depth=depth, + center_node_id=center_node_id + ) + + # 检查是否有错误消息 + if "message" in result and result["statistics"]["total_nodes"] == 0: + api_logger.warning(f"图数据查询返回空结果: {result.get('message')}") + return success(data=result, msg=result.get("message", "查询成功")) + + api_logger.info( + f"成功获取图数据: end_user_id={end_user_id}, " + f"nodes={result['statistics']['total_nodes']}, " + f"edges={result['statistics']['total_edges']}" + ) + return success(data=result, msg="查询成功") + + except Exception as e: + api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e)) + + +@router.get("/read_end_user/profile", response_model=ApiResponse) +async def get_end_user_profile( + end_user_id: str, + current_user: User = 
Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, " + f"workspace={workspace_id}" + ) + + try: + # 查询终端用户 + end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() + + if not end_user: + api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") + return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") + + # 构建响应数据 + profile_data = EndUserProfileResponse( + id=end_user.id, + name=end_user.name, + position=end_user.position, + department=end_user.department, + contact=end_user.contact, + phone=end_user.phone, + hire_date=end_user.hire_date, + updatetime_profile=end_user.updatetime_profile + ) + + api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}") + return success(data=profile_data.model_dump(), msg="查询成功") + + except Exception as e: + api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e)) + + +@router.post("/updated_end_user/profile", response_model=ApiResponse) +async def update_end_user_profile( + profile_update: EndUserProfileUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 更新终端用户的基本信息 + + 该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。 + 所有字段都是可选的,只更新提供的字段。 + + """ + workspace_id = current_user.current_workspace_id + end_user_id = profile_update.end_user_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"用户信息更新请求: end_user_id={end_user_id}, 
user={current_user.username}, " + f"workspace={workspace_id}" + ) + + try: + # 查询终端用户 + end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() + + if not end_user: + api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") + return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") + + # 更新字段(只更新提供的非 None 字段,排除 end_user_id) + update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) + for field, value in update_data.items(): + if value is not None: + setattr(end_user, field, value) + + # 更新 updated_at 时间戳 + end_user.updated_at = datetime.datetime.now() + + # 更新 updatetime_profile 为当前时间戳(毫秒) + current_timestamp = int(datetime.datetime.now().timestamp() * 1000) + end_user.updatetime_profile = current_timestamp + + # 提交更改 + db.commit() + db.refresh(end_user) + + # 构建响应数据 + profile_data = EndUserProfileResponse( + id=end_user.id, + name=end_user.name, + position=end_user.position, + department=end_user.department, + contact=end_user.contact, + phone=end_user.phone, + hire_date=end_user.hire_date, + updatetime_profile=end_user.updatetime_profile + ) + + api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}, updatetime_profile={current_timestamp}") + return success(data=profile_data.model_dump(), msg="更新成功") + + except Exception as e: + db.rollback() + api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e)) diff --git a/api/app/core/config.py b/api/app/core/config.py index bf5ff45a..7f4a99ba 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -149,6 +149,9 @@ class Settings: MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) + + # Memory Cache Regeneration 
Configuration + MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index 2a9ed8da..0ef11ffa 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -1,6 +1,6 @@ import datetime import uuid -from sqlalchemy import Column, String, DateTime, ForeignKey +from sqlalchemy import Column, String, DateTime, ForeignKey, Text, BigInteger from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from app.db import Base @@ -17,6 +17,21 @@ class EndUser(Base): reflection_time = Column(DateTime, nullable=True) created_at = Column(DateTime, default=datetime.datetime.now) updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) + + # 用户基本信息字段 + name = Column(String, nullable=True, comment="姓名") + position = Column(String, nullable=True, comment="职位") + department = Column(String, nullable=True, comment="部门") + contact = Column(String, nullable=True, comment="联系方式") + phone = Column(String, nullable=True, comment="电话") + hire_date = Column(BigInteger, nullable=True, comment="入职日期(时间戳,毫秒)") + updatetime_profile = Column(BigInteger, nullable=True, comment="核心档案信息最后更新时间(时间戳,毫秒)") + + # 缓存字段 - Cache fields for pre-computed analytics + memory_insight = Column(Text, nullable=True, comment="缓存的记忆洞察报告") + user_summary = Column(Text, nullable=True, comment="缓存的用户摘要") + memory_insight_updated_at = Column(DateTime, nullable=True, comment="洞察报告最后更新时间") + user_summary_updated_at = Column(DateTime, nullable=True, comment="用户摘要最后更新时间") # 与 App 的反向关系 app = relationship( diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 07e45a48..69932101 100644 --- a/api/app/repositories/end_user_repository.py +++ 
b/api/app/repositories/end_user_repository.py @@ -1,8 +1,11 @@ from sqlalchemy.orm import Session from typing import List, Optional import uuid +import datetime from app.models.end_user_model import EndUser +from app.models.app_model import App +from app.models.workspace_model import Workspace from app.core.logging_config import get_db_logger @@ -92,6 +95,157 @@ class EndUserRepository: db_logger.error(f"获取或创建终端用户时出错: {str(e)}") raise + def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]: + """根据ID获取终端用户(用于缓存操作) + + Args: + end_user_id: 终端用户ID + + Returns: + Optional[EndUser]: 终端用户对象,如果不存在则返回None + """ + try: + end_user = ( + self.db.query(EndUser) + .filter(EndUser.id == end_user_id) + .first() + ) + if end_user: + db_logger.debug(f"成功查询到终端用户 {end_user_id}") + else: + db_logger.debug(f"未找到终端用户 {end_user_id}") + return end_user + except Exception as e: + self.db.rollback() + db_logger.error(f"查询终端用户 {end_user_id} 时出错: {str(e)}") + raise + + def update_memory_insight( + self, + end_user_id: uuid.UUID, + insight: str + ) -> bool: + """更新记忆洞察缓存 + + Args: + end_user_id: 终端用户ID + insight: 记忆洞察内容 + + Returns: + bool: 更新成功返回True,否则返回False + """ + try: + updated_count = ( + self.db.query(EndUser) + .filter(EndUser.id == end_user_id) + .update( + { + EndUser.memory_insight: insight, + EndUser.memory_insight_updated_at: datetime.datetime.now() + }, + synchronize_session=False + ) + ) + + self.db.commit() + + if updated_count > 0: + db_logger.info(f"成功更新终端用户 {end_user_id} 的记忆洞察缓存") + return True + else: + db_logger.warning(f"未找到终端用户 {end_user_id},无法更新记忆洞察缓存") + return False + + except Exception as e: + self.db.rollback() + db_logger.error(f"更新终端用户 {end_user_id} 的记忆洞察缓存时出错: {str(e)}") + raise + + def update_user_summary( + self, + end_user_id: uuid.UUID, + summary: str + ) -> bool: + """更新用户摘要缓存 + + Args: + end_user_id: 终端用户ID + summary: 用户摘要内容 + + Returns: + bool: 更新成功返回True,否则返回False + """ + try: + updated_count = ( + self.db.query(EndUser) + .filter(EndUser.id 
== end_user_id) + .update( + { + EndUser.user_summary: summary, + EndUser.user_summary_updated_at: datetime.datetime.now() + }, + synchronize_session=False + ) + ) + + self.db.commit() + + if updated_count > 0: + db_logger.info(f"成功更新终端用户 {end_user_id} 的用户摘要缓存") + return True + else: + db_logger.warning(f"未找到终端用户 {end_user_id},无法更新用户摘要缓存") + return False + + except Exception as e: + self.db.rollback() + db_logger.error(f"更新终端用户 {end_user_id} 的用户摘要缓存时出错: {str(e)}") + raise + + def get_all_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]: + """获取工作空间的所有终端用户 + + Args: + workspace_id: 工作空间ID + + Returns: + List[EndUser]: 终端用户列表 + """ + try: + end_users = ( + self.db.query(EndUser) + .join(App, EndUser.app_id == App.id) + .filter(App.workspace_id == workspace_id) + .all() + ) + db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户") + return end_users + except Exception as e: + self.db.rollback() + db_logger.error(f"查询工作空间 {workspace_id} 下的终端用户时出错: {str(e)}") + raise + + def get_all_active_workspaces(self) -> List[uuid.UUID]: + """获取所有活动工作空间的ID + + Returns: + List[uuid.UUID]: 活动工作空间ID列表 + """ + try: + workspace_ids = ( + self.db.query(Workspace.id) + .filter(Workspace.is_active) + .all() + ) + # 提取ID(查询返回的是元组列表) + workspace_id_list = [workspace_id[0] for workspace_id in workspace_ids] + db_logger.info(f"成功查询到 {len(workspace_id_list)} 个活动工作空间") + return workspace_id_list + except Exception as e: + self.db.rollback() + db_logger.error(f"查询活动工作空间时出错: {str(e)}") + raise + def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]: """根据应用ID查询宿主(返回 EndUser ORM 列表)""" repo = EndUserRepository(db) @@ -138,4 +292,30 @@ def update_end_user_other_name( except Exception as e: db.rollback() db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}") - raise \ No newline at end of file + raise + +# 新增的缓存操作函数(保持与类方法一致的接口) +def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]: + """根据ID获取终端用户(用于缓存操作)""" + repo = 
EndUserRepository(db) + return repo.get_by_id(end_user_id) + +def update_memory_insight(db: Session, end_user_id: uuid.UUID, insight: str) -> bool: + """更新记忆洞察缓存""" + repo = EndUserRepository(db) + return repo.update_memory_insight(end_user_id, insight) + +def update_user_summary(db: Session, end_user_id: uuid.UUID, summary: str) -> bool: + """更新用户摘要缓存""" + repo = EndUserRepository(db) + return repo.update_user_summary(end_user_id, summary) + +def get_all_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[EndUser]: + """获取工作空间的所有终端用户""" + repo = EndUserRepository(db) + return repo.get_all_by_workspace(workspace_id) + +def get_all_active_workspaces(db: Session) -> List[uuid.UUID]: + """获取所有活动工作空间的ID""" + repo = EndUserRepository(db) + return repo.get_all_active_workspaces() diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index 74fc4a14..939d2d3e 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -16,3 +16,37 @@ class EndUser(BaseModel): reflection_time: Optional[datetime.datetime] = Field(description="反思时间", default_factory=datetime.datetime.now) created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now) updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now) + + # 用户基本信息字段 + name: Optional[str] = Field(description="姓名", default=None) + position: Optional[str] = Field(description="职位", default=None) + department: Optional[str] = Field(description="部门", default=None) + contact: Optional[str] = Field(description="联系方式", default=None) + phone: Optional[str] = Field(description="电话", default=None) + hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None) + updatetime_profile: Optional[int] = Field(description="核心档案信息最后更新时间(时间戳,毫秒)", default=None) + + +class EndUserProfileResponse(BaseModel): + """终端用户基本信息响应模型""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID = 
Field(description="终端用户ID") + name: Optional[str] = Field(description="姓名", default=None) + position: Optional[str] = Field(description="职位", default=None) + department: Optional[str] = Field(description="部门", default=None) + contact: Optional[str] = Field(description="联系方式", default=None) + phone: Optional[str] = Field(description="电话", default=None) + hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None) + updatetime_profile: Optional[int] = Field(description="核心档案信息最后更新时间(时间戳,毫秒)", default=None) + + +class EndUserProfileUpdate(BaseModel): + """终端用户基本信息更新请求模型""" + end_user_id: str = Field(description="终端用户ID") + name: Optional[str] = Field(description="姓名", default=None) + position: Optional[str] = Field(description="职位", default=None) + department: Optional[str] = Field(description="部门", default=None) + contact: Optional[str] = Field(description="联系方式", default=None) + phone: Optional[str] = Field(description="电话", default=None) + hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None) \ No newline at end of file diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index be249b5e..df70ec77 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -382,3 +382,12 @@ def fail( error=error_code, time=time or _now_ms(), ) + +class GenerateCacheRequest(BaseModel): + """缓存生成请求模型""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + end_user_id: Optional[str] = Field( + None, + description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成" + ) diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index be4ec12f..6acc699a 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -268,10 +268,20 @@ async def get_workspace_total_memory_count( # 如果提供了 end_user_id,只查询该用户 if end_user_id: search_result = await 
memory_storage_service.search_all(end_user_id=end_user_id) + # 查询用户名称 + from app.repositories.end_user_repository import EndUserRepository + repo = EndUserRepository(db) + end_user = repo.get_by_id(uuid.UUID(end_user_id)) + user_name = end_user.name if end_user else None + return { "total_memory_count": search_result.get("total", 0), "host_count": 1, - "details": [{"end_user_id": end_user_id, "count": search_result.get("total", 0)}] + "details": [{ + "end_user_id": end_user_id, + "count": search_result.get("total", 0), + "name": user_name + }] } for host in hosts: @@ -287,17 +297,19 @@ async def get_workspace_total_memory_count( details.append({ "end_user_id": end_user_id_str, - "count": host_total + "count": host_total, + "name": host.name # 添加 name 字段 }) - business_logger.debug(f"EndUser {end_user_id_str} 记忆数: {host_total}") + business_logger.debug(f"EndUser {end_user_id_str} ({host.name}) 记忆数: {host_total}") except Exception as e: business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}") # 失败的 host 记为 0 details.append({ "end_user_id": str(host.id), - "count": 0 + "count": 0, + "name": host.name # 添加 name 字段 }) result = { diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 0548b704..2644cd8d 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -15,11 +15,9 @@ from sqlalchemy.orm import Session from dotenv import load_dotenv from app.models.user_model import User -from app.models.end_user_model import EndUser from app.core.logging_config import get_logger from app.utils.sse_utils import format_sse_message from app.schemas.memory_storage_schema import ( - ConfigFilter, ConfigPilotRun, ConfigParamsCreate, ConfigParamsDelete, @@ -34,7 +32,8 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.analytics.memory_insight import MemoryInsight from app.core.memory.analytics.recent_activity_stats import 
get_recent_activity_stats from app.core.memory.analytics.user_summary import generate_user_summary -from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.end_user_repository import EndUserRepository +import uuid logger = get_logger(__name__) @@ -67,6 +66,7 @@ class MemoryStorageService: } return result + class DataConfigService: # 数据配置服务类(PostgreSQL) """Service layer for config params CRUD. @@ -85,7 +85,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) @staticmethod def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式""" - from datetime import datetime for item in data_list: for field in ['created_at', 'updated_at']: @@ -576,14 +575,6 @@ async def analytics_hot_memory_tags( return [{"name": t, "frequency": f} for t, f in top_tags] -async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]: - insight = MemoryInsight(end_user_id) - report = await insight.generate_insight_report() - await insight.close() - data = {"report": report} - return data - - async def analytics_recent_activity_stats() -> Dict[str, Any]: stats, _msg = get_recent_activity_stats() total = ( @@ -617,8 +608,3 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]: data = {"total": total, "stats": stats, "latest_relative": latest_relative} return data - -async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]: - summary = await generate_user_summary(end_user_id) - data = {"summary": summary} - return data \ No newline at end of file diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py new file mode 100644 index 00000000..a69c776e --- /dev/null +++ b/api/app/services/user_memory_service.py @@ -0,0 +1,831 @@ +""" +User Memory Service + +处理用户记忆相关的业务逻辑,包括记忆洞察、用户摘要、节点统计和图数据等。 +""" + +from typing import Dict, List, Optional, Any +import uuid 
+from sqlalchemy.orm import Session + +from app.core.logging_config import get_logger +from app.repositories.end_user_repository import EndUserRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.core.memory.analytics.memory_insight import MemoryInsight +from app.core.memory.analytics.user_summary import generate_user_summary + +logger = get_logger(__name__) + +# Neo4j connector instance +_neo4j_connector = Neo4jConnector() + + +class UserMemoryService: + """用户记忆服务类""" + + def __init__(self): + logger.info("UserMemoryService initialized") + + async def get_cached_memory_insight( + self, + db: Session, + end_user_id: str + ) -> Dict[str, Any]: + """ + 从数据库获取缓存的记忆洞察 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + + Returns: + { + "report": str, + "updated_at": datetime, + "is_cached": bool + } + """ + try: + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") + return { + "report": None, + "updated_at": None, + "is_cached": False, + "message": "用户不存在" + } + + # 检查是否有缓存数据 + if end_user.memory_insight: + logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察") + return { + "report": end_user.memory_insight, + "updated_at": end_user.memory_insight_updated_at, + "is_cached": True + } + else: + logger.info(f"end_user_id {end_user_id} 的记忆洞察缓存为空") + return { + "report": None, + "updated_at": None, + "is_cached": False, + "message": "数据尚未生成,请稍后重试或联系管理员" + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "report": None, + "updated_at": None, + "is_cached": False, + "message": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"获取缓存记忆洞察时出错: {str(e)}") + raise + + async def get_cached_user_summary( + self, + db: Session, + end_user_id: str + ) -> Dict[str, Any]: + """ + 从数据库获取缓存的用户摘要 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + + Returns: + 
{ + "summary": str, + "updated_at": datetime, + "is_cached": bool + } + """ + try: + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") + return { + "summary": None, + "updated_at": None, + "is_cached": False, + "message": "用户不存在" + } + + # 检查是否有缓存数据 + if end_user.user_summary: + logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要") + return { + "summary": end_user.user_summary, + "updated_at": end_user.user_summary_updated_at, + "is_cached": True + } + else: + logger.info(f"end_user_id {end_user_id} 的用户摘要缓存为空") + return { + "summary": None, + "updated_at": None, + "is_cached": False, + "message": "数据尚未生成,请稍后重试或联系管理员" + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "summary": None, + "updated_at": None, + "is_cached": False, + "message": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"获取缓存用户摘要时出错: {str(e)}") + raise + + async def generate_and_cache_insight( + self, + db: Session, + end_user_id: str, + workspace_id: Optional[uuid.UUID] = None + ) -> Dict[str, Any]: + """ + 生成并缓存记忆洞察 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + workspace_id: 工作空间ID (可选) + + Returns: + { + "success": bool, + "report": str, + "error": Optional[str] + } + """ + try: + logger.info(f"开始为 end_user_id {end_user_id} 生成记忆洞察") + + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.error(f"end_user_id {end_user_id} 不存在") + return { + "success": False, + "report": None, + "error": "用户不存在" + } + + # 使用 end_user_id 调用分析函数 + try: + logger.info(f"使用 end_user_id={end_user_id} 生成记忆洞察") + result = await analytics_memory_insight_report(end_user_id) + report = result.get("report", "") + + if not report: + logger.warning(f"end_user_id {end_user_id} 的记忆洞察生成结果为空") + return { + "success": False, + 
"report": None, + "error": "生成的洞察报告为空,可能Neo4j中没有该用户的数据" + } + + # 更新数据库缓存 + success = repo.update_memory_insight(user_uuid, report) + + if success: + logger.info(f"成功为 end_user_id {end_user_id} 生成并缓存记忆洞察") + return { + "success": True, + "report": report, + "error": None + } + else: + logger.error(f"更新 end_user_id {end_user_id} 的记忆洞察缓存失败") + return { + "success": False, + "report": report, + "error": "数据库更新失败" + } + + except Exception as e: + logger.error(f"调用分析函数生成记忆洞察时出错: {str(e)}") + return { + "success": False, + "report": None, + "error": f"Neo4j或LLM服务不可用: {str(e)}" + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "success": False, + "report": None, + "error": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"生成并缓存记忆洞察时出错: {str(e)}") + return { + "success": False, + "report": None, + "error": str(e) + } + + async def generate_and_cache_summary( + self, + db: Session, + end_user_id: str, + workspace_id: Optional[uuid.UUID] = None + ) -> Dict[str, Any]: + """ + 生成并缓存用户摘要 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + workspace_id: 工作空间ID (可选) + + Returns: + { + "success": bool, + "summary": str, + "error": Optional[str] + } + """ + try: + logger.info(f"开始为 end_user_id {end_user_id} 生成用户摘要") + + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.error(f"end_user_id {end_user_id} 不存在") + return { + "success": False, + "summary": None, + "error": "用户不存在" + } + + # 使用 end_user_id 调用分析函数 + try: + logger.info(f"使用 end_user_id={end_user_id} 生成用户摘要") + result = await analytics_user_summary(end_user_id) + summary = result.get("summary", "") + + if not summary: + logger.warning(f"end_user_id {end_user_id} 的用户摘要生成结果为空") + return { + "success": False, + "summary": None, + "error": "生成的用户摘要为空,可能Neo4j中没有该用户的数据" + } + + # 更新数据库缓存 + success = repo.update_user_summary(user_uuid, summary) + + if success: + logger.info(f"成功为 
end_user_id {end_user_id} 生成并缓存用户摘要") + return { + "success": True, + "summary": summary, + "error": None + } + else: + logger.error(f"更新 end_user_id {end_user_id} 的用户摘要缓存失败") + return { + "success": False, + "summary": summary, + "error": "数据库更新失败" + } + + except Exception as e: + logger.error(f"调用分析函数生成用户摘要时出错: {str(e)}") + return { + "success": False, + "summary": None, + "error": f"Neo4j或LLM服务不可用: {str(e)}" + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "success": False, + "summary": None, + "error": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"生成并缓存用户摘要时出错: {str(e)}") + return { + "success": False, + "summary": None, + "error": str(e) + } + + async def generate_cache_for_workspace( + self, + db: Session, + workspace_id: uuid.UUID + ) -> Dict[str, Any]: + """ + 为整个工作空间生成缓存 + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + + Returns: + { + "total_users": int, + "successful": int, + "failed": int, + "errors": List[Dict] + } + """ + logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存") + + total_users = 0 + successful = 0 + failed = 0 + errors = [] + + try: + # 获取工作空间的所有终端用户 + repo = EndUserRepository(db) + end_users = repo.get_all_by_workspace(workspace_id) + total_users = len(end_users) + + logger.info(f"工作空间 {workspace_id} 共有 {total_users} 个终端用户") + + # 遍历每个用户并生成缓存 + for end_user in end_users: + end_user_id = str(end_user.id) + + try: + # 生成记忆洞察 + insight_result = await self.generate_and_cache_insight(db, end_user_id) + + # 生成用户摘要 + summary_result = await self.generate_and_cache_summary(db, end_user_id) + + # 检查是否都成功 + if insight_result["success"] and summary_result["success"]: + successful += 1 + logger.info(f"成功为终端用户 {end_user_id} 生成缓存") + else: + failed += 1 + error_info = { + "end_user_id": end_user_id, + "insight_error": insight_result.get("error"), + "summary_error": summary_result.get("error") + } + errors.append(error_info) + logger.warning(f"终端用户 {end_user_id} 的缓存生成部分失败: {error_info}") + + except 
Exception as e: + # 单个用户失败不影响其他用户 + failed += 1 + error_info = { + "end_user_id": end_user_id, + "error": str(e) + } + errors.append(error_info) + logger.error(f"为终端用户 {end_user_id} 生成缓存时出错: {str(e)}") + + # 记录统计信息 + logger.info( + f"工作空间 {workspace_id} 批量生成完成: " + f"总数={total_users}, 成功={successful}, 失败={failed}" + ) + + return { + "total_users": total_users, + "successful": successful, + "failed": failed, + "errors": errors + } + + except Exception as e: + logger.error(f"为工作空间 {workspace_id} 批量生成缓存时出错: {str(e)}") + return { + "total_users": total_users, + "successful": successful, + "failed": failed, + "errors": errors + [{"error": f"批量处理失败: {str(e)}"}] + } + + +# 独立的分析函数 + +async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]: + """ + 生成记忆洞察报告 + + Args: + end_user_id: 可选的终端用户ID + + Returns: + 包含报告的字典 + """ + insight = MemoryInsight(end_user_id) + report = await insight.generate_insight_report() + await insight.close() + data = {"report": report} + return data + + +async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]: + """ + 生成用户摘要 + + Args: + end_user_id: 可选的终端用户ID + + Returns: + 包含摘要的字典 + """ + summary = await generate_user_summary(end_user_id) + data = {"summary": summary} + return data + + +async def analytics_node_statistics( + db: Session, + end_user_id: Optional[str] = None +) -> Dict[str, Any]: + """ + 统计 Neo4j 中四种节点类型的数量和百分比 + + Args: + db: 数据库会话 + end_user_id: 可选的终端用户ID (UUID),用于过滤特定用户的节点 + + Returns: + { + "total": int, # 总节点数 + "nodes": [ + { + "type": str, # 节点类型 + "count": int, # 节点数量 + "percentage": float # 百分比 + } + ] + } + """ + # 定义四种节点类型的查询 + node_types = ["Chunk", "MemorySummary", "Statement", "ExtractedEntity"] + + # 存储每种节点类型的计数 + node_counts = {} + + # 查询每种节点类型的数量 + for node_type in node_types: + # 构建查询语句 + if end_user_id: + query = f""" + MATCH (n:{node_type}) + WHERE n.group_id = $group_id + RETURN count(n) as count + """ + result = await 
_neo4j_connector.execute_query(query, group_id=end_user_id) + else: + query = f""" + MATCH (n:{node_type}) + RETURN count(n) as count + """ + result = await _neo4j_connector.execute_query(query) + + # 提取计数结果 + count = result[0]["count"] if result and len(result) > 0 else 0 + node_counts[node_type] = count + + # 计算总数 + total = sum(node_counts.values()) + + # 构建返回数据,包含百分比 + nodes = [] + for node_type in node_types: + count = node_counts[node_type] + percentage = round((count / total * 100), 2) if total > 0 else 0.0 + nodes.append({ + "type": node_type, + "count": count, + "percentage": percentage + }) + + data = { + "total": total, + "nodes": nodes + } + + return data + + +async def analytics_graph_data( + db: Session, + end_user_id: str, + node_types: Optional[List[str]] = None, + limit: int = 100, + depth: int = 1, + center_node_id: Optional[str] = None +) -> Dict[str, Any]: + """ + 获取 Neo4j 图数据,用于前端可视化 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID + node_types: 可选的节点类型列表 + limit: 返回节点数量限制 + depth: 图遍历深度 + center_node_id: 可选的中心节点ID + + Returns: + 包含节点、边和统计信息的字典 + """ + try: + # 1. 获取 group_id + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") + return { + "nodes": [], + "edges": [], + "statistics": { + "total_nodes": 0, + "total_edges": 0, + "node_types": {}, + "edge_types": {} + }, + "message": "用户不存在" + } + + # 2. 
构建节点查询 + if center_node_id: + # 基于中心节点的扩展查询 + node_query = f""" + MATCH path = (center)-[*1..{depth}]-(connected) + WHERE center.group_id = $group_id + AND elementId(center) = $center_node_id + WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes + UNWIND all_nodes as n + RETURN DISTINCT + elementId(n) as id, + labels(n)[0] as label, + properties(n) as properties + LIMIT $limit + """ + node_params = { + "group_id": end_user_id, + "center_node_id": center_node_id, + "limit": limit + } + elif node_types: + # 按节点类型过滤查询 + node_query = """ + MATCH (n) + WHERE n.group_id = $group_id + AND labels(n)[0] IN $node_types + RETURN + elementId(n) as id, + labels(n)[0] as label, + properties(n) as properties + LIMIT $limit + """ + node_params = { + "group_id": end_user_id, + "node_types": node_types, + "limit": limit + } + else: + # 查询所有节点 + node_query = """ + MATCH (n) + WHERE n.group_id = $group_id + RETURN + elementId(n) as id, + labels(n)[0] as label, + properties(n) as properties + LIMIT $limit + """ + node_params = { + "group_id": end_user_id, + "limit": limit + } + + # 执行节点查询 + node_results = await _neo4j_connector.execute_query(node_query, **node_params) + + # 3. 格式化节点数据 + nodes = [] + node_ids = [] + node_type_counts = {} + + for record in node_results: + node_id = record["id"] + node_label = record["label"] + node_props = record["properties"] + + # 根据节点类型提取需要的属性字段 + filtered_props = _extract_node_properties(node_label, node_props) + + # 直接使用数据库中的 caption,如果没有则使用节点类型作为默认值 + caption = filtered_props.get("caption", node_label) + + nodes.append({ + "id": node_id, + "label": node_label, + "properties": filtered_props, + "caption": caption + }) + + node_ids.append(node_id) + node_type_counts[node_label] = node_type_counts.get(node_label, 0) + 1 + + # 4. 
查询节点之间的关系 + if len(node_ids) > 0: + edge_query = """ + MATCH (n)-[r]->(m) + WHERE elementId(n) IN $node_ids + AND elementId(m) IN $node_ids + RETURN + elementId(r) as id, + elementId(n) as source, + elementId(m) as target, + type(r) as rel_type, + properties(r) as properties + """ + edge_results = await _neo4j_connector.execute_query( + edge_query, + node_ids=node_ids + ) + else: + edge_results = [] + + # 5. 格式化边数据 + edges = [] + edge_type_counts = {} + + for record in edge_results: + edge_id = record["id"] + source = record["source"] + target = record["target"] + rel_type = record["rel_type"] + edge_props = record["properties"] + + # 清理边属性中的 Neo4j 特殊类型 + # 对于边,我们保留所有属性,但清理特殊类型 + cleaned_edge_props = {} + if edge_props: + for key, value in edge_props.items(): + cleaned_edge_props[key] = _clean_neo4j_value(value) + + # 直接使用关系类型作为 caption,如果 properties 中有 caption 则使用它 + caption = cleaned_edge_props.get("caption", rel_type) + + edges.append({ + "id": edge_id, + "source": source, + "target": target, + "type": rel_type, + "properties": cleaned_edge_props, + "caption": caption + }) + + edge_type_counts[rel_type] = edge_type_counts.get(rel_type, 0) + 1 + + # 6. 
构建统计信息 + statistics = { + "total_nodes": len(nodes), + "total_edges": len(edges), + "node_types": node_type_counts, + "edge_types": edge_type_counts + } + + logger.info( + f"成功获取图数据: end_user_id={end_user_id}, " + f"nodes={len(nodes)}, edges={len(edges)}" + ) + + return { + "nodes": nodes, + "edges": edges, + "statistics": statistics + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "nodes": [], + "edges": [], + "statistics": { + "total_nodes": 0, + "total_edges": 0, + "node_types": {}, + "edge_types": {} + }, + "message": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"获取图数据失败: {str(e)}", exc_info=True) + raise + + +# 辅助函数 + +def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据节点类型提取需要的属性字段 + + Args: + label: 节点类型标签 + properties: 节点的所有属性 + + Returns: + 过滤后的属性字典 + """ + # 定义每种节点类型需要的字段(白名单) + field_whitelist = { + "Dialogue": ["content", "created_at"], + "Chunk": ["content", "created_at"], + "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"], + "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"], + "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 + } + + # 获取该节点类型的白名单字段 + allowed_fields = field_whitelist.get(label, []) + + # 如果没有定义白名单,返回空字典(或者可以返回所有字段) + if not allowed_fields: + # 对于未定义的节点类型,只返回基本字段 + allowed_fields = ["name", "created_at", "caption"] + + # 提取白名单中的字段 + filtered_props = {} + for field in allowed_fields: + if field in properties: + value = properties[field] + # 清理 Neo4j 特殊类型 + filtered_props[field] = _clean_neo4j_value(value) + + return filtered_props + + +def _clean_neo4j_value(value: Any) -> Any: + """ + 清理单个值的 Neo4j 特殊类型 + + Args: + value: 需要清理的值 + + Returns: + 清理后的值 + """ + if value is None: + return None + + # 处理列表 + if isinstance(value, list): + return [_clean_neo4j_value(item) for item in value] + + # 处理字典 + if isinstance(value, dict): + 
return {k: _clean_neo4j_value(v) for k, v in value.items()} + + # 处理 Neo4j DateTime 类型 + if hasattr(value, '__class__') and 'neo4j.time' in str(type(value)): + try: + if hasattr(value, 'to_native'): + native_dt = value.to_native() + return native_dt.isoformat() + return str(value) + except Exception: + return str(value) + + # 处理其他 Neo4j 特殊类型 + if hasattr(value, '__class__') and 'neo4j' in str(type(value)): + try: + return str(value) + except Exception: + return None + + # 返回原始值 + return value diff --git a/api/app/tasks.py b/api/app/tasks.py index 39758275..55d6680c 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,15 +1,13 @@ -import os import asyncio -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import requests from datetime import datetime, timezone import time import uuid from math import ceil import redis -import json -from app.db import get_db +from app.db import get_db_context from app.models.document_model import Document from app.models.knowledge_model import Knowledge from app.core.rag.llm.cv_model import QWenCV @@ -48,124 +46,122 @@ def parse_document(file_path: str, document_id: uuid.UUID): """ Document parsing, vectorization, and storage """ - db = next(get_db()) # Manually call the generator - db_document = None - db_knowledge = None - progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" - try: - db_document = db.query(Document).filter(Document.id == document_id).first() - db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() - # 1. 
Document parsing & segmentation - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n" - start_time = time.time() - db_document.progress = 0.0 - db_document.progress_msg = progress_msg - db_document.process_begin_at = datetime.now(tz=timezone.utc) - db_document.process_duration = 0.0 - db_document.run = 1 - db.commit() - db.refresh(db_document) - - def progress_callback(prog=None, msg=None): - nonlocal progress_msg # Declare the use of an external progress_msg variable - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n" - # Prepare to configure chat_mdl、vision_model information - chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base - ) - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - from app.core.rag.app.naive import chunk - res = chunk(filename=file_path, - from_page=0, - to_page=100000, - callback=progress_callback, - vision_model=vision_model, - parser_config=db_document.parser_config, - is_root=False) - - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n" - db_document.progress = 0.8 - db_document.progress_msg = progress_msg - db.commit() - db.refresh(db_document) - - # 2. 
Document vectorization and storage - total_chunks = len(res) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n" - batch_size = 100 - total_batches = ceil(total_chunks / batch_size) - progress_per_batch = 0.2 / total_batches # Progress of each batch - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # 2.1 Delete document vector index - vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) - # 2.2 Vectorize and import batch documents - for batch_start in range(0, total_chunks, batch_size): - batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds - batch = res[batch_start: batch_end] # Retrieve the current batch - chunks = [] - - # Process the current batch - for idx_in_batch, item in enumerate(batch): - global_idx = batch_start + idx_in_batch # Calculate global index - metadata = { - "doc_id": uuid.uuid4().hex, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(db_document.id), - "knowledge_id": str(db_document.kb_id), - "sort_id": global_idx, - "status": 1, - } - if db_document.parser_config.get("auto_questions", 0): - topn = db_document.parser_config["auto_questions"] - cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", {"topn": topn}) - if not cached: - cached = question_proposal(chat_model, item["content_with_weight"], topn) - set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", {"topn": topn}) - chunks.append(DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", metadata=metadata)) - else: - chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) - - # Bulk segmented vector import - vector_service.add_chunks(chunks) - - # Update progress - db_document.progress += progress_per_batch - 
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n" + with get_db_context() as db: + db_document = None + db_knowledge = None + progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" + try: + db_document = db.query(Document).filter(Document.id == document_id).first() + db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() + # 1. Document parsing & segmentation + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n" + start_time = time.time() + db_document.progress = 0.0 db_document.progress_msg = progress_msg - db_document.process_duration = time.time() - start_time - db_document.run = 0 + db_document.process_begin_at = datetime.now(tz=timezone.utc) + db_document.process_duration = 0.0 + db_document.run = 1 db.commit() db.refresh(db_document) - # Vectorization and data entry completed - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n" - db_document.chunk_num = total_chunks - db_document.progress = 1.0 - db_document.process_duration = time.time() - start_time - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n" - db_document.progress_msg = progress_msg - db_document.run = 0 - db.commit() - result = f"parse document '{db_document.file_name}' processed successfully." 
- return result - except Exception as e: - if 'db_document' in locals(): - db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" + def progress_callback(prog=None, msg=None): + nonlocal progress_msg # Declare the use of an external progress_msg variable + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n" + # Prepare to configure chat_mdl、vision_model information + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base + ) + vision_model = QWenCV( + key=db_knowledge.image2text.api_keys[0].api_key, + model_name=db_knowledge.image2text.api_keys[0].model_name, + lang="Chinese", + base_url=db_knowledge.image2text.api_keys[0].api_base + ) + from app.core.rag.app.naive import chunk + res = chunk(filename=file_path, + from_page=0, + to_page=100000, + callback=progress_callback, + vision_model=vision_model, + parser_config=db_document.parser_config, + is_root=False) + + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n" + db_document.progress = 0.8 + db_document.progress_msg = progress_msg + db.commit() + db.refresh(db_document) + + # 2. 
Document vectorization and storage + total_chunks = len(res) + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n" + batch_size = 100 + total_batches = ceil(total_chunks / batch_size) + progress_per_batch = 0.2 / total_batches # Progress of each batch + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + # 2.1 Delete document vector index + vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) + # 2.2 Vectorize and import batch documents + for batch_start in range(0, total_chunks, batch_size): + batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds + batch = res[batch_start: batch_end] # Retrieve the current batch + chunks = [] + + # Process the current batch + for idx_in_batch, item in enumerate(batch): + global_idx = batch_start + idx_in_batch # Calculate global index + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } + if db_document.parser_config.get("auto_questions", 0): + topn = db_document.parser_config["auto_questions"] + cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", {"topn": topn}) + if not cached: + cached = question_proposal(chat_model, item["content_with_weight"], topn) + set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", {"topn": topn}) + chunks.append(DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", metadata=metadata)) + else: + chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) + + # Bulk segmented vector import + vector_service.add_chunks(chunks) + + # Update progress + db_document.progress += progress_per_batch + 
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n" + db_document.progress_msg = progress_msg + db_document.process_duration = time.time() - start_time + db_document.run = 0 + db.commit() + db.refresh(db_document) + + # Vectorization and data entry completed + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n" + db_document.chunk_num = total_chunks + db_document.progress = 1.0 + db_document.process_duration = time.time() - start_time + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n" + db_document.progress_msg = progress_msg db_document.run = 0 db.commit() - result = f"parse document '{db_document.file_name}' failed." - return result - finally: - db.close() + result = f"parse document '{db_document.file_name}' processed successfully." + return result + except Exception as e: + if 'db_document' in locals(): + db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" + db_document.run = 0 + db.commit() + result = f"parse document '{db_document.file_name}' failed." + return result @celery_app.task(name="app.core.memory.agent.read_message", bind=True) @@ -362,75 +358,75 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: from app.models.end_user_model import EndUser from app.models.app_model import App - db = next(get_db()) - try: - workspace_uuid = uuid.UUID(workspace_id) - - # 1. 查询当前workspace下的所有app - apps = db.query(App).filter(App.workspace_id == workspace_uuid).all() - - if not apps: - # 如果没有app,总量为0 + with get_db_context() as db: + try: + workspace_uuid = uuid.UUID(workspace_id) + + # 1. 
查询当前workspace下的所有app + apps = db.query(App).filter(App.workspace_id == workspace_uuid).all() + + if not apps: + # 如果没有app,总量为0 + memory_increment = write_memory_increment( + db=db, + workspace_id=workspace_uuid, + total_num=0 + ) + return { + "status": "SUCCESS", + "workspace_id": workspace_id, + "total_num": 0, + "end_user_count": 0, + "memory_increment_id": str(memory_increment.id), + "created_at": memory_increment.created_at.isoformat(), + } + + # 2. 查询所有app下的end_user_id(去重) + app_ids = [app.id for app in apps] + end_users = db.query(EndUser.id).filter( + EndUser.app_id.in_(app_ids) + ).distinct().all() + + # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 + total_num = 0 + end_user_details = [] + + for (end_user_id,) in end_users: + try: + # 调用 search_all 接口查询该宿主的总量 + result = await search_all(str(end_user_id)) + user_total = result.get("total", 0) + total_num += user_total + end_user_details.append({ + "end_user_id": str(end_user_id), + "total": user_total + }) + except Exception as e: + # 记录单个用户查询失败,但继续处理其他用户 + end_user_details.append({ + "end_user_id": str(end_user_id), + "total": 0, + "error": str(e) + }) + + # 4. 写入数据库 memory_increment = write_memory_increment( db=db, workspace_id=workspace_uuid, - total_num=0 + total_num=total_num ) + return { "status": "SUCCESS", "workspace_id": workspace_id, - "total_num": 0, - "end_user_count": 0, + "total_num": total_num, + "end_user_count": len(end_users), + "end_user_details": end_user_details, "memory_increment_id": str(memory_increment.id), "created_at": memory_increment.created_at.isoformat(), } - - # 2. 查询所有app下的end_user_id(去重) - app_ids = [app.id for app in apps] - end_users = db.query(EndUser.id).filter( - EndUser.app_id.in_(app_ids) - ).distinct().all() - - # 3. 
遍历所有end_user,查询每个宿主的记忆总量并累加 - total_num = 0 - end_user_details = [] - - for (end_user_id,) in end_users: - try: - # 调用 search_all 接口查询该宿主的总量 - result = await search_all(str(end_user_id)) - user_total = result.get("total", 0) - total_num += user_total - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": user_total - }) - except Exception as e: - # 记录单个用户查询失败,但继续处理其他用户 - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": 0, - "error": str(e) - }) - - # 4. 写入数据库 - memory_increment = write_memory_increment( - db=db, - workspace_id=workspace_uuid, - total_num=total_num - ) - - return { - "status": "SUCCESS", - "workspace_id": workspace_id, - "total_num": total_num, - "end_user_count": len(end_users), - "end_user_details": end_user_details, - "memory_increment_id": str(memory_increment.id), - "created_at": memory_increment.created_at.isoformat(), - } - finally: - db.close() + except Exception as e: + raise e try: result = asyncio.run(_run()) @@ -447,6 +443,198 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: } +@celery_app.task(name="app.tasks.regenerate_memory_cache", bind=True) +def regenerate_memory_cache(self) -> Dict[str, Any]: + """定时任务:为所有用户重新生成记忆洞察和用户摘要缓存 + + 遍历所有活动工作空间的所有终端用户,为每个用户重新生成记忆洞察和用户摘要。 + 实现错误隔离,单个用户失败不影响其他用户的处理。 + + Returns: + 包含任务执行结果的字典,包括: + - status: 任务状态 (SUCCESS/FAILURE) + - message: 执行消息 + - workspace_count: 处理的工作空间数量 + - total_users: 总用户数 + - successful: 成功生成的用户数 + - failed: 失败的用户数 + - workspace_results: 每个工作空间的详细结果 + - elapsed_time: 执行耗时(秒) + - task_id: 任务ID + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.services.user_memory_service import UserMemoryService + from app.repositories.end_user_repository import EndUserRepository + from app.core.logging_config import get_logger + + logger = get_logger(__name__) + logger.info("开始执行记忆缓存重新生成定时任务") + + service = UserMemoryService() + + total_users = 0 + successful = 0 + failed = 0 + workspace_results = [] 
+ + with get_db_context() as db: + try: + # 获取所有活动工作空间 + repo = EndUserRepository(db) + workspaces = repo.get_all_active_workspaces() + logger.info(f"找到 {len(workspaces)} 个活动工作空间") + + # 遍历每个工作空间 + for workspace_id in workspaces: + logger.info(f"开始处理工作空间: {workspace_id}") + workspace_start_time = time.time() + + try: + # 获取工作空间的所有终端用户 + end_users = repo.get_all_by_workspace(workspace_id) + workspace_user_count = len(end_users) + total_users += workspace_user_count + + logger.info(f"工作空间 {workspace_id} 有 {workspace_user_count} 个终端用户") + + workspace_successful = 0 + workspace_failed = 0 + workspace_errors = [] + + # 遍历每个用户并生成缓存 + for end_user in end_users: + end_user_id = str(end_user.id) + + try: + # 生成记忆洞察 + insight_result = await service.generate_and_cache_insight(db, end_user_id) + + # 生成用户摘要 + summary_result = await service.generate_and_cache_summary(db, end_user_id) + + # 检查是否都成功 + if insight_result["success"] and summary_result["success"]: + workspace_successful += 1 + successful += 1 + logger.info(f"成功为终端用户 {end_user_id} 重新生成缓存") + else: + workspace_failed += 1 + failed += 1 + error_info = { + "end_user_id": end_user_id, + "insight_error": insight_result.get("error"), + "summary_error": summary_result.get("error") + } + workspace_errors.append(error_info) + logger.warning(f"终端用户 {end_user_id} 的缓存重新生成部分失败: {error_info}") + + except Exception as e: + # 单个用户失败不影响其他用户(错误隔离) + workspace_failed += 1 + failed += 1 + error_info = { + "end_user_id": end_user_id, + "error": str(e) + } + workspace_errors.append(error_info) + logger.error(f"为终端用户 {end_user_id} 重新生成缓存时出错: {str(e)}") + + workspace_elapsed = time.time() - workspace_start_time + + # 记录工作空间处理结果 + workspace_result = { + "workspace_id": str(workspace_id), + "total_users": workspace_user_count, + "successful": workspace_successful, + "failed": workspace_failed, + "errors": workspace_errors[:10], # 只保留前10个错误 + "elapsed_time": workspace_elapsed + } + workspace_results.append(workspace_result) + + logger.info( + 
f"工作空间 {workspace_id} 处理完成: " + f"总数={workspace_user_count}, 成功={workspace_successful}, " + f"失败={workspace_failed}, 耗时={workspace_elapsed:.2f}秒" + ) + + except Exception as e: + # 工作空间处理失败,记录错误并继续处理下一个 + logger.error(f"处理工作空间 {workspace_id} 时出错: {str(e)}") + workspace_results.append({ + "workspace_id": str(workspace_id), + "error": str(e), + "total_users": 0, + "successful": 0, + "failed": 0, + "errors": [] + }) + + # 记录总体统计信息 + logger.info( + f"记忆缓存重新生成定时任务完成: " + f"工作空间数={len(workspaces)}, 总用户数={total_users}, " + f"成功={successful}, 失败={failed}" + ) + + return { + "status": "SUCCESS", + "message": f"成功处理 {len(workspaces)} 个工作空间,总共 {successful}/{total_users} 个用户缓存重新生成成功", + "workspace_count": len(workspaces), + "total_users": total_users, + "successful": successful, + "failed": failed, + "workspace_results": workspace_results + } + + except Exception as e: + logger.error(f"记忆缓存重新生成定时任务执行失败: {str(e)}") + return { + "status": "FAILURE", + "error": str(e), + "workspace_count": len(workspace_results), + "total_users": total_users, + "successful": successful, + "failed": failed, + "workspace_results": workspace_results + } + + try: + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + result["elapsed_time"] = elapsed_time + result["task_id"] = self.request.id + + return result + except Exception as e: + elapsed_time = time.time() - start_time + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + + @celery_app.task(name="app.tasks.workspace_reflection_task", bind=True) def workspace_reflection_task(self) -> Dict[str, 
Any]: """定时任务:每30秒运行工作空间反思功能 @@ -462,100 +650,98 @@ def workspace_reflection_task(self) -> Dict[str, Any]: from app.core.logging_config import get_api_logger api_logger = get_api_logger() - db = next(get_db()) + + with get_db_context() as db: + try: + # 获取所有工作空间 + workspaces = db.query(Workspace).all() - try: - # 获取所有工作空间 - workspaces = db.query(Workspace).all() + if not workspaces: + return { + "status": "SUCCESS", + "message": "没有找到工作空间", + "workspace_count": 0, + "reflection_results": [] + } + + all_reflection_results = [] + + # 遍历每个工作空间 + for workspace in workspaces: + workspace_id = workspace.id + api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") + + try: + reflection_service = MemoryReflectionService(db) + + # 使用服务类处理复杂查询逻辑 + service = WorkspaceAppService(db) + result = service.get_workspace_apps_detailed(str(workspace_id)) + + workspace_reflection_results = [] + + for data in result['apps_detailed_info']: + if data['data_configs'] == []: + continue + + releases = data['releases'] + data_configs = data['data_configs'] + end_users = data['end_users'] + + for base, config, user in zip(releases, data_configs, end_users): + if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: + # 调用反思服务 + api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") + + reflection_result = await reflection_service.start_reflection_from_data( + config_data=config, + end_user_id=user['id'] + ) + + workspace_reflection_results.append({ + "app_id": base['app_id'], + "config_id": config['config_id'], + "end_user_id": user['id'], + "reflection_result": reflection_result + }) + + all_reflection_results.append({ + "workspace_id": str(workspace_id), + "reflection_count": len(workspace_reflection_results), + "reflection_results": workspace_reflection_results + }) + + api_logger.info( + f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") + + except Exception as e: + api_logger.error(f"处理工作空间 {workspace_id} 
反思失败: {str(e)}") + all_reflection_results.append({ + "workspace_id": str(workspace_id), + "error": str(e), + "reflection_count": 0, + "reflection_results": [] + }) + + total_reflections = sum(r.get("reflection_count", 0) for r in all_reflection_results) - if not workspaces: return { "status": "SUCCESS", - "message": "没有找到工作空间", + "message": f"成功处理 {len(workspaces)} 个工作空间,总共 {total_reflections} 个反思任务", + "workspace_count": len(workspaces), + "total_reflections": total_reflections, + "workspace_results": all_reflection_results + } + + except Exception as e: + api_logger.error(f"工作空间反思任务执行失败: {str(e)}") + return { + "status": "FAILURE", + "error": str(e), "workspace_count": 0, "reflection_results": [] } - all_reflection_results = [] - - # 遍历每个工作空间 - for workspace in workspaces: - workspace_id = workspace.id - api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") - - try: - reflection_service = MemoryReflectionService(db) - - # 使用服务类处理复杂查询逻辑 - service = WorkspaceAppService(db) - result = service.get_workspace_apps_detailed(str(workspace_id)) - - workspace_reflection_results = [] - - for data in result['apps_detailed_info']: - if data['data_configs'] == []: - continue - - releases = data['releases'] - data_configs = data['data_configs'] - end_users = data['end_users'] - - for base, config, user in zip(releases, data_configs, end_users): - if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: - # 调用反思服务 - api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") - - reflection_result = await reflection_service.start_reflection_from_data( - config_data=config, - end_user_id=user['id'] - ) - - workspace_reflection_results.append({ - "app_id": base['app_id'], - "config_id": config['config_id'], - "end_user_id": user['id'], - "reflection_result": reflection_result - }) - - all_reflection_results.append({ - "workspace_id": str(workspace_id), - "reflection_count": len(workspace_reflection_results), - 
"reflection_results": workspace_reflection_results - }) - - api_logger.info( - f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") - - except Exception as e: - api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") - all_reflection_results.append({ - "workspace_id": str(workspace_id), - "error": str(e), - "reflection_count": 0, - "reflection_results": [] - }) - - total_reflections = sum(r.get("reflection_count", 0) for r in all_reflection_results) - - return { - "status": "SUCCESS", - "message": f"成功处理 {len(workspaces)} 个工作空间,总共 {total_reflections} 个反思任务", - "workspace_count": len(workspaces), - "total_reflections": total_reflections, - "workspace_results": all_reflection_results - } - - except Exception as e: - api_logger.error(f"工作空间反思任务执行失败: {str(e)}") - return { - "status": "FAILURE", - "error": str(e), - "workspace_count": 0, - "reflection_results": [] - } - finally: - db.close() - try: # 使用 nest_asyncio 来避免事件循环冲突 try: diff --git a/api/env.example b/api/env.example index c4e0c1eb..1354233d 100644 --- a/api/env.example +++ b/api/env.example @@ -30,6 +30,11 @@ RESULT_BACKEND= CELERY_BROKER= CELERY_BACKEND= +# Memory Cache Regeneration Configuration +# Interval in hours for regenerating memory insight and user summary cache +# Default: 24 hours +MEMORY_CACHE_REGENERATION_HOURS=24 + # ElasticSearch configuration ELASTICSEARCH_HOST= ELASTICSEARCH_PORT= From 47c4f93014921a61bd89b4509fecdc3282368f7f Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 23 Dec 2025 16:38:25 +0800 Subject: [PATCH 07/15] [fix] end node bug --- api/app/core/workflow/nodes/end/node.py | 45 +++++++++++++------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 6e108e8d..2e5758eb 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -168,6 +168,8 @@ class EndNode(BaseNode): # 解析模板部分 parts = 
self._parse_template_parts(output_template, state) logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") + for i, part in enumerate(parts): + logger.info(f"[模板解析] part[{i}]: {part}") # 找到第一个引用直接上游节点的动态引用 upstream_ref_index = None @@ -204,38 +206,32 @@ class EndNode(BaseNode): # 收集后缀部分 suffix_parts = [] + logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}") for i in range(upstream_ref_index + 1, len(parts)): part = parts[i] - + logger.info(f"[后缀调试] 处理 part[{i}]: {part}") if part["type"] == "static": # 静态文本 + logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'") suffix_parts.append(part["content"]) elif part["type"] == "dynamic": - # 其他动态引用(如果有多个引用) + # Other dynamic references (if there are multiple references) node_id = part["node_id"] field = part["field"] - # 从 streaming_buffer 或 node_outputs 读取 - streaming_buffer = state.get("streaming_buffer", {}) - if node_id in streaming_buffer: - buffer_data = streaming_buffer[node_id] - content = buffer_data.get("full_content", "") - else: - node_outputs = state.get("node_outputs", {}) - runtime_vars = state.get("runtime_vars", {}) - + # Use VariablePool to get variable value + pool = self.get_variable_pool(state) + try: + # Try to get variable value with default empty string + content = pool.get([node_id, field], default="") + logger.info(f"[后缀调试] 获取变量 {node_id}.{field} 成功: '{content}'") + except Exception as e: + logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}") content = "" - if node_id in node_outputs: - node_output = node_outputs[node_id] - if isinstance(node_output, dict): - content = str(node_output.get(field, "")) - elif node_id in runtime_vars: - runtime_var = runtime_vars[node_id] - if isinstance(runtime_var, dict): - content = str(runtime_var.get(field, "")) - suffix_parts.append(content) + # Convert to string if not None + suffix_parts.append(str(content) if content is not None else "") # 拼接后缀 suffix = "".join(suffix_parts) @@ -243,8 +239,13 @@ class EndNode(BaseNode): # 
构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) full_output = self._render_template(output_template, state) + logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") + logger.info(f"[后缀调试] 后缀内容: '{suffix}'") + logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") + logger.info(f"[后缀调试] 后缀是否为空: {not suffix}") + if suffix: - logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})") + logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})") # 一次性输出后缀(作为单个 chunk) # 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理 # 而是通过 writer 直接发送 @@ -260,7 +261,7 @@ class EndNode(BaseNode): }) logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}") else: - logger.info(f"节点 {self.node_id} 没有后缀需要输出") + logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}") # 统计信息 node_outputs = state.get("node_outputs", {}) From 00f440f471f412e09efcc06a99cf5607f7343557 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 23 Dec 2025 16:49:30 +0800 Subject: [PATCH 08/15] [modify] migration script --- .../versions/f3d893ccb866_202512231644.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 api/migrations/versions/f3d893ccb866_202512231644.py diff --git a/api/migrations/versions/f3d893ccb866_202512231644.py b/api/migrations/versions/f3d893ccb866_202512231644.py new file mode 100644 index 00000000..2dffdc33 --- /dev/null +++ b/api/migrations/versions/f3d893ccb866_202512231644.py @@ -0,0 +1,50 @@ +"""202512231644 + +Revision ID: f3d893ccb866 +Revises: 022550fdcfda +Create Date: 2025-12-23 16:47:30.897690 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = 'f3d893ccb866' +down_revision: Union[str, None] = '022550fdcfda' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('end_users', sa.Column('name', sa.String(), nullable=True, comment='姓名')) + op.add_column('end_users', sa.Column('position', sa.String(), nullable=True, comment='职位')) + op.add_column('end_users', sa.Column('department', sa.String(), nullable=True, comment='部门')) + op.add_column('end_users', sa.Column('contact', sa.String(), nullable=True, comment='联系方式')) + op.add_column('end_users', sa.Column('phone', sa.String(), nullable=True, comment='电话')) + op.add_column('end_users', sa.Column('hire_date', sa.BigInteger(), nullable=True, comment='入职日期(时间戳,毫秒)')) + op.add_column('end_users', sa.Column('updatetime_profile', sa.BigInteger(), nullable=True, comment='核心档案信息最后更新时间(时间戳,毫秒)')) + op.add_column('end_users', sa.Column('memory_insight', sa.Text(), nullable=True, comment='缓存的记忆洞察报告')) + op.add_column('end_users', sa.Column('user_summary', sa.Text(), nullable=True, comment='缓存的用户摘要')) + op.add_column('end_users', sa.Column('memory_insight_updated_at', sa.DateTime(), nullable=True, comment='洞察报告最后更新时间')) + op.add_column('end_users', sa.Column('user_summary_updated_at', sa.DateTime(), nullable=True, comment='用户摘要最后更新时间')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('end_users', 'user_summary_updated_at') + op.drop_column('end_users', 'memory_insight_updated_at') + op.drop_column('end_users', 'user_summary') + op.drop_column('end_users', 'memory_insight') + op.drop_column('end_users', 'updatetime_profile') + op.drop_column('end_users', 'hire_date') + op.drop_column('end_users', 'phone') + op.drop_column('end_users', 'contact') + op.drop_column('end_users', 'department') + op.drop_column('end_users', 'position') + op.drop_column('end_users', 'name') + # ### end Alembic commands ### From c15a987701d4e740f9f104647fb3ff1c56d2075d Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 22 Dec 2025 14:59:21 +0800 Subject: [PATCH 09/15] style(service): workflow --- api/app/schemas/app_schema.py | 87 +++++++++--------- api/app/services/workflow_service.py | 132 +++++++++++++-------------- 2 files changed, 110 insertions(+), 109 deletions(-) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index c387cee9..b6b1de52 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -1,6 +1,7 @@ -import uuid import datetime -from typing import Optional, Any, List, Dict, TYPE_CHECKING +import uuid +from typing import Optional, Any, List, Dict + from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator @@ -20,20 +21,19 @@ class KnowledgeBaseConfig(BaseModel): class KnowledgeRetrievalConfig(BaseModel): """知识库检索配置(支持多个知识库,每个有独立配置)""" knowledge_bases: List[KnowledgeBaseConfig] = Field( - default_factory=list, + default_factory=list, description="关联的知识库列表,每个知识库有独立配置" ) - + # 多知识库融合策略 merge_strategy: str = Field( - default="weighted", + default="weighted", description="多知识库结果融合策略: weighted | rrf | concat" ) reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID") reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数") - class ToolConfig(BaseModel): """工具配置""" enabled: bool = 
Field(default=False, description="是否启用该工具") @@ -63,7 +63,7 @@ class VariableDefinition(BaseModel): name: str = Field(..., description="变量名称(标识符)") display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)") type: str = Field( - default="string", + default="string", description="变量类型: string(单行文本) | text(多行文本) | number(数字)" ) required: bool = Field(default=False, description="是否必填") @@ -75,32 +75,32 @@ class AgentConfigCreate(BaseModel): """Agent 行为配置""" # 提示词配置 system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则") - + # 模型配置 default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID") model_parameters: ModelParameters = Field( default_factory=ModelParameters, description="模型参数配置(temperature、max_tokens 等)" ) - + # 知识库关联 knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( default=None, description="知识库检索配置" ) - + # 记忆配置 memory: MemoryConfig = Field( default_factory=lambda: MemoryConfig(enabled=True), description="对话历史记忆配置" ) - + # 变量配置 variables: List[VariableDefinition] = Field( default_factory=list, description="Agent 可用的变量列表" ) - + # 工具配置 tools: Dict[str, ToolConfig] = Field( default_factory=dict, @@ -120,7 +120,7 @@ class AppCreate(BaseModel): # only for type=agent agent_config: Optional[AgentConfigCreate] = None - + # only for type=multi_agent multi_agent_config: Optional[Dict[str, Any]] = None @@ -139,23 +139,23 @@ class AgentConfigUpdate(BaseModel): """更新 Agent 行为配置""" # 提示词配置 system_prompt: Optional[str] = Field(default=None, description="系统提示词") - + # 模型配置 default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID") model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置") - + # 知识库关联 knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( default=None, description="知识库检索配置" ) - + # 记忆配置 memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置") - + # 变量配置 variables: 
Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") - + # 工具配置 tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") @@ -185,7 +185,7 @@ class App(BaseModel): @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -197,26 +197,26 @@ class AgentConfig(BaseModel): id: uuid.UUID app_id: uuid.UUID - + # 提示词 system_prompt: Optional[str] = None - + # 模型配置 default_model_config_id: Optional[uuid.UUID] = None model_parameters: ModelParameters = Field(default_factory=ModelParameters) - + # 知识库检索 knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None - + # 记忆配置 memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True)) - + # 变量配置 variables: List[VariableDefinition] = [] - + # 工具配置 tools: Dict[str, ToolConfig] = {} - + is_active: bool created_at: datetime.datetime updated_at: datetime.datetime @@ -228,7 +228,7 @@ class AgentConfig(BaseModel): if v is None: return ModelParameters() return v - + @field_validator("memory", mode="before") @classmethod def validate_memory(cls, v): @@ -236,7 +236,7 @@ class AgentConfig(BaseModel): if v is None: return MemoryConfig(enabled=True) return v - + @field_validator("variables", mode="before") @classmethod def validate_variables(cls, v): @@ -244,7 +244,7 @@ class AgentConfig(BaseModel): if v is None: return [] return v - + @field_validator("tools", mode="before") @classmethod def validate_tools(cls, v): @@ -256,7 +256,7 @@ class AgentConfig(BaseModel): @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: 
datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -294,15 +294,15 @@ class AppRelease(BaseModel): @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("published_at", when_used="json") def _serialize_published_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + # ---------- App Share Schemas ---------- @@ -314,7 +314,7 @@ class AppShareCreate(BaseModel): class AppShare(BaseModel): """应用分享输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID source_app_id: uuid.UUID source_workspace_id: uuid.UUID @@ -322,11 +322,11 @@ class AppShare(BaseModel): shared_by: uuid.UUID created_at: datetime.datetime updated_at: datetime.datetime - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -338,6 +338,7 @@ class DraftRunRequest(BaseModel): """试运行请求""" message: str = Field(..., description="用户消息") conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)") + conversation_vars: Optional[dict[str, Any]] = Field(default=None, description="会话变量") user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") @@ -382,14 +383,14 @@ class DraftRunCompareRequest(BaseModel): conversation_id: Optional[str] = Field(None, description="会话ID") user_id: Optional[str] = Field(None, description="用户ID") 
variables: Optional[Dict[str, Any]] = Field(None, description="变量参数") - + models: List[ModelCompareItem] = Field( ..., min_length=1, max_length=5, description="要对比的模型列表(1-5个)" ) - + parallel: bool = Field(True, description="是否并行执行") stream: bool = Field(False, description="是否流式返回") timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)") @@ -400,14 +401,14 @@ class ModelRunResult(BaseModel): model_config_id: uuid.UUID model_name: str label: Optional[str] = None - + parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数") - + message: Optional[str] = None usage: Optional[Dict[str, Any]] = None elapsed_time: float error: Optional[str] = None - + tokens_per_second: Optional[float] = None cost_estimate: Optional[float] = None conversation_id: Optional[str] = None @@ -416,10 +417,10 @@ class ModelRunResult(BaseModel): class DraftRunCompareResponse(BaseModel): """多模型对比响应""" results: List[ModelRunResult] - + total_elapsed_time: float successful_count: int failed_count: int - + fastest_model: Optional[str] = None cheapest_model: Optional[str] = None diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index ccf0442f..058767d9 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -39,14 +39,14 @@ class WorkflowService: # ==================== 配置管理 ==================== def create_workflow_config( - self, - app_id: uuid.UUID, - nodes: list[dict[str, Any]], - edges: list[dict[str, Any]], - variables: list[dict[str, Any]] | None = None, - execution_config: dict[str, Any] | None = None, - triggers: list[dict[str, Any]] | None = None, - validate: bool = True + self, + app_id: uuid.UUID, + nodes: list[dict[str, Any]], + edges: list[dict[str, Any]], + variables: list[dict[str, Any]] | None = None, + execution_config: dict[str, Any] | None = None, + triggers: list[dict[str, Any]] | None = None, + validate: bool = True ) -> WorkflowConfig: """创建工作流配置 @@ -109,14 +109,14 @@ class 
WorkflowService: return self.config_repo.get_by_app_id(app_id) def update_workflow_config( - self, - app_id: uuid.UUID, - nodes: list[dict[str, Any]] | None = None, - edges: list[dict[str, Any]] | None = None, - variables: list[dict[str, Any]] | None = None, - execution_config: dict[str, Any] | None = None, - triggers: list[dict[str, Any]] | None = None, - validate: bool = True + self, + app_id: uuid.UUID, + nodes: list[dict[str, Any]] | None = None, + edges: list[dict[str, Any]] | None = None, + variables: list[dict[str, Any]] | None = None, + execution_config: dict[str, Any] | None = None, + triggers: list[dict[str, Any]] | None = None, + validate: bool = True ) -> WorkflowConfig: """更新工作流配置 @@ -226,8 +226,8 @@ class WorkflowService: return config def validate_workflow_config_for_publish( - self, - app_id: uuid.UUID + self, + app_id: uuid.UUID ) -> tuple[bool, list[str]]: """验证工作流配置是否可以发布 @@ -260,13 +260,13 @@ class WorkflowService: # ==================== 执行管理 ==================== def create_execution( - self, - workflow_config_id: uuid.UUID, - app_id: uuid.UUID, - trigger_type: str, - triggered_by: uuid.UUID | None = None, - conversation_id: uuid.UUID | None = None, - input_data: dict[str, Any] | None = None + self, + workflow_config_id: uuid.UUID, + app_id: uuid.UUID, + trigger_type: str, + triggered_by: uuid.UUID | None = None, + conversation_id: uuid.UUID | None = None, + input_data: dict[str, Any] | None = None ) -> WorkflowExecution: """创建工作流执行记录 @@ -314,10 +314,10 @@ class WorkflowService: return self.execution_repo.get_by_execution_id(execution_id) def get_executions_by_app( - self, - app_id: uuid.UUID, - limit: int = 50, - offset: int = 0 + self, + app_id: uuid.UUID, + limit: int = 50, + offset: int = 0 ) -> list[WorkflowExecution]: """获取应用的执行记录列表 @@ -332,12 +332,12 @@ class WorkflowService: return self.execution_repo.get_by_app_id(app_id, limit, offset) def update_execution_status( - self, - execution_id: str, - status: str, - output_data: dict[str, 
Any] | None = None, - error_message: str | None = None, - error_node_id: str | None = None + self, + execution_id: str, + status: str, + output_data: dict[str, Any] | None = None, + error_message: str | None = None, + error_node_id: str | None = None ) -> WorkflowExecution: """更新执行状态 @@ -407,10 +407,10 @@ class WorkflowService: # ==================== 工作流执行 ==================== async def run( - self, - app_id: uuid.UUID, - payload: DraftRunRequest, - config: WorkflowConfig + self, + app_id: uuid.UUID, + payload: DraftRunRequest, + config: WorkflowConfig ): """运行工作流 @@ -527,10 +527,10 @@ class WorkflowService: ) async def run_stream( - self, - app_id: uuid.UUID, - payload: DraftRunRequest, - config: WorkflowConfig + self, + app_id: uuid.UUID, + payload: DraftRunRequest, + config: WorkflowConfig ): """运行工作流(流式) @@ -600,11 +600,11 @@ class WorkflowService: # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) async for event in self._run_workflow_stream( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id="", - user_id=payload.user_id + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id="", + user_id=payload.user_id ): # 直接转发 executor 的事件(已经是正确的格式) yield event @@ -626,12 +626,12 @@ class WorkflowService: } async def run_workflow( - self, - app_id: uuid.UUID, - input_data: dict[str, Any], - triggered_by: uuid.UUID, - conversation_id: uuid.UUID | None = None, - stream: bool = False + self, + app_id: uuid.UUID, + input_data: dict[str, Any], + triggered_by: uuid.UUID, + conversation_id: uuid.UUID | None = None, + stream: bool = False ) -> AsyncGenerator | dict: """运行工作流 @@ -778,12 +778,12 @@ class WorkflowService: return clean_value(event) async def _run_workflow_stream( - self, - workflow_config: dict[str, Any], - input_data: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str): + self, + workflow_config: dict[str, 
Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str): """运行工作流(流式,内部方法) Args: @@ -800,11 +800,11 @@ class WorkflowService: try: async for event in execute_workflow_stream( - workflow_config=workflow_config, - input_data=input_data, - execution_id=execution_id, - workspace_id=workspace_id, - user_id=user_id + workflow_config=workflow_config, + input_data=input_data, + execution_id=execution_id, + workspace_id=workspace_id, + user_id=user_id ): # 直接转发事件(executor 已经返回正确格式) yield event @@ -828,7 +828,7 @@ class WorkflowService: # ==================== 依赖注入函数 ==================== def get_workflow_service( - db: Annotated[Session, Depends(get_db)] + db: Annotated[Session, Depends(get_db)] ) -> WorkflowService: """获取工作流服务(依赖注入)""" return WorkflowService(db) From 75ee591202f8b87943f2165e7ef42f944c18a834 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 22 Dec 2025 16:18:25 +0800 Subject: [PATCH 10/15] style(workflow): remove unnecessary indentation --- api/app/core/workflow/nodes/end/node.py | 84 ++++++++++---------- api/app/core/workflow/nodes/llm/node.py | 3 +- api/app/core/workflow/nodes/node_factory.py | 12 +-- api/app/core/workflow/variable_pool.py | 7 +- api/app/models/models_model.py | 19 ----- api/app/services/prompt_optimizer_service.py | 2 +- 6 files changed, 57 insertions(+), 70 deletions(-) diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 2e5758eb..3cece96b 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -33,7 +33,7 @@ class EndNode(BaseNode): # 获取配置的输出模板 output_template = self.config.get("output") - + # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state) @@ -45,17 +45,17 @@ class EndNode(BaseNode): total_nodes = len(node_outputs) logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") - + return output - + def 
_extract_referenced_nodes(self, template: str) -> list[str]: """从模板中提取引用的节点 ID - + 例如:'结果:{{llm_qa.output}}' -> ['llm_qa'] - + Args: template: 模板字符串 - + Returns: 引用的节点 ID 列表 """ @@ -63,44 +63,44 @@ class EndNode(BaseNode): pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}' matches = re.findall(pattern, template) return list(set(matches)) # 去重 - + def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]: """解析模板,分离静态文本和动态引用 - + 例如:'你好 {{llm.output}}, 这是后缀' 返回:[ {"type": "static", "content": "你好 "}, {"type": "dynamic", "node_id": "llm", "field": "output"}, {"type": "static", "content": ", 这是后缀"} ] - + Args: template: 模板字符串 state: 工作流状态 - + Returns: 模板部分列表 """ import re - + parts = [] last_end = 0 - + # 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格) pattern = r'\{\{\s*([^}]+?)\s*\}\}' - + for match in re.finditer(pattern, template): start, end = match.span() - + # 添加前面的静态文本 if start > last_end: static_text = template[last_end:start] if static_text: parts.append({"type": "static", "content": static_text}) - + # 解析动态引用 ref = match.group(1).strip() - + # 检查是否是节点引用(如 llm.output 或 llm_qa.output) if '.' in ref: node_id, field = ref.split('.', 1) @@ -115,62 +115,62 @@ class EndNode(BaseNode): # 直接渲染这部分 rendered = self._render_template(f"{{{{{ref}}}}}", state) parts.append({"type": "static", "content": rendered}) - + last_end = end - + # 添加最后的静态文本 if last_end < len(template): static_text = template[last_end:] if static_text: parts.append({"type": "static", "content": static_text}) - + return parts - + async def execute_stream(self, state: WorkflowState): """流式执行 end 节点业务逻辑 - + 智能输出策略: 1. 检测模板中是否引用了直接上游节点 2. 如果引用了,只输出该引用**之后**的部分(后缀) 3. 
前缀和引用内容已经在上游节点流式输出时发送了 - + 示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' - 直接上游节点是 llm_qa - 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 - LLM 内容在 LLM 节点流式输出 - End 节点只输出 ' lalalalala a'(后缀,一次性输出) - + Args: state: 工作流状态 - + Yields: 完成标记 """ logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") - + # 获取配置的输出模板 output_template = self.config.get("output") - + if not output_template: output = "工作流已完成" yield {"__final__": True, "result": output} return - + # 找到直接上游节点 direct_upstream_nodes = [] for edge in self.workflow_config.get("edges", []): if edge.get("target") == self.node_id: source_node_id = edge.get("source") direct_upstream_nodes.append(source_node_id) - + logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") - + # 解析模板部分 parts = self._parse_template_parts(output_template, state) logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") for i, part in enumerate(parts): logger.info(f"[模板解析] part[{i}]: {part}") - + # 找到第一个引用直接上游节点的动态引用 upstream_ref_index = None for i, part in enumerate(parts): @@ -178,12 +178,12 @@ class EndNode(BaseNode): upstream_ref_index = i logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") break - + if upstream_ref_index is None: # 没有引用直接上游节点,输出完整模板内容 output = self._render_template(output_template, state) logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'") - + # 通过 writer 发送完整内容(作为一个 message chunk) from langgraph.config import get_stream_writer writer = get_stream_writer() @@ -196,14 +196,14 @@ class EndNode(BaseNode): "is_suffix": False }) logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") - + # yield 完成标记 yield {"__final__": True, "result": output} return - + # 有引用直接上游节点,只输出该引用之后的部分(后缀) logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") - + # 收集后缀部分 suffix_parts = [] logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}") @@ -214,7 +214,7 @@ class EndNode(BaseNode): # 静态文本 logger.info(f"[后缀调试] 添加静态文本: 
'{part['content']}'") suffix_parts.append(part["content"]) - + elif part["type"] == "dynamic": # Other dynamic references (if there are multiple references) node_id = part["node_id"] @@ -229,21 +229,21 @@ class EndNode(BaseNode): except Exception as e: logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}") content = "" - + # Convert to string if not None suffix_parts.append(str(content) if content is not None else "") # 拼接后缀 suffix = "".join(suffix_parts) - + # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) full_output = self._render_template(output_template, state) - + logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") logger.info(f"[后缀调试] 后缀是否为空: {not suffix}") - + if suffix: logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})") # 一次性输出后缀(作为单个 chunk) @@ -266,8 +266,8 @@ class EndNode(BaseNode): # 统计信息 node_outputs = state.get("node_outputs", {}) total_nodes = len(node_outputs) - + logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点") - + # yield 完成标记(包含完整输出) yield {"__final__": True, "result": full_output} diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 8f809923..65826d84 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.models import RedBearLLM, RedBearModelConfig from app.db import get_db_context +from app.models import ModelType from app.services.model_service import ModelConfigService from app.core.exceptions import BusinessException @@ -136,7 +137,7 @@ class LLMNode(BaseNode): base_url=api_base, extra_params=extra_params ), - type=model_type + type=ModelType(model_type) ) logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") diff 
--git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 1abace67..2ae31d4d 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -7,6 +7,7 @@ import logging from typing import Any, Union +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.end import EndNode @@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.assigner import AssignerNode logger = logging.getLogger(__name__) @@ -26,6 +28,8 @@ WorkflowNode = Union[ IfElseNode, AgentNode, TransformNode, + AssignerNode, + KnowledgeRetrievalNode, ] @@ -42,7 +46,9 @@ class NodeFactory: NodeType.LLM: LLMNode, NodeType.AGENT: AgentNode, NodeType.TRANSFORM: TransformNode, - NodeType.IF_ELSE: IfElseNode + NodeType.IF_ELSE: IfElseNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.ASSIGNER: AssignerNode, } @classmethod @@ -82,10 +88,6 @@ class NodeFactory: """ node_type = node_config.get("type") - # 跳过条件节点(由 LangGraph 处理) - if node_type == "condition": - return None - # 获取节点类 node_class = cls._node_types.get(node_type) if not node_class: diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index 1f589dab..0f97c349 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -10,7 +10,10 @@ """ import logging -from typing import Any +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from app.core.workflow.nodes import WorkflowState logger = logging.getLogger(__name__) @@ -82,7 +85,7 @@ class VariablePool: >>> pool.set(["conv", "user_name"], "张三") """ - def 
__init__(self, state: dict[str, Any]): + def __init__(self, state: "WorkflowState"): """初始化变量池 Args: diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 91c1d9c7..2e60ef1c 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -15,25 +15,6 @@ class ModelType(StrEnum): EMBEDDING = "embedding" RERANK = "rerank" - @classmethod - def from_str(cls, value: str) -> "ModelType": - """ - Get a ModelType enum instance from a string value. - - Args: - value (str): The string representation of the model type. - - Returns: - ModelType: The corresponding ModelType enum object. - - Raises: - ValueError: If the given value does not match any ModelType. - """ - try: - return cls(value) - except ValueError: - raise ValueError(f"Invalid ModelType: {value}") - class ModelProvider(StrEnum): """模型提供商枚举""" diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 5355474f..6af794b1 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -169,7 +169,7 @@ class PromptOptimizerService: provider=api_config.provider, api_key=api_config.api_key, base_url=api_config.api_base - ), type=ModelType.from_str(model_config.type)) + ), type=ModelType(model_config.type)) # build message messages = [ From 92c62bb46f6f562941a0592a681b9a1136e99601 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 22 Dec 2025 18:49:44 +0800 Subject: [PATCH 11/15] revert(workflow): read conversation variables from database instead of API input --- api/app/schemas/app_schema.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index b6b1de52..52c5ae81 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -338,7 +338,6 @@ class DraftRunRequest(BaseModel): """试运行请求""" message: str = Field(..., description="用户消息") conversation_id: Optional[str] = 
Field(default=None, description="会话ID(用于多轮对话)") - conversation_vars: Optional[dict[str, Any]] = Field(default=None, description="会话变量") user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") From 054a5976f5f1094e3aeacba3528e46743574a673 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 22 Dec 2025 20:04:18 +0800 Subject: [PATCH 12/15] feat(workflow): add assigner node and fix circular imports with minor code style cleanup --- api/app/core/workflow/nodes/__init__.py | 6 +- .../core/workflow/nodes/assigner/__init__.py | 4 + .../core/workflow/nodes/assigner/config.py | 21 +++ api/app/core/workflow/nodes/assigner/node.py | 80 ++++++++++ api/app/core/workflow/nodes/configs.py | 4 + api/app/core/workflow/nodes/enums.py | 39 +++++ api/app/core/workflow/nodes/if_else/node.py | 2 +- api/app/core/workflow/nodes/node_factory.py | 6 +- api/app/core/workflow/nodes/operators.py | 142 ++++++++++++++++++ 9 files changed, 299 insertions(+), 5 deletions(-) create mode 100644 api/app/core/workflow/nodes/assigner/__init__.py create mode 100644 api/app/core/workflow/nodes/assigner/config.py create mode 100644 api/app/core/workflow/nodes/assigner/node.py create mode 100644 api/app/core/workflow/nodes/operators.py diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index d143c693..1d00532e 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -5,9 +5,11 @@ """ from app.core.workflow.nodes.agent import AgentNode +from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.if_else import IfElseNode +# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from 
app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode @@ -23,5 +25,7 @@ __all__ = [ "StartNode", "EndNode", "NodeFactory", - "WorkflowNode" + "WorkflowNode", + # "KnowledgeRetrievalNode", + "AssignerNode", ] diff --git a/api/app/core/workflow/nodes/assigner/__init__.py b/api/app/core/workflow/nodes/assigner/__init__.py new file mode 100644 index 00000000..668e1aea --- /dev/null +++ b/api/app/core/workflow/nodes/assigner/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.assigner.config import AssignerNodeConfig +from app.core.workflow.nodes.assigner.node import AssignerNode + +__all__ = ["AssignerNode", "AssignerNodeConfig"] diff --git a/api/app/core/workflow/nodes/assigner/config.py b/api/app/core/workflow/nodes/assigner/config.py new file mode 100644 index 00000000..1cb0def3 --- /dev/null +++ b/api/app/core/workflow/nodes/assigner/config.py @@ -0,0 +1,21 @@ +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.nodes.enums import AssignmentOperator + + +class AssignerNodeConfig(BaseNodeConfig): + variable_selector: str | list[str] = Field( + ..., + description="Variables to be assigned", + ) + + operation: AssignmentOperator = Field( + ..., + description="Operator to assign", + ) + + value: str | list[str] = Field( + ..., + description="Values to assign", + ) diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py new file mode 100644 index 00000000..eb32bf8b --- /dev/null +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -0,0 +1,80 @@ +import logging +from typing import Any + +from app.core.workflow.expression_evaluator import ExpressionEvaluator +from app.core.workflow.nodes.assigner.config import AssignerNodeConfig +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.enums 
import AssignmentOperator +from app.core.workflow.nodes.operators import AssignmentOperatorInstance +from app.core.workflow.variable_pool import VariablePool + +logger = logging.getLogger(__name__) + + +class AssignerNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = AssignerNodeConfig(**self.config) + + async def execute(self, state: WorkflowState) -> Any: + """ + Execute the assignment operation defined by this node. + + Args: + state: The current workflow state, including conversation variables, + node outputs, and system variables. + + Returns: + None or the result of the assignment operation. + """ + # Initialize a variable pool for accessing conversation, node, and system variables + pool = VariablePool(state) + + # Get the target variable selector (e.g., "conv.test") + variable_selector = self.typed_config.variable_selector + if isinstance(variable_selector, str): + # Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"] + variable_selector = variable_selector.split('.') + + # Only conversation variables ('conv') are allowed + if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature) + raise ValueError("Only conversation variables can be assigned.") + + # Get the value or expression to assign + value = self.typed_config.value + if isinstance(value, list): + value = '.'.join(value) + value = ExpressionEvaluator.evaluate( + expression=value, + variables=pool.get_all_conversation_vars(), + node_outputs=pool.get_all_node_outputs(), + system_vars=pool.get_all_system_vars(), + ) + + # Select the appropriate assignment operator instance based on the target variable type + operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( + pool, variable_selector, value + ) + + # Execute the configured assignment operation + match self.typed_config.operation: + case 
AssignmentOperator.ASSIGN: + operator.assign() + case AssignmentOperator.CLEAR: + operator.clear() + case AssignmentOperator.ADD: + operator.add() + case AssignmentOperator.SUBTRACT: + operator.subtract() + case AssignmentOperator.MULTIPLY: + operator.multiply() + case AssignmentOperator.DIVIDE: + operator.divide() + case AssignmentOperator.APPEND: + operator.append() + case AssignmentOperator.REMOVE_FIRST: + operator.remove_first() + case AssignmentOperator.REMOVE_LAST: + operator.remove_last() + case _: + raise ValueError(f"Invalid Operator: {self.typed_config.operation}") diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index 15ab0ce9..ecded070 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -14,6 +14,8 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.agent.config import AgentNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig +# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.assigner.config import AssignerNodeConfig __all__ = [ # 基础类 @@ -28,4 +30,6 @@ __all__ = [ "AgentNodeConfig", "TransformNodeConfig", "IfElseNodeConfig", + # "KnowledgeRetrievalNodeConfig", + "AssignerNodeConfig", ] diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index af5ddbaa..82ecad5d 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -1,5 +1,14 @@ from enum import StrEnum +from app.core.workflow.nodes.operators import ( + StringOperator, + NumberOperator, + AssignmentOperatorType, + BooleanOperator, + ArrayOperator, + ObjectOperator +) + class NodeType(StrEnum): START = "start" @@ -14,6 +23,7 @@ class NodeType(StrEnum): HTTP_REQUEST = "http-request" TOOL = "tool" AGENT = "agent" + ASSIGNER 
= "assigner" class ComparisonOperator(StrEnum): @@ -34,3 +44,32 @@ class ComparisonOperator(StrEnum): class LogicOperator(StrEnum): AND = "and" OR = "or" + + +class AssignmentOperator(StrEnum): + ASSIGN = "assign" + CLEAR = "clear" + + ADD = "add" # += + SUBTRACT = "subtract" # -= + MULTIPLY = "multiply" # *= + DIVIDE = "divide" # /= + + APPEND = "append" + REMOVE_LAST = "remove_last" + REMOVE_FIRST = "remove_first" + + @classmethod + def get_operator(cls, obj) -> AssignmentOperatorType: + if isinstance(obj, str): + return StringOperator + elif isinstance(obj, bool): + return BooleanOperator + elif isinstance(obj, (int, float)): + return NumberOperator + elif isinstance(obj, list): + return ArrayOperator + elif isinstance(obj, dict): + return ObjectOperator + + raise TypeError(f"Unsupported variable type ({type(obj)})") diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index ed3dbbd6..aedf0727 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -1,7 +1,7 @@ import logging from typing import Any -from app.core.workflow.nodes import BaseNode, WorkflowState +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import ComparisonOperator from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else.config import ConditionDetail diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 2ae31d4d..93364083 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -7,7 +7,7 @@ import logging from typing import Any, Union -from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode +# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.base_node import BaseNode from 
app.core.workflow.nodes.end import EndNode @@ -29,7 +29,7 @@ WorkflowNode = Union[ AgentNode, TransformNode, AssignerNode, - KnowledgeRetrievalNode, + # KnowledgeRetrievalNode, ] @@ -47,7 +47,7 @@ class NodeFactory: NodeType.AGENT: AgentNode, NodeType.TRANSFORM: TransformNode, NodeType.IF_ELSE: IfElseNode, - NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + # NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.ASSIGNER: AssignerNode, } diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py new file mode 100644 index 00000000..9563c335 --- /dev/null +++ b/api/app/core/workflow/nodes/operators.py @@ -0,0 +1,142 @@ +from abc import ABC +from typing import Union, Type + +from app.core.workflow.variable_pool import VariablePool + + +class OperatorBase(ABC): + def __init__(self, pool: VariablePool, left_selector, right): + self.pool = pool + self.left_selector = left_selector + self.right = right + + self.type_limit: type[str, int, dict, list] = None + + def check(self, no_right=False): + left = self.pool.get(self.left_selector) + if not isinstance(left, self.type_limit): + raise TypeError(f"The variable to be operated on must be of {self.type_limit} type") + + if not no_right and not isinstance(self.right, self.type_limit): + raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type") + + +class StringOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = str + + def assign(self) -> None: + self.check() + self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, '') + + +class NumberOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = (float, int) + + def assign(self) -> None: + self.check() + 
self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, 0) + + def add(self) -> None: + self.check() + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin + self.right) + + def subtract(self) -> None: + self.check() + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin - self.right) + + def multiply(self) -> None: + self.check() + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin * self.right) + + def divide(self) -> None: + self.check() + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin / self.right) + + +class BooleanOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = bool + + def assign(self) -> None: + self.check() + self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, False) + + +class ArrayOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = list + + def assign(self) -> None: + self.check() + self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, list()) + + def append(self) -> None: + self.check() + # TODO:require type limit in list + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin.append(self.right)) + + def extend(self) -> None: + self.check(no_right=True) + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin.extend(self.right)) + + def remove_last(self) -> None: + self.check(no_right=True) + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin.pop()) + + def 
remove_first(self) -> None: + self.check(no_right=True) + origin = self.pool.get(self.left_selector) + self.pool.set(self.left_selector, origin.pop(0)) + + +class ObjectOperator(OperatorBase): + def __init__(self, pool: VariablePool, left_selector, right): + super().__init__(pool, left_selector, right) + self.type_limit = object + + def assign(self) -> None: + self.check() + self.pool.set(self.left_selector, self.right) + + def clear(self) -> None: + self.check(no_right=True) + self.pool.set(self.left_selector, dict()) + + +AssignmentOperatorInstance = Union[ + StringOperator, + NumberOperator, + BooleanOperator, + ArrayOperator, + ObjectOperator +] +AssignmentOperatorType = Type[AssignmentOperatorInstance] From 6b2c69ebf486e2f4c8c51023f17e40d182e83230 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 22 Dec 2025 20:23:20 +0800 Subject: [PATCH 13/15] fix(workflow): fix incorrect list append/pop logic in assigner node --- api/app/core/workflow/nodes/operators.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index 9563c335..5a1dba1f 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -100,7 +100,8 @@ class ArrayOperator(OperatorBase): self.check() # TODO:require type limit in list origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin.append(self.right)) + origin.append(self.right) + self.pool.set(self.left_selector, origin) def extend(self) -> None: self.check(no_right=True) @@ -110,12 +111,14 @@ class ArrayOperator(OperatorBase): def remove_last(self) -> None: self.check(no_right=True) origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin.pop()) + origin.pop() + self.pool.set(self.left_selector, origin) def remove_first(self) -> None: self.check(no_right=True) origin = self.pool.get(self.left_selector) - 
self.pool.set(self.left_selector, origin.pop(0)) + origin.pop(0) + self.pool.set(self.left_selector, origin) class ObjectOperator(OperatorBase): From 0ad5d1b662a1aca8fbf8f3f0b071f6823f5bf5a3 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 22 Dec 2025 20:26:30 +0800 Subject: [PATCH 14/15] fix(workflow): fix incorrect list extend logic in assigner node --- api/app/core/workflow/nodes/operators.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index 5a1dba1f..de215460 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -106,7 +106,8 @@ class ArrayOperator(OperatorBase): def extend(self) -> None: self.check(no_right=True) origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin.extend(self.right)) + origin.extend(self.right) + self.pool.set(self.left_selector, origin) def remove_last(self) -> None: self.check(no_right=True) From 28e88c2d4c299ff100935776fa725d26d0f7a36b Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 22 Dec 2025 20:37:42 +0800 Subject: [PATCH 15/15] fix(workflow): fix incorrect list append logic in assigner node --- api/app/core/workflow/nodes/operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index de215460..a80cf326 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -97,7 +97,7 @@ class ArrayOperator(OperatorBase): self.pool.set(self.left_selector, list()) def append(self) -> None: - self.check() + self.check(no_right=True) # TODO:require type limit in list origin = self.pool.get(self.left_selector) origin.append(self.right)