Merge #9 into develop from fix/memory_reflection
新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) * fix/memory_reflection: (24 commits squashed) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(功能配置接口+反思celery后台检测反思的迭代周期) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 新增反思功能(检测代码/规范化程序) - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 - 反思优化 Signed-off-by: aliyun8644380055 <accounts_68c0f5d519f260d93ee2997e@mail.teambition.com> Commented-by: aliyun8644380055 <accounts_68c0f5d519f260d93ee2997e@mail.teambition.com> Commented-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com> Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com> Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/9
This commit is contained in:
@@ -83,17 +83,18 @@ celery_app.autodiscover_tasks(['app'])
|
||||
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
|
||||
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-reflection-engine": {
|
||||
"task": "app.core.memory.agent.reflection.timer",
|
||||
"schedule": reflection_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"check-read-service": {
|
||||
"task": "app.core.memory.agent.health.check_read_service",
|
||||
"schedule": health_schedule,
|
||||
|
||||
# "check-read-service": {
|
||||
# "task": "app.core.memory.agent.health.check_read_service",
|
||||
# "schedule": health_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
"schedule": workspace_reflection_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ from . import (
|
||||
memory_dashboard_controller,
|
||||
memory_storage_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_reflection_controller,
|
||||
api_key_controller,
|
||||
release_share_controller,
|
||||
public_share_controller,
|
||||
@@ -62,6 +63,7 @@ manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
manager_router.include_router(workflow_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
manager_router.include_router(memory_reflection_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(tool_execution_controller.router)
|
||||
|
||||
|
||||
200
api/app/controllers/memory_reflection_controller.py
Normal file
200
api/app/controllers/memory_reflection_controller.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine
|
||||
from app.dependencies import get_current_user
|
||||
from app.db import get_db
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
|
||||
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory",
|
||||
tags=["Memory"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reflection/save")
|
||||
async def save_reflection_config(
|
||||
request: Memory_Reflection,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
|
||||
|
||||
|
||||
try:
|
||||
config_id = request.config_id
|
||||
if not config_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="缺少必需参数: config_id"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
update_params = {
|
||||
"enable_self_reflexion": request.reflectionenabled,
|
||||
"iteration_period": request.reflection_period_in_hours,
|
||||
"reflexion_range": request.reflexion_range,
|
||||
"baseline": request.baseline,
|
||||
"reflection_model_id": request.reflection_model_id,
|
||||
"memory_verify": request.memory_verify,
|
||||
"quality_assessment": request.quality_assessment,
|
||||
}
|
||||
|
||||
|
||||
|
||||
query, params = DataConfigRepository.build_update_reflection(config_id, **update_params)
|
||||
|
||||
result = db.execute(text(query), params)
|
||||
if result.rowcount == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 查询更新后的配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"更新后未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}")
|
||||
|
||||
# 返回结果
|
||||
return {
|
||||
"status": "成功",
|
||||
"message": "反思配置已保存",
|
||||
"config_id": config_id,
|
||||
"database_record": {
|
||||
"config_id": result.config_id,
|
||||
"enable_self_reflexion": result.enable_self_reflexion,
|
||||
"iteration_period": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
"baseline": result.baseline,
|
||||
"reflection_model_id": result.reflection_model_id,
|
||||
"memory_verify": result.memory_verify,
|
||||
"quality_assessment": result.quality_assessment,
|
||||
"user_id": result.user_id
|
||||
}
|
||||
}
|
||||
|
||||
except ValueError as ve:
|
||||
api_logger.error(f"参数错误: {str(ve)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"参数错误: {str(ve)}"
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"反思配置保存失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"反思配置保存失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reflection")
|
||||
async def start_workspace_reflection(
|
||||
request: dict,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
|
||||
reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
return {
|
||||
"status": "完成",
|
||||
"message": f"成功处理 {len(reflection_results)} 个反思任务",
|
||||
"workspace_id": str(workspace_id),
|
||||
"reflection_count": len(reflection_results),
|
||||
"reflection_results": reflection_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"启动workspace反思失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"启动workspace反思失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/reflection/run")
|
||||
async def reflection_run(
|
||||
reflection: Memory_Reflection,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
config = ReflectionConfig(
|
||||
enabled=reflection.reflectionenabled,
|
||||
iteration_period=reflection.reflection_period_in_hours,
|
||||
reflexion_range=reflection.reflexion_range,
|
||||
baseline=reflection.baseline,
|
||||
output_example='',
|
||||
memory_verify=reflection.memory_verify,
|
||||
quality_assessment=reflection.quality_assessment,
|
||||
violation_handling_strategy="block",
|
||||
model_id=reflection.reflection_model_id
|
||||
)
|
||||
connector = Neo4jConnector()
|
||||
engine = ReflectionEngine(
|
||||
config=config,
|
||||
neo4j_connector=connector,
|
||||
llm_client=reflection.reflection_model_id # 传入 model_id
|
||||
)
|
||||
|
||||
result=await (engine.reflection_run())
|
||||
return result
|
||||
@@ -148,6 +148,7 @@ class Settings:
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
{
|
||||
"memory_verify": {
|
||||
"source_data": [
|
||||
{
|
||||
"statement_name": "用户是2023年春天去北京工作的。",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户后来基本一直都在北京上班。",
|
||||
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从2023年开始就一直在北京生活。",
|
||||
"statement_id": "e612a44da4db483993c350df7c97a1a1",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从来没有长期离开过北京。",
|
||||
"statement_id": "b3c787a2e33c49f7981accabbbb4538a",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。",
|
||||
"statement_id": "64cde4230cb24a4da726e7db9e7aa616",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的银行卡号是6222023847595898。",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的身份信息和银行卡信息一直没变。",
|
||||
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户认为在上海的那段时间更多算是远程配合。",
|
||||
"statement_id": "150af89d2c154e6eb41ff1a91e37f962",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
}
|
||||
],
|
||||
"databasets": [
|
||||
{
|
||||
"entity1_name": "Person",
|
||||
"description": "表示人类个体的通用类型",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "用户",
|
||||
"entity2": {
|
||||
"entity_idx": 0,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Person",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "用户",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3d3896797b334572a80d57590026063d"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份信息",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用于个人身份识别的数据",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Information",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份信息",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "aa766a517e82490599a9b3af54cfd933"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "6222023847595898",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用户的银行卡号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Numeric",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "6222023847595898",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "610ba361918f4e68a65ce6ad06e5c7a0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "上海办公室",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["上海办"],
|
||||
"connect_strength": "Strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "位于上海的工作办公场所",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "上海办公室",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "fb702ef695c14e14af3e56786bc8815b"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "北京",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["京", "京城", "北平"],
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "中国的首都城市,用户主要工作和生活所在地",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "北京",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "81b2d1a571bb46a08a2d7a1e87efb945"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "11010119950308123X",
|
||||
"description": "具体的身份证号码值",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份证号",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"description": "中华人民共和国公民的身份号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Identifier",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份证号",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3e5f920645b2404fadb0e9ff60d1306e"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -8,17 +8,20 @@
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# 配置日志
|
||||
_root_logger = logging.getLogger()
|
||||
@@ -33,14 +36,14 @@ else:
|
||||
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
RETRIEVAL = "retrieval" # 从检索结果中反思
|
||||
DATABASE = "database" # 从整个数据库中反思
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
ALL = "all" # 从整个数据库中反思
|
||||
|
||||
|
||||
class ReflectionBaseline(str, Enum):
|
||||
"""反思基线枚举"""
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
HYBRID = "HYBRID" # 混合反思
|
||||
|
||||
|
||||
@@ -48,9 +51,16 @@ class ReflectionConfig(BaseModel):
|
||||
"""反思引擎配置"""
|
||||
enabled: bool = False
|
||||
iteration_period: str = "3" # 反思周期
|
||||
reflexion_range: ReflectionRange = ReflectionRange.RETRIEVAL
|
||||
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
||||
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
||||
concurrency: int = Field(default=5, description="并发数量")
|
||||
model_id: Optional[str] = None # 模型ID
|
||||
end_user_id: Optional[str] = None
|
||||
output_example: Optional[str] = None # 输出示例
|
||||
|
||||
# 评估相关字段
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
@@ -75,16 +85,16 @@ class ReflectionEngine:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ReflectionConfig,
|
||||
neo4j_connector: Optional[Any] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
get_data_func: Optional[Any] = None,
|
||||
render_evaluate_prompt_func: Optional[Any] = None,
|
||||
render_reflexion_prompt_func: Optional[Any] = None,
|
||||
conflict_schema: Optional[Any] = None,
|
||||
reflexion_schema: Optional[Any] = None,
|
||||
update_query: Optional[str] = None
|
||||
self,
|
||||
config: ReflectionConfig,
|
||||
neo4j_connector: Optional[Any] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
get_data_func: Optional[Any] = None,
|
||||
render_evaluate_prompt_func: Optional[Any] = None,
|
||||
render_reflexion_prompt_func: Optional[Any] = None,
|
||||
conflict_schema: Optional[Any] = None,
|
||||
reflexion_schema: Optional[Any] = None,
|
||||
update_query: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化反思引擎
|
||||
@@ -109,7 +119,7 @@ class ReflectionEngine:
|
||||
self.conflict_schema = conflict_schema
|
||||
self.reflexion_schema = reflexion_schema
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(config.concurrency)
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
self._lazy_init_done = False
|
||||
@@ -127,11 +137,21 @@ class ReflectionEngine:
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
model_id = self.llm_client
|
||||
self.llm_client = get_llm_client(model_id)
|
||||
|
||||
if self.get_data_func is None:
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
from app.core.memory.utils.config.get_data import get_data_statement
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
if self.render_evaluate_prompt_func is None:
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
self.render_evaluate_prompt_func = render_evaluate_prompt
|
||||
@@ -154,13 +174,11 @@ class ReflectionEngine:
|
||||
|
||||
self._lazy_init_done = True
|
||||
|
||||
async def execute_reflection(self, host_id: uuid.UUID) -> ReflectionResult:
|
||||
async def execute_reflection(self, host_id) -> ReflectionResult:
|
||||
"""
|
||||
执行完整的反思流程
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
"""
|
||||
@@ -176,9 +194,10 @@ class ReflectionEngine:
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
logging.info("====== 自我反思流程开始 ======")
|
||||
|
||||
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
||||
try:
|
||||
# 1. 获取反思数据
|
||||
reflexion_data = await self._get_reflexion_data(host_id)
|
||||
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
||||
if not reflexion_data:
|
||||
return ReflectionResult(
|
||||
success=True,
|
||||
@@ -187,22 +206,21 @@ class ReflectionEngine:
|
||||
)
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data)
|
||||
if not conflict_data:
|
||||
return ReflectionResult(
|
||||
success=True,
|
||||
message="无冲突,无需反思",
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||
print(100 * '-')
|
||||
print(conflict_data)
|
||||
print(100 * '-')
|
||||
|
||||
conflicts_found = len(conflict_data)
|
||||
logging.info(f"发现 {conflicts_found} 个冲突")
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
||||
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
||||
|
||||
# 记录冲突数据
|
||||
await self._log_data("conflict", conflict_data)
|
||||
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data)
|
||||
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
success=False,
|
||||
@@ -210,6 +228,9 @@ class ReflectionEngine:
|
||||
conflicts_found=conflicts_found,
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
print(100 * '*')
|
||||
print(solved_data)
|
||||
print(100 * '*')
|
||||
|
||||
conflicts_resolved = len(solved_data)
|
||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||
@@ -230,7 +251,8 @@ class ReflectionEngine:
|
||||
conflicts_found=conflicts_found,
|
||||
conflicts_resolved=conflicts_resolved,
|
||||
memories_updated=memories_updated,
|
||||
execution_time=execution_time
|
||||
execution_time=execution_time,
|
||||
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -241,6 +263,79 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
async def reflection_run(self):
|
||||
self._lazy_init()
|
||||
start_time = time.time()
|
||||
|
||||
asyncio.get_event_loop().time()
|
||||
logging.info("====== 自我反思流程开始 ======")
|
||||
|
||||
result_data = {}
|
||||
|
||||
source_data, databasets = await self.extract_fields_from_json()
|
||||
result_data['baseline'] = self.config.baseline
|
||||
result_data[
|
||||
'source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||
# 遍历数据提取字段
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
for item in conflict_data:
|
||||
print(item)
|
||||
quality_assessments.append(item['quality_assessment'])
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
||||
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
||||
|
||||
# 记录冲突数据
|
||||
await self._log_data("conflict", conflict_data)
|
||||
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data, source_data)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
success=False,
|
||||
message="反思失败,未解决冲突",
|
||||
conflicts_found=conflicts_found,
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
reflexion_data = []
|
||||
|
||||
# 遍历数据提取reflexion字段
|
||||
for item in solved_data:
|
||||
if 'results' in item:
|
||||
for result in item['results']:
|
||||
reflexion_data.append(result['reflexion'])
|
||||
result_data['reflexion_data'] = reflexion_data
|
||||
execution_time = time.time() - start_time
|
||||
return {"status": "SUCCESS", "message": "反思试运行", "data": result_data, "time": execution_time}
|
||||
|
||||
async def extract_fields_from_json(self):
|
||||
"""从example.json中提取source_data和databasets字段"""
|
||||
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
|
||||
try:
|
||||
# 读取JSON文件
|
||||
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
|
||||
data = json.loads(f.read())
|
||||
|
||||
# 提取memory_verify下的字段
|
||||
memory_verify = data.get("memory_verify", {})
|
||||
source_data = memory_verify.get("source_data", [])
|
||||
databasets = memory_verify.get("databasets", [])
|
||||
|
||||
return source_data, databasets
|
||||
|
||||
except Exception as e:
|
||||
return [], []
|
||||
|
||||
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
||||
"""
|
||||
获取反思数据
|
||||
@@ -253,17 +348,28 @@ class ReflectionEngine:
|
||||
Returns:
|
||||
List[Any]: 反思数据列表
|
||||
"""
|
||||
if self.config.reflexion_range == ReflectionRange.RETRIEVAL:
|
||||
# 从检索结果中获取数据
|
||||
return await self.get_data_func(host_id)
|
||||
elif self.config.reflexion_range == ReflectionRange.DATABASE:
|
||||
# 从整个数据库中获取数据(待实现)
|
||||
logging.warning("从数据库获取反思数据功能尚未实现")
|
||||
return []
|
||||
else:
|
||||
raise ValueError(f"未知的反思范围: {self.config.reflexion_range}")
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any]) -> List[Any]:
|
||||
|
||||
|
||||
if self.config.reflexion_range == ReflectionRange.PARTIAL:
|
||||
neo4j_query = neo4j_query_part.format(host_id)
|
||||
neo4j_statement = neo4j_statement_part.format(host_id)
|
||||
elif self.config.reflexion_range == ReflectionRange.ALL:
|
||||
neo4j_query = neo4j_query_all.format(host_id)
|
||||
neo4j_statement = neo4j_statement_all.format(host_id)
|
||||
try:
|
||||
result = await self.neo4j_connector.execute_query(neo4j_query)
|
||||
result_statement = await self.neo4j_connector.execute_query(neo4j_statement)
|
||||
neo4j_databasets = await self.get_data_func(result)
|
||||
neo4j_state = await self.get_data_statement(result_statement)
|
||||
return neo4j_databasets, neo4j_state
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Neo4j查询失败: {e}")
|
||||
return [], []
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
检测冲突(基于事实的反思)
|
||||
|
||||
@@ -278,14 +384,28 @@ class ReflectionEngine:
|
||||
if not data:
|
||||
return []
|
||||
|
||||
# 数据预处理:如果数据量太少,直接返回无冲突
|
||||
if len(data) < 2:
|
||||
logging.info("数据量不足,无需检测冲突")
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
quality_assessment = self.config.quality_assessment
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
rendered_prompt = await self.render_evaluate_prompt_func(
|
||||
data,
|
||||
self.conflict_schema
|
||||
self.conflict_schema,
|
||||
self.config.baseline,
|
||||
memory_verify,
|
||||
quality_assessment,
|
||||
statement_databasets
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
@@ -316,7 +436,7 @@ class ReflectionEngine:
|
||||
logging.error(f"冲突检测失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def _resolve_conflicts(self, conflicts: List[Any]) -> List[Any]:
|
||||
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
解决冲突
|
||||
|
||||
@@ -332,6 +452,8 @@ class ReflectionEngine:
|
||||
return []
|
||||
|
||||
logging.info("====== 冲突解决开始 ======")
|
||||
baseline = self.config.baseline
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
# 并行处理每个冲突
|
||||
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
||||
@@ -341,7 +463,10 @@ class ReflectionEngine:
|
||||
# 渲染反思提示词
|
||||
rendered_prompt = await self.render_reflexion_prompt_func(
|
||||
[conflict],
|
||||
self.reflexion_schema
|
||||
self.reflexion_schema,
|
||||
baseline,
|
||||
memory_verify,
|
||||
statement_databasets
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
@@ -381,8 +506,8 @@ class ReflectionEngine:
|
||||
return solved
|
||||
|
||||
async def _apply_reflection_results(
|
||||
self,
|
||||
solved_data: List[Dict[str, Any]]
|
||||
self,
|
||||
solved_data: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
应用反思结果(更新记忆库)
|
||||
@@ -395,57 +520,7 @@ class ReflectionEngine:
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
"""
|
||||
if not solved_data:
|
||||
logging.warning("无解决方案数据,跳过更新")
|
||||
return 0
|
||||
|
||||
logging.info("====== 记忆更新开始 ======")
|
||||
|
||||
success_count = 0
|
||||
|
||||
async def _update_one(item: Dict[str, Any]) -> bool:
|
||||
"""更新单条记忆"""
|
||||
async with self._semaphore:
|
||||
try:
|
||||
if not isinstance(item, dict):
|
||||
return False
|
||||
|
||||
# 提取更新参数
|
||||
resolved = item.get("resolved", {})
|
||||
resolved_mem = resolved.get("resolved_memory", {})
|
||||
group_id = resolved_mem.get("group_id")
|
||||
memory_id = resolved_mem.get("id")
|
||||
new_invalid_at = resolved_mem.get("invalid_at")
|
||||
|
||||
if not all([group_id, memory_id, new_invalid_at]):
|
||||
logging.warning(f"记忆更新参数缺失,跳过此项: {item}")
|
||||
return False
|
||||
|
||||
# 执行更新
|
||||
await self.neo4j_connector.execute_query(
|
||||
self.update_query,
|
||||
group_id=group_id,
|
||||
id=memory_id,
|
||||
new_invalid_at=new_invalid_at,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"更新单条记忆失败: {e}")
|
||||
return False
|
||||
|
||||
# 并发执行所有更新任务
|
||||
tasks = [
|
||||
_update_one(item)
|
||||
for item in solved_data
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
success_count = sum(1 for r in results if r)
|
||||
|
||||
logging.info(f"成功更新 {success_count}/{len(solved_data)} 条记忆")
|
||||
|
||||
success_count = await neo4j_data(solved_data)
|
||||
return success_count
|
||||
|
||||
async def _log_data(self, label: str, data: Any) -> None:
|
||||
@@ -456,6 +531,7 @@ class ReflectionEngine:
|
||||
label: 数据标签
|
||||
data: 要记录的数据
|
||||
"""
|
||||
|
||||
def _write():
|
||||
try:
|
||||
with open("reflexion_data.json", "a", encoding="utf-8") as f:
|
||||
@@ -470,9 +546,9 @@ class ReflectionEngine:
|
||||
|
||||
# 基于时间的反思方法
|
||||
async def time_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于时间的反思
|
||||
@@ -494,8 +570,8 @@ class ReflectionEngine:
|
||||
|
||||
# 基于事实的反思方法
|
||||
async def fact_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于事实的反思
|
||||
@@ -515,8 +591,8 @@ class ReflectionEngine:
|
||||
|
||||
# 综合反思方法
|
||||
async def comprehensive_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
综合反思
|
||||
@@ -553,33 +629,3 @@ class ReflectionEngine:
|
||||
else:
|
||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||
|
||||
|
||||
# 便捷函数:创建默认配置的反思引擎
|
||||
def create_reflection_engine(
|
||||
enabled: bool = False,
|
||||
iteration_period: str = "3",
|
||||
reflexion_range: str = "retrieval",
|
||||
baseline: str = "TIME",
|
||||
concurrency: int = 5
|
||||
) -> ReflectionEngine:
|
||||
"""
|
||||
创建反思引擎实例
|
||||
|
||||
Args:
|
||||
enabled: 是否启用反思
|
||||
iteration_period: 反思周期
|
||||
reflexion_range: 反思范围
|
||||
baseline: 反思基线
|
||||
concurrency: 并发数量
|
||||
|
||||
Returns:
|
||||
ReflectionEngine: 反思引擎实例
|
||||
"""
|
||||
config = ReflectionConfig(
|
||||
enabled=enabled,
|
||||
iteration_period=iteration_period,
|
||||
reflexion_range=reflexion_range,
|
||||
baseline=baseline,
|
||||
concurrency=concurrency
|
||||
)
|
||||
return ReflectionEngine(config)
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.schemas.memory_storage_schema import BaseDataSchema
|
||||
|
||||
import logging
|
||||
|
||||
from typing import List, Dict, Any
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def _load_(data: List[Any]) -> List[Dict]:
|
||||
@@ -60,27 +55,46 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
||||
return results
|
||||
|
||||
|
||||
async def get_data(host_id: uuid.UUID) -> List[Dict]:
|
||||
async def get_data(result):
|
||||
"""
|
||||
从数据库中获取数据
|
||||
"""
|
||||
# 从数据库会话中获取会话
|
||||
db: Session = next(get_db())
|
||||
try:
|
||||
data = db.query(RetrievalInfo.retrieve_info).filter(RetrievalInfo.host_id == host_id).all()
|
||||
neo4j_databasets=[]
|
||||
for item in result:
|
||||
filtered_item = {}
|
||||
for key, value in item.items():
|
||||
if 'name_embedding' not in key.lower():
|
||||
if key == 'relationship' and value is not None:
|
||||
# 只保留relationship的指定字段
|
||||
rel_filtered = {}
|
||||
if hasattr(value, 'get'):
|
||||
rel_filtered['run_id'] = value.get('run_id')
|
||||
rel_filtered['statement'] = value.get('statement')
|
||||
rel_filtered['statement_id'] = value.get('statement_id')
|
||||
rel_filtered['expired_at'] = value.get('expired_at')
|
||||
rel_filtered['created_at'] = value.get('created_at')
|
||||
filtered_item[key] = rel_filtered
|
||||
elif key == 'entity2' and value is not None:
|
||||
# 过滤entity2的name_embedding字段
|
||||
entity2_filtered = {}
|
||||
if hasattr(value, 'items'):
|
||||
for e_key, e_value in value.items():
|
||||
if 'name_embedding' not in e_key.lower():
|
||||
entity2_filtered[e_key] = e_value
|
||||
filtered_item[key] = entity2_filtered
|
||||
else:
|
||||
filtered_item[key] = value
|
||||
|
||||
# 直接将字典添加到列表中
|
||||
neo4j_databasets.append(filtered_item)
|
||||
return neo4j_databasets
|
||||
async def get_data_statement( result):
|
||||
neo4j_databasets=[]
|
||||
for i in result:
|
||||
neo4j_databasets.append(i)
|
||||
return neo4j_databasets
|
||||
|
||||
|
||||
# print(f"data:\n{data}")
|
||||
# 解析,提取为字典的列表
|
||||
results = await _load_(data)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"failed to get data from database, host_id: {host_id}, error: {e}")
|
||||
raise e
|
||||
finally:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,19 +1,222 @@
|
||||
你将收到一组记忆对象:{{ evaluate_data }}。
|
||||
任务:多维度判断这些记忆是否与已有记忆存在冲突,并给出冲突的对应记忆。(冗余不算冲突)
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j),以及相关配置参数:
|
||||
原本的输入句子:{{statement_databasets}}
|
||||
需要检测冲突对象:{{ evaluate_data }}
|
||||
冲突判定类型:{{ baseline }}(取值为 TIME / FACT / HYBRID)
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
记忆质量评估开关开关:{{ quality_assessment }}(取值为 true / false)
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
你的任务是:
|
||||
对用户历史记忆数据进行冲突检测和记忆审核,并输出严格结构化的 JSON 分析结果
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id是连接evaluate_data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容,
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估)
|
||||
## 冲突定义
|
||||
|
||||
### 时间冲突
|
||||
时间冲突是指同一用户的相关事件在时间维度上存在逻辑矛盾:
|
||||
|
||||
1. **同一活动的时间冲突**:
|
||||
- 同一用户的同一活动在不同时间点被记录(如"周五打球"和"周六打球")
|
||||
- 同一用户在同一时间段内被记录进行不同的互斥活动
|
||||
|
||||
2. **时间逻辑错误**:
|
||||
- expired_at 早于 created_at
|
||||
- 同一事实的 created_at 时间差异超过合理误差范围(>5分钟)
|
||||
|
||||
3. **日期属性冲突**:
|
||||
- 同一人的生日记录为不同日期(如"2月10号"和"2月16号")
|
||||
4.存在明确先后约束 A -> B,但 t(A) > t(B)
|
||||
-例:入学时间晚于毕业时间。
|
||||
-处理:标记异常、降权、触发逻辑反思或人工审查。
|
||||
5.时间属性冲突
|
||||
-单值日期属性出现多值(生日、入职日期)
|
||||
-注意:本质属于事实冲突的日期特例,归入事实冲突仲裁框架。
|
||||
6.互斥重叠冲突
|
||||
-例:同一主体的两个事件区间重叠且互斥(如同一时间出现在两地)
|
||||
-处理:证据仲裁、保留多版本(active + candidate)。
|
||||
|
||||
|
||||
|
||||
### 事实冲突
|
||||
事实冲突是指同一实体的属性或关系存在相互矛盾的陈述:
|
||||
|
||||
1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
|
||||
### 混合冲突检测
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:
|
||||
检测任何逻辑上不一致或相互矛盾的记录
|
||||
## 记忆审核定义
|
||||
|
||||
### 隐私信息检测(隐私冲突)
|
||||
当memory_verify为true时,需要额外检测包含个人隐私信息的记录:
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
|
||||
### 隐私检测原则
|
||||
- 检测description、entity1_name、entity2_name等字段中的隐私信息
|
||||
- 识别数字模式(如手机号11位数字、身份证18位等)
|
||||
- 识别关键词(如"身份证"、"银行卡"、"密码"等)
|
||||
- 检测敏感实体类型和关系
|
||||
|
||||
## 冲突检测原则
|
||||
|
||||
**全面检测**:不区分冲突类型,检测所有可能的冲突
|
||||
**完整输出**:如果发现任何冲突或隐私信息,必须将所有相关记录都放入data字段
|
||||
**实体关联**:重点检查涉及相同实体(entity1_name, entity2_name)的记录
|
||||
**语义分析**:分析description字段的语义相似性和冲突性
|
||||
**时间逻辑**:检查时间字段的逻辑一致性
|
||||
**隐私检测**:当memory_verify为true时,检测所有包含隐私信息的记录
|
||||
|
||||
## 不符合冲突检测
|
||||
-称呼
|
||||
## 重要检测示例
|
||||
|
||||
### 冲突检测示例
|
||||
- 用户与不同时间点的关系(周五 vs 周六,2月10号 vs 2月16号)
|
||||
- 同一实体的重复定义但描述不同
|
||||
- 同一关系的不同表述但含义冲突
|
||||
- 任何逻辑上不可能同时为真的记录
|
||||
|
||||
### 隐私信息检测示例
|
||||
- 包含手机号的记录:"用户的手机号是13812345678"
|
||||
- 包含身份证的记录:"身份证号码为110101199001011234"
|
||||
- 包含银行卡的记录:"银行卡号6222021234567890"
|
||||
- 包含社交账号的记录:"微信号是user123456"
|
||||
- 包含敏感信息的实体名称或描述
|
||||
|
||||
## 输出要求
|
||||
|
||||
**关键原则**:
|
||||
1. 当存在冲突或检测到隐私信息时,conflict才为true,data字段才包含相关记录
|
||||
2. 如果发现冲突,必须将所有相关的冲突记录都放入data数组中
|
||||
3. 如果memory_verify为true且检测到隐私信息,必须将包含隐私信息的记录也放入data数组中
|
||||
4. 既没有冲突也没有隐私信息时,conflict为false,data为空数组
|
||||
5. 如果quality_assessment为true,独立分析数据质量并输出评估结果;如果为false,quality_assessment字段输出null
|
||||
6. 冲突检测、隐私审核和质量评估三个功能完全独立,互不影响
|
||||
7. 不输出conflict_memory字段
|
||||
|
||||
**处理逻辑**:
|
||||
- 首先进行冲突检测,将冲突记录加入data数组
|
||||
- 如果memory_verify为true,再进行隐私信息检测,将包含隐私信息的记录也加入data数组
|
||||
- 如果quality_assessment为true,独立进行质量评估,分析所有输入数据的质量并输出评估结果
|
||||
- 最终data数组包含所有冲突记录和隐私信息记录(去重)
|
||||
- quality_assessment字段独立输出,不影响冲突检测和隐私审核结果
|
||||
- memory_verify字段独立输出隐私检测结果,包含检测到的隐私信息类型和概述
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||
1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号("")或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
|
||||
## 记忆质量评估定义
|
||||
|
||||
### 质量评估标准
|
||||
当quality_assessment为true时,需要对记忆数据进行质量评估:
|
||||
|
||||
1. **数据完整性**:
|
||||
- 检查必要字段是否完整(entity1_name、entity2_name、description等)
|
||||
- 检查关系描述是否清晰明确
|
||||
- 检查时间字段的有效性
|
||||
|
||||
2. **重复字段检测**:
|
||||
- 识别相同或高度相似的记录
|
||||
- 检测冗余的实体关系
|
||||
- 分析描述内容的重复度
|
||||
|
||||
3. **无意义字段检测**:
|
||||
- 识别空值、无效值或占位符内容
|
||||
- 检测过于简单或无信息量的描述
|
||||
- 识别格式错误或不规范的数据
|
||||
|
||||
4. **上下文依赖性**:
|
||||
- 评估记录是否需要额外上下文才能理解
|
||||
- 检查实体名称的明确性
|
||||
- 分析关系描述的自包含性
|
||||
|
||||
### 质量评估输出
|
||||
- **质量百分比**:基于上述标准计算的整体质量分数(0-100)
|
||||
- **质量概述**:简要描述数据质量状况,包括主要问题和优点
|
||||
|
||||
输出是仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
{
|
||||
"data": [ ...与输入同结构的记忆对象数组... ],
|
||||
"conflict": true 或 false,
|
||||
"conflict_memory": 若冲突为 true,则填写与其冲突的记忆对象;否则为 null
|
||||
"data": [
|
||||
{
|
||||
"entity1_name": "实体1名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间戳",
|
||||
"expired_at": "过期时间戳",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": "关系对象",
|
||||
"entity2_name": "实体2名称",
|
||||
"entity2": "实体2对象"
|
||||
}
|
||||
],
|
||||
"conflict": true或false,
|
||||
"quality_assessment": {
|
||||
"score": 质量百分比数字,
|
||||
"summary": "质量概述文本"
|
||||
} 或 null,
|
||||
"memory_verify": {
|
||||
"has_privacy": true或false,
|
||||
"privacy_types": ["检测到的隐私信息类型列表"],
|
||||
"summary": "隐私检测结果概述"
|
||||
} 或 null
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本。
|
||||
- 使用标准双引号,必要时对内部引号进行转义。
|
||||
- 字段名与结构必须与给定模式一致。
|
||||
- data数组中包含冲突记录和隐私信息记录,如果都没有则为空数组。
|
||||
- quality_assessment字段:当quality_assessment参数为true时输出评估对象,为false时输出null。
|
||||
- memory_verify字段:当memory_verify参数为true时输出隐私检测结果对象,为false时输出null。
|
||||
|
||||
### memory_verify字段说明
|
||||
当memory_verify为true时,需要输出隐私检测结果:
|
||||
- **has_privacy**: 布尔值,表示是否检测到隐私信息
|
||||
- **privacy_types**: 字符串数组,包含检测到的隐私信息类型(如["手机号码", "身份证信息"])
|
||||
- **summary**: 字符串,简要描述隐私检测结果
|
||||
|
||||
当memory_verify为false时,memory_verify字段输出null。
|
||||
|
||||
### memory_verify字段示例
|
||||
|
||||
**示例1:检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": true,
|
||||
"privacy_types": ["手机号码", "身份证信息"],
|
||||
"summary": "检测到2条记录包含隐私信息:1个手机号码,1个身份证号码"
|
||||
}
|
||||
```
|
||||
|
||||
**示例2:未检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": false,
|
||||
"privacy_types": [],
|
||||
"summary": "未检测到隐私信息"
|
||||
}
|
||||
```
|
||||
|
||||
**示例3:memory_verify为false时**
|
||||
```json
|
||||
"memory_verify": null
|
||||
```
|
||||
|
||||
模式参考:
|
||||
[
|
||||
{{ json_schema }}
|
||||
]
|
||||
{{ json_schema }}
|
||||
@@ -1,23 +1,300 @@
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j)
|
||||
你将收到一条冲突判定对象:{{ data }}。
|
||||
任务:分析冲突产生原因,给出解决方案,并生成设为失效后的记忆。
|
||||
需要检测冲突对象:{{ statement_databasets }}
|
||||
以及需要识别的冲突对象为:{{ baseline }}
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
|
||||
角色:
|
||||
- 你是数据领域中解决数据冲突的专家
|
||||
|
||||
任务:分析冲突产生原因,按冲突类型分组处理,为每种冲突类型生成独立的解决方案。
|
||||
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id是连接data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容,
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估),data里面的statement_created_at是用户输入的时间
|
||||
|
||||
**处理模式**:
|
||||
- 当memory_verify为false时:仅处理数据冲突
|
||||
- 当memory_verify为true时:处理数据冲突 + 隐私信息脱敏
|
||||
|
||||
## 分组处理原则
|
||||
|
||||
**冲突类型识别与分组**:
|
||||
1. **日期冲突**:
|
||||
1.1.涉及用户生日的不同日期记录(如2月10号 vs 2月16号),
|
||||
1.2.涉及同一活动的不同时间记录(如周五打球 vs 周六打球)
|
||||
3. **事实属性冲突**:
|
||||
3.1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
3.2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
3.3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
4. **其他冲突类型/混合冲突(时间+事实)**:根据具体数据识别
|
||||
|
||||
**分组输出要求**:
|
||||
- 每种冲突类型生成一个独立的reflexion_result对象
|
||||
- 同一类型的多个冲突记录归并到一个结果中
|
||||
- 不同类型的冲突分别处理,各自生成独立结果
|
||||
|
||||
## 冲突类型定义
|
||||
|
||||
### 时间冲突(TIME)
|
||||
时间维度冲突是指两个事件发生时间重叠,或者用户同一件事情和场景等情况下,时间出现了变化。
|
||||
|
||||
### 事实冲突(FACT)
|
||||
事实冲突是指同一事实对象(同一个人、同一个时间、同一个状态)但陈述内容相互矛盾,主要为真假不能共存的情况。
|
||||
### 混合冲突(HYBRID)
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:检测任何逻辑上不一致或相互矛盾的记录
|
||||
{% if memory_verify %}
|
||||
## 隐私信息处理(memory_verify为true时启用)
|
||||
|
||||
### 隐私信息识别
|
||||
需要识别并处理以下类型的隐私信息:
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
|
||||
### 隐私数据脱敏规则
|
||||
对于检测到的隐私信息,按以下规则进行脱敏处理:
|
||||
|
||||
**数字类隐私信息脱敏**:
|
||||
- 保留前三位和后四位,中间用*代替
|
||||
- 示例:手机号13812345678 → 138****5678
|
||||
- 示例:身份证110101199001011234 → 110***********1234
|
||||
- 示例:银行卡6222021234567890 → 622***********7890
|
||||
|
||||
**文本类隐私信息脱敏**:
|
||||
- 社交账号:保留前三后四位字符,中间用*代替
|
||||
- 示例:微信号user123456 → use****3456
|
||||
- 示例:邮箱zhang.san@example.com → zha****@example.com
|
||||
|
||||
**脱敏处理字段**:
|
||||
- name字段:如包含隐私信息需脱敏
|
||||
- entity1_name字段:如包含隐私信息需脱敏
|
||||
- entity2_name字段:如包含隐私信息需脱敏
|
||||
- description字段:如包含隐私信息需脱敏
|
||||
{% endif %}
|
||||
|
||||
## 工作步骤
|
||||
|
||||
### 第一步:分析冲突类型匹配
|
||||
首先判断输入的冲突数据是否符合baseline要求的类型:
|
||||
|
||||
**类型匹配规则**:
|
||||
- 如果baseline是"TIME":只处理时间相关的冲突(涉及时间表达式、日期、时间点的冲突)
|
||||
- 如果baseline是"FACT":只处理事实相关的冲突(属性矛盾、关系冲突、描述不一致)
|
||||
- 如果baseline是"HYBRID":处理所有类型的冲突,也可以当作混合冲突类型处理
|
||||
|
||||
**类型识别**:
|
||||
- 时间冲突标识:entity2的entity_type包含"TimeExpression"、"TemporalExpression",或entity2_name包含时间词汇(周一到周日、月份日期等)
|
||||
- 事实冲突标识:相同实体的不同属性描述、互斥的关系陈述
|
||||
|
||||
**重要**:如果输入的冲突类型与baseline不匹配,必须输出空结果(resolved为null)
|
||||
|
||||
### 第二步:筛选并分组冲突数据
|
||||
按冲突类型对数据进行分组:
|
||||
|
||||
**分组策略**:
|
||||
1. **时间冲突组**:筛选涉及用户时间的所有记录
|
||||
2. **活动时间冲突组**:筛选涉及同一活动不同时间的记录
|
||||
3. **事实冲突组**:筛选涉及同一实体不同属性的记录
|
||||
4. **其他冲突组**:其他类型的冲突记录
|
||||
|
||||
**筛选条件**:
|
||||
- 只处理与baseline匹配的冲突类型
|
||||
- 相同entity1_name但entity2_name不同的记录
|
||||
- 相同关系但描述矛盾的记录
|
||||
- 时间逻辑不一致的记录
|
||||
|
||||
### 第三步:冲突解决策略
|
||||
** 不可以解决的冲突情况
|
||||
1. 数据被判定为正确的情况下,不可以进行修改
|
||||
**仅当冲突类型与baseline匹配时**,对筛选出的冲突数据进行处理:
|
||||
|
||||
**智能解决策略**:
|
||||
1. **分析冲突数据**:识别哪些记录是正确的,哪些是错误的,需要结合statement_databasets的输入原文来判定
|
||||
2. **判断正确答案是否存在**:
|
||||
- 如果正确答案已存在于data中:只需将错误记录的expired_at设为当前日期(2025-12-16T12:00:00)
|
||||
- 如果正确答案已存在于data中:错误记录的expired_at已经设为日期,则不需要对正确的数据进行修改
|
||||
- 如果正确答案不存在于data中:需要修改现有记录的内容以包含正确信息
|
||||
|
||||
{% if memory_verify %}
|
||||
**隐私处理集成**:
|
||||
- 在处理冲突的同时,需要对涉及的记录进行隐私脱敏
|
||||
- 脱敏处理应该在冲突解决之后进行,确保最终输出的记录都已脱敏
|
||||
- 在change字段中记录隐私脱敏的变更
|
||||
{% endif %}
|
||||
|
||||
**具体处理规则**:
|
||||
|
||||
**情况1:正确答案存在于data中**
|
||||
- 保留正确的记录不变
|
||||
- 基于时间关系的冲突:
|
||||
需要只修改错误记录的expired_at为当前时间(2025-12-16T12:00:00)
|
||||
- 基于事实的关系冲突
|
||||
- resolved.resolved_memory只包含被设为失效的错误记录
|
||||
- change字段只记录expired_at的变更:`[{"expired_at": "2025-12-16T12:00:00"}]`(注意:如果已存在时间,则不需要对其修改,也不需要变更 时间)
|
||||
|
||||
**情况2:正确答案不存在于data中**
|
||||
- 选择最合适的记录进行修改
|
||||
- 更新该记录的相关字段:
|
||||
- description字段:添加或修改描述信息{% if memory_verify %}(如包含隐私信息,需脱敏处理){% endif %}
|
||||
- name字段:修改名称字段{% if memory_verify %}(如需要,包含隐私信息时需脱敏){% endif %}
|
||||
- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %}
|
||||
- change字段记录所有被修改的字段{% if memory_verify %},包括脱敏变更{% endif %},例如:`[{"description": "新描述"{% if memory_verify %}, "entity2_name": "138****5678"{% endif %}}]`
|
||||
|
||||
**重要原则**:
|
||||
- **只输出需要修改的记录**:resolved.resolved_memory只包含实际需要修改的数据
|
||||
- **优先保留策略**:时间冲突保留最可信的created_at时间的记录,事实冲突选择最新且可信度最高的记录
|
||||
- **精确记录变更**:change字段必须包含记录ID、字段名称、新值和旧值
|
||||
{% if memory_verify %}- **隐私保护优先**:所有输出的记录必须完成隐私脱敏处理
|
||||
- **脱敏变更记录**:隐私脱敏的变更也必须在change字段中详细记录{% endif %}
|
||||
- **不可修改数据**:数据被判定为正确时,不可以进行修改,如果没有数据可输出空
|
||||
|
||||
**变更记录格式**:
|
||||
```json
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
**类型不匹配处理**:
|
||||
- 如果冲突类型与baseline不匹配,resolved必须设为null
|
||||
- reflexion.reason说明类型不匹配的原因
|
||||
- reflexion.solution说明无需处理
|
||||
|
||||
### 第四步:输出解决方案
|
||||
|
||||
## 输出要求
|
||||
**嵌套字段映射**(系统会自动处理):
|
||||
- `entity2.name` → 自动映射为 `name`
|
||||
- `entity1.name` → 自动映射为 `name`
|
||||
- `entity1.description` → 自动映射为 `description`
|
||||
- `entity2.description` → 自动映射为 `description`
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 响应必须是与此确切模式匹配的有效JSON对象
|
||||
- 不要在JSON之前或之后包含任何文本
|
||||
|
||||
JSON格式要求:
|
||||
1. JSON结构仅使用标准ASCII双引号(")
|
||||
2. 如果提取的语句文本包含引号,请使用反斜杠(\")正确转义
|
||||
3. 确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4. JSON字符串值中不包括换行符
|
||||
5. 不允许输出```json```相关符号
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
|
||||
**输出格式:按冲突类型分组的列表**
|
||||
{
|
||||
"conflict": 与输入同结构,包含 data 与 conflict_memory,
|
||||
"reflexion": { "reason": string, "solution": string },
|
||||
"resolved": {
|
||||
"original_memory_id": 被设为失效的记忆 id,
|
||||
"resolved_memory": 完整的设为失效后的记忆对象
|
||||
}
|
||||
"results": [
|
||||
{
|
||||
"conflict": {
|
||||
"data": [该冲突类型相关的数据记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "该冲突类型的原因分析",
|
||||
"solution": "该冲突类型的解决方案"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "被设为失效的记忆id",
|
||||
"resolved_memory": {
|
||||
"entity1_name": "实体1名称",
|
||||
"entity2_name": "实体2名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间",
|
||||
"expired_at": "过期时间",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": {},
|
||||
"entity2": {...}
|
||||
},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
**示例:多种冲突类型的输出**
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"conflict": {
|
||||
"data": [生日冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到生日冲突:用户同时关联2月10号和2月16号两个不同日期",
|
||||
"solution": "保留最新记录(2月16号),将旧记录(2月10号)设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "df066210883545a08e727ccd8ad4ec77",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"expired_at": "2025-12-16T12:00:00"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
},
|
||||
{
|
||||
"conflict": {
|
||||
"data": [篮球时间冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到活动时间冲突:用户打篮球时间存在周五和周六的冲突",
|
||||
"solution": "保留最可信的时间记录,将冲突记录设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "另一个记录ID",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"description": "使用系统的个人,指代说话者本人,篮球时间为周六"},
|
||||
{"entity2_name": "周六"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本。
|
||||
- 使用标准双引号,必要时对内部引号进行转义。
|
||||
- 字段名与结构必须与给定模式一致。
|
||||
- 当 conflict 为 false 时,resolved 必须为 null。
|
||||
- 其中 conflict.data 必须为数组形式,即使只有一个对象也需使用 [ ] 包裹。
|
||||
- 只输出 JSON,不要添加解释或多余文本
|
||||
- 使用标准双引号,必要时对内部引号进行转义
|
||||
- 字段名与结构必须与给定模式一致
|
||||
- **输出必须是results数组格式**,每个冲突类型作为一个独立的对象
|
||||
- **按冲突类型分组**:相同类型的冲突记录归并到一个result对象中
|
||||
- **每个result对象的conflict.data**只包含该冲突类型相关的记录
|
||||
- **resolved.resolved_memory 只包含需要修改的记录**,不需要修改的记录不要输出
|
||||
- **resolved.change 必须包含详细的变更信息**:field数组包含所有被修改的字段及其新值
|
||||
- 如果某个冲突类型经分析无需修改任何数据,该类型的resolved 必须为 null
|
||||
- 如果与baseline不匹配的冲突类型,不要在results中包含该类型
|
||||
|
||||
模式参考:
|
||||
[
|
||||
{{ json_schema }}
|
||||
]
|
||||
{{ json_schema }}
|
||||
@@ -7,36 +7,50 @@ from typing import List, Dict, Any
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any]) -> str:
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any],
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,statement_databasets: List[str] = []) -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate.jinja2 template.
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
Args:
|
||||
evaluate_data: The data to evaluate
|
||||
schema: The JSON schema to use for the output.
|
||||
baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT)
|
||||
memory_verify: Whether to enable memory verification for privacy detection
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("evaluate.jinja2")
|
||||
|
||||
rendered_prompt = template.render(evaluate_data=evaluate_data, json_schema=schema)
|
||||
|
||||
rendered_prompt = template.render(
|
||||
evaluate_data=evaluate_data,
|
||||
json_schema=schema,
|
||||
baseline=baseline,
|
||||
memory_verify=memory_verify,
|
||||
quality_assessment=quality_assessment,
|
||||
statement_databasets=statement_databasets
|
||||
)
|
||||
return rendered_prompt
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any]) -> str:
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = []) -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the extract_temporal.jinja2 template.
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
Args:
|
||||
data: The data to reflex on.
|
||||
schema: The JSON schema to use for the output.
|
||||
baseline: The baseline type for conflict resolution.
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
"""
|
||||
template = prompt_env.get_template("reflexion.jinja2")
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=schema)
|
||||
rendered_prompt = template.render(data=data, json_schema=schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
@@ -11,50 +10,53 @@ class DataConfig(Base):
|
||||
|
||||
# 主键
|
||||
config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID")
|
||||
|
||||
|
||||
# 基本信息
|
||||
config_name = Column(String, nullable=False, comment="配置名称")
|
||||
config_desc = Column(String, nullable=True, comment="配置描述")
|
||||
|
||||
|
||||
# 组织信息
|
||||
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
|
||||
group_id = Column(String, nullable=True, comment="组ID")
|
||||
user_id = Column(String, nullable=True, comment="用户ID")
|
||||
apply_id = Column(String, nullable=True, comment="应用ID")
|
||||
|
||||
|
||||
# 模型选择(从workspace继承)
|
||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||
llm = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
|
||||
|
||||
# 记忆萃取引擎配置
|
||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
|
||||
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
|
||||
|
||||
|
||||
# 阈值配置 (0-1 之间的浮点数)
|
||||
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
|
||||
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
|
||||
t_overall = Column(Float, default=0.8, comment="综合阈值")
|
||||
|
||||
|
||||
# 状态配置
|
||||
state = Column(Boolean, default=False, comment="配置使用状态")
|
||||
|
||||
|
||||
# 分块策略
|
||||
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
|
||||
|
||||
|
||||
# 剪枝配置
|
||||
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
|
||||
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景:education/online_service/outbound")
|
||||
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值(0-0.9)")
|
||||
|
||||
|
||||
# 自我反思配置
|
||||
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
|
||||
iteration_period = Column(String, default="3", comment="反思迭代周期")
|
||||
reflexion_range = Column(String, default="retrieval", comment="反思范围:部分/全部")
|
||||
baseline = Column(String, default="time", comment="基线:时间/事实/时间和事实")
|
||||
|
||||
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
|
||||
memory_verify = Column(Boolean, default=True, comment="记忆验证")
|
||||
quality_assessment = Column(Boolean, default=True, comment="质量评估")
|
||||
|
||||
# 遗忘引擎配置
|
||||
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
|
||||
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
|
||||
@@ -62,7 +64,7 @@ class DataConfig(Base):
|
||||
lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度,0-1 小数")
|
||||
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数")
|
||||
offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数")
|
||||
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
@@ -14,6 +14,7 @@ class EndUser(Base):
|
||||
other_id = Column(String, nullable=True) # Store original user_id
|
||||
other_name = Column(String, default="", nullable=False)
|
||||
other_address = Column(String, default="", nullable=False)
|
||||
reflection_time = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||
|
||||
|
||||
@@ -16,48 +16,46 @@ import uuid
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
TABLE_NAME = "data_config"
|
||||
class DataConfigRepository:
|
||||
"""数据配置Repository
|
||||
|
||||
|
||||
提供data_config表的数据访问方法,包括:
|
||||
- SQLAlchemy ORM 数据库操作
|
||||
- Neo4j Cypher查询常量
|
||||
"""
|
||||
|
||||
|
||||
# ==================== Neo4j Cypher 查询常量 ====================
|
||||
|
||||
|
||||
# Dialogue count by group
|
||||
SEARCH_FOR_DIALOGUE = """
|
||||
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# Chunk count by group
|
||||
SEARCH_FOR_CHUNK = """
|
||||
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# Statement count by group
|
||||
SEARCH_FOR_STATEMENT = """
|
||||
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# ExtractedEntity count by group
|
||||
SEARCH_FOR_ENTITY = """
|
||||
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# All counts by label and total
|
||||
SEARCH_FOR_ALL = """
|
||||
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
|
||||
@@ -70,7 +68,7 @@ class DataConfigRepository:
|
||||
UNION ALL
|
||||
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
|
||||
"""
|
||||
|
||||
|
||||
# Extracted entity details within group/app/user
|
||||
SEARCH_FOR_DETIALS = """
|
||||
MATCH (n:ExtractedEntity)
|
||||
@@ -86,7 +84,7 @@ class DataConfigRepository:
|
||||
n.user_id AS user_id,
|
||||
n.id AS id
|
||||
"""
|
||||
|
||||
|
||||
# Edges between extracted entities within group/app/user
|
||||
SEARCH_FOR_EDGES = """
|
||||
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
|
||||
@@ -102,7 +100,7 @@ class DataConfigRepository:
|
||||
r.statement_id AS statement_id,
|
||||
r.statement AS statement
|
||||
"""
|
||||
|
||||
|
||||
# Entity graph within group (source node, edge, target node)
|
||||
SEARCH_FOR_ENTITY_GRAPH = """
|
||||
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
|
||||
@@ -135,22 +133,106 @@ class DataConfigRepository:
|
||||
id: m.id
|
||||
} AS targetNode
|
||||
"""
|
||||
|
||||
|
||||
# ==================== SQLAlchemy ORM 数据库操作方法 ====================
|
||||
|
||||
@staticmethod
|
||||
def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]:
|
||||
"""构建反思配置更新语句(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
**kwargs: 反思配置参数
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
|
||||
|
||||
key_where = "config_id = :config_id"
|
||||
set_fields: List[str] = []
|
||||
params: Dict = {
|
||||
"config_id": config_id,
|
||||
}
|
||||
|
||||
# 反思配置字段映射
|
||||
mapping = {
|
||||
"enable_self_reflexion": "enable_self_reflexion",
|
||||
"iteration_period": "iteration_period",
|
||||
"reflexion_range": "reflexion_range",
|
||||
"baseline": "baseline",
|
||||
"reflection_model_id": "reflection_model_id",
|
||||
"memory_verify": "memory_verify",
|
||||
"quality_assessment": "quality_assessment",
|
||||
}
|
||||
|
||||
for api_field, db_col in mapping.items():
|
||||
if api_field in kwargs and kwargs[api_field] is not None:
|
||||
set_fields.append(f"{db_col} = :{api_field}")
|
||||
params[api_field] = kwargs[api_field]
|
||||
|
||||
if not set_fields:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
set_fields.append("updated_at = timezone('Asia/Shanghai', now())")
|
||||
query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}"
|
||||
return query, params
|
||||
|
||||
@staticmethod
|
||||
def build_select_reflection(config_id: int) -> Tuple[str, Dict]:
|
||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
||||
"""
|
||||
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
|
||||
|
||||
query = (
|
||||
f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
|
||||
f"reflection_model_id, memory_verify, quality_assessment, user_id "
|
||||
f"FROM {TABLE_NAME} WHERE config_id = :config_id"
|
||||
)
|
||||
params = {"config_id": config_id}
|
||||
return query, params
|
||||
|
||||
@staticmethod
|
||||
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
|
||||
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
|
||||
"""
|
||||
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
|
||||
|
||||
query = (
|
||||
f"SELECT config_id, config_name, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
|
||||
f"reflection_model_id, memory_verify, quality_assessment, user_id, created_at, updated_at "
|
||||
f"FROM {TABLE_NAME} WHERE workspace_id = :workspace_id ORDER BY updated_at DESC"
|
||||
)
|
||||
params = {"workspace_id": workspace_id}
|
||||
return query, params
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, params: ConfigParamsCreate) -> DataConfig:
|
||||
"""创建数据配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
params: 配置参数创建模型
|
||||
|
||||
|
||||
Returns:
|
||||
DataConfig: 创建的配置对象
|
||||
"""
|
||||
db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = DataConfig(
|
||||
config_name=params.config_name,
|
||||
@@ -162,37 +244,37 @@ class DataConfigRepository:
|
||||
)
|
||||
db.add(db_config)
|
||||
db.flush() # 获取自增ID但不提交事务
|
||||
|
||||
|
||||
db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]:
|
||||
"""更新基础配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
update: 配置更新模型
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新数据配置: config_id={update.config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新字段
|
||||
has_update = False
|
||||
if update.config_name is not None:
|
||||
@@ -201,44 +283,44 @@ class DataConfigRepository:
|
||||
if update.config_desc is not None:
|
||||
db_config.config_desc = update.config_desc
|
||||
has_update = True
|
||||
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
|
||||
db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]:
|
||||
"""更新记忆萃取引擎配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
update: 萃取配置更新模型
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新字段映射
|
||||
field_mapping = {
|
||||
# 模型选择
|
||||
@@ -268,50 +350,50 @@ class DataConfigRepository:
|
||||
"reflexion_range": "reflexion_range",
|
||||
"baseline": "baseline",
|
||||
}
|
||||
|
||||
|
||||
has_update = False
|
||||
for api_field, db_field in field_mapping.items():
|
||||
value = getattr(update, api_field, None)
|
||||
if value is not None:
|
||||
setattr(db_config, db_field, value)
|
||||
has_update = True
|
||||
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
|
||||
db_logger.info(f"萃取配置更新成功: config_id={update.config_id}")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新萃取配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]:
|
||||
"""更新遗忘引擎配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
update: 遗忘配置更新模型
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新字段
|
||||
has_update = False
|
||||
if update.lambda_time is not None:
|
||||
@@ -323,40 +405,40 @@ class DataConfigRepository:
|
||||
if update.offset is not None:
|
||||
db_config.offset = update.offset
|
||||
has_update = True
|
||||
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
|
||||
db_logger.info(f"遗忘配置更新成功: config_id={update.config_id}")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新遗忘配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]:
|
||||
"""获取萃取配置,通过主键查询某条配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 萃取配置字典,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询萃取配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.debug(f"萃取配置不存在: config_id={config_id}")
|
||||
return None
|
||||
|
||||
|
||||
result = {
|
||||
"llm_id": db_config.llm_id,
|
||||
"embedding_id": db_config.embedding_id,
|
||||
@@ -379,62 +461,62 @@ class DataConfigRepository:
|
||||
"reflexion_range": db_config.reflexion_range,
|
||||
"baseline": db_config.baseline,
|
||||
}
|
||||
|
||||
|
||||
db_logger.debug(f"萃取配置查询成功: config_id={config_id}")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询萃取配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]:
|
||||
"""获取遗忘配置,通过主键查询某条配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 遗忘配置字典,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询遗忘配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
|
||||
return None
|
||||
|
||||
|
||||
result = {
|
||||
"lambda_time": db_config.lambda_time,
|
||||
"lambda_mem": db_config.lambda_mem,
|
||||
"offset": db_config.offset,
|
||||
}
|
||||
|
||||
|
||||
db_logger.debug(f"遗忘配置查询成功: config_id={config_id}")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询遗忘配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]:
|
||||
"""根据ID获取数据配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 配置对象,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"根据ID查询数据配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})")
|
||||
else:
|
||||
@@ -443,60 +525,60 @@ class DataConfigRepository:
|
||||
except Exception as e:
|
||||
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
|
||||
"""获取所有配置参数
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID,用于过滤查询结果
|
||||
|
||||
|
||||
Returns:
|
||||
List[DataConfig]: 配置列表
|
||||
"""
|
||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
query = db.query(DataConfig)
|
||||
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(DataConfig.workspace_id == workspace_id)
|
||||
|
||||
|
||||
configs = query.order_by(desc(DataConfig.updated_at)).all()
|
||||
|
||||
|
||||
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
|
||||
return configs
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, config_id: int) -> bool:
|
||||
"""删除数据配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 删除成功返回True,配置不存在返回False
|
||||
"""
|
||||
db_logger.debug(f"删除数据配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={config_id}")
|
||||
return False
|
||||
|
||||
|
||||
db.delete(db_config)
|
||||
db.commit()
|
||||
|
||||
|
||||
db_logger.info(f"数据配置删除成功: config_id={config_id}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}")
|
||||
|
||||
@@ -746,3 +746,57 @@ DETACH DELETE losing
|
||||
|
||||
RETURN count(losing) as deleted
|
||||
"""
|
||||
|
||||
neo4j_statement_part = '''
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = "{}"
|
||||
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
||||
RETURN
|
||||
n.statement as statement_name,
|
||||
n.id as statement_id,
|
||||
n.created_at as statement_created_at
|
||||
|
||||
'''
|
||||
neo4j_statement_all = '''
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = "{}"
|
||||
RETURN
|
||||
n.statement as statement_name,
|
||||
n.id as statement_id
|
||||
|
||||
'''
|
||||
neo4j_query_part = """
|
||||
MATCH (n)-[r]-(m:ExtractedEntity)
|
||||
WHERE n.group_id = "{}"
|
||||
AND datetime(n.created_at) >= datetime() - duration('P3D')
|
||||
WITH DISTINCT m
|
||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||
RETURN
|
||||
m.name as entity1_name,
|
||||
m.description as description,
|
||||
m.statement_id as statement_id,
|
||||
m.created_at as created_at,
|
||||
m.expired_at as expired_at,
|
||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||
rel as relationship,
|
||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||
other as entity2
|
||||
"""
|
||||
neo4j_query_all = """
|
||||
MATCH (n)-[r]-(m:ExtractedEntity)
|
||||
WHERE n.group_id = "{}"
|
||||
WITH DISTINCT m
|
||||
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
|
||||
RETURN
|
||||
m.name as entity1_name,
|
||||
m.description as description,
|
||||
m.statement_id as statement_id,
|
||||
m.created_at as created_at,
|
||||
m.expired_at as expired_at,
|
||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||
rel as relationship,
|
||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||
other as entity2
|
||||
"""
|
||||
|
||||
|
||||
|
||||
227
api/app/repositories/neo4j/neo4j_update.py
Normal file
227
api/app/repositories/neo4j/neo4j_update.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from app.repositories import Neo4jConnector
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
async def update_neo4j_data(neo4j_dict_data, update_databases):
|
||||
"""
|
||||
Update Neo4j data based on query criteria and update parameters
|
||||
|
||||
Args:
|
||||
neo4j_dict_data: find
|
||||
update_databases: update
|
||||
"""
|
||||
try:
|
||||
# 构建WHERE条件
|
||||
where_conditions = []
|
||||
params = {}
|
||||
|
||||
for key, value in neo4j_dict_data.items():
|
||||
if value is not None:
|
||||
param_name = f"param_{key}"
|
||||
where_conditions.append(f"e.{key} = ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||
|
||||
# 构建SET条件
|
||||
set_conditions = []
|
||||
for key, value in update_databases.items():
|
||||
if value is not None:
|
||||
param_name = f"update_{key}"
|
||||
set_conditions.append(f"e.{key} = ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
set_clause = ", ".join(set_conditions)
|
||||
|
||||
if not set_clause:
|
||||
print("警告: 没有需要更新的字段")
|
||||
return False
|
||||
|
||||
# 构建Cypher查询
|
||||
cypher_query = f"""
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE {where_clause}
|
||||
SET {set_clause}
|
||||
RETURN count(e) as updated_count, collect(e.name) as updated_names
|
||||
"""
|
||||
|
||||
print(f"\n执行Cypher查询: {cypher_query}")
|
||||
print(f"参数: {params}")
|
||||
|
||||
# 执行更新
|
||||
result = await neo4j_connector.execute_query(cypher_query, **params)
|
||||
|
||||
if result:
|
||||
updated_count = result[0].get('updated_count', 0)
|
||||
updated_names = result[0].get('updated_names', [])
|
||||
print(f"成功更新 {updated_count} 个节点")
|
||||
if updated_names:
|
||||
print(f"更新的实体名称: {updated_names}")
|
||||
return updated_count > 0
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"更新过程中出现错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def map_field_names(data_dict):
|
||||
mapped_dict = {}
|
||||
has_name_field = False
|
||||
|
||||
# 第一遍:检查是否有name相关字段
|
||||
for key, value in data_dict.items():
|
||||
if key in ['name', 'entity2.name', 'entity1.name']:
|
||||
has_name_field = True
|
||||
break
|
||||
|
||||
print(f"字段检查: has_name_field = {has_name_field}")
|
||||
|
||||
# 第二遍:根据规则映射和过滤字段
|
||||
for key, value in data_dict.items():
|
||||
if key == 'entity2.name' or key == 'entity2_name':
|
||||
# 将 entity2.name 映射为 name
|
||||
mapped_dict['name'] = value
|
||||
print(f"字段名映射: {key} -> name")
|
||||
elif key == 'entity1.name' or key == 'entity1_name':
|
||||
# 将 entity1.name 映射为 name
|
||||
mapped_dict['name'] = value
|
||||
print(f"字段名映射: {key} -> name")
|
||||
elif key == 'entity1.description':
|
||||
# 将 entity1.description 映射为 description
|
||||
mapped_dict['description'] = value
|
||||
print(f"字段名映射: {key} -> description")
|
||||
elif key == 'entity2.description':
|
||||
# 将 entity2.description 映射为 description
|
||||
mapped_dict['description'] = value
|
||||
print(f"字段名映射: {key} -> description")
|
||||
elif key == 'relationship_type':
|
||||
# 跳过relationship_type字段
|
||||
print(f"字段过滤: 跳过不需要的字段 '{key}'")
|
||||
continue
|
||||
elif key == 'entity1_name':
|
||||
if has_name_field:
|
||||
# 如果有name字段,跳过entity1_name
|
||||
print(f"字段过滤: 由于存在name字段,跳过 '{key}'")
|
||||
continue
|
||||
else:
|
||||
# 如果没有name字段,保留entity1_name
|
||||
mapped_dict[key] = value
|
||||
print(f"字段保留: {key}")
|
||||
elif key == 'entity2_name':
|
||||
if has_name_field:
|
||||
# 如果有name字段,跳过entity2_name
|
||||
print(f"字段过滤: 由于存在name字段,跳过 '{key}'")
|
||||
continue
|
||||
else:
|
||||
# 即使没有name字段,也不使用entity2_name(根据需求)
|
||||
print(f"字段过滤: 跳过不推荐的字段 '{key}'")
|
||||
continue
|
||||
elif '.' not in key:
|
||||
# 不包含点号的其他字段直接保留
|
||||
mapped_dict[key] = value
|
||||
else:
|
||||
# 其他包含点号的字段跳过并警告
|
||||
print(f"警告: 跳过不支持的嵌套字段 '{key}'")
|
||||
|
||||
print(f"字段映射结果: {mapped_dict}")
|
||||
return mapped_dict
|
||||
async def neo4j_data(solved_data):
|
||||
"""
|
||||
Process the resolved data and update the Neo4j database
|
||||
Args:
|
||||
Solved_data: Solution Data List
|
||||
Returns:
|
||||
Int: Number of successfully updated records
|
||||
"""
|
||||
success_count = 0
|
||||
|
||||
for i in solved_data:
|
||||
neo4j_dict_data = {}
|
||||
update_databases = {}
|
||||
results = i['results']
|
||||
for data in results:
|
||||
resolved = data.get('resolved')
|
||||
if not resolved:
|
||||
print("跳过:resolved为None")
|
||||
continue
|
||||
|
||||
try:
|
||||
change_list = resolved.get('change', [])
|
||||
except (AttributeError, TypeError):
|
||||
change_list = []
|
||||
|
||||
if change_list == []:
|
||||
print("跳过:change_list为空")
|
||||
continue
|
||||
|
||||
if change_list and len(change_list) > 0:
|
||||
change = change_list[0]
|
||||
print(f"change: {change}")
|
||||
field_data = change.get('field', [])
|
||||
print(f"field_data: {field_data}")
|
||||
print(f"field_data type: {type(field_data)}")
|
||||
|
||||
# 字段名映射和过滤函数
|
||||
|
||||
|
||||
# 处理field数据,可能是字典或列表
|
||||
if isinstance(field_data, dict):
|
||||
# 如果是字典,映射字段名后更新
|
||||
mapped_data = map_field_names(field_data)
|
||||
update_databases.update(mapped_data)
|
||||
elif isinstance(field_data, list):
|
||||
# 如果是列表,遍历每个字典并更新
|
||||
for field_item in field_data:
|
||||
if isinstance(field_item, dict):
|
||||
mapped_item = map_field_names(field_item)
|
||||
update_databases.update(mapped_item)
|
||||
else:
|
||||
print(f"警告: field_item不是字典: {field_item}")
|
||||
else:
|
||||
print(f"警告: field_data类型不支持: {type(field_data)}")
|
||||
|
||||
if 'entity1_name' in data:
|
||||
data['name'] = data.pop('entity1_name')
|
||||
if 'entity2_name' in data:
|
||||
data.pop('entity2_name', None)
|
||||
|
||||
resolved_memory = resolved.get('resolved_memory', {})
|
||||
|
||||
entity2 = None
|
||||
if isinstance(resolved_memory, dict):
|
||||
entity2 = resolved_memory.get('entity2')
|
||||
|
||||
if entity2 and isinstance(entity2, dict) and len(entity2) >= 5:
|
||||
stat_id = resolved.get('original_memory_id')
|
||||
# 安全地获取description
|
||||
statement_id = None
|
||||
if isinstance(resolved_memory, dict):
|
||||
statement_id = resolved_memory.get('statement_id')
|
||||
|
||||
# 只有当neo4j_dict_data中还没有statement_id时才使用original_memory_id
|
||||
if statement_id and 'id' not in neo4j_dict_data:
|
||||
neo4j_dict_data['id'] = stat_id
|
||||
neo4j_dict_data['statement_id'] = statement_id
|
||||
else:
|
||||
# 处理original_memory_id,它可能是字符串或字典
|
||||
try:
|
||||
for key, value in resolved_memory.items():
|
||||
if key == 'statement_id':
|
||||
neo4j_dict_data['statement_id'] = value
|
||||
if key == 'description':
|
||||
neo4j_dict_data['description'] = value
|
||||
except AttributeError:
|
||||
neo4j_dict_data=[]
|
||||
|
||||
print(neo4j_dict_data)
|
||||
print(update_databases)
|
||||
if neo4j_dict_data!=[]:
|
||||
await update_neo4j_data(neo4j_dict_data, update_databases)
|
||||
success_count += 1
|
||||
|
||||
return success_count
|
||||
|
||||
@@ -13,5 +13,6 @@ class EndUser(BaseModel):
|
||||
other_id: Optional[str] = Field(description="第三方ID", default=None)
|
||||
other_name: Optional[str] = Field(description="其他名称", default="")
|
||||
other_address: Optional[str] = Field(description="其他地址", default="")
|
||||
reflection_time: Optional[datetime.datetime] = Field(description="反思时间", default_factory=datetime.datetime.now)
|
||||
created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now)
|
||||
updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now)
|
||||
|
||||
54
api/app/schemas/memory_reflection_schemas.py
Normal file
54
api/app/schemas/memory_reflection_schemas.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OptimizationStrategy(str, Enum):
|
||||
"""优化策略枚举"""
|
||||
SPEED_FIRST = "speed_first"
|
||||
ACCURACY_FIRST = "accuracy_first"
|
||||
BALANCED = "balanced"
|
||||
|
||||
|
||||
class Memory_Reflection(BaseModel):
|
||||
config_id: Optional[int] = None
|
||||
reflectionenabled: bool
|
||||
reflection_period_in_hours: str
|
||||
reflexion_range: str
|
||||
baseline: str
|
||||
reflection_model_id: str
|
||||
memory_verify: bool
|
||||
quality_assessment: bool
|
||||
|
||||
# 新增快速引擎优化参数
|
||||
optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED
|
||||
use_fast_model: Optional[bool] = True
|
||||
enable_caching: Optional[bool] = True
|
||||
enable_streaming: Optional[bool] = True
|
||||
batch_size: Optional[int] = Field(default=3, ge=1, le=10)
|
||||
max_concurrent: Optional[int] = Field(default=5, ge=1, le=20)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class FastReflectionRequest(BaseModel):
|
||||
"""快速反思请求模型"""
|
||||
reflection: Memory_Reflection
|
||||
host_id: Optional[str] = "88a459f5_text02"
|
||||
optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class ReflectionBenchmarkRequest(BaseModel):
|
||||
"""反思基准测试请求模型"""
|
||||
reflection: Memory_Reflection
|
||||
host_id: Optional[str] = "88a459f5_text02"
|
||||
iterations: Optional[int] = Field(default=3, ge=1, le=10)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
所有的内容是放错误地方了,应该放在models
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List, Dict, Literal
|
||||
from typing import Any, Optional, List, Dict, Literal, Union
|
||||
import time
|
||||
import uuid
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
||||
@@ -28,25 +28,48 @@ class Write_UserInput(BaseModel):
|
||||
# ============================================================================
|
||||
class BaseDataSchema(BaseModel):
|
||||
"""Base schema for the data"""
|
||||
id: str = Field(..., description="The unique identifier for the data entry.")
|
||||
statement: str = Field(..., description="The statement text.")
|
||||
group_id: str = Field(..., description="The group identifier.")
|
||||
chunk_id: str = Field(..., description="The chunk identifier.")
|
||||
# 保持原有必需字段为可选,以兼容不同数据源
|
||||
id: Optional[str] = Field(None, description="The unique identifier for the data entry.")
|
||||
statement: Optional[str] = Field(None, description="The statement text.")
|
||||
group_id: Optional[str] = Field(None, description="The group identifier.")
|
||||
chunk_id: Optional[str] = Field(None, description="The chunk identifier.")
|
||||
created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.")
|
||||
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
|
||||
valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.")
|
||||
invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.")
|
||||
entity_ids: List[str] = Field([], description="The list of entity identifiers.")
|
||||
description: Optional[str] = Field(None, description="The description of the data entry.")
|
||||
|
||||
# 新增字段以匹配实际输入数据
|
||||
entity1_name: str = Field(..., description="The first entity name.")
|
||||
entity2_name: Optional[str] = Field(None, description="The second entity name.")
|
||||
statement_id: str = Field(..., description="The statement identifier.")
|
||||
relationship_type: str = Field(..., description="The relationship type.")
|
||||
relationship: Optional[Dict[str, Any]] = Field(None, description="The relationship object.")
|
||||
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
|
||||
|
||||
|
||||
class QualityAssessmentSchema(BaseModel):
|
||||
"""Schema for memory quality assessment results."""
|
||||
score: int = Field(..., ge=0, le=100, description="Quality score percentage (0-100).")
|
||||
summary: str = Field(..., description="Brief summary of data quality status, including main issues and strengths.")
|
||||
|
||||
|
||||
class MemoryVerifySchema(BaseModel):
|
||||
"""Schema for memory privacy verification results."""
|
||||
has_privacy: bool = Field(..., description="Whether privacy information was detected.")
|
||||
privacy_types: List[str] = Field([], description="List of detected privacy information types.")
|
||||
summary: str = Field(..., description="Brief summary of privacy detection results.")
|
||||
|
||||
|
||||
class ConflictResultSchema(BaseModel):
|
||||
"""Schema for the conflict result data in the reflexion_data.json file."""
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data.")
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.")
|
||||
conflict: bool = Field(..., description="Whether the memory is in conflict.")
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
|
||||
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -61,7 +84,6 @@ class ConflictSchema(BaseModel):
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -76,21 +98,30 @@ class ReflexionSchema(BaseModel):
|
||||
solution: str = Field(..., description="The solution for the reflexion.")
|
||||
|
||||
|
||||
class ChangeRecordSchema(BaseModel):
|
||||
"""Schema for individual change records"""
|
||||
field: List[Dict[str, str]] = Field(..., description="List of field changes, each containing field name and new value.")
|
||||
|
||||
class ResolvedSchema(BaseModel):
|
||||
"""Schema for the resolved memory data in the reflexion_data"""
|
||||
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
|
||||
resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data.")
|
||||
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
|
||||
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
|
||||
change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.")
|
||||
|
||||
|
||||
class SingleReflexionResultSchema(BaseModel):
|
||||
"""Schema for a single reflexion result item."""
|
||||
conflict: ConflictResultSchema = Field(..., description="The conflict result data for this specific conflict type.")
|
||||
reflexion: ReflexionSchema = Field(..., description="The reflexion data for this conflict.")
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
|
||||
type: str = Field("reflexion_result", description="The type identifier.")
|
||||
|
||||
class ReflexionResultSchema(BaseModel):
|
||||
"""Schema for the reflexion result data in the reflexion_data.json file."""
|
||||
# 模型输出中 "conflict" 为单个冲突对象(包含 data 与 conflict_memory),而非字典映射
|
||||
conflict: ConflictResultSchema = Field(..., description="The conflict result data.")
|
||||
reflexion: Optional[ReflexionSchema] = Field(None, description="The reflexion data.")
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
|
||||
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_resolved(cls, v):
|
||||
if isinstance(v, dict):
|
||||
conflict = v.get("conflict")
|
||||
|
||||
397
api/app/services/memory_reflection_service.py
Normal file
397
api/app/services/memory_reflection_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""
|
||||
记忆反思服务
|
||||
处理反思引擎的调用和执行
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Set
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine import ReflectionConfig, ReflectionEngine
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionRange, ReflectionBaseline
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.models.app_model import App
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
class WorkspaceAppService:
|
||||
"""Workplace Application Service Class """
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_workspace_apps_detailed(self, workspace_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get detailed information of all applications in the workspace
|
||||
|
||||
Args:
|
||||
Workspace_id: Workspace ID
|
||||
|
||||
Returns:
|
||||
Dictionary containing detailed application information
|
||||
"""
|
||||
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all()
|
||||
app_ids = [str(app.id) for app in apps]
|
||||
|
||||
apps_detailed_info = []
|
||||
|
||||
for app in apps:
|
||||
app_info = self._build_app_info(app)
|
||||
self._process_app_releases(app, app_info)
|
||||
self._process_end_users(app, app_info)
|
||||
apps_detailed_info.append(app_info)
|
||||
|
||||
return {
|
||||
"status": "成功",
|
||||
"message": f"成功查询到 {len(app_ids)} 个应用及其详细信息",
|
||||
"workspace_id": str(workspace_id),
|
||||
"apps_count": len(app_ids),
|
||||
"app_ids": app_ids,
|
||||
"apps_detailed_info": apps_detailed_info
|
||||
}
|
||||
|
||||
def _build_app_info(self, app: App) -> Dict[str, Any]:
|
||||
"""base_infomation"""
|
||||
return {
|
||||
"id": str(app.id),
|
||||
"name": app.name,
|
||||
"description": app.description,
|
||||
"type": app.type,
|
||||
"status": app.status,
|
||||
"visibility": app.visibility,
|
||||
"created_at": app.created_at.isoformat() if app.created_at else None,
|
||||
"updated_at": app.updated_at.isoformat() if app.updated_at else None,
|
||||
"releases": [],
|
||||
"data_configs": [],
|
||||
"end_users": []
|
||||
}
|
||||
|
||||
def _process_app_releases(self, app: App, app_info: Dict[str, Any]) -> None:
|
||||
"""Process the release version and configuration information of the application"""
|
||||
app_releases = self.db.query(AppRelease).filter(AppRelease.app_id == app.id).all()
|
||||
|
||||
if not app_releases:
|
||||
return
|
||||
|
||||
processed_configs: Set[str] = set()
|
||||
|
||||
for release in app_releases:
|
||||
memory_content = self._extract_memory_content(release.config)
|
||||
|
||||
|
||||
if memory_content and memory_content in processed_configs:
|
||||
continue
|
||||
|
||||
release_info = {
|
||||
"app_id": str(release.app_id),
|
||||
"config": memory_content
|
||||
}
|
||||
|
||||
|
||||
if memory_content:
|
||||
processed_configs.add(memory_content)
|
||||
data_config_info = self._get_data_config(memory_content)
|
||||
|
||||
if data_config_info:
|
||||
if not any(dc["config_id"] == data_config_info["config_id"] for dc in app_info["data_configs"]):
|
||||
app_info["data_configs"].append(data_config_info)
|
||||
|
||||
app_info["releases"].append(release_info)
|
||||
|
||||
def _extract_memory_content(self, config: Any) -> str:
|
||||
"""Extract memory_comtent from config"""
|
||||
if not config or not isinstance(config, dict):
|
||||
return None
|
||||
|
||||
memory_obj = config.get('memory')
|
||||
if memory_obj and isinstance(memory_obj, dict):
|
||||
return memory_obj.get('memory_content')
|
||||
|
||||
return None
|
||||
|
||||
def _get_data_config(self, memory_content: str) -> Dict[str, Any]:
|
||||
"""Retrieve data_comfig information based on memory_comtent"""
|
||||
try:
|
||||
data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content)
|
||||
data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone()
|
||||
if data_config_result is None:
|
||||
return None
|
||||
|
||||
if data_config_result:
|
||||
return {
|
||||
"config_id": data_config_result.config_id,
|
||||
"enable_self_reflexion": data_config_result.enable_self_reflexion,
|
||||
"iteration_period": data_config_result.iteration_period,
|
||||
"reflexion_range": data_config_result.reflexion_range,
|
||||
"baseline": data_config_result.baseline,
|
||||
"reflection_model_id": data_config_result.reflection_model_id,
|
||||
"memory_verify": data_config_result.memory_verify,
|
||||
"quality_assessment": data_config_result.quality_assessment,
|
||||
"user_id": data_config_result.user_id
|
||||
}
|
||||
except Exception as e:
|
||||
api_logger.warning(f"查询data_config失败,memory_content: {memory_content}, 错误: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
def _process_end_users(self, app: App, app_info: Dict[str, Any]) -> None:
|
||||
"""Processing end-user information for applications"""
|
||||
end_users = self.db.query(EndUser).filter(EndUser.app_id == app.id).all()
|
||||
|
||||
for end_user in end_users:
|
||||
end_user_info = {
|
||||
"id": str(end_user.id),
|
||||
"app_id": str(end_user.app_id)
|
||||
}
|
||||
app_info["end_users"].append(end_user_info)
|
||||
|
||||
def get_end_user_reflection_time(self, end_user_id: str) -> Optional[Any]:
|
||||
"""
|
||||
Read the reflection time of end users
|
||||
|
||||
Args:
|
||||
End_user_id: End User ID
|
||||
|
||||
Returns:
|
||||
Reflection time or None
|
||||
"""
|
||||
try:
|
||||
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if end_user:
|
||||
return end_user.reflection_time
|
||||
return None
|
||||
except Exception as e:
|
||||
api_logger.error(f"读取用户反思时间失败,end_user_id: {end_user_id}, 错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def update_end_user_reflection_time(self, end_user_id: str) -> bool:
|
||||
"""
|
||||
Update the reflection time of end users to the current time
|
||||
|
||||
Args:
|
||||
End_user_id: End User ID
|
||||
|
||||
Returns:
|
||||
Is the update successful
|
||||
"""
|
||||
try:
|
||||
from datetime import datetime
|
||||
|
||||
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if end_user:
|
||||
end_user.reflection_time = datetime.now()
|
||||
self.db.commit()
|
||||
api_logger.info(f"成功更新用户反思时间,end_user_id: {end_user_id}")
|
||||
return True
|
||||
else:
|
||||
api_logger.warning(f"未找到用户,end_user_id: {end_user_id}")
|
||||
return False
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新用户反思时间失败,end_user_id: {end_user_id}, 错误: {str(e)}")
|
||||
self.db.rollback()
|
||||
return False
|
||||
|
||||
|
||||
class MemoryReflectionService:
|
||||
"""Memory reflection service category"""
|
||||
|
||||
def __init__(self,db: Session = Depends(get_db)):
|
||||
self.db=db
|
||||
|
||||
|
||||
async def start_reflection_from_data(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Starting Reflection from Configuration Data
|
||||
|
||||
Args:
|
||||
config_data: Configure data dictionary, including reflective configuration information
|
||||
end_user_id: end_user_id
|
||||
|
||||
Returns:
|
||||
Reflect on the execution results
|
||||
"""
|
||||
try:
|
||||
config_id = config_data.get("config_id")
|
||||
api_logger.info(f"从配置数据启动反思,config_id: {config_id}, end_user_id: {end_user_id}")
|
||||
|
||||
|
||||
if not config_data.get("enable_self_reflexion", False):
|
||||
return {
|
||||
"status": "跳过",
|
||||
"message": "反思引擎未启用",
|
||||
"config_id": config_id,
|
||||
"end_user_id": end_user_id,
|
||||
"config_data": config_data
|
||||
}
|
||||
|
||||
|
||||
config_data_id=config_data['config_id']
|
||||
reflection_config=WorkspaceAppService(self.db)._get_data_config(config_data_id)
|
||||
if reflection_config is not None and reflection_config['enable_self_reflexion']:
|
||||
reflection_config= self._create_reflection_config_from_data(reflection_config)
|
||||
iteration_period=reflection_config.iteration_period
|
||||
workspace_service = WorkspaceAppService(self.db)
|
||||
current_reflection_time = workspace_service.get_end_user_reflection_time(end_user_id)
|
||||
|
||||
reflection_time = datetime.fromisoformat(str(current_reflection_time))
|
||||
|
||||
current_time = datetime.now()
|
||||
time_diff = current_time - reflection_time
|
||||
hours_diff = int(time_diff.total_seconds() / 3600)
|
||||
if iteration_period==hours_diff or current_reflection_time is None:
|
||||
api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时")
|
||||
# 3. 执行反思引擎
|
||||
reflection_results = await self._execute_reflection_engine(
|
||||
reflection_config, end_user_id
|
||||
)
|
||||
# 更新反思时间为当前时间
|
||||
update_success = workspace_service.update_end_user_reflection_time(end_user_id)
|
||||
if update_success:
|
||||
api_logger.info(f"成功更新用户 {end_user_id} 的反思时间")
|
||||
else:
|
||||
api_logger.error(f"更新用户 {end_user_id} 的反思时间失败")
|
||||
|
||||
return {
|
||||
"status": "完成",
|
||||
"message": "反思引擎执行完成",
|
||||
"config_id": config_id,
|
||||
"end_user_id": end_user_id,
|
||||
"config_data": config_data,
|
||||
"reflection_results": reflection_results
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "等待中..",
|
||||
"message": "反思引擎未开始执行执",
|
||||
"config_id": config_id,
|
||||
"end_user_id": end_user_id,
|
||||
"config_data": config_data,
|
||||
"reflection_results": ''
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
config_id = config_data.get("config_id", "unknown")
|
||||
api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}")
|
||||
return {
|
||||
"status": "错误",
|
||||
"message": f"启动反思失败: {str(e)}",
|
||||
"config_id": config_id,
|
||||
"end_user_id": end_user_id,
|
||||
"config_data": config_data
|
||||
}
|
||||
|
||||
def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig:
|
||||
"""Create reflective configuration objects from configuration data"""
|
||||
|
||||
reflexion_range_value = config_data.get("reflexion_range")
|
||||
if reflexion_range_value is None or reflexion_range_value == "":
|
||||
reflexion_range_value = "partial"
|
||||
reflexion_range = ReflectionRange(reflexion_range_value)
|
||||
|
||||
baseline_value = config_data.get("baseline")
|
||||
if baseline_value is None or baseline_value == "":
|
||||
baseline_value = "TIME"
|
||||
baseline = ReflectionBaseline(baseline_value)
|
||||
|
||||
# iteration_period =
|
||||
iteration_period = config_data.get("iteration_period", 24)
|
||||
if isinstance(iteration_period, str):
|
||||
try:
|
||||
iteration_period = int(iteration_period)
|
||||
except (ValueError, TypeError):
|
||||
iteration_period = 24 # 默认24小时
|
||||
|
||||
return ReflectionConfig(
|
||||
enabled=config_data.get("enable_self_reflexion", False),
|
||||
iteration_period=str(iteration_period), # ReflectionConfig期望字符串
|
||||
reflexion_range=reflexion_range,
|
||||
baseline=baseline,
|
||||
memory_verify=config_data.get("memory_verify", False),
|
||||
quality_assessment=config_data.get("quality_assessment", False),
|
||||
model_id=config_data.get("reflection_model_id", "")
|
||||
)
|
||||
|
||||
async def _execute_reflection_engine(
|
||||
self,
|
||||
reflection_config: ReflectionConfig,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute Reflection Engine"""
|
||||
try:
|
||||
# 创建Neo4j连接器
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 创建反思引擎
|
||||
engine = ReflectionEngine(
|
||||
config=reflection_config,
|
||||
neo4j_connector=connector,
|
||||
llm_client=reflection_config.model_id
|
||||
)
|
||||
|
||||
# 执行反思
|
||||
reflection_result = await engine.execute_reflection(user_id)
|
||||
|
||||
return {
|
||||
"success": reflection_result.success,
|
||||
"message": reflection_result.message,
|
||||
"conflicts_found": reflection_result.conflicts_found,
|
||||
"conflicts_resolved": reflection_result.conflicts_resolved,
|
||||
"memories_updated": reflection_result.memories_updated,
|
||||
"execution_time": reflection_result.execution_time,
|
||||
"details": reflection_result.details
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"反思引擎执行失败: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"反思引擎执行失败: {str(e)}",
|
||||
"conflicts_found": 0,
|
||||
"conflicts_resolved": 0,
|
||||
"memories_updated": 0,
|
||||
"execution_time": 0.0
|
||||
}
|
||||
|
||||
|
||||
class Memory_Reflection_Service:
|
||||
"""Memory Reflection Service - Used for calling the/reflection interface"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.reflection_service = MemoryReflectionService(db)
|
||||
|
||||
async def start_reflection(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Activate the reflection function
|
||||
|
||||
Args:
|
||||
config_data: 配置数据,格式如下:
|
||||
{
|
||||
"config_id": 26,
|
||||
"enable_self_reflexion": true,
|
||||
"iteration_period": "6",
|
||||
"reflexion_range": "partial",
|
||||
"baseline": "TIME",
|
||||
"reflection_model_id": "ea405fa6-c387-4d78-80ab-826d692301b3",
|
||||
"memory_verify": true,
|
||||
"quality_assessment": false,
|
||||
"user_id": null
|
||||
}
|
||||
end_user_id: end_user_id,example "12a8b235-6eb1-4481-a53c-b77933b5c949"
|
||||
|
||||
Returns:
|
||||
"""
|
||||
api_logger.info(f"Memory_Reflection_Service启动反思,config_id: {config_data.get('config_id')}, end_user_id: {end_user_id}")
|
||||
|
||||
# 调用核心反思服务
|
||||
result = await self.reflection_service.start_reflection_from_data(config_data, end_user_id)
|
||||
|
||||
return result
|
||||
163
api/app/tasks.py
163
api/app/tasks.py
@@ -295,26 +295,6 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
||||
}
|
||||
|
||||
|
||||
def reflection_engine() -> None:
|
||||
"""Empty function placeholder for timed background reflection.
|
||||
|
||||
Intentionally left blank; replace with real reflection logic later.
|
||||
"""
|
||||
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
|
||||
import asyncio
|
||||
|
||||
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
|
||||
asyncio.run(self_reflexion(host_id))
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.reflection.timer")
|
||||
def reflection_timer_task() -> None:
|
||||
"""Periodic Celery task that invokes reflection_engine.
|
||||
|
||||
Raises an exception on failure.
|
||||
"""
|
||||
reflection_engine()
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||
def check_read_service_task() -> Dict[str, str]:
|
||||
@@ -464,4 +444,147 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
"error": str(e),
|
||||
"workspace_id": workspace_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(name="app.tasks.workspace_reflection_task", bind=True)
|
||||
def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
"""定时任务:每30秒运行工作空间反思功能
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
api_logger = get_api_logger()
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# 获取所有工作空间
|
||||
workspaces = db.query(Workspace).all()
|
||||
|
||||
if not workspaces:
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "没有找到工作空间",
|
||||
"workspace_count": 0,
|
||||
"reflection_results": []
|
||||
}
|
||||
|
||||
all_reflection_results = []
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
|
||||
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
# 使用服务类处理复杂查询逻辑
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(str(workspace_id))
|
||||
|
||||
workspace_reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
workspace_reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"reflection_count": len(workspace_reflection_results),
|
||||
"reflection_results": workspace_reflection_results
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"error": str(e),
|
||||
"reflection_count": 0,
|
||||
"reflection_results": []
|
||||
})
|
||||
|
||||
total_reflections = sum(r.get("reflection_count", 0) for r in all_reflection_results)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": f"成功处理 {len(workspaces)} 个工作空间,总共 {total_reflections} 个反思任务",
|
||||
"workspace_count": len(workspaces),
|
||||
"total_reflections": total_reflections,
|
||||
"workspace_results": all_reflection_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"工作空间反思任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"workspace_count": 0,
|
||||
"reflection_results": []
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
108
api/check_code.py
Executable file
108
api/check_code.py
Executable file
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
代码质量检查脚本
|
||||
自动检查代码中的导入错误、未使用变量、语法问题等
|
||||
|
||||
用法:
|
||||
python check_code.py # 检查整个 app/ 目录
|
||||
python check_code.py file1.py file2.py # 检查指定文件
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_command(cmd: list[str], description: str) -> tuple[bool, str]:
|
||||
"""运行命令并返回结果"""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"🔍 {description}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||
|
||||
output = result.stdout + result.stderr
|
||||
success = result.returncode == 0
|
||||
|
||||
if success:
|
||||
print(f"✅ {description} - 通过")
|
||||
else:
|
||||
print(f"❌ {description} - 发现问题")
|
||||
if output:
|
||||
print(output[:2000]) # 只显示前2000字符
|
||||
|
||||
return success, output
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 执行失败: {e}")
|
||||
return False, str(e)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
# 获取命令行参数中的文件列表
|
||||
target_files = sys.argv[1:] if len(sys.argv) > 1 else None
|
||||
|
||||
if target_files:
|
||||
# 检查指定文件
|
||||
print(f"🚀 开始代码质量检查 (指定文件: {len(target_files)} 个)...")
|
||||
target_paths = target_files
|
||||
ruff_target = target_files
|
||||
py_compile_files = [f for f in target_files if f.endswith('.py')]
|
||||
else:
|
||||
# 检查整个 app/ 目录
|
||||
print("🚀 开始代码质量检查 (整个 app/ 目录)...")
|
||||
target_paths = ["app/"]
|
||||
ruff_target = ["app/"]
|
||||
py_compile_files = list(Path("app").rglob("*.py"))
|
||||
|
||||
checks = [
|
||||
{
|
||||
"cmd": ["ruff", "check"] + ruff_target + ["--output-format=concise"],
|
||||
"description": "Ruff 代码检查 (导入、语法、风格)",
|
||||
"auto_fix": ["ruff", "check"] + ruff_target + ["--fix", "--unsafe-fixes"],
|
||||
},
|
||||
{
|
||||
"cmd": ["python", "-m", "py_compile"] + [str(f) for f in py_compile_files],
|
||||
"description": "Python 语法检查",
|
||||
"auto_fix": None,
|
||||
},
|
||||
]
|
||||
|
||||
results = []
|
||||
for check in checks:
|
||||
success, output = run_command(check["cmd"], check["description"])
|
||||
results.append(
|
||||
{"name": check["description"], "success": success, "output": output, "auto_fix": check.get("auto_fix")}
|
||||
)
|
||||
|
||||
# 汇总报告
|
||||
print(f"\n{'=' * 60}")
|
||||
print("📊 检查汇总")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
all_passed = True
|
||||
for result in results:
|
||||
status = "✅ 通过" if result["success"] else "❌ 失败"
|
||||
print(f"{status} - {result['name']}")
|
||||
if not result["success"]:
|
||||
all_passed = False
|
||||
if result["auto_fix"]:
|
||||
print(f" 💡 可以运行自动修复: {' '.join(result['auto_fix'])}")
|
||||
|
||||
if all_passed:
|
||||
print("\n🎉 所有检查通过!")
|
||||
return 0
|
||||
else:
|
||||
print("\n⚠️ 发现问题,请查看上面的详细信息")
|
||||
print("\n💡 快速修复命令:")
|
||||
if target_files:
|
||||
print(f" ruff check {' '.join(target_files)} --fix --unsafe-fixes")
|
||||
else:
|
||||
print(" ruff check app/ --fix --unsafe-fixes")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user