diff --git a/api/app/celery_app.py b/api/app/celery_app.py index d072a346..ce7e9300 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -83,17 +83,18 @@ celery_app.autodiscover_tasks(['app']) reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS) health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) - +workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME # 构建定时任务配置 beat_schedule_config = { - "run-reflection-engine": { - "task": "app.core.memory.agent.reflection.timer", - "schedule": reflection_schedule, - "args": (), - }, - "check-read-service": { - "task": "app.core.memory.agent.health.check_read_service", - "schedule": health_schedule, + + # "check-read-service": { + # "task": "app.core.memory.agent.health.check_read_service", + # "schedule": health_schedule, + # "args": (), + # }, + "run-workspace-reflection": { + "task": "app.tasks.workspace_reflection_task", + "schedule": workspace_reflection_schedule, "args": (), }, } diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index a3caaf4a..ddf534c6 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -23,12 +23,13 @@ from . import ( memory_dashboard_controller, memory_storage_controller, memory_dashboard_controller, + memory_reflection_controller, api_key_controller, release_share_controller, public_share_controller, multi_agent_controller, workflow_controller, - prompt_optimizer_controller + prompt_optimizer_controller, ) # 创建管理端 API 路由器 @@ -60,5 +61,5 @@ manager_router.include_router(memory_dashboard_controller.router) manager_router.include_router(multi_agent_controller.router) manager_router.include_router(workflow_controller.router) manager_router.include_router(prompt_optimizer_controller.router) - +manager_router.include_router(memory_reflection_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py new file mode 100644 index 00000000..759c25c5 --- /dev/null +++ b/api/app/controllers/memory_reflection_controller.py @@ -0,0 +1,200 @@ +import asyncio + +from dotenv import load_dotenv +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session +from sqlalchemy import text + +from app.core.logging_config import get_api_logger +from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine +from app.dependencies import get_current_user +from app.db import get_db +from app.models.user_model import User +from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService + +from app.schemas.memory_reflection_schemas import Memory_Reflection + +load_dotenv() +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory", + tags=["Memory"], +) + + +@router.post("/reflection/save") +async def save_reflection_config( + request: Memory_Reflection, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """Save reflection configuration to data_comfig table""" + + + + try: + config_id = request.config_id + if not config_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="缺少必需参数: config_id" + ) + + api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") + + update_params = { + "enable_self_reflexion": request.reflectionenabled, + "iteration_period": request.reflection_period_in_hours, + "reflexion_range": request.reflexion_range, + "baseline": request.baseline, + "reflection_model_id": request.reflection_model_id, + "memory_verify": request.memory_verify, + "quality_assessment": request.quality_assessment, + } + + + + query, params = DataConfigRepository.build_update_reflection(config_id, **update_params) + + result = db.execute(text(query), params) + if result.rowcount == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"未找到config_id为 {config_id} 的配置" + ) + + db.commit() + + # 查询更新后的配置 + select_query, select_params = DataConfigRepository.build_select_reflection(config_id) + result = db.execute(text(select_query), select_params).fetchone() + + if not result: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"更新后未找到config_id为 {config_id} 的配置" + ) + + api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}") + + # 返回结果 + return { + "status": "成功", + "message": "反思配置已保存", + "config_id": config_id, + "database_record": { + "config_id": result.config_id, + "enable_self_reflexion": result.enable_self_reflexion, + "iteration_period": result.iteration_period, + "reflexion_range": result.reflexion_range, + "baseline": result.baseline, + "reflection_model_id": result.reflection_model_id, + "memory_verify": result.memory_verify, + "quality_assessment": result.quality_assessment, + "user_id": result.user_id + } + } + + except ValueError as ve: + api_logger.error(f"参数错误: {str(ve)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"参数错误: {str(ve)}" + ) + except Exception as e: + api_logger.error(f"反思配置保存失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"反思配置保存失败: {str(e)}" + ) + + +@router.post("/reflection") +async def start_workspace_reflection( + request: dict, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """Activate the reflection function for all matching applications in the workspace""" + workspace_id = current_user.current_workspace_id + reflection_service = MemoryReflectionService(db) + + try: + api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}") + + service = WorkspaceAppService(db) + result = service.get_workspace_apps_detailed(workspace_id) + + reflection_results = [] + + for data in result['apps_detailed_info']: + if data['data_configs'] == []: + continue + + releases = data['releases'] + data_configs = data['data_configs'] + end_users = data['end_users'] + + for base, config, user in zip(releases, data_configs, end_users): + if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: + # 调用反思服务 + api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") + + reflection_result = await reflection_service.start_reflection_from_data( + config_data=config, + end_user_id=user['id'] + ) + + reflection_results.append({ + "app_id": base['app_id'], + "config_id": config['config_id'], + "end_user_id": user['id'], + "reflection_result": reflection_result + }) + + return { + "status": "完成", + "message": f"成功处理 {len(reflection_results)} 个反思任务", + "workspace_id": str(workspace_id), + "reflection_count": len(reflection_results), + "reflection_results": reflection_results + } + + except Exception as e: + api_logger.error(f"启动workspace反思失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"启动workspace反思失败: {str(e)}" + ) + +@router.post("/reflection/run") +async def reflection_run( + reflection: Memory_Reflection, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """Activate the reflection function for all matching applications in the workspace""" + config = ReflectionConfig( + enabled=reflection.reflectionenabled, + iteration_period=reflection.reflection_period_in_hours, + reflexion_range=reflection.reflexion_range, + baseline=reflection.baseline, + output_example='', + memory_verify=reflection.memory_verify, + quality_assessment=reflection.quality_assessment, + violation_handling_strategy="block", + model_id=reflection.reflection_model_id + ) + connector = Neo4jConnector() + engine = ReflectionEngine( + config=config, + neo4j_connector=connector, + llm_client=reflection.reflection_model_id # 传入 model_id + ) + + result=await (engine.reflection_run()) + return result diff --git a/api/app/core/config.py b/api/app/core/config.py index 48f79d5e..41e9f0cf 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -148,6 +148,7 @@ class Settings: HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) + REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") diff --git a/api/app/core/memory/storage_services/reflection_engine/example/example.json b/api/app/core/memory/storage_services/reflection_engine/example/example.json new file mode 100644 index 00000000..6528da60 --- /dev/null +++ b/api/app/core/memory/storage_services/reflection_engine/example/example.json @@ -0,0 +1,210 @@ +{ + "memory_verify": { + "source_data": [ + { + "statement_name": "用户是2023年春天去北京工作的。", + "statement_id": "62beac695b1346f4871740a45db88782", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户后来基本一直都在北京上班。", + "statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户从2023年开始就一直在北京生活。", + "statement_id": "e612a44da4db483993c350df7c97a1a1", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户从来没有长期离开过北京。", + "statement_id": "b3c787a2e33c49f7981accabbbb4538a", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。", + "statement_id": "64cde4230cb24a4da726e7db9e7aa616", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。", + "statement_id": "8b1b12e23b844b8088dfeb67da6ad669", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。", + "statement_id": "030afd362e9b4110b139e68e5d3e7143", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户的银行卡号是6222023847595898。", + "statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户的身份信息和银行卡信息一直没变。", + "statement_id": "b3ca618e1e204b83bebd70e75cf2073f", + "statement_created_at": "2025-12-19T10:31:15.239252" + }, + { + "statement_name": "用户认为在上海的那段时间更多算是远程配合。", + "statement_id": "150af89d2c154e6eb41ff1a91e37f962", + "statement_created_at": "2025-12-19T10:31:15.239252" + } + ], + "databasets": [ + { + "entity1_name": "Person", + "description": "表示人类个体的通用类型", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "用户", + "entity2": { + "entity_idx": 0, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "connect_strength": "strong", + "created_at": "2025-12-19T10:31:15.239252000", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Person", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "用户", + "apply_id": "88a459f5_text08", + "id": "3d3896797b334572a80d57590026063d" + } + }, + { + "entity1_name": "用户", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "身份信息", + "entity2": { + "entity_idx": 1, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "connect_strength": "Strong", + "description": "用于个人身份识别的数据", + "created_at": "2025-12-19T10:31:15.239252000", + "statement_id": "030afd362e9b4110b139e68e5d3e7143", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Information", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "身份信息", + "apply_id": "88a459f5_text08", + "id": "aa766a517e82490599a9b3af54cfd933" + } + }, + { + "entity1_name": "用户", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "6222023847595898", + "entity2": { + "entity_idx": 1, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "connect_strength": "Strong", + "description": "用户的银行卡号码", + "created_at": "2025-12-19T10:31:15.239252000", + "statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Numeric", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "6222023847595898", + "apply_id": "88a459f5_text08", + "id": "610ba361918f4e68a65ce6ad06e5c7a0" + } + }, + { + "entity1_name": "用户", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "上海办公室", + "entity2": { + "entity_idx": 1, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "aliases": ["上海办"], + "connect_strength": "Strong", + "created_at": "2025-12-19T10:31:15.239252000", + "description": "位于上海的工作办公场所", + "statement_id": "8b1b12e23b844b8088dfeb67da6ad669", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Location", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "上海办公室", + "apply_id": "88a459f5_text08", + "id": "fb702ef695c14e14af3e56786bc8815b" + } + }, + { + "entity1_name": "用户", + "description": "叙述者,讲述个人工作与生活经历的个体", + "statement_id": "62beac695b1346f4871740a45db88782", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "北京", + "entity2": { + "entity_idx": 2, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "aliases": ["京", "京城", "北平"], + "connect_strength": "strong", + "created_at": "2025-12-19T10:31:15.239252000", + "description": "中国的首都城市,用户主要工作和生活所在地", + "statement_id": "62beac695b1346f4871740a45db88782", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Location", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "北京", + "apply_id": "88a459f5_text08", + "id": "81b2d1a571bb46a08a2d7a1e87efb945" + } + }, + { + "entity1_name": "11010119950308123X", + "description": "具体的身份证号码值", + "statement_id": "030afd362e9b4110b139e68e5d3e7143", + "created_at": "2025-12-19T10:31:15.239252000", + "expired_at": "9999-12-31T00:00:00.000000000", + "relationship_type": "EXTRACTED_RELATIONSHIP", + "relationship": {}, + "entity2_name": "身份证号", + "entity2": { + "entity_idx": 2, + "run_id": "62b59cfebeea43dd94d91763056f069a", + "connect_strength": "strong", + "description": "中华人民共和国公民的身份号码", + "created_at": "2025-12-19T10:31:15.239252000", + "statement_id": "030afd362e9b4110b139e68e5d3e7143", + "expired_at": "9999-12-31T00:00:00.000000000", + "entity_type": "Identifier", + "group_id": "88a459f5_text08", + "user_id": "88a459f5_text08", + "name": "身份证号", + "apply_id": "88a459f5_text08", + "id": "3e5f920645b2404fadb0e9ff60d1306e" + } + } + ] + } +} \ No newline at end of file diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index b3e5813d..8f5b9bae 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -8,17 +8,20 @@ 4. 反思结果应用 - 更新记忆库 """ -import os import json import logging import asyncio +import os +import time from typing import List, Dict, Any, Optional -from datetime import datetime from enum import Enum import uuid -from pydantic import BaseModel, Field +from pydantic import BaseModel +from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all +from app.repositories.neo4j.neo4j_update import neo4j_data +from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 配置日志 _root_logger = logging.getLogger() @@ -33,14 +36,14 @@ else: class ReflectionRange(str, Enum): """反思范围枚举""" - RETRIEVAL = "retrieval" # 从检索结果中反思 - DATABASE = "database" # 从整个数据库中反思 + PARTIAL = "partial" # 从检索结果中反思 + ALL = "all" # 从整个数据库中反思 class ReflectionBaseline(str, Enum): """反思基线枚举""" - TIME = "TIME" # 基于时间的反思 - FACT = "FACT" # 基于事实的反思 + TIME = "TIME" # 基于时间的反思 + FACT = "FACT" # 基于事实的反思 HYBRID = "HYBRID" # 混合反思 @@ -48,9 +51,16 @@ class ReflectionConfig(BaseModel): """反思引擎配置""" enabled: bool = False iteration_period: str = "3" # 反思周期 - reflexion_range: ReflectionRange = ReflectionRange.RETRIEVAL + reflexion_range: ReflectionRange = ReflectionRange.PARTIAL baseline: ReflectionBaseline = ReflectionBaseline.TIME - concurrency: int = Field(default=5, description="并发数量") + model_id: Optional[str] = None # 模型ID + end_user_id: Optional[str] = None + output_example: Optional[str] = None # 输出示例 + + # 评估相关字段 + memory_verify: bool = True # 记忆验证 + quality_assessment: bool = True # 质量评估 + violation_handling_strategy: str = "warn" # 违规处理策略 class Config: use_enum_values = True @@ -75,16 +85,16 @@ class ReflectionEngine: """ def __init__( - self, - config: ReflectionConfig, - neo4j_connector: Optional[Any] = None, - llm_client: Optional[Any] = None, - get_data_func: Optional[Any] = None, - render_evaluate_prompt_func: Optional[Any] = None, - render_reflexion_prompt_func: Optional[Any] = None, - conflict_schema: Optional[Any] = None, - reflexion_schema: Optional[Any] = None, - update_query: Optional[str] = None + self, + config: ReflectionConfig, + neo4j_connector: Optional[Any] = None, + llm_client: Optional[Any] = None, + get_data_func: Optional[Any] = None, + render_evaluate_prompt_func: Optional[Any] = None, + render_reflexion_prompt_func: Optional[Any] = None, + conflict_schema: Optional[Any] = None, + reflexion_schema: Optional[Any] = None, + update_query: Optional[str] = None ): """ 初始化反思引擎 @@ -109,7 +119,7 @@ class ReflectionEngine: self.conflict_schema = conflict_schema self.reflexion_schema = reflexion_schema self.update_query = update_query - self._semaphore = asyncio.Semaphore(config.concurrency) + self._semaphore = asyncio.Semaphore(5) # 默认并发数为5 # 延迟导入以避免循环依赖 self._lazy_init_done = False @@ -127,11 +137,21 @@ class ReflectionEngine: from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.utils.config import definitions as config_defs self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) + elif isinstance(self.llm_client, str): + # 如果 llm_client 是字符串(model_id),则用它初始化客户端 + from app.core.memory.utils.llm.llm_utils import get_llm_client + model_id = self.llm_client + self.llm_client = get_llm_client(model_id) if self.get_data_func is None: from app.core.memory.utils.config.get_data import get_data self.get_data_func = get_data + # 导入get_data_statement函数 + if not hasattr(self, 'get_data_statement'): + from app.core.memory.utils.config.get_data import get_data_statement + self.get_data_statement = get_data_statement + if self.render_evaluate_prompt_func is None: from app.core.memory.utils.prompt.template_render import render_evaluate_prompt self.render_evaluate_prompt_func = render_evaluate_prompt @@ -154,13 +174,11 @@ class ReflectionEngine: self._lazy_init_done = True - async def execute_reflection(self, host_id: uuid.UUID) -> ReflectionResult: + async def execute_reflection(self, host_id) -> ReflectionResult: """ 执行完整的反思流程 - Args: host_id: 主机ID - Returns: ReflectionResult: 反思结果 """ @@ -176,9 +194,10 @@ class ReflectionEngine: start_time = asyncio.get_event_loop().time() logging.info("====== 自我反思流程开始 ======") + print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment) try: # 1. 获取反思数据 - reflexion_data = await self._get_reflexion_data(host_id) + reflexion_data, statement_databasets = await self._get_reflexion_data(host_id) if not reflexion_data: return ReflectionResult( success=True, @@ -187,22 +206,21 @@ class ReflectionEngine: ) # 2. 检测冲突(基于事实的反思) - conflict_data = await self._detect_conflicts(reflexion_data) - if not conflict_data: - return ReflectionResult( - success=True, - message="无冲突,无需反思", - execution_time=asyncio.get_event_loop().time() - start_time - ) + conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets) + print(100 * '-') + print(conflict_data) + print(100 * '-') - conflicts_found = len(conflict_data) - logging.info(f"发现 {conflicts_found} 个冲突") + # 检查是否真的有冲突 + has_conflict = conflict_data[0].get('conflict', False) + conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0 + logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突") # 记录冲突数据 await self._log_data("conflict", conflict_data) # 3. 解决冲突 - solved_data = await self._resolve_conflicts(conflict_data) + solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) if not solved_data: return ReflectionResult( success=False, @@ -210,6 +228,9 @@ class ReflectionEngine: conflicts_found=conflicts_found, execution_time=asyncio.get_event_loop().time() - start_time ) + print(100 * '*') + print(solved_data) + print(100 * '*') conflicts_resolved = len(solved_data) logging.info(f"解决了 {conflicts_resolved} 个冲突") @@ -230,7 +251,8 @@ class ReflectionEngine: conflicts_found=conflicts_found, conflicts_resolved=conflicts_resolved, memories_updated=memories_updated, - execution_time=execution_time + execution_time=execution_time, + ) except Exception as e: @@ -241,6 +263,79 @@ class ReflectionEngine: execution_time=asyncio.get_event_loop().time() - start_time ) + async def reflection_run(self): + self._lazy_init() + start_time = time.time() + + asyncio.get_event_loop().time() + logging.info("====== 自我反思流程开始 ======") + + result_data = {} + + source_data, databasets = await self.extract_fields_from_json() + result_data['baseline'] = self.config.baseline + result_data[ + 'source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合" + + # 2. 检测冲突(基于事实的反思) + conflict_data = await self._detect_conflicts(databasets, source_data) + # 遍历数据提取字段 + quality_assessments = [] + memory_verifies = [] + for item in conflict_data: + print(item) + quality_assessments.append(item['quality_assessment']) + memory_verifies.append(item['memory_verify']) + result_data['quality_assessments'] = quality_assessments + result_data['memory_verifies'] = memory_verifies + + # 检查是否真的有冲突 + has_conflict = conflict_data[0].get('conflict', False) + conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0 + logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突") + + # 记录冲突数据 + await self._log_data("conflict", conflict_data) + + # 3. 解决冲突 + solved_data = await self._resolve_conflicts(conflict_data, source_data) + if not solved_data: + return ReflectionResult( + success=False, + message="反思失败,未解决冲突", + conflicts_found=conflicts_found, + execution_time=asyncio.get_event_loop().time() - start_time + ) + reflexion_data = [] + + # 遍历数据提取reflexion字段 + for item in solved_data: + if 'results' in item: + for result in item['results']: + reflexion_data.append(result['reflexion']) + result_data['reflexion_data'] = reflexion_data + execution_time = time.time() - start_time + return {"status": "SUCCESS", "message": "反思试运行", "data": result_data, "time": execution_time} + + async def extract_fields_from_json(self): + """从example.json中提取source_data和databasets字段""" + + prompt_dir = os.path.join(os.path.dirname(__file__), "example") + try: + # 读取JSON文件 + with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f: + data = json.loads(f.read()) + + # 提取memory_verify下的字段 + memory_verify = data.get("memory_verify", {}) + source_data = memory_verify.get("source_data", []) + databasets = memory_verify.get("databasets", []) + + return source_data, databasets + + except Exception as e: + return [], [] + async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]: """ 获取反思数据 @@ -253,17 +348,28 @@ class ReflectionEngine: Returns: List[Any]: 反思数据列表 """ - if self.config.reflexion_range == ReflectionRange.RETRIEVAL: - # 从检索结果中获取数据 - return await self.get_data_func(host_id) - elif self.config.reflexion_range == ReflectionRange.DATABASE: - # 从整个数据库中获取数据(待实现) - logging.warning("从数据库获取反思数据功能尚未实现") - return [] - else: - raise ValueError(f"未知的反思范围: {self.config.reflexion_range}") - async def _detect_conflicts(self, data: List[Any]) -> List[Any]: + + + if self.config.reflexion_range == ReflectionRange.PARTIAL: + neo4j_query = neo4j_query_part.format(host_id) + neo4j_statement = neo4j_statement_part.format(host_id) + elif self.config.reflexion_range == ReflectionRange.ALL: + neo4j_query = neo4j_query_all.format(host_id) + neo4j_statement = neo4j_statement_all.format(host_id) + try: + result = await self.neo4j_connector.execute_query(neo4j_query) + result_statement = await self.neo4j_connector.execute_query(neo4j_statement) + neo4j_databasets = await self.get_data_func(result) + neo4j_state = await self.get_data_statement(result_statement) + return neo4j_databasets, neo4j_state + + + except Exception as e: + logging.error(f"Neo4j查询失败: {e}") + return [], [] + + async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]: """ 检测冲突(基于事实的反思) @@ -278,14 +384,28 @@ class ReflectionEngine: if not data: return [] + # 数据预处理:如果数据量太少,直接返回无冲突 + if len(data) < 2: + logging.info("数据量不足,无需检测冲突") + return [] + + # 使用转换后的数据 + print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长 + memory_verify = self.config.memory_verify + logging.info("====== 冲突检测开始 ======") start_time = asyncio.get_event_loop().time() + quality_assessment = self.config.quality_assessment try: # 渲染冲突检测提示词 rendered_prompt = await self.render_evaluate_prompt_func( data, - self.conflict_schema + self.conflict_schema, + self.config.baseline, + memory_verify, + quality_assessment, + statement_databasets ) messages = [{"role": "user", "content": rendered_prompt}] @@ -316,7 +436,7 @@ class ReflectionEngine: logging.error(f"冲突检测失败: {e}", exc_info=True) return [] - async def _resolve_conflicts(self, conflicts: List[Any]) -> List[Any]: + async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]: """ 解决冲突 @@ -332,6 +452,8 @@ class ReflectionEngine: return [] logging.info("====== 冲突解决开始 ======") + baseline = self.config.baseline + memory_verify = self.config.memory_verify # 并行处理每个冲突 async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]: @@ -341,7 +463,10 @@ class ReflectionEngine: # 渲染反思提示词 rendered_prompt = await self.render_reflexion_prompt_func( [conflict], - self.reflexion_schema + self.reflexion_schema, + baseline, + memory_verify, + statement_databasets ) messages = [{"role": "user", "content": rendered_prompt}] @@ -381,8 +506,8 @@ class ReflectionEngine: return solved async def _apply_reflection_results( - self, - solved_data: List[Dict[str, Any]] + self, + solved_data: List[Dict[str, Any]] ) -> int: """ 应用反思结果(更新记忆库) @@ -395,57 +520,7 @@ class ReflectionEngine: Returns: int: 成功更新的记忆数量 """ - if not solved_data: - logging.warning("无解决方案数据,跳过更新") - return 0 - - logging.info("====== 记忆更新开始 ======") - - success_count = 0 - - async def _update_one(item: Dict[str, Any]) -> bool: - """更新单条记忆""" - async with self._semaphore: - try: - if not isinstance(item, dict): - return False - - # 提取更新参数 - resolved = item.get("resolved", {}) - resolved_mem = resolved.get("resolved_memory", {}) - group_id = resolved_mem.get("group_id") - memory_id = resolved_mem.get("id") - new_invalid_at = resolved_mem.get("invalid_at") - - if not all([group_id, memory_id, new_invalid_at]): - logging.warning(f"记忆更新参数缺失,跳过此项: {item}") - return False - - # 执行更新 - await self.neo4j_connector.execute_query( - self.update_query, - group_id=group_id, - id=memory_id, - new_invalid_at=new_invalid_at, - ) - - return True - - except Exception as e: - logging.error(f"更新单条记忆失败: {e}") - return False - - # 并发执行所有更新任务 - tasks = [ - _update_one(item) - for item in solved_data - if isinstance(item, dict) - ] - results = await asyncio.gather(*tasks, return_exceptions=False) - success_count = sum(1 for r in results if r) - - logging.info(f"成功更新 {success_count}/{len(solved_data)} 条记忆") - + success_count = await neo4j_data(solved_data) return success_count async def _log_data(self, label: str, data: Any) -> None: @@ -456,6 +531,7 @@ class ReflectionEngine: label: 数据标签 data: 要记录的数据 """ + def _write(): try: with open("reflexion_data.json", "a", encoding="utf-8") as f: @@ -470,9 +546,9 @@ class ReflectionEngine: # 基于时间的反思方法 async def time_based_reflection( - self, - host_id: uuid.UUID, - time_period: Optional[str] = None + self, + host_id: uuid.UUID, + time_period: Optional[str] = None ) -> ReflectionResult: """ 基于时间的反思 @@ -494,8 +570,8 @@ class ReflectionEngine: # 基于事实的反思方法 async def fact_based_reflection( - self, - host_id: uuid.UUID + self, + host_id: uuid.UUID ) -> ReflectionResult: """ 基于事实的反思 @@ -515,8 +591,8 @@ class ReflectionEngine: # 综合反思方法 async def comprehensive_reflection( - self, - host_id: uuid.UUID + self, + host_id: uuid.UUID ) -> ReflectionResult: """ 综合反思 @@ -553,33 +629,3 @@ class ReflectionEngine: else: raise ValueError(f"未知的反思基线: {self.config.baseline}") - -# 便捷函数:创建默认配置的反思引擎 -def create_reflection_engine( - enabled: bool = False, - iteration_period: str = "3", - reflexion_range: str = "retrieval", - baseline: str = "TIME", - concurrency: int = 5 -) -> ReflectionEngine: - """ - 创建反思引擎实例 - - Args: - enabled: 是否启用反思 - iteration_period: 反思周期 - reflexion_range: 反思范围 - baseline: 反思基线 - concurrency: 并发数量 - - Returns: - ReflectionEngine: 反思引擎实例 - """ - config = ReflectionConfig( - enabled=enabled, - iteration_period=iteration_period, - reflexion_range=reflexion_range, - baseline=baseline, - concurrency=concurrency - ) - return ReflectionEngine(config) diff --git a/api/app/core/memory/utils/config/get_data.py b/api/app/core/memory/utils/config/get_data.py index f2f21198..a099694e 100644 --- a/api/app/core/memory/utils/config/get_data.py +++ b/api/app/core/memory/utils/config/get_data.py @@ -1,13 +1,8 @@ import json -import os import uuid -from typing import List, Dict, Any, Optional -from sqlalchemy.orm import Session -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo -from app.schemas.memory_storage_schema import BaseDataSchema - import logging + +from typing import List, Dict, Any logger = logging.getLogger(__name__) async def _load_(data: List[Any]) -> List[Dict]: @@ -60,27 +55,46 @@ async def _load_(data: List[Any]) -> List[Dict]: return results -async def get_data(host_id: uuid.UUID) -> List[Dict]: +async def get_data(result): """ 从数据库中获取数据 """ - # 从数据库会话中获取会话 - db: Session = next(get_db()) - try: - data = db.query(RetrievalInfo.retrieve_info).filter(RetrievalInfo.host_id == host_id).all() + neo4j_databasets=[] + for item in result: + filtered_item = {} + for key, value in item.items(): + if 'name_embedding' not in key.lower(): + if key == 'relationship' and value is not None: + # 只保留relationship的指定字段 + rel_filtered = {} + if hasattr(value, 'get'): + rel_filtered['run_id'] = value.get('run_id') + rel_filtered['statement'] = value.get('statement') + rel_filtered['statement_id'] = value.get('statement_id') + rel_filtered['expired_at'] = value.get('expired_at') + rel_filtered['created_at'] = value.get('created_at') + filtered_item[key] = rel_filtered + elif key == 'entity2' and value is not None: + # 过滤entity2的name_embedding字段 + entity2_filtered = {} + if hasattr(value, 'items'): + for e_key, e_value in value.items(): + if 'name_embedding' not in e_key.lower(): + entity2_filtered[e_key] = e_value + filtered_item[key] = entity2_filtered + else: + filtered_item[key] = value + + # 直接将字典添加到列表中 + neo4j_databasets.append(filtered_item) + return neo4j_databasets +async def get_data_statement( result): + neo4j_databasets=[] + for i in result: + neo4j_databasets.append(i) + return neo4j_databasets + - # print(f"data:\n{data}") - # 解析,提取为字典的列表 - results = await _load_(data) - return results - except Exception as e: - logger.error(f"failed to get data from database, host_id: {host_id}, error: {e}") - raise e - finally: - try: - db.close() - except Exception: - pass if __name__ == "__main__": diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index cb5b917d..e1ecf820 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -1,19 +1,222 @@ -你将收到一组记忆对象:{{ evaluate_data }}。 -任务:多维度判断这些记忆是否与已有记忆存在冲突,并给出冲突的对应记忆。(冗余不算冲突) +你将收到一组用户历史记忆原始数据(来源于 Neo4j),以及相关配置参数: +原本的输入句子:{{statement_databasets}} +需要检测冲突对象:{{ evaluate_data }} +冲突判定类型:{{ baseline }}(取值为 TIME / FACT / HYBRID) +记忆审核开关:{{ memory_verify }}(取值为 true / false) +记忆质量评估开关开关:{{ quality_assessment }}(取值为 true / false) -仅输出一个合法 JSON 对象,严格遵循下述结构: +你的任务是: +对用户历史记忆数据进行冲突检测和记忆审核,并输出严格结构化的 JSON 分析结果 +数据的结构: + statement_databasets里面statement_name是输入的句子,statement_id是连接evaluate_data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容, + 需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估) +## 冲突定义 + +### 时间冲突 +时间冲突是指同一用户的相关事件在时间维度上存在逻辑矛盾: + +1. **同一活动的时间冲突**: + - 同一用户的同一活动在不同时间点被记录(如"周五打球"和"周六打球") + - 同一用户在同一时间段内被记录进行不同的互斥活动 + +2. **时间逻辑错误**: + - expired_at 早于 created_at + - 同一事实的 created_at 时间差异超过合理误差范围(>5分钟) + +3. **日期属性冲突**: + - 同一人的生日记录为不同日期(如"2月10号"和"2月16号") +4.存在明确先后约束 A -> B,但 t(A) > t(B) + -例:入学时间晚于毕业时间。 + -处理:标记异常、降权、触发逻辑反思或人工审查。 +5.时间属性冲突 + -单值日期属性出现多值(生日、入职日期) + -注意:本质属于事实冲突的日期特例,归入事实冲突仲裁框架。 +6.互斥重叠冲突 + -例:同一主体的两个事件区间重叠且互斥(如同一时间出现在两地) + -处理:证据仲裁、保留多版本(active + candidate)。 + + + +### 事实冲突 +事实冲突是指同一实体的属性或关系存在相互矛盾的陈述: + +1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是) +2. **关系矛盾**:同一实体在相同语境下的不同关系描述 +3. **身份冲突**:同一实体被赋予不同的类型或角色 + +### 混合冲突检测 +检测所有类型的冲突,包括但不限于时间冲突和事实冲突: +检测任何逻辑上不一致或相互矛盾的记录 +## 记忆审核定义 + +### 隐私信息检测(隐私冲突) +当memory_verify为true时,需要额外检测包含个人隐私信息的记录: + +1. **身份证信息**:包含身份证号码、身份证相关描述 +2. **手机号码**:包含手机号、电话号码等联系方式 +3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息 +4. **银行信息**:包含银行卡号、账户信息、支付信息 +5. **税务信息**:包含税号、纳税信息、发票信息 +6. **贷款信息**:包含贷款记录、信贷信息、借款信息 +7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息 + +### 隐私检测原则 +- 检测description、entity1_name、entity2_name等字段中的隐私信息 +- 识别数字模式(如手机号11位数字、身份证18位等) +- 识别关键词(如"身份证"、"银行卡"、"密码"等) +- 检测敏感实体类型和关系 + +## 冲突检测原则 + +**全面检测**:不区分冲突类型,检测所有可能的冲突 +**完整输出**:如果发现任何冲突或隐私信息,必须将所有相关记录都放入data字段 +**实体关联**:重点检查涉及相同实体(entity1_name, entity2_name)的记录 +**语义分析**:分析description字段的语义相似性和冲突性 +**时间逻辑**:检查时间字段的逻辑一致性 +**隐私检测**:当memory_verify为true时,检测所有包含隐私信息的记录 + +## 不符合冲突检测 + -称呼 +## 重要检测示例 + +### 冲突检测示例 +- 用户与不同时间点的关系(周五 vs 周六,2月10号 vs 2月16号) +- 同一实体的重复定义但描述不同 +- 同一关系的不同表述但含义冲突 +- 任何逻辑上不可能同时为真的记录 + +### 隐私信息检测示例 +- 包含手机号的记录:"用户的手机号是13812345678" +- 包含身份证的记录:"身份证号码为110101199001011234" +- 包含银行卡的记录:"银行卡号6222021234567890" +- 包含社交账号的记录:"微信号是user123456" +- 包含敏感信息的实体名称或描述 + +## 输出要求 + +**关键原则**: +1. 当存在冲突或检测到隐私信息时,conflict才为true,data字段才包含相关记录 +2. 如果发现冲突,必须将所有相关的冲突记录都放入data数组中 +3. 如果memory_verify为true且检测到隐私信息,必须将包含隐私信息的记录也放入data数组中 +4. 既没有冲突也没有隐私信息时,conflict为false,data为空数组 +5. 如果quality_assessment为true,独立分析数据质量并输出评估结果;如果为false,quality_assessment字段输出null +6. 冲突检测、隐私审核和质量评估三个功能完全独立,互不影响 +7. 不输出conflict_memory字段 + +**处理逻辑**: +- 首先进行冲突检测,将冲突记录加入data数组 +- 如果memory_verify为true,再进行隐私信息检测,将包含隐私信息的记录也加入data数组 +- 如果quality_assessment为true,独立进行质量评估,分析所有输入数据的质量并输出评估结果 +- 最终data数组包含所有冲突记录和隐私信息记录(去重) +- quality_assessment字段独立输出,不影响冲突检测和隐私审核结果 +- memory_verify字段独立输出隐私检测结果,包含检测到的隐私信息类型和概述 + +返回数据格式以json方式输出: +- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。 +- 关键的JSON格式要求{"statement":识别出的文本内容} +1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号("")或其他Unicode引号 +2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们 +3.确保所有JSON字符串都正确关闭并以逗号分隔 +4.JSON字符串值中不包括换行符 +5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\"" +6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby``` + +## 记忆质量评估定义 + +### 质量评估标准 +当quality_assessment为true时,需要对记忆数据进行质量评估: + +1. **数据完整性**: + - 检查必要字段是否完整(entity1_name、entity2_name、description等) + - 检查关系描述是否清晰明确 + - 检查时间字段的有效性 + +2. **重复字段检测**: + - 识别相同或高度相似的记录 + - 检测冗余的实体关系 + - 分析描述内容的重复度 + +3. **无意义字段检测**: + - 识别空值、无效值或占位符内容 + - 检测过于简单或无信息量的描述 + - 识别格式错误或不规范的数据 + +4. **上下文依赖性**: + - 评估记录是否需要额外上下文才能理解 + - 检查实体名称的明确性 + - 分析关系描述的自包含性 + +### 质量评估输出 +- **质量百分比**:基于上述标准计算的整体质量分数(0-100) +- **质量概述**:简要描述数据质量状况,包括主要问题和优点 + +输出是仅输出一个合法 JSON 对象,严格遵循下述结构: { - "data": [ ...与输入同结构的记忆对象数组... ], - "conflict": true 或 false, - "conflict_memory": 若冲突为 true,则填写与其冲突的记忆对象;否则为 null + "data": [ + { + "entity1_name": "实体1名称", + "description": "描述信息", + "statement_id": "陈述ID", + "created_at": "创建时间戳", + "expired_at": "过期时间戳", + "relationship_type": "关系类型", + "relationship": "关系对象", + "entity2_name": "实体2名称", + "entity2": "实体2对象" + } + ], + "conflict": true或false, + "quality_assessment": { + "score": 质量百分比数字, + "summary": "质量概述文本" + } 或 null, + "memory_verify": { + "has_privacy": true或false, + "privacy_types": ["检测到的隐私信息类型列表"], + "summary": "隐私检测结果概述" + } 或 null } 必须遵守: - 只输出 JSON,不要添加解释或多余文本。 - 使用标准双引号,必要时对内部引号进行转义。 - 字段名与结构必须与给定模式一致。 +- data数组中包含冲突记录和隐私信息记录,如果都没有则为空数组。 +- quality_assessment字段:当quality_assessment参数为true时输出评估对象,为false时输出null。 +- memory_verify字段:当memory_verify参数为true时输出隐私检测结果对象,为false时输出null。 + +### memory_verify字段说明 +当memory_verify为true时,需要输出隐私检测结果: +- **has_privacy**: 布尔值,表示是否检测到隐私信息 +- **privacy_types**: 字符串数组,包含检测到的隐私信息类型(如["手机号码", "身份证信息"]) +- **summary**: 字符串,简要描述隐私检测结果 + +当memory_verify为false时,memory_verify字段输出null。 + +### memory_verify字段示例 + +**示例1:检测到隐私信息** +```json +"memory_verify": { + "has_privacy": true, + "privacy_types": ["手机号码", "身份证信息"], + "summary": "检测到2条记录包含隐私信息:1个手机号码,1个身份证号码" +} +``` + +**示例2:未检测到隐私信息** +```json +"memory_verify": { + "has_privacy": false, + "privacy_types": [], + "summary": "未检测到隐私信息" +} +``` + +**示例3:memory_verify为false时** +```json +"memory_verify": null +``` 模式参考: -[ - {{ json_schema }} -] \ No newline at end of file +{{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index 3f78b137..43e8e100 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -1,23 +1,300 @@ +你将收到一组用户历史记忆原始数据(来源于 Neo4j) 你将收到一条冲突判定对象:{{ data }}。 -任务:分析冲突产生原因,给出解决方案,并生成设为失效后的记忆。 +需要检测冲突对象:{{ statement_databasets }} +以及需要识别的冲突对象为:{{ baseline }} +记忆审核开关:{{ memory_verify }}(取值为 true / false) + +角色: +- 你是数据领域中解决数据冲突的专家 + +任务:分析冲突产生原因,按冲突类型分组处理,为每种冲突类型生成独立的解决方案。 + +数据的结构: + statement_databasets里面statement_name是输入的句子,statement_id是连接data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容, + 需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估),data里面的statement_created_at是用户输入的时间 + +**处理模式**: +- 当memory_verify为false时:仅处理数据冲突 +- 当memory_verify为true时:处理数据冲突 + 隐私信息脱敏 + +## 分组处理原则 + +**冲突类型识别与分组**: +1. **日期冲突**: + 1.1.涉及用户生日的不同日期记录(如2月10号 vs 2月16号), + 1.2.涉及同一活动的不同时间记录(如周五打球 vs 周六打球) +3. **事实属性冲突**: + 3.1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是) + 3.2. **关系矛盾**:同一实体在相同语境下的不同关系描述 + 3.3. **身份冲突**:同一实体被赋予不同的类型或角色 +4. **其他冲突类型/混合冲突(时间+事实)**:根据具体数据识别 + +**分组输出要求**: +- 每种冲突类型生成一个独立的reflexion_result对象 +- 同一类型的多个冲突记录归并到一个结果中 +- 不同类型的冲突分别处理,各自生成独立结果 + +## 冲突类型定义 + +### 时间冲突(TIME) +时间维度冲突是指两个事件发生时间重叠,或者用户同一件事情和场景等情况下,时间出现了变化。 + +### 事实冲突(FACT) +事实冲突是指同一事实对象(同一个人、同一个时间、同一个状态)但陈述内容相互矛盾,主要为真假不能共存的情况。 +### 混合冲突(HYBRID) +检测所有类型的冲突,包括但不限于时间冲突和事实冲突:检测任何逻辑上不一致或相互矛盾的记录 +{% if memory_verify %} +## 隐私信息处理(memory_verify为true时启用) + +### 隐私信息识别 +需要识别并处理以下类型的隐私信息: + +1. **身份证信息**:包含身份证号码、身份证相关描述 +2. **手机号码**:包含手机号、电话号码等联系方式 +3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息 +4. **银行信息**:包含银行卡号、账户信息、支付信息 +5. **税务信息**:包含税号、纳税信息、发票信息 +6. **贷款信息**:包含贷款记录、信贷信息、借款信息 +7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息 + +### 隐私数据脱敏规则 +对于检测到的隐私信息,按以下规则进行脱敏处理: + +**数字类隐私信息脱敏**: +- 保留前三位和后四位,中间用*代替 +- 示例:手机号13812345678 → 138****5678 +- 示例:身份证110101199001011234 → 110***********1234 +- 示例:银行卡6222021234567890 → 622***********7890 + +**文本类隐私信息脱敏**: +- 社交账号:保留前三后四位字符,中间用*代替 +- 示例:微信号user123456 → use****3456 +- 示例:邮箱zhang.san@example.com → zha****@example.com + +**脱敏处理字段**: +- name字段:如包含隐私信息需脱敏 +- entity1_name字段:如包含隐私信息需脱敏 +- entity2_name字段:如包含隐私信息需脱敏 +- description字段:如包含隐私信息需脱敏 +{% endif %} + +## 工作步骤 + +### 第一步:分析冲突类型匹配 +首先判断输入的冲突数据是否符合baseline要求的类型: + +**类型匹配规则**: +- 如果baseline是"TIME":只处理时间相关的冲突(涉及时间表达式、日期、时间点的冲突) +- 如果baseline是"FACT":只处理事实相关的冲突(属性矛盾、关系冲突、描述不一致) +- 如果baseline是"HYBRID":处理所有类型的冲突,也可以当作混合冲突类型处理 + +**类型识别**: +- 时间冲突标识:entity2的entity_type包含"TimeExpression"、"TemporalExpression",或entity2_name包含时间词汇(周一到周日、月份日期等) +- 事实冲突标识:相同实体的不同属性描述、互斥的关系陈述 + +**重要**:如果输入的冲突类型与baseline不匹配,必须输出空结果(resolved为null) + +### 第二步:筛选并分组冲突数据 +按冲突类型对数据进行分组: + +**分组策略**: +1. **时间冲突组**:筛选涉及用户时间的所有记录 +2. **活动时间冲突组**:筛选涉及同一活动不同时间的记录 +3. **事实冲突组**:筛选涉及同一实体不同属性的记录 +4. **其他冲突组**:其他类型的冲突记录 + +**筛选条件**: +- 只处理与baseline匹配的冲突类型 +- 相同entity1_name但entity2_name不同的记录 +- 相同关系但描述矛盾的记录 +- 时间逻辑不一致的记录 + +### 第三步:冲突解决策略 +** 不可以解决的冲突情况 + 1. 数据被判定为正确的情况下,不可以进行修改 +**仅当冲突类型与baseline匹配时**,对筛选出的冲突数据进行处理: + +**智能解决策略**: +1. **分析冲突数据**:识别哪些记录是正确的,哪些是错误的,需要结合statement_databasets的输入原文来判定 +2. **判断正确答案是否存在**: + - 如果正确答案已存在于data中:只需将错误记录的expired_at设为当前日期(2025-12-16T12:00:00) + - 如果正确答案已存在于data中:错误记录的expired_at已经设为日期,则不需要对正确的数据进行修改 + - 如果正确答案不存在于data中:需要修改现有记录的内容以包含正确信息 + +{% if memory_verify %} +**隐私处理集成**: +- 在处理冲突的同时,需要对涉及的记录进行隐私脱敏 +- 脱敏处理应该在冲突解决之后进行,确保最终输出的记录都已脱敏 +- 在change字段中记录隐私脱敏的变更 +{% endif %} + +**具体处理规则**: + +**情况1:正确答案存在于data中** +- 保留正确的记录不变 +- 基于时间关系的冲突: + 需要只修改错误记录的expired_at为当前时间(2025-12-16T12:00:00) +- 基于事实的关系冲突 +- resolved.resolved_memory只包含被设为失效的错误记录 +- change字段只记录expired_at的变更:`[{"expired_at": "2025-12-16T12:00:00"}]`(注意:如果已存在时间,则不需要对其修改,也不需要变更 时间) + +**情况2:正确答案不存在于data中** +- 选择最合适的记录进行修改 +- 更新该记录的相关字段: + - description字段:添加或修改描述信息{% if memory_verify %}(如包含隐私信息,需脱敏处理){% endif %} + - name字段:修改名称字段{% if memory_verify %}(如需要,包含隐私信息时需脱敏){% endif %} +- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %} +- change字段记录所有被修改的字段{% if memory_verify %},包括脱敏变更{% endif %},例如:`[{"description": "新描述"{% if memory_verify %}, "entity2_name": "138****5678"{% endif %}}]` + +**重要原则**: +- **只输出需要修改的记录**:resolved.resolved_memory只包含实际需要修改的数据 +- **优先保留策略**:时间冲突保留最可信的created_at时间的记录,事实冲突选择最新且可信度最高的记录 +- **精确记录变更**:change字段必须包含记录ID、字段名称、新值和旧值 +{% if memory_verify %}- **隐私保护优先**:所有输出的记录必须完成隐私脱敏处理 +- **脱敏变更记录**:隐私脱敏的变更也必须在change字段中详细记录{% endif %} +- **不可修改数据**:数据被判定为正确时,不可以进行修改,如果没有数据可输出空 + +**变更记录格式**: +```json +"change": [ + { + "field": [ + {"字段名1": "修改后的值1"}, + {"字段名2": "修改后的值2"} + ] + } +] +``` + +**类型不匹配处理**: +- 如果冲突类型与baseline不匹配,resolved必须设为null +- reflexion.reason说明类型不匹配的原因 +- reflexion.solution说明无需处理 + +### 第四步:输出解决方案 + +## 输出要求 +**嵌套字段映射**(系统会自动处理): +- `entity2.name` → 自动映射为 `name` +- `entity1.name` → 自动映射为 `name` +- `entity1.description` → 自动映射为 `description` +- `entity2.description` → 自动映射为 `description` + +返回数据格式以json方式输出: +- 必须通过json.loads()的格式支持的形式输出 +- 响应必须是与此确切模式匹配的有效JSON对象 +- 不要在JSON之前或之后包含任何文本 + +JSON格式要求: +1. JSON结构仅使用标准ASCII双引号(") +2. 如果提取的语句文本包含引号,请使用反斜杠(\")正确转义 +3. 确保所有JSON字符串都正确关闭并以逗号分隔 +4. JSON字符串值中不包括换行符 +5. 不允许输出```json```相关符号 仅输出一个合法 JSON 对象,严格遵循下述结构: + +**输出格式:按冲突类型分组的列表** { - "conflict": 与输入同结构,包含 data 与 conflict_memory, - "reflexion": { "reason": string, "solution": string }, - "resolved": { - "original_memory_id": 被设为失效的记忆 id, - "resolved_memory": 完整的设为失效后的记忆对象 - } + "results": [ + { + "conflict": { + "data": [该冲突类型相关的数据记录], + "conflict": true + }, + "reflexion": { + "reason": "该冲突类型的原因分析", + "solution": "该冲突类型的解决方案" + }, + "resolved": { + "original_memory_id": "被设为失效的记忆id", + "resolved_memory": { + "entity1_name": "实体1名称", + "entity2_name": "实体2名称", + "description": "描述信息", + "statement_id": "陈述ID", + "created_at": "创建时间", + "expired_at": "过期时间", + "relationship_type": "关系类型", + "relationship": {}, + "entity2": {...} + }, + "change": [ + { + "field": [ + {"字段名1": "修改后的值1"}, + {"字段名2": "修改后的值2"} + ] + } + ] + }, + "type": "reflexion_result" + } + ] +} + +**示例:多种冲突类型的输出** +{ + "results": [ + { + "conflict": { + "data": [生日冲突相关的记录], + "conflict": true + }, + "reflexion": { + "reason": "检测到生日冲突:用户同时关联2月10号和2月16号两个不同日期", + "solution": "保留最新记录(2月16号),将旧记录(2月10号)设为失效" + }, + "resolved": { + "original_memory_id": "df066210883545a08e727ccd8ad4ec77", + "resolved_memory": {...}, + "change": [ + { + "field": [ + {"expired_at": "2025-12-16T12:00:00"} + ] + } + ] + }, + "type": "reflexion_result" + }, + { + "conflict": { + "data": [篮球时间冲突相关的记录], + "conflict": true + }, + "reflexion": { + "reason": "检测到活动时间冲突:用户打篮球时间存在周五和周六的冲突", + "solution": "保留最可信的时间记录,将冲突记录设为失效" + }, + "resolved": { + "original_memory_id": "另一个记录ID", + "resolved_memory": {...}, + "change": [ + { + "field": [ + {"description": "使用系统的个人,指代说话者本人,篮球时间为周六"}, + {"entity2_name": "周六"} + ] + } + ] + }, + "type": "reflexion_result" + } + ] } 必须遵守: -- 只输出 JSON,不要添加解释或多余文本。 -- 使用标准双引号,必要时对内部引号进行转义。 -- 字段名与结构必须与给定模式一致。 -- 当 conflict 为 false 时,resolved 必须为 null。 - - 其中 conflict.data 必须为数组形式,即使只有一个对象也需使用 [ ] 包裹。 +- 只输出 JSON,不要添加解释或多余文本 +- 使用标准双引号,必要时对内部引号进行转义 +- 字段名与结构必须与给定模式一致 +- **输出必须是results数组格式**,每个冲突类型作为一个独立的对象 +- **按冲突类型分组**:相同类型的冲突记录归并到一个result对象中 +- **每个result对象的conflict.data**只包含该冲突类型相关的记录 +- **resolved.resolved_memory 只包含需要修改的记录**,不需要修改的记录不要输出 +- **resolved.change 必须包含详细的变更信息**:field数组包含所有被修改的字段及其新值 +- 如果某个冲突类型经分析无需修改任何数据,该类型的resolved 必须为 null +- 如果与baseline不匹配的冲突类型,不要在results中包含该类型 + 模式参考: -[ - {{ json_schema }} -] +{{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/template_render.py b/api/app/core/memory/utils/prompt/template_render.py index c783e095..818d456a 100644 --- a/api/app/core/memory/utils/prompt/template_render.py +++ b/api/app/core/memory/utils/prompt/template_render.py @@ -7,36 +7,50 @@ from typing import List, Dict, Any prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) -async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any]) -> str: +async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any], + baseline: str = "TIME", + memory_verify: bool = False,quality_assessment:bool = False,statement_databasets: List[str] = []) -> str: """ - Renders the evaluate prompt using the evaluate.jinja2 template. + Renders the evaluate prompt using the evaluate_optimized.jinja2 template. Args: evaluate_data: The data to evaluate schema: The JSON schema to use for the output. + baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT) + memory_verify: Whether to enable memory verification for privacy detection Returns: Rendered prompt content as string """ template = prompt_env.get_template("evaluate.jinja2") - rendered_prompt = template.render(evaluate_data=evaluate_data, json_schema=schema) - + rendered_prompt = template.render( + evaluate_data=evaluate_data, + json_schema=schema, + baseline=baseline, + memory_verify=memory_verify, + quality_assessment=quality_assessment, + statement_databasets=statement_databasets + ) return rendered_prompt -async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any]) -> str: +async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False, + statement_databasets: List[str] = []) -> str: """ - Renders the reflexion prompt using the extract_temporal.jinja2 template. + Renders the reflexion prompt using the reflexion_optimized.jinja2 template. Args: data: The data to reflex on. schema: The JSON schema to use for the output. + baseline: The baseline type for conflict resolution. Returns: Rendered prompt content as a string. """ template = prompt_env.get_template("reflexion.jinja2") - rendered_prompt = template.render(data=data, json_schema=schema) + rendered_prompt = template.render(data=data, json_schema=schema, + baseline=baseline,memory_verify=memory_verify, + statement_databasets=statement_databasets) return rendered_prompt diff --git a/api/app/models/data_config_model.py b/api/app/models/data_config_model.py index 9f27562c..be43bd8d 100644 --- a/api/app/models/data_config_model.py +++ b/api/app/models/data_config_model.py @@ -1,5 +1,4 @@ import datetime -import uuid from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float from sqlalchemy.dialects.postgresql import UUID from app.db import Base @@ -11,50 +10,53 @@ class DataConfig(Base): # 主键 config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID") - + # 基本信息 config_name = Column(String, nullable=False, comment="配置名称") config_desc = Column(String, nullable=True, comment="配置描述") - + # 组织信息 workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID") group_id = Column(String, nullable=True, comment="组ID") user_id = Column(String, nullable=True, comment="用户ID") apply_id = Column(String, nullable=True, comment="应用ID") - + # 模型选择(从workspace继承) llm_id = Column(String, nullable=True, comment="LLM模型配置ID") embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") llm = Column(String, nullable=True, comment="LLM模型配置ID") - + # 记忆萃取引擎配置 enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧") deep_retrieval = Column(Boolean, default=True, comment="深度检索开关") - + # 阈值配置 (0-1 之间的浮点数) t_type_strict = Column(Float, default=0.8, comment="类型严格阈值") t_name_strict = Column(Float, default=0.8, comment="名称严格阈值") t_overall = Column(Float, default=0.8, comment="综合阈值") - + # 状态配置 state = Column(Boolean, default=False, comment="配置使用状态") - + # 分块策略 chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略") - + # 剪枝配置 pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝") pruning_scene = Column(String, nullable=True, comment="智能剪枝场景:education/online_service/outbound") pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值(0-0.9)") - + # 自我反思配置 enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思") iteration_period = Column(String, default="3", comment="反思迭代周期") reflexion_range = Column(String, default="retrieval", comment="反思范围:部分/全部") baseline = Column(String, default="time", comment="基线:时间/事实/时间和事实") - + reflection_model_id = Column(String, nullable=True, comment="反思模型ID") + memory_verify = Column(Boolean, default=True, comment="记忆验证") + quality_assessment = Column(Boolean, default=True, comment="质量评估") + # 遗忘引擎配置 statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3") include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文") @@ -62,7 +64,7 @@ class DataConfig(Base): lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度,0-1 小数") lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数") offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数") - + # 时间戳 created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index a2c02f84..2a9ed8da 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -14,6 +14,7 @@ class EndUser(Base): other_id = Column(String, nullable=True) # Store original user_id other_name = Column(String, default="", nullable=False) other_address = Column(String, default="", nullable=False) + reflection_time = Column(DateTime, nullable=True) created_at = Column(DateTime, default=datetime.datetime.now) updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) diff --git a/api/app/repositories/data_config_repository.py b/api/app/repositories/data_config_repository.py index ed1a482a..6b281ef1 100644 --- a/api/app/repositories/data_config_repository.py +++ b/api/app/repositories/data_config_repository.py @@ -16,48 +16,46 @@ import uuid from app.models.data_config_model import DataConfig from app.schemas.memory_storage_schema import ( ConfigParamsCreate, - ConfigParamsDelete, ConfigUpdate, ConfigUpdateExtracted, ConfigUpdateForget, - ConfigKey, ) from app.core.logging_config import get_db_logger # 获取数据库专用日志器 db_logger = get_db_logger() - +TABLE_NAME = "data_config" class DataConfigRepository: """数据配置Repository - + 提供data_config表的数据访问方法,包括: - SQLAlchemy ORM 数据库操作 - Neo4j Cypher查询常量 """ - + # ==================== Neo4j Cypher 查询常量 ==================== - + # Dialogue count by group SEARCH_FOR_DIALOGUE = """ MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num """ - + # Chunk count by group SEARCH_FOR_CHUNK = """ MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num """ - + # Statement count by group SEARCH_FOR_STATEMENT = """ MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num """ - + # ExtractedEntity count by group SEARCH_FOR_ENTITY = """ MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num """ - + # All counts by label and total SEARCH_FOR_ALL = """ OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count @@ -70,7 +68,7 @@ class DataConfigRepository: UNION ALL OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count """ - + # Extracted entity details within group/app/user SEARCH_FOR_DETIALS = """ MATCH (n:ExtractedEntity) @@ -86,7 +84,7 @@ class DataConfigRepository: n.user_id AS user_id, n.id AS id """ - + # Edges between extracted entities within group/app/user SEARCH_FOR_EDGES = """ MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) @@ -102,7 +100,7 @@ class DataConfigRepository: r.statement_id AS statement_id, r.statement AS statement """ - + # Entity graph within group (source node, edge, target node) SEARCH_FOR_ENTITY_GRAPH = """ MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) @@ -135,22 +133,106 @@ class DataConfigRepository: id: m.id } AS targetNode """ - + # ==================== SQLAlchemy ORM 数据库操作方法 ==================== - + @staticmethod + def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]: + """构建反思配置更新语句(SQLAlchemy text() 命名参数) + + Args: + config_id: 配置ID + **kwargs: 反思配置参数 + + Returns: + Tuple[str, Dict]: (SQL查询字符串, 参数字典) + + Raises: + ValueError: 没有字段需要更新时抛出 + """ + db_logger.debug(f"构建反思配置更新语句: config_id={config_id}") + + key_where = "config_id = :config_id" + set_fields: List[str] = [] + params: Dict = { + "config_id": config_id, + } + + # 反思配置字段映射 + mapping = { + "enable_self_reflexion": "enable_self_reflexion", + "iteration_period": "iteration_period", + "reflexion_range": "reflexion_range", + "baseline": "baseline", + "reflection_model_id": "reflection_model_id", + "memory_verify": "memory_verify", + "quality_assessment": "quality_assessment", + } + + for api_field, db_col in mapping.items(): + if api_field in kwargs and kwargs[api_field] is not None: + set_fields.append(f"{db_col} = :{api_field}") + params[api_field] = kwargs[api_field] + + if not set_fields: + raise ValueError("No fields to update") + + set_fields.append("updated_at = timezone('Asia/Shanghai', now())") + query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}" + return query, params + + @staticmethod + def build_select_reflection(config_id: int) -> Tuple[str, Dict]: + """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) + + Args: + config_id: 配置ID + + Returns: + Tuple[str, Dict]: (SQL查询字符串, 参数字典) + """ + db_logger.debug(f"构建反思配置查询语句: config_id={config_id}") + + query = ( + f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, " + f"reflection_model_id, memory_verify, quality_assessment, user_id " + f"FROM {TABLE_NAME} WHERE config_id = :config_id" + ) + params = {"config_id": config_id} + return query, params + + @staticmethod + def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]: + """构建查询所有配置的语句(SQLAlchemy text() 命名参数) + + Args: + workspace_id: 工作空间ID + + Returns: + Tuple[str, Dict]: (SQL查询字符串, 参数字典) + """ + db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}") + + query = ( + f"SELECT config_id, config_name, enable_self_reflexion, iteration_period, reflexion_range, baseline, " + f"reflection_model_id, memory_verify, quality_assessment, user_id, created_at, updated_at " + f"FROM {TABLE_NAME} WHERE workspace_id = :workspace_id ORDER BY updated_at DESC" + ) + params = {"workspace_id": workspace_id} + return query, params + @staticmethod def create(db: Session, params: ConfigParamsCreate) -> DataConfig: """创建数据配置 - + Args: db: 数据库会话 params: 配置参数创建模型 - + Returns: DataConfig: 创建的配置对象 """ db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}") - + try: db_config = DataConfig( config_name=params.config_name, @@ -162,37 +244,37 @@ class DataConfigRepository: ) db.add(db_config) db.flush() # 获取自增ID但不提交事务 - + db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})") return db_config - + except Exception as e: db.rollback() db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}") raise - + @staticmethod def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]: """更新基础配置 - + Args: db: 数据库会话 update: 配置更新模型 - + Returns: Optional[DataConfig]: 更新后的配置对象,不存在则返回None - + Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新数据配置: config_id={update.config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() if not db_config: db_logger.warning(f"数据配置不存在: config_id={update.config_id}") return None - + # 更新字段 has_update = False if update.config_name is not None: @@ -201,44 +283,44 @@ class DataConfigRepository: if update.config_desc is not None: db_config.config_desc = update.config_desc has_update = True - + if not has_update: raise ValueError("No fields to update") - + db.commit() db.refresh(db_config) - + db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})") return db_config - + except Exception as e: db.rollback() db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}") raise - + @staticmethod def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]: """更新记忆萃取引擎配置 - + Args: db: 数据库会话 update: 萃取配置更新模型 - + Returns: Optional[DataConfig]: 更新后的配置对象,不存在则返回None - + Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新萃取配置: config_id={update.config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() if not db_config: db_logger.warning(f"数据配置不存在: config_id={update.config_id}") return None - + # 更新字段映射 field_mapping = { # 模型选择 @@ -268,50 +350,50 @@ class DataConfigRepository: "reflexion_range": "reflexion_range", "baseline": "baseline", } - + has_update = False for api_field, db_field in field_mapping.items(): value = getattr(update, api_field, None) if value is not None: setattr(db_config, db_field, value) has_update = True - + if not has_update: raise ValueError("No fields to update") - + db.commit() db.refresh(db_config) - + db_logger.info(f"萃取配置更新成功: config_id={update.config_id}") return db_config - + except Exception as e: db.rollback() db_logger.error(f"更新萃取配置失败: config_id={update.config_id} - {str(e)}") raise - + @staticmethod def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]: """更新遗忘引擎配置 - + Args: db: 数据库会话 update: 遗忘配置更新模型 - + Returns: Optional[DataConfig]: 更新后的配置对象,不存在则返回None - + Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新遗忘配置: config_id={update.config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() if not db_config: db_logger.warning(f"数据配置不存在: config_id={update.config_id}") return None - + # 更新字段 has_update = False if update.lambda_time is not None: @@ -323,40 +405,40 @@ class DataConfigRepository: if update.offset is not None: db_config.offset = update.offset has_update = True - + if not has_update: raise ValueError("No fields to update") - + db.commit() db.refresh(db_config) - + db_logger.info(f"遗忘配置更新成功: config_id={update.config_id}") return db_config - + except Exception as e: db.rollback() db_logger.error(f"更新遗忘配置失败: config_id={update.config_id} - {str(e)}") raise - + @staticmethod def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]: """获取萃取配置,通过主键查询某条配置 - + Args: db: 数据库会话 config_id: 配置ID - + Returns: Optional[Dict]: 萃取配置字典,不存在则返回None """ db_logger.debug(f"查询萃取配置: config_id={config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() if not db_config: db_logger.debug(f"萃取配置不存在: config_id={config_id}") return None - + result = { "llm_id": db_config.llm_id, "embedding_id": db_config.embedding_id, @@ -379,62 +461,62 @@ class DataConfigRepository: "reflexion_range": db_config.reflexion_range, "baseline": db_config.baseline, } - + db_logger.debug(f"萃取配置查询成功: config_id={config_id}") return result - + except Exception as e: db_logger.error(f"查询萃取配置失败: config_id={config_id} - {str(e)}") raise - + @staticmethod def get_forget_config(db: Session, config_id: int) -> Optional[Dict]: """获取遗忘配置,通过主键查询某条配置 - + Args: db: 数据库会话 config_id: 配置ID - + Returns: Optional[Dict]: 遗忘配置字典,不存在则返回None """ db_logger.debug(f"查询遗忘配置: config_id={config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() if not db_config: db_logger.debug(f"遗忘配置不存在: config_id={config_id}") return None - + result = { "lambda_time": db_config.lambda_time, "lambda_mem": db_config.lambda_mem, "offset": db_config.offset, } - + db_logger.debug(f"遗忘配置查询成功: config_id={config_id}") return result - + except Exception as e: db_logger.error(f"查询遗忘配置失败: config_id={config_id} - {str(e)}") raise - + @staticmethod def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]: """根据ID获取数据配置 - + Args: db: 数据库会话 config_id: 配置ID - + Returns: Optional[DataConfig]: 配置对象,不存在则返回None """ db_logger.debug(f"根据ID查询数据配置: config_id={config_id}") - + try: config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() - + if config: db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})") else: @@ -443,60 +525,60 @@ class DataConfigRepository: except Exception as e: db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}") raise - + @staticmethod def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]: """获取所有配置参数 - + Args: db: 数据库会话 workspace_id: 工作空间ID,用于过滤查询结果 - + Returns: List[DataConfig]: 配置列表 """ db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") - + try: query = db.query(DataConfig) - + if workspace_id: query = query.filter(DataConfig.workspace_id == workspace_id) - + configs = query.order_by(desc(DataConfig.updated_at)).all() - + db_logger.debug(f"配置列表查询成功: 数量={len(configs)}") return configs - + except Exception as e: db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}") raise - + @staticmethod def delete(db: Session, config_id: int) -> bool: """删除数据配置 - + Args: db: 数据库会话 config_id: 配置ID - + Returns: bool: 删除成功返回True,配置不存在返回False """ db_logger.debug(f"删除数据配置: config_id={config_id}") - + try: db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() if not db_config: db_logger.warning(f"数据配置不存在: config_id={config_id}") return False - + db.delete(db_config) db.commit() - + db_logger.info(f"数据配置删除成功: config_id={config_id}") return True - + except Exception as e: db.rollback() db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}") diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 7330a00f..95e2ee03 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -746,3 +746,57 @@ DETACH DELETE losing RETURN count(losing) as deleted """ + +neo4j_statement_part = ''' +MATCH (n:Statement) +WHERE n.group_id = "{}" + AND datetime(n.created_at) >= datetime() - duration('P3D') +RETURN + n.statement as statement_name, + n.id as statement_id, + n.created_at as statement_created_at + +''' +neo4j_statement_all = ''' +MATCH (n:Statement) +WHERE n.group_id = "{}" +RETURN + n.statement as statement_name, + n.id as statement_id + +''' +neo4j_query_part = """ + MATCH (n)-[r]-(m:ExtractedEntity) + WHERE n.group_id = "{}" + AND datetime(n.created_at) >= datetime() - duration('P3D') + WITH DISTINCT m + OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) + RETURN + m.name as entity1_name, + m.description as description, + m.statement_id as statement_id, + m.created_at as created_at, + m.expired_at as expired_at, + CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + rel as relationship, + CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name, + other as entity2 + """ +neo4j_query_all = """ + MATCH (n)-[r]-(m:ExtractedEntity) + WHERE n.group_id = "{}" + WITH DISTINCT m + OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) + RETURN + m.name as entity1_name, + m.description as description, + m.statement_id as statement_id, + m.created_at as created_at, + m.expired_at as expired_at, + CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + rel as relationship, + CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name, + other as entity2 + """ + + diff --git a/api/app/repositories/neo4j/neo4j_update.py b/api/app/repositories/neo4j/neo4j_update.py new file mode 100644 index 00000000..9644224c --- /dev/null +++ b/api/app/repositories/neo4j/neo4j_update.py @@ -0,0 +1,227 @@ +from app.repositories import Neo4jConnector + +neo4j_connector = Neo4jConnector() + +async def update_neo4j_data(neo4j_dict_data, update_databases): + """ + Update Neo4j data based on query criteria and update parameters + + Args: + neo4j_dict_data: find + update_databases: update + """ + try: + # 构建WHERE条件 + where_conditions = [] + params = {} + + for key, value in neo4j_dict_data.items(): + if value is not None: + param_name = f"param_{key}" + where_conditions.append(f"e.{key} = ${param_name}") + params[param_name] = value + + where_clause = " AND ".join(where_conditions) if where_conditions else "1=1" + + # 构建SET条件 + set_conditions = [] + for key, value in update_databases.items(): + if value is not None: + param_name = f"update_{key}" + set_conditions.append(f"e.{key} = ${param_name}") + params[param_name] = value + + set_clause = ", ".join(set_conditions) + + if not set_clause: + print("警告: 没有需要更新的字段") + return False + + # 构建Cypher查询 + cypher_query = f""" + MATCH (e:ExtractedEntity) + WHERE {where_clause} + SET {set_clause} + RETURN count(e) as updated_count, collect(e.name) as updated_names + """ + + print(f"\n执行Cypher查询: {cypher_query}") + print(f"参数: {params}") + + # 执行更新 + result = await neo4j_connector.execute_query(cypher_query, **params) + + if result: + updated_count = result[0].get('updated_count', 0) + updated_names = result[0].get('updated_names', []) + print(f"成功更新 {updated_count} 个节点") + if updated_names: + print(f"更新的实体名称: {updated_names}") + return updated_count > 0 + else: + return False + + except Exception as e: + print(f"更新过程中出现错误: {e}") + import traceback + traceback.print_exc() + return False + + +def map_field_names(data_dict): + mapped_dict = {} + has_name_field = False + + # 第一遍:检查是否有name相关字段 + for key, value in data_dict.items(): + if key in ['name', 'entity2.name', 'entity1.name']: + has_name_field = True + break + + print(f"字段检查: has_name_field = {has_name_field}") + + # 第二遍:根据规则映射和过滤字段 + for key, value in data_dict.items(): + if key == 'entity2.name' or key == 'entity2_name': + # 将 entity2.name 映射为 name + mapped_dict['name'] = value + print(f"字段名映射: {key} -> name") + elif key == 'entity1.name' or key == 'entity1_name': + # 将 entity1.name 映射为 name + mapped_dict['name'] = value + print(f"字段名映射: {key} -> name") + elif key == 'entity1.description': + # 将 entity1.description 映射为 description + mapped_dict['description'] = value + print(f"字段名映射: {key} -> description") + elif key == 'entity2.description': + # 将 entity2.description 映射为 description + mapped_dict['description'] = value + print(f"字段名映射: {key} -> description") + elif key == 'relationship_type': + # 跳过relationship_type字段 + print(f"字段过滤: 跳过不需要的字段 '{key}'") + continue + elif key == 'entity1_name': + if has_name_field: + # 如果有name字段,跳过entity1_name + print(f"字段过滤: 由于存在name字段,跳过 '{key}'") + continue + else: + # 如果没有name字段,保留entity1_name + mapped_dict[key] = value + print(f"字段保留: {key}") + elif key == 'entity2_name': + if has_name_field: + # 如果有name字段,跳过entity2_name + print(f"字段过滤: 由于存在name字段,跳过 '{key}'") + continue + else: + # 即使没有name字段,也不使用entity2_name(根据需求) + print(f"字段过滤: 跳过不推荐的字段 '{key}'") + continue + elif '.' not in key: + # 不包含点号的其他字段直接保留 + mapped_dict[key] = value + else: + # 其他包含点号的字段跳过并警告 + print(f"警告: 跳过不支持的嵌套字段 '{key}'") + + print(f"字段映射结果: {mapped_dict}") + return mapped_dict +async def neo4j_data(solved_data): + """ + Process the resolved data and update the Neo4j database + Args: + Solved_data: Solution Data List + Returns: + Int: Number of successfully updated records + """ + success_count = 0 + + for i in solved_data: + neo4j_dict_data = {} + update_databases = {} + results = i['results'] + for data in results: + resolved = data.get('resolved') + if not resolved: + print("跳过:resolved为None") + continue + + try: + change_list = resolved.get('change', []) + except (AttributeError, TypeError): + change_list = [] + + if change_list == []: + print("跳过:change_list为空") + continue + + if change_list and len(change_list) > 0: + change = change_list[0] + print(f"change: {change}") + field_data = change.get('field', []) + print(f"field_data: {field_data}") + print(f"field_data type: {type(field_data)}") + + # 字段名映射和过滤函数 + + + # 处理field数据,可能是字典或列表 + if isinstance(field_data, dict): + # 如果是字典,映射字段名后更新 + mapped_data = map_field_names(field_data) + update_databases.update(mapped_data) + elif isinstance(field_data, list): + # 如果是列表,遍历每个字典并更新 + for field_item in field_data: + if isinstance(field_item, dict): + mapped_item = map_field_names(field_item) + update_databases.update(mapped_item) + else: + print(f"警告: field_item不是字典: {field_item}") + else: + print(f"警告: field_data类型不支持: {type(field_data)}") + + if 'entity1_name' in data: + data['name'] = data.pop('entity1_name') + if 'entity2_name' in data: + data.pop('entity2_name', None) + + resolved_memory = resolved.get('resolved_memory', {}) + + entity2 = None + if isinstance(resolved_memory, dict): + entity2 = resolved_memory.get('entity2') + + if entity2 and isinstance(entity2, dict) and len(entity2) >= 5: + stat_id = resolved.get('original_memory_id') + # 安全地获取description + statement_id = None + if isinstance(resolved_memory, dict): + statement_id = resolved_memory.get('statement_id') + + # 只有当neo4j_dict_data中还没有statement_id时才使用original_memory_id + if statement_id and 'id' not in neo4j_dict_data: + neo4j_dict_data['id'] = stat_id + neo4j_dict_data['statement_id'] = statement_id + else: + # 处理original_memory_id,它可能是字符串或字典 + try: + for key, value in resolved_memory.items(): + if key == 'statement_id': + neo4j_dict_data['statement_id'] = value + if key == 'description': + neo4j_dict_data['description'] = value + except AttributeError: + neo4j_dict_data=[] + + print(neo4j_dict_data) + print(update_databases) + if neo4j_dict_data!=[]: + await update_neo4j_data(neo4j_dict_data, update_databases) + success_count += 1 + + return success_count + diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index 30dafddd..74fc4a14 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -13,5 +13,6 @@ class EndUser(BaseModel): other_id: Optional[str] = Field(description="第三方ID", default=None) other_name: Optional[str] = Field(description="其他名称", default="") other_address: Optional[str] = Field(description="其他地址", default="") + reflection_time: Optional[datetime.datetime] = Field(description="反思时间", default_factory=datetime.datetime.now) created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now) updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now) diff --git a/api/app/schemas/memory_reflection_schemas.py b/api/app/schemas/memory_reflection_schemas.py new file mode 100644 index 00000000..9eb11c6c --- /dev/null +++ b/api/app/schemas/memory_reflection_schemas.py @@ -0,0 +1,54 @@ +from pydantic import BaseModel, Field +from typing import Optional +from enum import Enum + + +class OptimizationStrategy(str, Enum): + """优化策略枚举""" + SPEED_FIRST = "speed_first" + ACCURACY_FIRST = "accuracy_first" + BALANCED = "balanced" + + +class Memory_Reflection(BaseModel): + config_id: Optional[int] = None + reflectionenabled: bool + reflection_period_in_hours: str + reflexion_range: str + baseline: str + reflection_model_id: str + memory_verify: bool + quality_assessment: bool + + # 新增快速引擎优化参数 + optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED + use_fast_model: Optional[bool] = True + enable_caching: Optional[bool] = True + enable_streaming: Optional[bool] = True + batch_size: Optional[int] = Field(default=3, ge=1, le=10) + max_concurrent: Optional[int] = Field(default=5, ge=1, le=20) + + class Config: + use_enum_values = True + + +class FastReflectionRequest(BaseModel): + """快速反思请求模型""" + reflection: Memory_Reflection + host_id: Optional[str] = "88a459f5_text02" + optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED + + class Config: + use_enum_values = True + + +class ReflectionBenchmarkRequest(BaseModel): + """反思基准测试请求模型""" + reflection: Memory_Reflection + host_id: Optional[str] = "88a459f5_text02" + iterations: Optional[int] = Field(default=3, ge=1, le=10) + + class Config: + use_enum_values = True + + diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 66b2e45f..ab6b0512 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -2,7 +2,7 @@ 所有的内容是放错误地方了,应该放在models """ -from typing import Any, Optional, List, Dict, Literal +from typing import Any, Optional, List, Dict, Literal, Union import time import uuid from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator @@ -28,25 +28,48 @@ class Write_UserInput(BaseModel): # ============================================================================ class BaseDataSchema(BaseModel): """Base schema for the data""" - id: str = Field(..., description="The unique identifier for the data entry.") - statement: str = Field(..., description="The statement text.") - group_id: str = Field(..., description="The group identifier.") - chunk_id: str = Field(..., description="The chunk identifier.") + # 保持原有必需字段为可选,以兼容不同数据源 + id: Optional[str] = Field(None, description="The unique identifier for the data entry.") + statement: Optional[str] = Field(None, description="The statement text.") + group_id: Optional[str] = Field(None, description="The group identifier.") + chunk_id: Optional[str] = Field(None, description="The chunk identifier.") created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.") expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.") valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.") invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.") entity_ids: List[str] = Field([], description="The list of entity identifiers.") + description: Optional[str] = Field(None, description="The description of the data entry.") + + # 新增字段以匹配实际输入数据 + entity1_name: str = Field(..., description="The first entity name.") + entity2_name: Optional[str] = Field(None, description="The second entity name.") + statement_id: str = Field(..., description="The statement identifier.") + relationship_type: str = Field(..., description="The relationship type.") + relationship: Optional[Dict[str, Any]] = Field(None, description="The relationship object.") + entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.") + + +class QualityAssessmentSchema(BaseModel): + """Schema for memory quality assessment results.""" + score: int = Field(..., ge=0, le=100, description="Quality score percentage (0-100).") + summary: str = Field(..., description="Brief summary of data quality status, including main issues and strengths.") + + +class MemoryVerifySchema(BaseModel): + """Schema for memory privacy verification results.""" + has_privacy: bool = Field(..., description="Whether privacy information was detected.") + privacy_types: List[str] = Field([], description="List of detected privacy information types.") + summary: str = Field(..., description="Brief summary of privacy detection results.") class ConflictResultSchema(BaseModel): """Schema for the conflict result data in the reflexion_data.json file.""" - data: List[BaseDataSchema] = Field(..., description="The conflict memory data.") + data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") conflict: bool = Field(..., description="Whether the memory is in conflict.") - conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.") + quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") + memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") @model_validator(mode="before") - @classmethod def _normalize_data(cls, v): if isinstance(v, dict): d = v.get("data") @@ -61,7 +84,6 @@ class ConflictSchema(BaseModel): conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.") @model_validator(mode="before") - @classmethod def _normalize_data(cls, v): if isinstance(v, dict): d = v.get("data") @@ -76,21 +98,30 @@ class ReflexionSchema(BaseModel): solution: str = Field(..., description="The solution for the reflexion.") +class ChangeRecordSchema(BaseModel): + """Schema for individual change records""" + field: List[Dict[str, str]] = Field(..., description="List of field changes, each containing field name and new value.") + class ResolvedSchema(BaseModel): """Schema for the resolved memory data in the reflexion_data""" original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") - resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data.") + # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") + resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") + change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") +class SingleReflexionResultSchema(BaseModel): + """Schema for a single reflexion result item.""" + conflict: ConflictResultSchema = Field(..., description="The conflict result data for this specific conflict type.") + reflexion: ReflexionSchema = Field(..., description="The reflexion data for this conflict.") + resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") + type: str = Field("reflexion_result", description="The type identifier.") + class ReflexionResultSchema(BaseModel): - """Schema for the reflexion result data in the reflexion_data.json file.""" - # 模型输出中 "conflict" 为单个冲突对象(包含 data 与 conflict_memory),而非字典映射 - conflict: ConflictResultSchema = Field(..., description="The conflict result data.") - reflexion: Optional[ReflexionSchema] = Field(None, description="The reflexion data.") - resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.") + """Schema for the complete reflexion result data - a list of individual conflict resolutions.""" + results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") @model_validator(mode="before") - @classmethod def _normalize_resolved(cls, v): if isinstance(v, dict): conflict = v.get("conflict") diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py new file mode 100644 index 00000000..0f8fb569 --- /dev/null +++ b/api/app/services/memory_reflection_service.py @@ -0,0 +1,397 @@ +""" +记忆反思服务 +处理反思引擎的调用和执行 +""" +from datetime import datetime +from typing import Dict, Any, Optional, Set + +from fastapi import Depends +from sqlalchemy.orm import Session +from sqlalchemy import text + +from app.db import get_db +from app.core.logging_config import get_api_logger +from app.core.memory.storage_services.reflection_engine import ReflectionConfig, ReflectionEngine +from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionRange, ReflectionBaseline +from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.models.app_model import App +from app.models.app_release_model import AppRelease +from app.models.end_user_model import EndUser + +api_logger = get_api_logger() + + +class WorkspaceAppService: + """Workplace Application Service Class """ + + def __init__(self, db: Session): + self.db = db + + def get_workspace_apps_detailed(self, workspace_id: str) -> Dict[str, Any]: + """ + Get detailed information of all applications in the workspace + + Args: + Workspace_id: Workspace ID + + Returns: + Dictionary containing detailed application information + """ + apps = self.db.query(App).filter(App.workspace_id == workspace_id).all() + app_ids = [str(app.id) for app in apps] + + apps_detailed_info = [] + + for app in apps: + app_info = self._build_app_info(app) + self._process_app_releases(app, app_info) + self._process_end_users(app, app_info) + apps_detailed_info.append(app_info) + + return { + "status": "成功", + "message": f"成功查询到 {len(app_ids)} 个应用及其详细信息", + "workspace_id": str(workspace_id), + "apps_count": len(app_ids), + "app_ids": app_ids, + "apps_detailed_info": apps_detailed_info + } + + def _build_app_info(self, app: App) -> Dict[str, Any]: + """base_infomation""" + return { + "id": str(app.id), + "name": app.name, + "description": app.description, + "type": app.type, + "status": app.status, + "visibility": app.visibility, + "created_at": app.created_at.isoformat() if app.created_at else None, + "updated_at": app.updated_at.isoformat() if app.updated_at else None, + "releases": [], + "data_configs": [], + "end_users": [] + } + + def _process_app_releases(self, app: App, app_info: Dict[str, Any]) -> None: + """Process the release version and configuration information of the application""" + app_releases = self.db.query(AppRelease).filter(AppRelease.app_id == app.id).all() + + if not app_releases: + return + + processed_configs: Set[str] = set() + + for release in app_releases: + memory_content = self._extract_memory_content(release.config) + + + if memory_content and memory_content in processed_configs: + continue + + release_info = { + "app_id": str(release.app_id), + "config": memory_content + } + + + if memory_content: + processed_configs.add(memory_content) + data_config_info = self._get_data_config(memory_content) + + if data_config_info: + if not any(dc["config_id"] == data_config_info["config_id"] for dc in app_info["data_configs"]): + app_info["data_configs"].append(data_config_info) + + app_info["releases"].append(release_info) + + def _extract_memory_content(self, config: Any) -> str: + """Extract memory_comtent from config""" + if not config or not isinstance(config, dict): + return None + + memory_obj = config.get('memory') + if memory_obj and isinstance(memory_obj, dict): + return memory_obj.get('memory_content') + + return None + + def _get_data_config(self, memory_content: str) -> Dict[str, Any]: + """Retrieve data_comfig information based on memory_comtent""" + try: + data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content) + data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone() + if data_config_result is None: + return None + + if data_config_result: + return { + "config_id": data_config_result.config_id, + "enable_self_reflexion": data_config_result.enable_self_reflexion, + "iteration_period": data_config_result.iteration_period, + "reflexion_range": data_config_result.reflexion_range, + "baseline": data_config_result.baseline, + "reflection_model_id": data_config_result.reflection_model_id, + "memory_verify": data_config_result.memory_verify, + "quality_assessment": data_config_result.quality_assessment, + "user_id": data_config_result.user_id + } + except Exception as e: + api_logger.warning(f"查询data_config失败,memory_content: {memory_content}, 错误: {str(e)}") + + return None + + def _process_end_users(self, app: App, app_info: Dict[str, Any]) -> None: + """Processing end-user information for applications""" + end_users = self.db.query(EndUser).filter(EndUser.app_id == app.id).all() + + for end_user in end_users: + end_user_info = { + "id": str(end_user.id), + "app_id": str(end_user.app_id) + } + app_info["end_users"].append(end_user_info) + + def get_end_user_reflection_time(self, end_user_id: str) -> Optional[Any]: + """ + Read the reflection time of end users + + Args: + End_user_id: End User ID + + Returns: + Reflection time or None + """ + try: + end_user = self.db.query(EndUser).filter(EndUser.id == end_user_id).first() + if end_user: + return end_user.reflection_time + return None + except Exception as e: + api_logger.error(f"读取用户反思时间失败,end_user_id: {end_user_id}, 错误: {str(e)}") + return None + + def update_end_user_reflection_time(self, end_user_id: str) -> bool: + """ + Update the reflection time of end users to the current time + + Args: + End_user_id: End User ID + + Returns: + Is the update successful + """ + try: + from datetime import datetime + + end_user = self.db.query(EndUser).filter(EndUser.id == end_user_id).first() + if end_user: + end_user.reflection_time = datetime.now() + self.db.commit() + api_logger.info(f"成功更新用户反思时间,end_user_id: {end_user_id}") + return True + else: + api_logger.warning(f"未找到用户,end_user_id: {end_user_id}") + return False + except Exception as e: + api_logger.error(f"更新用户反思时间失败,end_user_id: {end_user_id}, 错误: {str(e)}") + self.db.rollback() + return False + + +class MemoryReflectionService: + """Memory reflection service category""" + + def __init__(self,db: Session = Depends(get_db)): + self.db=db + + + async def start_reflection_from_data(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]: + """ + Starting Reflection from Configuration Data + + Args: + config_data: Configure data dictionary, including reflective configuration information + end_user_id: end_user_id + + Returns: + Reflect on the execution results + """ + try: + config_id = config_data.get("config_id") + api_logger.info(f"从配置数据启动反思,config_id: {config_id}, end_user_id: {end_user_id}") + + + if not config_data.get("enable_self_reflexion", False): + return { + "status": "跳过", + "message": "反思引擎未启用", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data + } + + + config_data_id=config_data['config_id'] + reflection_config=WorkspaceAppService(self.db)._get_data_config(config_data_id) + if reflection_config is not None and reflection_config['enable_self_reflexion']: + reflection_config= self._create_reflection_config_from_data(reflection_config) + iteration_period=reflection_config.iteration_period + workspace_service = WorkspaceAppService(self.db) + current_reflection_time = workspace_service.get_end_user_reflection_time(end_user_id) + + reflection_time = datetime.fromisoformat(str(current_reflection_time)) + + current_time = datetime.now() + time_diff = current_time - reflection_time + hours_diff = int(time_diff.total_seconds() / 3600) + if iteration_period==hours_diff or current_reflection_time is None: + api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时") + # 3. 执行反思引擎 + reflection_results = await self._execute_reflection_engine( + reflection_config, end_user_id + ) + # 更新反思时间为当前时间 + update_success = workspace_service.update_end_user_reflection_time(end_user_id) + if update_success: + api_logger.info(f"成功更新用户 {end_user_id} 的反思时间") + else: + api_logger.error(f"更新用户 {end_user_id} 的反思时间失败") + + return { + "status": "完成", + "message": "反思引擎执行完成", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data, + "reflection_results": reflection_results + } + else: + return { + "status": "等待中..", + "message": "反思引擎未开始执行执", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data, + "reflection_results": '' + } + + except Exception as e: + config_id = config_data.get("config_id", "unknown") + api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}") + return { + "status": "错误", + "message": f"启动反思失败: {str(e)}", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data + } + + def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig: + """Create reflective configuration objects from configuration data""" + + reflexion_range_value = config_data.get("reflexion_range") + if reflexion_range_value is None or reflexion_range_value == "": + reflexion_range_value = "partial" + reflexion_range = ReflectionRange(reflexion_range_value) + + baseline_value = config_data.get("baseline") + if baseline_value is None or baseline_value == "": + baseline_value = "TIME" + baseline = ReflectionBaseline(baseline_value) + + # iteration_period = + iteration_period = config_data.get("iteration_period", 24) + if isinstance(iteration_period, str): + try: + iteration_period = int(iteration_period) + except (ValueError, TypeError): + iteration_period = 24 # 默认24小时 + + return ReflectionConfig( + enabled=config_data.get("enable_self_reflexion", False), + iteration_period=str(iteration_period), # ReflectionConfig期望字符串 + reflexion_range=reflexion_range, + baseline=baseline, + memory_verify=config_data.get("memory_verify", False), + quality_assessment=config_data.get("quality_assessment", False), + model_id=config_data.get("reflection_model_id", "") + ) + + async def _execute_reflection_engine( + self, + reflection_config: ReflectionConfig, + user_id: str + ) -> Dict[str, Any]: + """Execute Reflection Engine""" + try: + # 创建Neo4j连接器 + connector = Neo4jConnector() + + # 创建反思引擎 + engine = ReflectionEngine( + config=reflection_config, + neo4j_connector=connector, + llm_client=reflection_config.model_id + ) + + # 执行反思 + reflection_result = await engine.execute_reflection(user_id) + + return { + "success": reflection_result.success, + "message": reflection_result.message, + "conflicts_found": reflection_result.conflicts_found, + "conflicts_resolved": reflection_result.conflicts_resolved, + "memories_updated": reflection_result.memories_updated, + "execution_time": reflection_result.execution_time, + "details": reflection_result.details + } + + except Exception as e: + api_logger.error(f"反思引擎执行失败: {str(e)}") + return { + "success": False, + "message": f"反思引擎执行失败: {str(e)}", + "conflicts_found": 0, + "conflicts_resolved": 0, + "memories_updated": 0, + "execution_time": 0.0 + } + + +class Memory_Reflection_Service: + """Memory Reflection Service - Used for calling the/reflection interface""" + + def __init__(self, db: Session): + self.db = db + self.reflection_service = MemoryReflectionService(db) + + async def start_reflection(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]: + """ + Activate the reflection function + + Args: + config_data: 配置数据,格式如下: + { + "config_id": 26, + "enable_self_reflexion": true, + "iteration_period": "6", + "reflexion_range": "partial", + "baseline": "TIME", + "reflection_model_id": "ea405fa6-c387-4d78-80ab-826d692301b3", + "memory_verify": true, + "quality_assessment": false, + "user_id": null + } + end_user_id: end_user_id,example "12a8b235-6eb1-4481-a53c-b77933b5c949" + + Returns: + """ + api_logger.info(f"Memory_Reflection_Service启动反思,config_id: {config_data.get('config_id')}, end_user_id: {end_user_id}") + + # 调用核心反思服务 + result = await self.reflection_service.start_reflection_from_data(config_data, end_user_id) + + return result \ No newline at end of file diff --git a/api/app/tasks.py b/api/app/tasks.py index 2d461cd3..39758275 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -295,26 +295,6 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage } -def reflection_engine() -> None: - """Empty function placeholder for timed background reflection. - - Intentionally left blank; replace with real reflection logic later. - """ - from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion - import asyncio - - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - asyncio.run(self_reflexion(host_id)) - - -@celery_app.task(name="app.core.memory.agent.reflection.timer") -def reflection_timer_task() -> None: - """Periodic Celery task that invokes reflection_engine. - - Raises an exception on failure. - """ - reflection_engine() - @celery_app.task(name="app.core.memory.agent.health.check_read_service") def check_read_service_task() -> Dict[str, str]: @@ -464,4 +444,147 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: "error": str(e), "workspace_id": workspace_id, "elapsed_time": elapsed_time, + } + + +@celery_app.task(name="app.tasks.workspace_reflection_task", bind=True) +def workspace_reflection_task(self) -> Dict[str, Any]: + """定时任务:每30秒运行工作空间反思功能 + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService + from app.models.workspace_model import Workspace + from app.core.logging_config import get_api_logger + + api_logger = get_api_logger() + db = next(get_db()) + + try: + # 获取所有工作空间 + workspaces = db.query(Workspace).all() + + if not workspaces: + return { + "status": "SUCCESS", + "message": "没有找到工作空间", + "workspace_count": 0, + "reflection_results": [] + } + + all_reflection_results = [] + + # 遍历每个工作空间 + for workspace in workspaces: + workspace_id = workspace.id + api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") + + try: + reflection_service = MemoryReflectionService(db) + + # 使用服务类处理复杂查询逻辑 + service = WorkspaceAppService(db) + result = service.get_workspace_apps_detailed(str(workspace_id)) + + workspace_reflection_results = [] + + for data in result['apps_detailed_info']: + if data['data_configs'] == []: + continue + + releases = data['releases'] + data_configs = data['data_configs'] + end_users = data['end_users'] + + for base, config, user in zip(releases, data_configs, end_users): + if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: + # 调用反思服务 + api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") + + reflection_result = await reflection_service.start_reflection_from_data( + config_data=config, + end_user_id=user['id'] + ) + + workspace_reflection_results.append({ + "app_id": base['app_id'], + "config_id": config['config_id'], + "end_user_id": user['id'], + "reflection_result": reflection_result + }) + + all_reflection_results.append({ + "workspace_id": str(workspace_id), + "reflection_count": len(workspace_reflection_results), + "reflection_results": workspace_reflection_results + }) + + api_logger.info( + f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") + + except Exception as e: + api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") + all_reflection_results.append({ + "workspace_id": str(workspace_id), + "error": str(e), + "reflection_count": 0, + "reflection_results": [] + }) + + total_reflections = sum(r.get("reflection_count", 0) for r in all_reflection_results) + + return { + "status": "SUCCESS", + "message": f"成功处理 {len(workspaces)} 个工作空间,总共 {total_reflections} 个反思任务", + "workspace_count": len(workspaces), + "total_reflections": total_reflections, + "workspace_results": all_reflection_results + } + + except Exception as e: + api_logger.error(f"工作空间反思任务执行失败: {str(e)}") + return { + "status": "FAILURE", + "error": str(e), + "workspace_count": 0, + "reflection_results": [] + } + finally: + db.close() + + try: + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + result["elapsed_time"] = elapsed_time + result["task_id"] = self.request.id + + return result + except Exception as e: + elapsed_time = time.time() - start_time + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": elapsed_time, + "task_id": self.request.id } \ No newline at end of file diff --git a/api/check_code.py b/api/check_code.py new file mode 100755 index 00000000..e4634d91 --- /dev/null +++ b/api/check_code.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +""" +代码质量检查脚本 +自动检查代码中的导入错误、未使用变量、语法问题等 + +用法: + python check_code.py # 检查整个 app/ 目录 + python check_code.py file1.py file2.py # 检查指定文件 +""" + +import subprocess +import sys +from pathlib import Path + + +def run_command(cmd: list[str], description: str) -> tuple[bool, str]: + """运行命令并返回结果""" + print(f"\n{'=' * 60}") + print(f"🔍 {description}") + print(f"{'=' * 60}") + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=False) + + output = result.stdout + result.stderr + success = result.returncode == 0 + + if success: + print(f"✅ {description} - 通过") + else: + print(f"❌ {description} - 发现问题") + if output: + print(output[:2000]) # 只显示前2000字符 + + return success, output + + except Exception as e: + print(f"❌ 执行失败: {e}") + return False, str(e) + + +def main(): + """主函数""" + # 获取命令行参数中的文件列表 + target_files = sys.argv[1:] if len(sys.argv) > 1 else None + + if target_files: + # 检查指定文件 + print(f"🚀 开始代码质量检查 (指定文件: {len(target_files)} 个)...") + target_paths = target_files + ruff_target = target_files + py_compile_files = [f for f in target_files if f.endswith('.py')] + else: + # 检查整个 app/ 目录 + print("🚀 开始代码质量检查 (整个 app/ 目录)...") + target_paths = ["app/"] + ruff_target = ["app/"] + py_compile_files = list(Path("app").rglob("*.py")) + + checks = [ + { + "cmd": ["ruff", "check"] + ruff_target + ["--output-format=concise"], + "description": "Ruff 代码检查 (导入、语法、风格)", + "auto_fix": ["ruff", "check"] + ruff_target + ["--fix", "--unsafe-fixes"], + }, + { + "cmd": ["python", "-m", "py_compile"] + [str(f) for f in py_compile_files], + "description": "Python 语法检查", + "auto_fix": None, + }, + ] + + results = [] + for check in checks: + success, output = run_command(check["cmd"], check["description"]) + results.append( + {"name": check["description"], "success": success, "output": output, "auto_fix": check.get("auto_fix")} + ) + + # 汇总报告 + print(f"\n{'=' * 60}") + print("📊 检查汇总") + print(f"{'=' * 60}") + + all_passed = True + for result in results: + status = "✅ 通过" if result["success"] else "❌ 失败" + print(f"{status} - {result['name']}") + if not result["success"]: + all_passed = False + if result["auto_fix"]: + print(f" 💡 可以运行自动修复: {' '.join(result['auto_fix'])}") + + if all_passed: + print("\n🎉 所有检查通过!") + return 0 + else: + print("\n⚠️ 发现问题,请查看上面的详细信息") + print("\n💡 快速修复命令:") + if target_files: + print(f" ruff check {' '.join(target_files)} --fix --unsafe-fixes") + else: + print(" ruff check app/ --fix --unsafe-fixes") + return 1 + + +if __name__ == "__main__": + sys.exit(main())