[MODIFY] Code optimization
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -20,8 +20,7 @@ examples/
|
||||
.idea
|
||||
|
||||
# Temporary outputs
|
||||
app/core/memory/agent/.DS_Store
|
||||
app/core/memory/src/utils/.DS_Store
|
||||
**/.DS_Store
|
||||
time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
from celery import Celery
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
@@ -12,6 +13,7 @@ celery_app = Celery(
|
||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
)
|
||||
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
|
||||
|
||||
# 配置使用本地队列,避免与远程 worker 冲突
|
||||
celery_app.conf.task_default_queue = 'localhost_test_wyl'
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
"""API Key 管理接口 - 基于 JWT 认证"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models.user_model import User
|
||||
@@ -10,142 +14,344 @@ from app.core.response_utils import success
|
||||
from app.schemas import api_key_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.api_key_service import ApiKeyService
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.exceptions import (
|
||||
BusinessException,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/apikeys", tags=["API Keys"])
|
||||
logger = get_business_logger()
|
||||
logger = get_api_logger()
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def create_api_key(
|
||||
data: api_key_schema.ApiKeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
data: api_key_schema.ApiKeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建 API Key
|
||||
|
||||
"""
|
||||
创建 API Key
|
||||
|
||||
- 支持三种类型:app/rag/memory
|
||||
- 创建后返回明文 API Key(仅此一次)
|
||||
- 支持设置权限范围、速率限制、配额等
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key_obj, api_key = ApiKeyService.create_api_key(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
return success(data=response_data, msg="API Key 创建成功")
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 创建 API Key
|
||||
api_key_obj, api_key = ApiKeyService.create_api_key(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
return success(data=response_data, msg="API Key 创建成功")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "create_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"创建API Key失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def list_api_keys(
|
||||
type: api_key_schema.ApiKeyType = Query(None),
|
||||
is_active: bool = Query(None),
|
||||
resource_id: uuid.UUID = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
type: api_key_schema.ApiKeyType = Query(None, description="按类型过滤"),
|
||||
is_active: bool = Query(True, description="按状态过滤"),
|
||||
resource_id: uuid.UUID = Query(None, description="按资源过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""列出 API Keys"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
query = api_key_schema.ApiKeyQuery(
|
||||
type=type,
|
||||
is_active=is_active,
|
||||
resource_id=resource_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
result = ApiKeyService.list_api_keys(db, workspace_id, query)
|
||||
return success(data=result)
|
||||
"""
|
||||
列出 API Keys
|
||||
|
||||
- 支持多维度过滤
|
||||
- 支持分页
|
||||
- 自动按创建时间倒序
|
||||
"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
query = api_key_schema.ApiKeyQuery(
|
||||
type=type,
|
||||
is_active=is_active,
|
||||
resource_id=resource_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
result = ApiKeyService.list_api_keys(db, workspace_id, query)
|
||||
|
||||
logger.info("API Keys 查询成功", extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total_count": result.get("total", 0) if isinstance(result, dict) else 0
|
||||
})
|
||||
|
||||
return success(data=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "list_api_keys"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"API Keys 查询失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("/{api_key_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取 API Key 详情"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key))
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
logger.info("获取API Key详情成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "get_api_key"
|
||||
})
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key))
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "get_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"获取API Key失败: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/{api_key_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def update_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新 API Key"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key = ApiKeyService.update_api_key(db, api_key_id, workspace_id, data)
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||
"""更新 API Key配置"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_key = ApiKeyService.update_api_key(db, api_key_id, workspace_id, data)
|
||||
|
||||
logger.info("API Key 更新配置成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id)
|
||||
})
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "update_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"更新API Key失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{api_key_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def delete_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除 API Key"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
ApiKeyService.delete_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
return success(msg="API Key 删除成功")
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
ApiKeyService.delete_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
logger.info("API Key 删除成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id)
|
||||
})
|
||||
|
||||
return success(msg="API Key 删除成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "delete_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"删除API Key失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{api_key_id}/regenerate", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def regenerate_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""重新生成 API Key
|
||||
"""
|
||||
重新生成 API Key
|
||||
|
||||
- 生成新的 API Key 并返回明文(仅此一次)
|
||||
- 旧的 API Key 立即失效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
return success(data=response_data, msg="API Key 重新生成成功")
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
logger.info("API Key 重新生成成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id)
|
||||
})
|
||||
|
||||
return success(data=response_data, msg="API Key 重新生成成功")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "regenerate_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"重新生成API Key失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{api_key_id}/stats", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_api_key_stats(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取 API Key 使用统计"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats = ApiKeyService.get_stats(db, api_key_id, workspace_id)
|
||||
|
||||
return success(data=stats)
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats = ApiKeyService.get_stats(db, api_key_id, workspace_id)
|
||||
|
||||
logger.info("API Key stats retrieved successfully", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id)
|
||||
})
|
||||
|
||||
return success(data=stats)
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "get_api_key_stats"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"获取API Key统计失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{api_key_id}/logs", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_api_key_logs(
    api_key_id: uuid.UUID,
    start_date: Optional[datetime] = Query(None, description="开始日期"),
    end_date: Optional[datetime] = Query(None, description="结束日期"),
    status_code: Optional[int] = Query(None, description="HTTP状态码过滤"),
    endpoint: Optional[str] = Query(None, description="端点路径过滤"),
    page: int = Query(1, ge=1, description="页码"),
    pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the usage logs of an API Key.

    - Optional filtering by time range (``start_date`` / ``end_date``)
    - Optional filtering by HTTP status code and endpoint path
    - Paginated through ``page`` / ``pagesize``

    Raises:
        BusinessException: with ``BizCode.INVALID_PARAMETER`` when the date
            range is inverted or the status code is outside 100-599. Any
            BusinessException coming from the service layer is propagated
            unchanged as well.
        Exception: for any other, unexpected failure (logged with traceback).
    """
    try:
        workspace_id = current_user.current_workspace_id

        # Validate the date range.
        if start_date and end_date and start_date > end_date:
            logger.warning("开始日期晚于结束日期", extra={
                "api_key_id": str(api_key_id),
                "workspace_id": str(workspace_id),
                "user_id": str(current_user.id),
                "start_date": start_date.isoformat(),
                "end_date": end_date.isoformat()
            })
            raise BusinessException("开始日期不能晚于结束日期", BizCode.INVALID_PARAMETER)

        # Validate the status code. Compare against None explicitly so that
        # a supplied value of 0 is rejected instead of silently skipping the check.
        if status_code is not None and not (100 <= status_code <= 599):
            logger.warning("查询无效的状态码", extra={
                "api_key_id": str(api_key_id),
                "workspace_id": str(workspace_id),
                "user_id": str(current_user.id),
                "status_code": status_code
            })
            raise BusinessException("无效的HTTP状态码", BizCode.INVALID_PARAMETER)

        # Build the filter set handed to the service layer.
        filters = {
            "start_date": start_date,
            "end_date": end_date,
            "status_code": status_code,
            "endpoint": endpoint
        }

        # Delegate the actual query to the service layer.
        result = ApiKeyService.get_logs(
            db, api_key_id, workspace_id, filters, page, pagesize
        )

        logger.info("API Key 日志查询成功", extra={
            "api_key_id": str(api_key_id),
            "workspace_id": str(workspace_id),
            "user_id": str(current_user.id),
            "page": page,
            "pagesize": pagesize,
            "filters": {k: str(v) if v else None for k, v in filters.items()}
        })

        return success(data=result)

    except BusinessException:
        # Business errors (including the validation errors raised above) carry
        # their own error code; let them through untouched, as the create and
        # regenerate endpoints do, instead of wrapping them as unknown errors.
        raise
    except Exception as e:
        logger.error(f"未知错误: {str(e)}", extra={
            "api_key_id": str(api_key_id),
            "workspace_id": str(current_user.current_workspace_id),
            "user_id": str(current_user.id),
            "operation": "get_api_key_logs"
        }, exc_info=True)
        # Chain the original exception so the root cause stays in the traceback.
        raise Exception(f"API Key 日志查询失败: {str(e)}") from e
|
||||
|
||||
@@ -121,7 +121,7 @@ def delete_app(
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
logger.info(
|
||||
f"用户请求删除应用",
|
||||
"用户请求删除应用",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"user_id": str(current_user.id),
|
||||
@@ -151,7 +151,7 @@ def copy_app(
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
logger.info(
|
||||
f"用户请求复制应用",
|
||||
"用户请求复制应用",
|
||||
extra={
|
||||
"source_app_id": str(app_id),
|
||||
"user_id": str(current_user.id),
|
||||
@@ -432,7 +432,7 @@ async def draft_run(
|
||||
|
||||
# 非流式返回
|
||||
logger.debug(
|
||||
f"开始非流式试运行",
|
||||
"开始非流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
@@ -456,7 +456,7 @@ async def draft_run(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"试运行返回结果",
|
||||
"试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict"
|
||||
@@ -466,11 +466,11 @@ async def draft_run(
|
||||
# 验证结果
|
||||
try:
|
||||
validated_result = app_schema.DraftRunResponse.model_validate(result)
|
||||
logger.debug(f"结果验证成功")
|
||||
logger.debug("结果验证成功")
|
||||
return success(data=validated_result)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"结果验证失败",
|
||||
"结果验证失败",
|
||||
extra={
|
||||
"error": str(e),
|
||||
"error_type": str(type(e)),
|
||||
@@ -496,7 +496,7 @@ async def draft_run(
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
f"开始多智能体流式试运行",
|
||||
"开始多智能体流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
@@ -530,7 +530,7 @@ async def draft_run(
|
||||
|
||||
# 4. 非流式返回
|
||||
logger.debug(
|
||||
f"开始多智能体非流式试运行",
|
||||
"开始多智能体非流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
@@ -542,7 +542,7 @@ async def draft_run(
|
||||
result = await multiservice.run(app_id, multi_agent_request)
|
||||
|
||||
logger.debug(
|
||||
f"多智能体试运行返回结果",
|
||||
"多智能体试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
@@ -599,7 +599,7 @@ async def draft_run_compare(
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
logger.info(
|
||||
f"多模型对比试运行",
|
||||
"多模型对比试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"model_count": len(payload.models),
|
||||
@@ -705,7 +705,7 @@ async def draft_run_compare(
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"多模型对比完成",
|
||||
"多模型对比完成",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"successful": result["successful_count"],
|
||||
|
||||
@@ -178,7 +178,7 @@ async def get_chunks(
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing document chunk query")
|
||||
api_logger.debug("Start executing document chunk query")
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.search_by_segment(document_id=str(document_id), query=keywords, pagesize=pagesize, page=page, asc=True)
|
||||
api_logger.info(f"Document chunk query successful: total={total}, returned={len(items)} records")
|
||||
@@ -213,7 +213,9 @@ async def create_chunk(
|
||||
"""
|
||||
create chunk
|
||||
"""
|
||||
api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}")
|
||||
# Obtain the actual content
|
||||
content = create_data.chunk_content
|
||||
api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={content}, username: {current_user.username}")
|
||||
|
||||
# 1. Obtain knowledge base information
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
@@ -250,7 +252,7 @@ async def create_chunk(
|
||||
"sort_id": sort_id,
|
||||
"status": 1,
|
||||
}
|
||||
chunk = DocumentChunk(page_content=create_data.content, metadata=metadata)
|
||||
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
||||
# 3. Segmented vector storage
|
||||
vector_service.add_chunks([chunk])
|
||||
|
||||
@@ -305,7 +307,9 @@ async def update_chunk(
|
||||
"""
|
||||
Update document chunk content
|
||||
"""
|
||||
api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={update_data.content}, username: {current_user.username}")
|
||||
# Obtain the actual content
|
||||
content = update_data.chunk_content
|
||||
api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={content}, username: {current_user.username}")
|
||||
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
@@ -318,7 +322,7 @@ async def update_chunk(
|
||||
total, items = vector_service.get_by_segment(doc_id=doc_id)
|
||||
if total:
|
||||
chunk = items[0]
|
||||
chunk.page_content = update_data.content
|
||||
chunk.page_content = content
|
||||
vector_service.update_by_segment(chunk)
|
||||
return success(data=chunk, msg="The document chunk has been successfully updated")
|
||||
else:
|
||||
|
||||
@@ -78,7 +78,7 @@ async def get_documents(
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing document paging query")
|
||||
api_logger.debug("Start executing document paging query")
|
||||
total, items = document_service.get_documents_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
|
||||
@@ -66,7 +66,7 @@ async def get_files(
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing file paging query")
|
||||
api_logger.debug("Start executing file paging query")
|
||||
total, items = file_service.get_files_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
|
||||
@@ -74,8 +74,6 @@ async def get_knowledges(
|
||||
filters = [
|
||||
knowledge_model.Knowledge.workspace_id == current_user.current_workspace_id
|
||||
]
|
||||
if parent_id:
|
||||
filters.append(knowledge_model.Knowledge.parent_id == parent_id)
|
||||
|
||||
# Keyword search (fuzzy matching of knowledge base name)
|
||||
if keywords:
|
||||
@@ -91,9 +89,14 @@ async def get_knowledges(
|
||||
filters.append(knowledge_model.Knowledge.id.in_(kb_ids.split(',')))
|
||||
else:
|
||||
filters.append(knowledge_model.Knowledge.status != 2)
|
||||
if parent_id:
|
||||
filters.append(knowledge_model.Knowledge.parent_id == parent_id)
|
||||
else:
|
||||
filters.append(knowledge_model.Knowledge.parent_id == current_user.current_workspace_id)
|
||||
filters.append(knowledge_model.Knowledge.permission_id != knowledge_model.PermissionType.Memory)
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing knowledge base paging query")
|
||||
api_logger.debug("Start executing knowledge base paging query")
|
||||
total, items = knowledge_service.get_knowledges_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
|
||||
@@ -58,7 +58,7 @@ async def get_knowledgeshares(
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing knowledge base sharing and paging query")
|
||||
api_logger.debug("Start executing knowledge base sharing and paging query")
|
||||
total, items = knowledgeshare_service.get_knowledgeshares_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
|
||||
@@ -54,7 +54,7 @@ def validate_config_id(config_id: int, db: Session) -> int:
|
||||
ValueError: If config_id is None, invalid, or doesn't exist in database
|
||||
"""
|
||||
if config_id is None:
|
||||
api_logger.info(f"config_id is required but was not provided")
|
||||
api_logger.info("config_id is required but was not provided")
|
||||
config_id = os.getenv('config_id')
|
||||
if config_id is None:
|
||||
raise ValueError("config_id is required but was not provided")
|
||||
@@ -257,7 +257,7 @@ async def write_server(
|
||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning(f"workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
|
||||
@@ -2,11 +2,13 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
from app.repositories.end_user_repository import update_end_user_other_name
|
||||
import uuid
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_agent_schema import End_User_Information
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
|
||||
@@ -41,6 +43,56 @@ def get_workspace_total_end_users(
|
||||
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
|
||||
return success(data=total_end_users, msg="用户数量获取成功")
|
||||
|
||||
@router.post("/update/end_users", response_model=ApiResponse)
async def update_workspace_end_users(
    user_input: End_User_Information,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Update the display name (``other_name``) of an end user.

    Reads ``end_user_name`` and ``id`` from the request body, converts the id
    to a UUID and calls ``update_end_user_other_name``.

    Raises:
        HTTPException: 400 when the supplied id is not a valid UUID;
            500 when the update itself fails.
    """
    username = user_input.end_user_name  # new name to store
    end_user_input_id = user_input.id  # end-user id as supplied by the client
    workspace_id = current_user.current_workspace_id

    api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
    api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")

    # Parse the id outside the broad try-block so that a malformed value is
    # reported as a client error (400) rather than an internal error (500).
    try:
        if isinstance(end_user_input_id, uuid.UUID):
            end_user_uuid = end_user_input_id
        else:
            end_user_uuid = uuid.UUID(str(end_user_input_id))
    except ValueError as e:
        api_logger.warning(f"无效的宿主ID: {end_user_input_id}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的宿主ID: {end_user_input_id}"
        ) from e

    # NOTE(review): workspace_id is only logged here; the update is keyed on the
    # end-user id alone. Confirm the repository (or a caller) verifies that the
    # end user belongs to the current workspace.
    try:
        # Update the other_name column directly in the database.
        updated_count = update_end_user_other_name(
            db=db,
            end_user_id=end_user_uuid,
            other_name=username
        )
    except Exception as e:
        api_logger.error(f"更新宿主信息失败: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"更新宿主信息失败: {str(e)}"
        ) from e

    api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")

    return success(
        data={
            "updated_count": updated_count,
            "end_user_id": end_user_input_id,
            "updated_other_name": username
        },
        msg=f"成功更新 {updated_count} 个宿主的信息"
    )
|
||||
|
||||
|
||||
|
||||
@router.get("/end_users", response_model=ApiResponse)
|
||||
async def get_workspace_end_users(
|
||||
@@ -53,6 +105,8 @@ async def get_workspace_end_users(
|
||||
返回格式与原 memory_list 接口中的 end_users 字段相同
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
db=db,
|
||||
@@ -61,14 +115,21 @@ async def get_workspace_end_users(
|
||||
)
|
||||
result = []
|
||||
for end_user in end_users:
|
||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
||||
memory_num = {}
|
||||
if current_workspace_type == "neo4j":
|
||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
||||
elif current_workspace_type == "rag":
|
||||
memory_num = {
|
||||
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
||||
}
|
||||
result.append(
|
||||
{
|
||||
'end_user':end_user,
|
||||
'memory_num':memory_num
|
||||
}
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
@@ -203,7 +264,7 @@ def get_workspace_memory_list(
|
||||
current_user=current_user,
|
||||
limit=limit
|
||||
)
|
||||
api_logger.info(f"成功获取记忆列表")
|
||||
api_logger.info("成功获取记忆列表")
|
||||
return success(data=memory_list, msg="记忆列表获取成功")
|
||||
|
||||
|
||||
@@ -354,7 +415,7 @@ async def get_chunk_insight(
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取chunk洞察")
|
||||
api_logger.info("成功获取chunk洞察")
|
||||
return success(data=data, msg="chunk洞察获取成功")
|
||||
|
||||
|
||||
@@ -469,7 +530,7 @@ async def dashboard_data(
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info(f"成功获取neo4j_data")
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
|
||||
# 如果 storage_type 为 'rag',获取 rag_data
|
||||
elif storage_type == 'rag':
|
||||
@@ -503,9 +564,9 @@ async def dashboard_data(
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
result["rag_data"] = rag_data
|
||||
api_logger.info(f"成功获取rag_data")
|
||||
api_logger.info("成功获取rag_data")
|
||||
|
||||
api_logger.info(f"成功获取dashboard整合数据")
|
||||
api_logger.info("成功获取dashboard整合数据")
|
||||
return success(data=result, msg="Dashboard数据获取成功")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from typing import Optional
|
||||
import os
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -62,7 +65,7 @@ async def get_storage_info(
|
||||
Returns:
|
||||
Storage information
|
||||
"""
|
||||
api_logger.info(f"Storage info requested ")
|
||||
api_logger.info("Storage info requested ")
|
||||
try:
|
||||
result = await memory_storage_service.get_storage_info()
|
||||
return success(data=result)
|
||||
@@ -139,6 +142,7 @@ def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -151,7 +155,7 @@ def create_config(
|
||||
try:
|
||||
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
|
||||
payload.workspace_id = workspace_id
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.create(payload)
|
||||
return success(data=result, msg="创建成功")
|
||||
except Exception as e:
|
||||
@@ -163,6 +167,7 @@ def create_config(
|
||||
def delete_config(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -173,7 +178,7 @@ def delete_config(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.delete(ConfigParamsDelete(config_id=config_id))
|
||||
return success(data=result, msg="删除成功")
|
||||
except Exception as e:
|
||||
@@ -184,6 +189,7 @@ def delete_config(
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -194,7 +200,7 @@ def update_config(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.update(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
@@ -206,6 +212,7 @@ def update_config(
|
||||
def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -216,7 +223,7 @@ def update_config_extracted(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.update_extracted(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
@@ -229,6 +236,7 @@ def update_config_extracted(
|
||||
def update_config_forget(
|
||||
payload: ConfigUpdateForget,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -239,7 +247,7 @@ def update_config_forget(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.update_forget(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
@@ -251,6 +259,7 @@ def update_config_forget(
|
||||
def read_config_extracted(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -261,7 +270,7 @@ def read_config_extracted(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.get_extracted(ConfigKey(config_id=config_id))
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
@@ -272,6 +281,7 @@ def read_config_extracted(
|
||||
def read_config_forget(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -282,7 +292,7 @@ def read_config_forget(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.get_forget(ConfigKey(config_id=config_id))
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
@@ -292,6 +302,7 @@ def read_config_forget(
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -302,7 +313,7 @@ def read_all_config(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
# 传递 workspace_id 进行过滤(保持为 UUID 类型)
|
||||
result = svc.get_all(workspace_id=workspace_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
@@ -315,6 +326,7 @@ def read_all_config(
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
|
||||
|
||||
@@ -330,7 +342,7 @@ async def pilot_run(
|
||||
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
|
||||
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = await svc.pilot_run(payload)
|
||||
return success(data=result, msg="试运行完成")
|
||||
except ValueError as e:
|
||||
@@ -475,13 +487,13 @@ async def search_for_entity_graph(
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_api(
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Hot memory tags requested for end_user_id: {end_user_id}")
|
||||
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
|
||||
try:
|
||||
result = await analytics_hot_memory_tags(end_user_id, limit)
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
|
||||
@@ -46,7 +46,8 @@ def get_model_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取模型配置列表
|
||||
@@ -55,7 +56,7 @@ def get_model_list(
|
||||
- 单个:?type=LLM
|
||||
- 多个:?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}")
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
query = model_schema.ModelConfigQuery(
|
||||
@@ -69,7 +70,7 @@ def get_model_list(
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query)
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
@@ -81,16 +82,17 @@ def get_model_list(
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
def get_model_by_id(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
根据ID获取模型配置
|
||||
"""
|
||||
api_logger.info(f"获取模型配置请求: model_id={model_id}")
|
||||
api_logger.info(f"获取模型配置请求: model_id={model_id}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始获取模型配置: model_id={model_id}")
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置获取成功: {result_orm.name}")
|
||||
|
||||
# 将ORM对象转换为Pydantic模型
|
||||
@@ -116,11 +118,11 @@ async def create_model(
|
||||
- 验证失败时会抛出异常,不会创建配置
|
||||
- 可通过 skip_validation=true 跳过验证
|
||||
"""
|
||||
api_logger.info(f"创建模型配置请求: {model_data.name}, 用户: {current_user.username}")
|
||||
api_logger.info(f"创建模型配置请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始创建模型配置: {model_data.name}")
|
||||
result_orm = await ModelConfigService.create_model(db=db, model_data=model_data)
|
||||
result_orm = await ModelConfigService.create_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置创建成功: {result_orm.name} (ID: {result_orm.id})")
|
||||
|
||||
# 将ORM对象转换为Pydantic模型
|
||||
@@ -142,11 +144,11 @@ def update_model(
|
||||
"""
|
||||
更新模型配置
|
||||
"""
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data)
|
||||
result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
# 将ORM对象转换为Pydantic模型
|
||||
@@ -167,11 +169,11 @@ def delete_model(
|
||||
"""
|
||||
删除模型配置
|
||||
"""
|
||||
api_logger.info(f"删除模型配置请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
api_logger.info(f"删除模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始删除模型配置: model_id={model_id}")
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id)
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置删除成功: model_id={model_id}")
|
||||
return success(msg="模型配置删除成功")
|
||||
except Exception as e:
|
||||
|
||||
@@ -158,7 +158,7 @@ async def run_multi_agent(
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/multi-agent/test-routing",
|
||||
summary="测试智能路由"
|
||||
summary="测试智能路由(支持 Master Agent 模式)"
|
||||
)
|
||||
async def test_routing(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
@@ -166,19 +166,20 @@ async def test_routing(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""测试智能路由功能
|
||||
"""测试智能路由功能(重构版 - 支持 Master Agent)
|
||||
|
||||
支持三种路由模式:
|
||||
- keyword: 仅使用关键词路由
|
||||
- llm: 使用 LLM 路由(需要提供 routing_model_id)
|
||||
- hybrid: 混合路由(关键词 + LLM)
|
||||
- master_agent: 使用 Master Agent 决策(推荐)
|
||||
- llm_router: 使用旧 LLM 路由器(向后兼容)
|
||||
- rule_only: 仅使用规则路由(最快)
|
||||
|
||||
参数:
|
||||
- message: 测试消息
|
||||
- conversation_id: 会话 ID(可选)
|
||||
- routing_model_id: 路由模型 ID(可选,用于 LLM 路由)
|
||||
- routing_model_id: 路由模型 ID(可选)
|
||||
- use_llm: 是否启用 LLM(默认 False)
|
||||
- keyword_threshold: 关键词置信度阈值(默认 0.8)
|
||||
- force_new: 是否强制重新路由(默认 False)
|
||||
"""
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.services.llm_router import LLMRouter
|
||||
@@ -276,7 +277,161 @@ async def test_routing(
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/",
|
||||
"/{app_id}/multi-agent/test-master-agent",
|
||||
summary="测试 Master Agent 决策"
|
||||
)
|
||||
async def test_master_agent(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    request: multi_agent_schema.RoutingTestRequest = ...,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Run the Master Agent router on one test message and report its decision.

    The handler loads the app's multi-agent configuration, the Master Agent
    release with its app and default model configuration, and every sub-agent
    release that can be resolved. It then asks ``MasterAgentRouter.route`` for
    a decision and returns it together with some configuration info.

    Args:
        app_id: ID of the application whose multi-agent config is tested.
        request: Test payload; ``message`` and ``conversation_id`` are read.
        current_user: Resolved by ``get_current_user``; not referenced in the
            body, so it only forces the dependency to run.
        db: Database session used for all lookups.

    Returns:
        The result of ``success(...)``. Every failure path (missing config,
        release, app or model config, or a routing error) is also reported
        through ``success`` with ``data=None`` and an explanatory ``msg``.

    NOTE(review): nothing visible here checks that ``app_id`` belongs to the
    caller's workspace; confirm ``MultiAgentService.get_config`` enforces it.
    """
    from app.models import App, AppRelease, ModelConfig
    from app.services.conversation_state_manager import ConversationStateManager
    from app.services.master_agent_router import MasterAgentRouter

    class AgentConfigProxy:
        """Attribute bag exposing the release fields the router is given."""

        def __init__(self, release, app, config_data):
            self.id = release.id
            self.app_id = release.app_id
            self.app = app
            self.name = release.name
            self.description = release.description
            self.system_prompt = config_data.get("system_prompt")
            self.default_model_config_id = release.default_model_config_id

    # 1. Multi-agent configuration of the application.
    service = MultiAgentService(db)
    config = service.get_config(app_id)
    if not config:
        return success(
            data=None,
            msg="应用未配置多 Agent,无法测试"
        )

    # 2. Master Agent release and the application it belongs to.
    master_release = db.get(AppRelease, config.master_agent_id)
    if not master_release:
        return success(
            data=None,
            msg=f"Master Agent 发布版本不存在: {config.master_agent_id}"
        )

    app = db.get(App, master_release.app_id)
    if not app:
        return success(
            data=None,
            msg=f"应用不存在: {master_release.app_id}"
        )

    master_agent_config = AgentConfigProxy(master_release, app, master_release.config or {})

    # 3. Model configuration used by the Master Agent.
    master_model_config = db.get(ModelConfig, master_agent_config.default_model_config_id)
    if not master_model_config:
        return success(
            data=None,
            msg=f"Master Agent 模型配置不存在: {master_agent_config.default_model_config_id}"
        )

    # 4. Sub-agents. A malformed or missing id skips that entry instead of
    #    failing the whole request with an unhandled ValueError/KeyError.
    sub_agents = {}
    for sub_agent_info in config.sub_agents or []:
        agent_id = sub_agent_info.get("agent_id")
        try:
            release_id = uuid.UUID(str(agent_id))
        except ValueError:
            logger.warning(f"子 Agent ID 无效,已跳过: {agent_id}")
            continue

        sub_release = db.get(AppRelease, release_id)
        if not sub_release:
            continue

        sub_app = db.get(App, sub_release.app_id)
        sub_agents[str(agent_id)] = {
            "config": AgentConfigProxy(sub_release, sub_app, sub_release.config or {}),
            "info": sub_agent_info
        }

    # 5. Router under test. Named master_router so it does not shadow the
    #    module-level APIRouter called `router`.
    state_manager = ConversationStateManager()
    master_router = MasterAgentRouter(
        db=db,
        master_agent_config=master_agent_config,
        master_model_config=master_model_config,
        sub_agents=sub_agents,
        state_manager=state_manager,
        enable_rule_fast_path=True
    )

    # 6. Ask for a routing decision and shape the response.
    try:
        decision = await master_router.route(
            message=request.message,
            conversation_id=str(request.conversation_id) if request.conversation_id else None,
            variables=None
        )

        agent_id = decision["selected_agent_id"]
        # sub_agents is keyed by string ids; normalise the lookup key in case
        # the router hands back a non-string id (TODO confirm its type).
        agent_info = sub_agents.get(str(agent_id), {}).get("info", {})

        response_data = {
            "message": request.message,
            "master_agent": {
                "name": master_agent_config.name,
                "model": master_model_config.name
            },
            "decision": {
                "selected_agent_id": agent_id,
                "selected_agent_name": agent_info.get("name", "未知"),
                "selected_agent_role": agent_info.get("role", ""),
                "confidence": decision["confidence"],
                "reasoning": decision.get("reasoning", ""),
                "topic": decision.get("topic", ""),
                "strategy": decision["strategy"],
                "routing_method": decision.get("routing_method", ""),
                "need_collaboration": decision.get("need_collaboration", False),
                "collaboration_agents": decision.get("collaboration_agents", [])
            },
            "config_info": {
                "total_sub_agents": len(sub_agents),
                "enable_rule_fast_path": True
            }
        }

        return success(
            data=response_data,
            msg="Master Agent 决策测试成功"
        )

    except Exception as e:
        # Deliberate best-effort boundary for a test endpoint: report the
        # failure to the caller, but keep the traceback in the log.
        logger.error(f"Master Agent 决策测试失败: {str(e)}", exc_info=True)
        return success(
            data=None,
            msg=f"测试失败: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/multi-agent/batch-test-routing",
|
||||
summary="批量测试智能路由"
|
||||
)
|
||||
async def batch_test_routing(
|
||||
|
||||
@@ -5,9 +5,10 @@ import uuid
|
||||
import hashlib
|
||||
import time
|
||||
import jwt
|
||||
from app.services import task_service, workspace_service
|
||||
from typing import Optional, Dict
|
||||
from functools import wraps
|
||||
|
||||
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -21,8 +22,10 @@ from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.repositories.app_repository import AppRepository
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.repositories import knowledge_repository
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -95,7 +98,7 @@ def get_access_token(
|
||||
access_token = create_access_token(user_id, share_token)
|
||||
|
||||
logger.info(
|
||||
f"生成访问 token",
|
||||
"生成访问 token",
|
||||
extra={
|
||||
"share_token": share_token,
|
||||
"user_id": user_id
|
||||
@@ -270,7 +273,7 @@ def get_conversation(
|
||||
async def chat(
|
||||
payload: conversation_schema.ChatRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""发送消息并获取回复
|
||||
|
||||
@@ -313,6 +316,45 @@ async def chat(
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
)
|
||||
|
||||
|
||||
appid=share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(App.id == appid).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
db=db,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
# 获取应用类型
|
||||
app_type = release.app.type if release.app else None
|
||||
|
||||
@@ -339,7 +381,7 @@ async def chat(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"参数验证完成",
|
||||
"参数验证完成",
|
||||
extra={
|
||||
"share_token": share_token,
|
||||
"app_type": app_type,
|
||||
@@ -365,7 +407,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -388,7 +432,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
@@ -403,7 +449,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -426,7 +474,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
|
||||
@@ -6,7 +6,7 @@ from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
router = APIRouter(prefix="/v1/apps", tags=["V1 - App API"])
|
||||
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get(f"/llm/{{model_id}}", response_model=ApiResponse)
|
||||
@router.get("/llm/{model_id}", response_model=ApiResponse)
|
||||
def test_llm(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
@@ -62,7 +62,7 @@ Answer: Let's think step by step."""
|
||||
raise
|
||||
|
||||
|
||||
@router.get(f"/embedding/{{model_id}}", response_model=ApiResponse)
|
||||
@router.get("/embedding/{model_id}", response_model=ApiResponse)
|
||||
def test_embedding(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
@@ -96,7 +96,7 @@ def test_embedding(
|
||||
return success(msg="测试LLM成功")
|
||||
|
||||
|
||||
@router.get(f"/rerank/{{model_id}}", response_model=ApiResponse)
|
||||
@router.get("/rerank/{model_id}", response_model=ApiResponse)
|
||||
def test_rerank(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
|
||||
@@ -73,7 +73,7 @@ def get_workspaces(
|
||||
if not include_current and current_user.current_workspace_id:
|
||||
workspaces = [w for w in workspaces if w.id != current_user.current_workspace_id]
|
||||
api_logger.debug(
|
||||
f"过滤掉当前工作空间",
|
||||
"过滤掉当前工作空间",
|
||||
extra={"current_workspace_id": str(current_user.current_workspace_id)}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.agent.agent_chat import Agent_chat
|
||||
from app.core.logging_config import get_business_logger
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from app.dependencies import workspace_access_guard
|
||||
from app.services.agent_server import config,ChatRequest
|
||||
router = APIRouter(prefix="/Test", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
class CombinedRequest(BaseModel):
    """Request body that bundles the agent setup with the chat request."""

    # Settings handed to Agent_chat.
    config_base: config
    # The chat request whose fields are copied into a fresh ChatRequest.
    agent_config: ChatRequest


@router.post("", summary="uuid")
async def agent_chat(
    config_base: CombinedRequest
):
    """Copy the incoming chat request and run it through ``Agent_chat``.

    Args:
        config_base: Combined payload; ``config_base.config_base`` configures
            the agent and ``config_base.agent_config`` carries the message.

    Returns:
        Whatever ``Agent_chat(...).chat(...)`` returns.
    """
    incoming = config_base.agent_config
    # Only these fields are carried over into the new ChatRequest.
    copied_fields = (
        "end_user_id",
        "message",
        "search_switch",
        "kb_ids",
        "similarity_threshold",
        "vector_similarity_weight",
        "top_k",
        "hybrid",
        "token",
    )
    request = ChatRequest(**{name: getattr(incoming, name) for name in copied_fields})

    return await Agent_chat(config_base.config_base).chat(request)
|
||||
@@ -1,109 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.api_resquests_server import messages_type, write_messages
|
||||
from app.services.agent_server import ChatRequest, tool_memory, create_dynamic_agent, tool_Retrieval
|
||||
|
||||
logger = get_business_logger()
|
||||
class Agent_chat:
    """Send one user message to several dynamically created agents.

    The constructor renders the configured prompt template once. ``chat``
    optionally extends that prompt with memory and knowledge-base content,
    builds one agent per entry in ``model_configs`` and collects their answers.
    """

    def __init__(self, config_data: Any):
        """Render the base prompt and remember the feature switches.

        Args:
            config_data: Object exposing ``template_str``, ``params``,
                ``model_configs``, ``history_memory`` and ``knowledge_base``
                as attributes.
        """
        self.prompt_message = render_prompt_message(
            config_data.template_str,
            PromptMessageRole.USER,
            config_data.params
        )
        self.prompt = self.prompt_message.get_text_content()
        self.model_configs = config_data.model_configs
        self.history_memory = config_data.history_memory
        self.knowledge_base = config_data.knowledge_base
        logger.info(f"渲染结果:{self.prompt}")

    async def run_agent(self, agent, end_user_id: str, user_prompt: str, model_name: str):
        """Invoke one agent and keep only its plain AI replies.

        Args:
            agent: Object with an ``invoke(payload, config)`` method whose
                result has a ``"messages"`` sequence.
            end_user_id: Combined with ``model_name`` to form the thread id.
            user_prompt: Text sent as the single user message.
            model_name: Name reported back in the result.

        Returns:
            ``{"model_name", "end_user_id", "response"}`` where ``response``
            lists the non-empty contents of messages whose class name reduces
            to ``"ai"`` and that carry no tool calls.
        """
        # NOTE(review): invoke() is a synchronous call inside a coroutine, so
        # coroutines gathered by chat() run one after another, not in parallel.
        response = agent.invoke(
            {"messages": [{"role": "user", "content": user_prompt}]},
            {"configurable": {"thread_id": f'{model_name}_{end_user_id}'}},
        )

        ai_messages = []
        for msg in response["messages"]:
            if getattr(msg, "tool_calls", None):
                # Tool-call messages are never part of the reply text.
                continue
            content = getattr(msg, "content", None)
            if not content:
                continue
            # The role comes from the class name, e.g. AIMessage -> "ai".
            role = type(msg).__name__.lower().replace("message", "")
            if role == "ai":
                ai_messages.append(content)

        return {"model_name": model_name, "end_user_id": end_user_id, "response": ai_messages}

    async def chat(self, req: "ChatRequest") -> str:
        """Answer ``req.message`` with every configured model.

        Args:
            req: Chat request; ``end_user_id``, ``message`` and ``token`` are
                read here and the whole object is passed to the tool helpers.

        Returns:
            A string holding the elapsed time followed by the ``repr`` of the
            list of per-model results from ``run_agent``.
        """
        end_user_id = req.end_user_id  # also used as the conversation thread key
        start = time.time()
        user_prompt = req.message

        # Classify the message; questions are persisted via write_messages.
        message_kind = (await messages_type(req.message, end_user_id))['data']
        if message_kind == 'question':
            writer_result = await write_messages(f'{end_user_id}', req.message)
            logger.info(f'判断类型写入耗时:{time.time() - start},{writer_result}')

        # Work on a local copy so repeated chat() calls on one instance do not
        # keep appending to the rendered base prompt.
        prompt = self.prompt

        # Optional: enrich the prompt with history memory.
        if self.history_memory:
            tool_result = await tool_memory(req)
            if tool_result != '':
                tool_result = tool_result['data']
            if tool_result != '':
                prompt += f''',历史消息:{tool_result},结合历史消息'''
            logger.info(f"记忆科学消耗时间:{time.time()-start},工具调用结果:{tool_result}")

        # Optional: enrich the prompt with knowledge-base retrieval.
        if self.knowledge_base:
            retrieval_result = await tool_Retrieval(req)
            retrieval_knowledge = ','.join(i['page_content'] for i in retrieval_result['data'])
            logger.info(f"检索消耗时间:{time.time()-start},{retrieval_knowledge}")
            if retrieval_knowledge != '':
                prompt += f",知识库检索内容:{retrieval_knowledge},结合检索结果"

        # Closing instruction; assumed to apply unconditionally - TODO confirm
        # it was not nested under the knowledge-base branch.
        prompt += '给出最合适的答案,确保答案的完整性,只保留用户的问题的回答,不额外输出提示语'
        logger.info(f"用户输入:{user_prompt}")
        logger.info(f"系统prompt:{prompt}")

        # One agent per configured model, keyed by model name.
        agents = {}
        for cfg in self.model_configs:
            agents[cfg["name"]] = await create_dynamic_agent(
                cfg["name"], cfg["moder_id"], prompt, req.token
            )

        results = await asyncio.gather(*(
            self.run_agent(agent, end_user_id, user_prompt, model_name)
            for model_name, agent in agents.items()
        ))

        return f"最终耗时:{time.time()-start},{list(results)}"
|
||||
@@ -15,6 +15,8 @@ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, Base
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from app.core.memory.agent.mcp_server.services import session_service
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -89,7 +91,7 @@ class LangChainAgent:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"LangChain Agent 初始化完成",
|
||||
"LangChain Agent 初始化完成",
|
||||
extra={
|
||||
"model": model_name,
|
||||
"provider": provider,
|
||||
@@ -139,6 +141,42 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
async def term_memory_save(self,messages,end_user_end,aimessages):
    """Store one user/AI exchange in the session store under a ``Term_`` key.

    The key ``Term_<end_user_end>`` is passed as user id, apply id and group
    id alike. After saving, ``store.delete_duplicate_sessions()`` is called.

    NOTE(review): the original note described six-message batches being handed
    to neo4j; that batching is not implemented in this method.

    Args:
        messages: The user-side content, passed through as ``messages``.
        end_user_end: End-user identifier; prefixed with ``Term_``.
        aimessages: The AI-side content, passed through as ``aimessages``.

    Returns:
        The value returned by ``store.save_session``.
    """
    term_key = f"Term_{end_user_end}"
    session_id = store.save_session(
        userid=term_key,
        messages=messages,
        apply_id=term_key,
        group_id=term_key,
        aimessages=aimessages
    )
    store.delete_duplicate_sessions()
    # Debug log instead of print(): raw conversation text no longer goes to stdout.
    logger.debug(f'Redis_Agent:{term_key};{session_id}')
    return session_id
|
||||
async def term_memory_redis_read(self,end_user_end):
    """Read the stored exchanges for a user and format them as text lines.

    Looks up ``Term_<end_user_end>`` (used as user, apply and group id) via
    ``store.find_user_apply_group`` and renders each record's ``Query`` and
    ``Answer`` values into one string.

    Args:
        end_user_end: End-user identifier; prefixed with ``Term_``.

    Returns:
        A list with one formatted string per stored record.
    """
    term_key = f"Term_{end_user_end}"
    records = store.find_user_apply_group(term_key, term_key, term_key)
    return [
        f'用户:{record.get("Query")}。AI回复:{record.get("Answer")}'
        for record in records
    ]
|
||||
|
||||
|
||||
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
    """Write a message to memory through the path chosen by ``storage_type``.

    Args:
        storage_type: ``"rag"`` selects ``write_rag``; any other value goes
            through ``write_message_task``.
        end_user_id: User id used on the RAG path.
        message: Text written on the RAG path.
        user_rag_memory_id: Passed to both paths.
        actual_end_user_id: User id used on the task path.
        content: Text written on the task path.
        actual_config_id: Config id forwarded on the task path.

    Returns:
        None.
    """
    if storage_type != "rag":
        # Task path: enqueue via .delay(), then look up the result by the
        # string form of the returned handle and log it.
        task = write_message_task.delay(
            actual_end_user_id, content, actual_config_id, storage_type, user_rag_memory_id
        )
        status = get_task_memory_write_result(str(task))
        logger.info(f'Agent:{actual_end_user_id};{status}')
        return

    # RAG path: write directly and log the target memory id.
    await write_rag(end_user_id, message, user_rag_memory_id)
    logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@@ -149,6 +187,7 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -160,29 +199,29 @@ class LangChainAgent:
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat= message
|
||||
start_time = time.time()
|
||||
|
||||
if config_id == None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:
|
||||
actual_config_id = config_id
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
if config_id==None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:actual_config_id=config_id
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
|
||||
|
||||
history_term_memory=await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
history_term_memory=';'.join(history_term_memory)
|
||||
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
||||
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
|
||||
logger.debug(
|
||||
f"准备调用 LangChain Agent",
|
||||
"准备调用 LangChain Agent",
|
||||
extra={
|
||||
"has_context": bool(context),
|
||||
"has_history": bool(history),
|
||||
@@ -203,15 +242,9 @@ class LangChainAgent:
|
||||
break
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, user_rag_memory_id)
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
|
||||
if memory_flag:
|
||||
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
|
||||
await self.term_memory_save(message_chat,end_user_id,content)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -224,7 +257,7 @@ class LangChainAgent:
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"Agent 调用完成",
|
||||
"Agent 调用完成",
|
||||
extra={
|
||||
"elapsed_time": elapsed_time,
|
||||
"content_length": len(response["content"])
|
||||
@@ -234,7 +267,7 @@ class LangChainAgent:
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent 调用失败", extra={"error": str(e)})
|
||||
logger.error("Agent 调用失败", extra={"error": str(e)})
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
@@ -246,7 +279,7 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
|
||||
memory_flag: Optional[bool] = True
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -259,28 +292,27 @@ class LangChainAgent:
|
||||
str: 消息内容块
|
||||
"""
|
||||
logger.info("=" * 80)
|
||||
logger.info(f" chat_stream 方法开始执行")
|
||||
logger.info(" chat_stream 方法开始执行")
|
||||
logger.info(f" Message: {message[:100]}")
|
||||
logger.info(f" Has tools: {bool(self.tools)}")
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
|
||||
start_time = time.time()
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
message_chat = message
|
||||
if config_id == None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:
|
||||
if config_id==None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:actual_config_id=config_id
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
|
||||
actual_config_id = config_id
|
||||
|
||||
try:
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
except Exception as e:
|
||||
logger.error(f"Agent 记忆用户输入出错", extra={"error": str(e)})
|
||||
history_term_memory = await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||
history_term_memory, actual_config_id)
|
||||
|
||||
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
@@ -294,7 +326,7 @@ class LangChainAgent:
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
|
||||
full_content=''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
@@ -307,6 +339,7 @@ class LangChainAgent:
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
full_content+=chunk.content
|
||||
if chunk and hasattr(chunk, "content") and chunk.content:
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
@@ -316,6 +349,7 @@ class LangChainAgent:
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
full_content+=chunk.content
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
@@ -329,6 +363,9 @@ class LangChainAgent:
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
if memory_flag:
|
||||
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
|
||||
await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
@@ -341,7 +378,7 @@ class LangChainAgent:
|
||||
raise
|
||||
finally:
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"chat_stream 方法执行结束")
|
||||
logger.info("chat_stream 方法执行结束")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
|
||||
228
api/app/core/api_key_auth.py
Normal file
228
api/app/core/api_key_auth.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.api_key_utils import add_rate_limit_headers
|
||||
from app.core.exceptions import (
|
||||
BusinessException,
|
||||
RateLimitException,
|
||||
)
|
||||
from app.repositories.api_key_repository import ApiKeyLogRepository, ApiKeyRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.services.api_key_service import ApiKeyAuthService, RateLimiterService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = get_api_logger()
|
||||
|
||||
|
||||
# Strong references to in-flight usage-logging tasks. The event loop keeps only
# weak references to tasks, so a fire-and-forget task could otherwise be
# garbage-collected before it finishes.
_usage_log_tasks = set()


def require_api_key(
        scopes: Optional[List[str]] = None,
        resource_type: Optional[str] = None
):
    """
    Decorator that authenticates a request by API key.

    The wrapper reads ``request`` and ``db`` from the handler's keyword
    arguments and awaits the wrapped handler. Checks run in this order: key
    present, key valid, rate limits, required scopes, resource binding. On
    success an ``ApiKeyAuth`` is injected as ``kwargs["api_key_auth"]``, the
    rate-limit headers are added to the handler's response, and usage logging
    is scheduled as a background task.

    Args:
        scopes: Required permission scopes, e.g. ["app:all",
            "rag:search", "rag:upload", "rag:delete",
            "memory:read", "memory:write", "memory:delete", "memory:search"]
        resource_type: Required resource type ("Agent", "Cluster", "Workflow",
            "Knowledge", "Memory_Engine"). The resource check only runs when
            the handler receives a truthy ``resource_id`` keyword argument.

    Raises:
        BusinessException: The key is missing or invalid, a required scope is
            missing, or the key is not allowed to access the resource.
        RateLimitException: A rate limit was hit.

    Usage:
        @router.get("/app/{resource_id}/chat")
        @require_api_key(scopes=["app:all"], resource_type="Agent")
        async def chat_with_app(
            request: Request,
            resource_id: uuid.UUID,
            api_key_auth: ApiKeyAuth = Depends(),
            db: Session = Depends(get_db),
            message: str
        ):
            # api_key_auth holds the validated API key information
            pass
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            request: Request = kwargs.get("request")
            db: Session = kwargs.get("db")

            api_key = extract_api_key_from_request(request)
            if not api_key:
                logger.warning("API Key 缺失", extra={
                    "endpoint": str(request.url),
                    "method": request.method,
                    "ip_address": request.client.host if request.client else None
                })
                raise BusinessException("API Key 不存在", BizCode.API_KEY_NOT_FOUND)

            api_key_obj = ApiKeyAuthService.validate_api_key(db, api_key)
            if not api_key_obj:
                logger.warning("API Key 无效或已过期", extra={
                    # Log only a short prefix so the full key is not written out.
                    "key_prefix": api_key[:10] + "..." if len(api_key) > 10 else api_key,
                    "endpoint": str(request.url),
                    "method": request.method,
                    "ip_address": request.client.host if request.client else None
                })
                raise BusinessException("API Key 无效或已过期", BizCode.API_KEY_INVALID)

            rate_limiter = RateLimiterService()
            is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
            if not is_allowed:
                logger.warning("API Key 限流触发", extra={
                    "api_key_id": str(api_key_obj.id),
                    "endpoint": str(request.url),
                    "method": request.method,
                    "error_msg": error_msg
                })
                # Infer which limit was hit from the error message text.
                if "QPS" in error_msg:
                    code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
                elif "Daily" in error_msg:
                    code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
                else:
                    code = BizCode.API_KEY_QUOTA_EXCEEDED

                raise RateLimitException(
                    error_msg,
                    code,
                    rate_headers=rate_headers
                )

            if scopes:
                missing_scopes = [
                    scope for scope in scopes
                    if not ApiKeyAuthService.check_scope(api_key_obj, scope)
                ]
                if missing_scopes:
                    logger.warning("API Key 权限不足", extra={
                        "api_key_id": str(api_key_obj.id),
                        "missing_scopes": missing_scopes,
                        "available_scopes": api_key_obj.scopes,
                        "endpoint": str(request.url)
                    })
                    raise BusinessException(
                        f"缺少必须的权限范围:{','.join(missing_scopes)}",
                        BizCode.API_KEY_INVALID_SCOPE,
                        context={"required_scopes": scopes, "missing_scopes": missing_scopes}
                    )

            if resource_type:
                resource_id = kwargs.get("resource_id")
                if resource_id and not ApiKeyAuthService.check_resource(
                        api_key_obj,
                        resource_type,
                        resource_id
                ):
                    bound_resource_id = str(api_key_obj.resource_id) if api_key_obj.resource_id else None
                    logger.warning("API Key 资源访问被拒绝", extra={
                        "api_key_id": str(api_key_obj.id),
                        "required_resource_type": resource_type,
                        "required_resource_id": str(resource_id),
                        "bound_resource_type": api_key_obj.resource_type,
                        "bound_resource_id": bound_resource_id,
                        "endpoint": str(request.url)
                    })
                    # Must be raised: returning the exception object would hand
                    # it back as the response instead of rejecting the request.
                    raise BusinessException(
                        "API Key 未授权访问该资源",
                        BizCode.API_KEY_INVALID_RESOURCE,
                        context={
                            "required_resource_type": resource_type,
                            "required_resource_id": str(resource_id),
                            "bound_resource_type": api_key_obj.resource_type,
                            "bound_resource_id": bound_resource_id
                        }
                    )

            kwargs["api_key_auth"] = ApiKeyAuth(
                api_key_id=api_key_obj.id,
                workspace_id=api_key_obj.workspace_id,
                type=api_key_obj.type,
                scopes=api_key_obj.scopes,
                resource_id=api_key_obj.resource_id,
                resource_type=api_key_obj.resource_type
            )

            response = await func(*args, **kwargs)
            response = add_rate_limit_headers(response, rate_headers)

            # Log usage in the background; keep a reference until it completes.
            usage_task = asyncio.create_task(log_api_key_usage(
                db, api_key_obj.id, request, response
            ))
            _usage_log_tasks.add(usage_task)
            usage_task.add_done_callback(_usage_log_tasks.discard)
            return response

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def extract_api_key_from_request(request: Request) -> Optional[str]:
    """Extract the API key from the request headers.

    Two forms are supported:
    1. Authorization: Bearer <api_key>
    2. X-API-Key: <api_key>

    ``X-API-Key`` is consulted only when no ``Authorization`` header is
    present. A malformed ``Authorization`` header or a non-bearer scheme
    yields ``None``. Any exception raised while reading the headers is logged
    and also yields ``None``.
    """
    try:
        authorization = request.headers.get("Authorization")
        if not authorization:
            # No Authorization header: fall back to X-API-Key.
            return request.headers.get("X-API-Key") or None

        scheme, separator, token = authorization.partition(" ")
        if not separator:
            # No space, so the value cannot be "<scheme> <token>".
            preview = authorization if len(authorization) <= 20 else authorization[:20] + "..."
            logger.warning("无效的 Authorization header 格式", extra={
                "auth_header": preview,
                "endpoint": str(request.url)
            })
            return None

        if scheme.lower() != "bearer":
            logger.warning("无效的认证方案", extra={
                "auth_scheme": scheme,
                "endpoint": str(request.url)
            })
            return None

        return token
    except Exception as e:
        logger.error(f"提取 API Key 时发生错误: {str(e)}", extra={
            "endpoint": str(request.url)
        })
        return None
|
||||
|
||||
|
||||
async def log_api_key_usage(
        db: Session,
        api_key_id: uuid.UUID,
        request: Request,
        response: Response
):
    """Record one API key usage log entry and update the key's usage data.

    This is best-effort: any failure is logged and swallowed so that logging
    never breaks the request. On failure the session is rolled back so it is
    not left holding a failed transaction.

    Args:
        db: Database session used for the writes and the commit.
        api_key_id: Id of the API key that served the request.
        request: Incoming request; path, method, client host and User-Agent
            are recorded.
        response: Handler result; ``status_code`` is recorded when present.
    """
    try:
        log_data = {
            "id": uuid.uuid4(),
            "api_key_id": api_key_id,
            "endpoint": str(request.url.path),
            "method": request.method,
            "ip_address": request.client.host if request.client else None,
            "user_agent": request.headers.get("User-Agent"),
            "status_code": getattr(response, "status_code", None),
            "response_time": None,  # would need to be computed in middleware
            "tokens_used": None,  # would need to be extracted from the response
            "created_at": datetime.now()
        }

        ApiKeyLogRepository.create(db, log_data)
        ApiKeyRepository.update_usage(db, api_key_id)
        db.commit()
    except Exception as e:
        logger.error(f"未能记录API密钥的使用情况: {e}")
        # Discard any partially applied changes; the rollback is guarded so a
        # second failure (e.g. a missing session) stays best-effort as well.
        try:
            db.rollback()
        except Exception as rollback_error:
            logger.error(f"回滚 API Key 使用日志事务失败: {rollback_error}")
|
||||
@@ -1,11 +1,35 @@
|
||||
"""API Key 工具函数"""
|
||||
import secrets
|
||||
import hashlib
|
||||
from app.models.api_key_model import ApiKeyType
|
||||
from typing import Optional
|
||||
|
||||
from app.schemas.api_key_schema import ApiKeyType
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class ResourceType:
    """Resource type constants."""

    AGENT = "Agent"
    CLUSTER = "Cluster"
    WORKFLOW = "Workflow"
    KNOWLEDGE = "Knowledge"
    MEMORY_ENGINE = "Memory_Engine"

    @classmethod
    def get_all_types(cls) -> list[str]:
        """Return every supported resource type, as a new list."""
        attribute_names = ("AGENT", "CLUSTER", "WORKFLOW", "KNOWLEDGE", "MEMORY_ENGINE")
        return [getattr(cls, attribute) for attribute in attribute_names]

    @classmethod
    def is_valid_type(cls, resource_type: str) -> bool:
        """Return whether ``resource_type`` is one of the supported types."""
        supported = cls.get_all_types()
        return resource_type in supported
|
||||
|
||||
|
||||
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
|
||||
"""生成 API Key
|
||||
"""
|
||||
生成 API Key
|
||||
|
||||
Args:
|
||||
key_type: API Key 类型
|
||||
@@ -18,16 +42,15 @@ def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
|
||||
ApiKeyType.APP: "sk-app-",
|
||||
ApiKeyType.RAG: "sk-rag-",
|
||||
ApiKeyType.MEMORY: "sk-mem-",
|
||||
ApiKeyType.GENERAL: "sk-gen-",
|
||||
}
|
||||
|
||||
|
||||
prefix = prefix_map[key_type]
|
||||
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
|
||||
api_key = f"{prefix}{random_string}"
|
||||
|
||||
|
||||
# 生成哈希值存储
|
||||
key_hash = hash_api_key(api_key)
|
||||
|
||||
|
||||
return api_key, key_hash, prefix
|
||||
|
||||
|
||||
@@ -44,7 +67,8 @@ def hash_api_key(api_key: str) -> str:
|
||||
|
||||
|
||||
def verify_api_key(api_key: str, key_hash: str) -> bool:
|
||||
"""验证 API Key
|
||||
"""
|
||||
验证 API Key
|
||||
|
||||
Args:
|
||||
api_key: API Key 明文
|
||||
@@ -53,4 +77,77 @@ def verify_api_key(api_key: str, key_hash: str) -> bool:
|
||||
Returns:
|
||||
bool: 是否匹配
|
||||
"""
|
||||
return hash_api_key(api_key) == key_hash
|
||||
computed_hash = hash_api_key(api_key)
|
||||
return secrets.compare_digest(computed_hash, key_hash)
|
||||
|
||||
|
||||
def validate_resource_binding(
        resource_type: Optional[str],
        resource_id: Optional[str]
) -> tuple[bool, str]:
    """
    Validate a resource binding.

    Args:
        resource_type: Resource type
        resource_id: Resource id

    Returns:
        tuple: (is valid, error message); the message is empty when valid
    """
    has_type = bool(resource_type)
    has_id = bool(resource_id)

    # Exactly one of the two supplied: invalid.
    if has_type != has_id:
        return False, "resource_type 和 resource_id 必须同时提供或同时为空"

    # Neither supplied: the key is not bound to a resource, which is valid.
    if not has_type:
        return True, ""

    # Both supplied: the type must be one of the supported types.
    if ResourceType.is_valid_type(resource_type):
        return True, ""

    valid_types = ", ".join(ResourceType.get_all_types())
    return False, f"不支持的资源类型 '{resource_type}',支持的类型:{valid_types}"
|
||||
|
||||
|
||||
def get_resource_scope_mapping() -> dict[str, list[str]]:
    """
    Build the mapping from resource type to permission scopes.

    Returns:
        dict: Resource type mapped to its recommended scopes; a fresh dict
        with fresh lists on every call
    """
    app_resource_types = (ResourceType.AGENT, ResourceType.CLUSTER, ResourceType.WORKFLOW)
    # Each app-style resource gets its own list, so entries are not shared.
    mapping = {app_type: ["app:all"] for app_type in app_resource_types}
    mapping[ResourceType.KNOWLEDGE] = ["rag:search", "rag:upload", "rag:delete"]
    mapping[ResourceType.MEMORY_ENGINE] = [
        "memory:read", "memory:write", "memory:delete", "memory:search"
    ]
    return mapping
|
||||
|
||||
|
||||
def add_rate_limit_headers(response, headers: dict):
    """Add the rate-limit headers to a response.

    Args:
        response: The handler's result. Headers are only added when it is a
            ``Response`` (which includes ``JSONResponse``, a subclass) or any
            other object exposing a ``headers`` attribute; anything else is
            returned untouched.
        headers: Header names mapped to values. ``None`` or empty means there
            is nothing to add.

    Returns:
        The same ``response`` object.
    """
    if not headers:
        return response

    if isinstance(response, Response):
        for key, value in headers.items():
            # Response header values must be strings; coerce numeric limits.
            response.headers[key] = value if isinstance(value, str) else str(value)
    elif hasattr(response, 'headers'):
        response.headers.update(headers)

    return response
|
||||
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class CompensationHandler:
|
||||
for compensation in reversed(self._compensations):
|
||||
try:
|
||||
compensation()
|
||||
logger.debug(f"Compensation operation executed successfully")
|
||||
logger.debug("Compensation operation executed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"补偿操作失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ class Settings:
|
||||
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
# Neo4j Configuration (记忆系统数据库)
|
||||
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://127.0.0.1:7687")
|
||||
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
|
||||
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
|
||||
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
|
||||
|
||||
@@ -23,6 +23,11 @@ class Settings:
|
||||
DB_USER: str = os.getenv("DB_USER", "postgres")
|
||||
DB_PASSWORD: str = os.getenv("DB_PASSWORD", "password")
|
||||
DB_NAME: str = os.getenv("DB_NAME", "redbear-mem")
|
||||
DB_POOL_SIZE: int = int(os.getenv("DB_POOL_SIZE", "50"))
|
||||
DB_MAX_OVERFLOW: int = int(os.getenv("DB_MAX_OVERFLOW", "20"))
|
||||
DB_POOL_RECYCLE: int = int(os.getenv("DB_POOL_RECYCLE", "1800"))
|
||||
DB_POOL_TIMEOUT: int = int(os.getenv("DB_POOL_TIMEOUT", "30"))
|
||||
DB_POOL_PRE_PING: bool = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"
|
||||
|
||||
DB_AUTO_UPGRADE = os.getenv("DB_AUTO_UPGRADE", "false").lower() == "true"
|
||||
|
||||
|
||||
@@ -19,6 +19,17 @@ class BizCode(IntEnum):
|
||||
TENANT_NOT_FOUND = 3002
|
||||
WORKSPACE_NO_ACCESS = 3003
|
||||
WORKSPACE_INVITE_NOT_FOUND = 3004
|
||||
# API Key 管理(3xxx)
|
||||
API_KEY_NOT_FOUND = 3007
|
||||
API_KEY_DUPLICATE_NAME = 3008
|
||||
API_KEY_INVALID = 3009
|
||||
API_KEY_EXPIRED = 3010
|
||||
API_KEY_INACTIVE = 3011
|
||||
API_KEY_INVALID_SCOPE = 3012
|
||||
API_KEY_INVALID_RESOURCE = 3013
|
||||
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
||||
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
||||
API_KEY_QUOTA_EXCEEDED = 3016
|
||||
# 资源(4xxx)
|
||||
NOT_FOUND = 4000
|
||||
USER_NOT_FOUND = 4001
|
||||
@@ -112,6 +123,19 @@ HTTP_MAPPING = {
|
||||
BizCode.EMBED_NOT_ALLOWED: 403,
|
||||
BizCode.PERMISSION_DENIED: 403,
|
||||
BizCode.INVALID_CONVERSATION: 400,
|
||||
|
||||
# API Key 错误码映射
|
||||
BizCode.API_KEY_NOT_FOUND: 400,
|
||||
BizCode.API_KEY_DUPLICATE_NAME: 400,
|
||||
BizCode.API_KEY_INVALID: 401,
|
||||
BizCode.API_KEY_EXPIRED: 401,
|
||||
BizCode.API_KEY_INACTIVE: 401,
|
||||
BizCode.API_KEY_INVALID_SCOPE: 403,
|
||||
BizCode.API_KEY_INVALID_RESOURCE: 403,
|
||||
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
|
||||
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
|
||||
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
|
||||
|
||||
BizCode.MODEL_CONFIG_INVALID: 400,
|
||||
BizCode.API_KEY_MISSING: 400,
|
||||
BizCode.PROVIDER_NOT_SUPPORTED: 400,
|
||||
|
||||
@@ -83,4 +83,21 @@ class PermissionDeniedException(BusinessException):
|
||||
"""权限拒绝异常"""
|
||||
|
||||
def __init__(self, message: str = "权限不足", **kwargs):
|
||||
super().__init__(message, BizCode.FORBIDDEN, **kwargs)
|
||||
super().__init__(message, BizCode.FORBIDDEN, **kwargs)
|
||||
|
||||
|
||||
class RateLimitException(BusinessException):
    """Rate limit exception."""

    def __init__(self, message: str, code: BizCode = None, rate_headers: dict = None, **kwargs):
        """Create the exception, carrying rate-limit headers in its context.

        Args:
            message: Error message.
            code: Business error code; defaults to ``BizCode.RATE_LIMITED``.
            rate_headers: Rate-limit headers; when non-empty they are stored
                in the context under ``"rate_limit_headers"``.
            **kwargs: Forwarded to ``BusinessException``; an optional
                ``context`` mapping is copied, not modified in place.
        """
        # Fall back to the generic rate-limit code when none is given.
        if code is None:
            code = BizCode.RATE_LIMITED

        # Copy the context so the caller's dict is never mutated, and accept
        # an explicit ``context=None``.
        context = dict(kwargs.get("context") or {})
        if rate_headers:
            context["rate_limit_headers"] = rate_headers
        kwargs["context"] = context

        super().__init__(message, code, **kwargs)
|
||||
@@ -210,7 +210,7 @@ class ProblemExtensionNode:
|
||||
last_message = messages[-1] if messages else ""
|
||||
logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
|
||||
if self.tool_name=='Input_Summary':
|
||||
tool_call =re.findall(f"'id': '(.*?)'",str(last_message))[0]
|
||||
tool_call =re.findall("'id': '(.*?)'",str(last_message))[0]
|
||||
else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
|
||||
# try:
|
||||
# content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
Agent logger module for backward compatibility.
|
||||
|
||||
This module maintains the get_named_logger() function for backward compatibility
|
||||
while delegating to the centralized logging configuration.
|
||||
|
||||
All new code should import directly from app.core.logging_config instead.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__author__ = "RED_BEAR"
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
|
||||
def get_named_logger(name):
    """Return the agent logger for ``name``.

    Kept for backward compatibility: the call is delegated unchanged to the
    centralized ``get_agent_logger()``.

    Args:
        name: Logger name, passed through to ``get_agent_logger``.

    Returns:
        The object returned by ``get_agent_logger(name)``.

    Example:
        >>> logger = get_named_logger("my_agent")
        >>> logger.info("Agent operation started")
    """
    agent_logger = get_agent_logger(name)
    return agent_logger
|
||||
@@ -144,13 +144,15 @@ def main():
|
||||
import asyncio
|
||||
tools_list = asyncio.run(mcp.list_tools())
|
||||
# logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}")
|
||||
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:8081 with SSE transport")
|
||||
|
||||
# Get MCP port from environment (default: 8081)
|
||||
mcp_port = int(os.getenv("MCP_PORT", "8081"))
|
||||
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
|
||||
|
||||
# Run the server with SSE transport for HTTP connections
|
||||
# The server will be available at http://127.0.0.1:8081
|
||||
import uvicorn
|
||||
app = mcp.sse_app()
|
||||
uvicorn.run(app, host=settings.SERVER_IP, port=8081, log_level="info")
|
||||
uvicorn.run(app, host=settings.SERVER_IP, port=mcp_port, log_level="info")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
|
||||
|
||||
@@ -66,80 +66,27 @@ class ParameterBuilder:
|
||||
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
|
||||
|
||||
# Tool-specific argument construction
|
||||
if tool_name == "Verify":
|
||||
if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']:
|
||||
# Verify expects dict context
|
||||
return {
|
||||
"context": content if isinstance(content, dict) else {},
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name == "Retrieve":
|
||||
# Retrieve expects dict context + search_switch
|
||||
|
||||
elif tool_name in ["Retrieve"]:
|
||||
return {
|
||||
"context": content if isinstance(content, dict) else {},
|
||||
"search_switch": search_switch,
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name in ["Summary", "Summary_fails"]:
|
||||
# Summary tools expect JSON string context
|
||||
if isinstance(content, dict):
|
||||
context_str = json.dumps(content, ensure_ascii=False)
|
||||
elif isinstance(content, str):
|
||||
context_str = content
|
||||
else:
|
||||
context_str = json.dumps({"data": content}, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"context": context_str,
|
||||
**base_args
|
||||
}
|
||||
|
||||
elif tool_name == "Retrieve_Summary":
|
||||
# Retrieve_Summary needs to unwrap nested context structures
|
||||
# Handle both 'content' and 'context' keys
|
||||
context_dict = content
|
||||
|
||||
if isinstance(content, dict):
|
||||
# Check for nested 'content' wrapper
|
||||
if "content" in content:
|
||||
inner = content["content"]
|
||||
|
||||
# If it's a JSON string, parse it
|
||||
if isinstance(inner, str):
|
||||
try:
|
||||
parsed = json.loads(inner)
|
||||
# Check if parsed has 'context' wrapper
|
||||
if isinstance(parsed, dict) and "context" in parsed:
|
||||
context_dict = parsed["context"]
|
||||
else:
|
||||
context_dict = parsed
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse JSON content for {tool_name}: {inner[:100]}"
|
||||
)
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
elif isinstance(inner, dict):
|
||||
context_dict = inner
|
||||
|
||||
# Check for 'context' wrapper
|
||||
elif "context" in content:
|
||||
context_dict = content["context"] if isinstance(content["context"], dict) else content
|
||||
|
||||
return {
|
||||
"context": context_dict,
|
||||
**base_args
|
||||
}
|
||||
|
||||
|
||||
elif tool_name == "Input_Summary":
|
||||
# Input_Summary expects raw message string + search_switch
|
||||
# Content should be the raw message string
|
||||
if isinstance(content, dict):
|
||||
# Try to extract message from dict
|
||||
message_str = content.get("sentence", str(content))
|
||||
else:
|
||||
message_str = str(content)
|
||||
|
||||
|
||||
return {
|
||||
"context": message_str,
|
||||
"search_switch": search_switch,
|
||||
|
||||
@@ -116,7 +116,7 @@ async def Split_The_Problem(
|
||||
)
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
logger.info(f"问题拆分")
|
||||
logger.info("问题拆分")
|
||||
logger.info(f"问题拆分结果==>>:{split_result}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
@@ -250,7 +250,7 @@ async def Problem_Extension(
|
||||
)
|
||||
aggregated_dict = {}
|
||||
|
||||
logger.info(f"问题扩展")
|
||||
logger.info("问题扩展")
|
||||
logger.info(f"问题扩展==>>:{aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
|
||||
@@ -167,7 +167,7 @@ async def Retrieve(
|
||||
val.append(items_value)
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val):
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
|
||||
@@ -73,16 +73,16 @@ async def Summary(
|
||||
answer_small, query = await Summary_messages_deal(context)
|
||||
|
||||
|
||||
# Get conversation history
|
||||
start_time= time.time()
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
# Prepare data for template
|
||||
end_time=time.time()
|
||||
logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}")
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": answer_small
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Summary: initialization failed: {e}",
|
||||
@@ -92,7 +92,7 @@ async def Summary(
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
@@ -110,23 +110,23 @@ async def Summary(
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=SummaryResponse
|
||||
)
|
||||
|
||||
|
||||
aimessages = structured.query_answer or ""
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = ""
|
||||
|
||||
|
||||
try:
|
||||
# Save session
|
||||
if aimessages != "":
|
||||
@@ -147,16 +147,16 @@ async def Summary(
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
|
||||
logger.info(f"验证之后的总结==>>:{aimessages}")
|
||||
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
@@ -164,7 +164,7 @@ async def Summary(
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('总结', duration)
|
||||
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
@@ -185,7 +185,7 @@ async def Retrieve_Summary(
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize data directly from retrieval results.
|
||||
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing Query and Expansion_issue from Retrieve
|
||||
@@ -194,23 +194,23 @@ async def Retrieve_Summary(
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
|
||||
|
||||
|
||||
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
|
||||
if isinstance(context, dict):
|
||||
if "content" in context:
|
||||
@@ -219,13 +219,13 @@ async def Retrieve_Summary(
|
||||
if isinstance(inner, str):
|
||||
try:
|
||||
parsed = json.loads(inner)
|
||||
logger.info(f"Retrieve_Summary: successfully parsed JSON")
|
||||
logger.info("Retrieve_Summary: successfully parsed JSON")
|
||||
except json.JSONDecodeError:
|
||||
# Try unescaping first
|
||||
try:
|
||||
unescaped = inner.encode('utf-8').decode('unicode_escape')
|
||||
parsed = json.loads(unescaped)
|
||||
logger.info(f"Retrieve_Summary: parsed after unescaping")
|
||||
logger.info("Retrieve_Summary: parsed after unescaping")
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: parsing failed even after unescape: {e}"
|
||||
@@ -249,10 +249,10 @@ async def Retrieve_Summary(
|
||||
context_dict = context
|
||||
else:
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
|
||||
|
||||
query = context_dict.get("Query", "")
|
||||
expansion_issue = context_dict.get("Expansion_issue", [])
|
||||
|
||||
|
||||
# Extract retrieve_info from expansion_issue
|
||||
retrieve_info = []
|
||||
for item in expansion_issue:
|
||||
@@ -263,7 +263,7 @@ async def Retrieve_Summary(
|
||||
answer = item["Answer_Small"]
|
||||
elif "Answer_Samll" in item:
|
||||
answer = item["Answer_Samll"]
|
||||
|
||||
|
||||
if answer is not None:
|
||||
# Handle both string and list formats
|
||||
if isinstance(answer, list):
|
||||
@@ -273,14 +273,15 @@ async def Retrieve_Summary(
|
||||
retrieve_info.append(answer)
|
||||
else:
|
||||
retrieve_info.append(str(answer))
|
||||
|
||||
|
||||
# Join all retrieve_info into a single string
|
||||
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
|
||||
|
||||
# Get conversation history
|
||||
start_time=time.time()
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
|
||||
end_time=time.time()
|
||||
logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: initialization failed: {e}",
|
||||
@@ -290,7 +291,7 @@ async def Retrieve_Summary(
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
@@ -309,14 +310,14 @@ async def Retrieve_Summary(
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
|
||||
|
||||
# Handle case where structured response might be None or incomplete
|
||||
if structured and hasattr(structured, 'data') and structured.data:
|
||||
aimessages = structured.data.query_answer or ""
|
||||
@@ -324,7 +325,7 @@ async def Retrieve_Summary(
|
||||
logger.warning("Structured response is None or incomplete, using default message")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
|
||||
# Check for insufficient information response
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="":
|
||||
# Save session
|
||||
@@ -344,13 +345,13 @@ async def Retrieve_Summary(
|
||||
aimessages = ""
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
|
||||
logger.info(f"检索之后的总结==>>:{aimessages}")
|
||||
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
@@ -358,7 +359,7 @@ async def Retrieve_Summary(
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索总结', duration)
|
||||
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
@@ -388,7 +389,7 @@ async def Input_Summary(
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a quick summary for direct input without verification.
|
||||
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: String containing the input sentence
|
||||
@@ -398,44 +399,46 @@ async def Input_Summary(
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with the summary result
|
||||
"""
|
||||
start = time.time()
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
|
||||
# Initialize variables to avoid UnboundLocalError
|
||||
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
search_service = get_context_resource(ctx, 'search_service')
|
||||
|
||||
|
||||
# Check if llm_client is None
|
||||
if llm_client is None:
|
||||
error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
sessionid = sessionid.replace('call_id_', '')
|
||||
|
||||
# Get conversation history
|
||||
|
||||
start_time=time.time()
|
||||
history = await session_service.get_history(
|
||||
str(sessionid),
|
||||
str(apply_id),
|
||||
str(group_id)
|
||||
)
|
||||
end_time=time.time()
|
||||
logger.info(f"Input_Summary-REDIS搜索:{end_time - start_time}")
|
||||
# Override with empty list for now (as in original)
|
||||
|
||||
|
||||
# Log the raw context for debugging
|
||||
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
|
||||
|
||||
|
||||
# Extract sentence from context
|
||||
# Context can be a string or might contain the sentence in various formats
|
||||
try:
|
||||
@@ -457,23 +460,23 @@ async def Input_Summary(
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract query from context: {e}")
|
||||
query = context
|
||||
|
||||
|
||||
# Clean query
|
||||
query = str(query).strip().strip("\"'")
|
||||
|
||||
|
||||
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
|
||||
|
||||
|
||||
# Execute search based on search_switch and storage_type
|
||||
try:
|
||||
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
|
||||
|
||||
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": query,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
|
||||
# Add storage-specific parameters
|
||||
|
||||
'''检索'''
|
||||
@@ -509,10 +512,10 @@ async def Input_Summary(
|
||||
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
logger.info(f"Input_Summary: 使用 summary 进行检索")
|
||||
logger.info("Input_Summary: 使用 summary 进行检索")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: hybrid_search failed, using empty results: {e}",
|
||||
@@ -520,7 +523,7 @@ async def Input_Summary(
|
||||
)
|
||||
retrieve_info, question, raw_results = "", query, []
|
||||
|
||||
|
||||
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
@@ -529,7 +532,7 @@ async def Input_Summary(
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
@@ -543,9 +546,9 @@ async def Input_Summary(
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
@@ -563,7 +566,7 @@ async def Input_Summary(
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary failed: {e}",
|
||||
@@ -576,7 +579,7 @@ async def Input_Summary(
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
@@ -599,7 +602,7 @@ async def Summary_fails(
|
||||
) -> dict:
|
||||
"""
|
||||
Handle workflow failure when summary cannot be generated.
|
||||
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Failure context string
|
||||
@@ -608,22 +611,22 @@ async def Summary_fails(
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with failure message
|
||||
"""
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
|
||||
|
||||
# Parse session ID from usermessages
|
||||
usermessages_parts = usermessages.split('_')[1:]
|
||||
sessionid = '_'.join(usermessages_parts[:-1])
|
||||
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
logger.info(f"没有相关数据")
|
||||
|
||||
logger.info("没有相关数据")
|
||||
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
|
||||
|
||||
return {
|
||||
|
||||
@@ -78,7 +78,7 @@ async def Verify(
|
||||
|
||||
# Build query list for verification
|
||||
query_list = []
|
||||
for query_small, anser in zip(Query_small, Result_small):
|
||||
for query_small, anser in zip(Query_small, Result_small, strict=False):
|
||||
query_list.append({
|
||||
'Query_small': query_small,
|
||||
'Answer_Small': anser
|
||||
|
||||
114
api/app/core/memory/agent/multimodal/oss_picture.py
Normal file
114
api/app/core/memory/agent/multimodal/oss_picture.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import requests
|
||||
|
||||
# from qcloud_cos import CosConfig, CosS3Client
|
||||
# from qcloud_cos.cos_exception import CosClientError, CosServiceError
|
||||
|
||||
# from config.paths import BASE_DIR
|
||||
BASE_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
|
||||
|
||||
class OSSUploader:
    """Upload local files to object storage through the Lingqi upload API.

    Uploading is a two-step flow: the upload API is first asked for a signed
    policy for the object key, then the file itself is posted to the storage
    host named in that policy.
    """

    # Upload-policy endpoints keyed by environment name.
    _ENDPOINTS = {
        "test": "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon",
        "prod": "https://lingqi.redbearai.com/api/user/file/common/upload/v2/anon",
    }
    # Seconds to wait for a server response. Without a timeout a stalled
    # connection would block the caller indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, env):
        """Select the upload endpoint for *env*.

        :param env: ``"test"`` or ``"prod"``; any other value falls back to
            the test endpoint.
        """
        self.api = self._ENDPOINTS.get(env, self._ENDPOINTS["test"])
        self.privacy = "false"
        # NOTE(review): these headers are stored but upload_image does not
        # send them; confirm whether they are still needed.
        self.headers = {
            "User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
                          'AppleWebKit/537.36 (KHTML, like Gecko)'
                          ' Chrome/133.0.6833.84 Safari/537.36'
        }

    @staticmethod
    def _generate_object_key(file_path, prefix='xhs_'):
        """Build the object-storage key for a local file.

        :param file_path: local file path; only its base name is used
        :param prefix: string prepended to the base name to group objects
        :return: ``prefix`` followed by the file's base name
        """
        return f"{prefix}{os.path.basename(file_path)}"

    def upload_image(self, file_name, prefix='jd_'):
        """Upload a file and return its public URL.

        :param file_name: file path, resolved relative to ``BASE_DIR``
        :param prefix: storage prefix; the file extension is appended to it
            before the object key is generated
        :return: URL of the uploaded file as reported by the upload API
        :raises FileNotFoundError: if the resolved file does not exist
        :raises Exception: if either HTTP step fails; the original error is
            chained as ``__cause__``
        """
        file_path = os.path.join(BASE_DIR, file_name)
        # Fail before any network traffic instead of requesting an upload
        # policy for a file that cannot be read.
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")

        object_key = self._generate_object_key(file_path, prefix + file_name.split('.')[-1])

        try:
            # Step 1: ask the upload API for a signed policy for this key.
            upload_response = requests.post(
                self.api,
                data={
                    "privacy": self.privacy,
                    "fileName": object_key,
                },
                timeout=self.REQUEST_TIMEOUT,
            )
            if upload_response.status_code != 200:
                raise Exception('上传接口请求失败')
            resp = upload_response.json()
            name = resp["data"]["name"]
            file_url = resp["data"]["path"]
            policy = resp["data"]["policy"]

            # Step 2: push the file to the storage host named by the policy.
            with open(file_path, 'rb') as f:
                oss_push_resp = requests.post(
                    policy["host"],
                    files={
                        "key": policy["dir"],
                        "OSSAccessKeyId": policy["accessid"],
                        "name": name,
                        "policy": policy["policy"],
                        "success_action_status": 200,
                        "signature": policy["signature"],
                        "file": f,
                    },
                    timeout=self.REQUEST_TIMEOUT,
                )
            if oss_push_resp.status_code != 200:
                raise Exception("OSS上传失败")
        except Exception as exc:
            # Keep the historical message (it embeds the traceback text) and
            # chain the cause so the real error type stays inspectable.
            raise Exception(f"上传失败: \n{traceback.format_exc()}") from exc

        # Reported only after both steps succeeded; a `finally` clause here
        # would also print on failure.
        print('success')
        return file_url
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: upload a sample image through the production
    # endpoint and print the URL it is served from.
    print(OSSUploader("prod").upload_image('./example01.jpg'))
|
||||
121
api/app/core/memory/agent/multimodal/speech_model.py
Normal file
121
api/app/core/memory/agent/multimodal/speech_model.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_, picture_model_requests,Picture_recognize, Voice_recognize
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
# file_urls = [
|
||||
# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_female2.wav",
|
||||
# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav",
|
||||
# ]
|
||||
class Vico_recognition:
    """Transcribe audio files through the DashScope asynchronous ASR REST API.

    Usage: ``text = await Vico_recognition(file_urls).run()``. The blocking
    ``requests`` calls are executed in a worker thread and the polling delay
    uses ``asyncio.sleep`` so the event loop is never stalled.
    """

    # Seconds to wait for each HTTP response. Without a timeout a stalled
    # connection would hang the calling task forever.
    REQUEST_TIMEOUT = 30
    # Seconds between two task-status polls.
    POLL_INTERVAL = 0.1

    def __init__(self, file_urls):
        """Store the audio URLs; credentials are loaded lazily per request.

        :param file_urls: list of URLs of the audio files to transcribe
        """
        self.api_key = ''
        self.backend_model_name = ''
        self.api_base = ''
        self.file_urls = file_urls

    def _headers(self):
        """Return the request headers built from the current API key."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "X-DashScope-Async": "enable",
        }

    @staticmethod
    def _first_text(node):
        """Return the first string stored under a ``"text"`` key, or None.

        The search is depth-first in document order over nested dicts and
        lists, so it picks the same entry that a first-match search over the
        serialised JSON would, but works on decoded values (quotes and
        newlines inside the transcript are preserved).
        """
        if isinstance(node, dict):
            for key, value in node.items():
                if key == "text" and isinstance(value, str):
                    return value
                found = Vico_recognition._first_text(value)
                if found is not None:
                    return found
        elif isinstance(node, list):
            for child in node:
                found = Vico_recognition._first_text(child)
                if found is not None:
                    return found
        return None

    async def submit_task(self) -> str:
        """Submit a transcription task for ``self.file_urls``.

        :return: the task id, or None when the service rejects the request
        """
        self.api_key, self.backend_model_name, self.api_base = await Voice_recognize()

        data = {
            "model": self.backend_model_name,
            "input": {"file_urls": self.file_urls},
            "parameters": {
                "channel_id": [0],
                # NOTE(review): this looks like a placeholder vocabulary id;
                # confirm it is valid for the configured account.
                "vocabulary_id": "vocab-Xxxx",
            },
        }
        # Recorded-file transcription endpoint.
        service_url = (
            "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription"
        )
        response = await asyncio.to_thread(
            requests.post,
            service_url,
            headers=self._headers(),
            data=json.dumps(data),
            timeout=self.REQUEST_TIMEOUT,
        )

        if response.status_code == 200:
            return response.json()["output"]["task_id"]
        print("task failed!")
        print(response.json())
        return None

    async def download_transcription_result(self, transcription_url):
        """Download one transcription result file.

        Args:
            transcription_url (str): URL of the transcription result file
        Returns:
            dict: decoded JSON content, or None when the download fails
        """
        try:
            response = await asyncio.to_thread(
                requests.get, transcription_url, timeout=self.REQUEST_TIMEOUT
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Best effort: a failed download is reported and skipped by run().
            print(f"下载转写结果失败: {e}")
            return None

    async def wait_for_complete(self, task_id):
        """Poll the task until it leaves the RUNNING/PENDING states.

        :param task_id: id returned by :meth:`submit_task`
        :return: the ``results`` list on success, None when the task or the
            status query fails

        NOTE(review): there is no overall deadline; a task that never leaves
        RUNNING/PENDING is polled forever.
        """
        self.api_key, self.backend_model_name, self.api_base = await Voice_recognize()
        headers = self._headers()
        # Task-status endpoint.
        service_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"

        while True:
            response = await asyncio.to_thread(
                requests.post,
                service_url,
                headers=headers,
                timeout=self.REQUEST_TIMEOUT,
            )
            if response.status_code != 200:
                print("query failed!")
                return None

            output = response.json()['output']
            status = output['task_status']
            if status == 'SUCCEEDED':
                print("task succeeded!")
                return output['results']
            if status not in ('RUNNING', 'PENDING'):
                print("task failed!")
                return None
            # Yield to the event loop between polls instead of blocking it.
            await asyncio.sleep(self.POLL_INTERVAL)

    async def run(self):
        """Submit the task, wait for it and return the concatenated text.

        :return: the first transcript text of every result file, joined
            without a separator
        :raises RuntimeError: when the task cannot be submitted or does not
            complete successfully
        """
        # submit_task and wait_for_complete load the credentials themselves,
        # so no separate Voice_recognize() call is needed here.
        task_id = await self.submit_task()
        if task_id is None:
            raise RuntimeError("transcription task submission failed")

        results = await self.wait_for_complete(task_id)
        if results is None:
            raise RuntimeError(f"transcription task {task_id} did not succeed")

        texts = []
        for item in results:
            # Presumably failed sub-tasks carry no transcription URL; skip
            # them rather than abort the whole batch.
            transcription_url = item.get('transcription_url')
            if not transcription_url:
                continue
            print(f"转写URL: {transcription_url}")

            content = await self.download_transcription_result(transcription_url)
            if content:
                text = self._first_text(content)
                if text is not None:
                    texts.append(text)
        return ''.join(texts)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -16,31 +16,13 @@ from app.core.memory.utils.config.config_utils import get_picture_config, get_vo
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
|
||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
#TODO: Refactor entire picture/voice
|
||||
# async def LLM_model_request(context,data,query):
|
||||
# '''
|
||||
# Agent model request
|
||||
# Args:
|
||||
# context:Input request
|
||||
# data: template parameters
|
||||
# query:request content
|
||||
# Returns:
|
||||
|
||||
# '''
|
||||
# template = Template(context)
|
||||
# system_prompt = template.render(**data)
|
||||
# llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
# result = await llm_client.chat(
|
||||
# messages=[{"role": "system", "content": system_prompt}] + [{"role": "user", "content": query}]
|
||||
# )
|
||||
# return result
|
||||
|
||||
async def picture_model_requests(image_url):
|
||||
'''
|
||||
@@ -106,33 +88,9 @@ class COUNTState:
|
||||
def reset(self):
|
||||
"""手动重置累加值"""
|
||||
self.total = 0
|
||||
print(f"[COUNTState] 已重置为 0")
|
||||
print("[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
|
||||
# def embed(texts: list[str]) -> list[list[float]]:
|
||||
# # 这里可以换成 LangChain Embeddings
|
||||
# return [[float(len(t) % 5), float(len(t) % 3)] for t in texts]
|
||||
|
||||
|
||||
# def export_store_to_json(store, namespace):
|
||||
# """Export the entire storage content to a JSON file"""
|
||||
# # 搜索所有存储项
|
||||
# all_items = store.search(namespace)
|
||||
|
||||
# # 整理数据
|
||||
# export_data = {}
|
||||
# for item in all_items:
|
||||
# if hasattr(item, 'key') and hasattr(item, 'value'):
|
||||
# export_data[item.key] = item.value
|
||||
|
||||
# # 保存到文件
|
||||
# os.makedirs("memory_logs", exist_ok=True)
|
||||
# with open("memory_logs/full_memory_export.json", "w", encoding="utf-8") as f:
|
||||
# json.dump(export_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# print(f"{len(export_data)} 条记忆到 JSON 文件")
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
import os
|
||||
from app.core.config import settings
|
||||
|
||||
def get_mcp_server_config():
|
||||
"""
|
||||
Get the MCP server configuration
|
||||
Get the MCP server configuration.
|
||||
|
||||
Uses MCP_SERVER_URL environment variable if set (for Docker),
|
||||
otherwise falls back to SERVER_IP and MCP_PORT (for local development).
|
||||
"""
|
||||
# Get MCP port from environment (default: 8081)
|
||||
mcp_port = os.getenv("MCP_PORT", "8081")
|
||||
|
||||
# In Docker: MCP_SERVER_URL=http://mcp-server:8081
|
||||
# In local dev: uses SERVER_IP (127.0.0.1 or localhost)
|
||||
mcp_server_url = os.getenv("MCP_SERVER_URL")
|
||||
|
||||
if mcp_server_url:
|
||||
# Docker environment: use full URL from environment
|
||||
base_url = mcp_server_url
|
||||
else:
|
||||
# Local development: build URL from SERVER_IP and MCP_PORT
|
||||
base_url = f"http://{settings.SERVER_IP}:{mcp_port}"
|
||||
|
||||
mcp_server_config = {
|
||||
"data_flow": {
|
||||
"url": f"http://{settings.SERVER_IP}:8081/sse", # 你前面的 FastMCP(weather) 服务端口
|
||||
"url": f"{base_url}/sse",
|
||||
"transport": "sse",
|
||||
"timeout": 15000,
|
||||
"sse_read_timeout": 15000,
|
||||
|
||||
@@ -191,9 +191,9 @@ async def VerifyTool_messages_deal(context):
|
||||
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
content_messages = messages.split('"context":')[1].replace('""', '"')
|
||||
messages = str(content_messages).split("name='Retrieve'")[0]
|
||||
query = re.findall(f'"Query": "(.*?)"', messages)[0]
|
||||
Query_small = re.findall(f'"Query_small": "(.*?)"', messages)
|
||||
Result_small = re.findall(f'"Result_small": "(.*?)"', messages)
|
||||
query = re.findall('"Query": "(.*?)"', messages)[0]
|
||||
Query_small = re.findall('"Query_small": "(.*?)"', messages)
|
||||
Result_small = re.findall('"Result_small": "(.*?)"', messages)
|
||||
return Query_small, Result_small, query
|
||||
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ This module provides utilities for detecting and processing multimodal inputs
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
# TODO 后续更新
|
||||
# from app.core.memory.agent.multimodal.speech_model import Vico_recognition
|
||||
|
||||
from app.core.memory.agent.multimodal.speech_model import Vico_recognition
|
||||
from app.core.memory.agent.utils.llm_tools import picture_model_requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -124,7 +124,7 @@ class MultimodalProcessor:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True)
|
||||
logger.info(f"[MultimodalProcessor] Falling back to original content")
|
||||
logger.info("[MultimodalProcessor] Falling back to original content")
|
||||
return content
|
||||
|
||||
# Return original content if not multimodal
|
||||
|
||||
@@ -2,24 +2,47 @@ import redis
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.core.config import settings
|
||||
class RedisSessionStore:
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None,session_id=''):
|
||||
self.r = redis.Redis(host=host, port=port, db=db, password=password)
|
||||
self.uudi=session_id
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def _fix_encoding(self, text):
|
||||
"""修复错误编码的文本"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
优化版本:确保写入时间不超过1秒
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||
|
||||
# 使用 Hash 存储结构化数据
|
||||
result = self.r.hset(key, mapping={
|
||||
# 使用 pipeline 批量写入,减少网络往返
|
||||
pipe = self.r.pipeline()
|
||||
|
||||
# 直接写入数据,decode_responses=True 已经处理了编码
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"apply_id": apply_id,
|
||||
@@ -28,12 +51,54 @@ class RedisSessionStore:
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
})
|
||||
print(f"保存结果: {result}, session_id: {session_id}")
|
||||
|
||||
# 可选:设置过期时间(例如30天),避免数据无限增长
|
||||
# pipe.expire(key, 30 * 24 * 60 * 60)
|
||||
|
||||
# 执行批量操作
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id # 返回新生成的 session_id
|
||||
except Exception as e:
|
||||
print(f"保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def save_sessions_batch(self, sessions_data):
    """Write several session records through a single pipeline.

    :param sessions_data: iterable of dicts, each read for the keys
        ``userid``, ``messages``, ``aimessages``, ``apply_id`` and
        ``group_id``
    :return: list of the newly generated session ids, in input order
    :raises Exception: whatever the Redis client raises; it is reported and
        re-raised unchanged

    NOTE(review): values are passed through as given, so a missing key is
    written as None; confirm the Redis client accepts that.
    """
    try:
        session_ids = []
        pipe = self.r.pipeline()

        for session in sessions_data:
            # Every record gets its own id, stored under session:<id>.
            session_id = str(uuid.uuid4())
            pipe.hset(f"session:{session_id}", mapping={
                "id": self.uudi,
                "sessionid": session.get('userid'),
                "apply_id": session.get('apply_id'),
                "group_id": session.get('group_id'),
                "messages": session.get('messages'),
                "aimessages": session.get('aimessages'),
                "starttime": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            })
            session_ids.append(session_id)

        # Send all queued writes at once; the per-command replies are not needed.
        pipe.execute()
        print(f"批量保存完成: {len(session_ids)} 条记录")
        return session_ids
    except Exception as e:
        print(f"批量保存会话失败: {e}")
        # Bare raise keeps the original traceback intact.
        raise
|
||||
|
||||
# ---------------- 读取 ----------------
|
||||
def get_session(self, session_id):
|
||||
"""
|
||||
@@ -41,9 +106,7 @@ class RedisSessionStore:
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
data = self.r.hgetall(key)
|
||||
if data:
|
||||
return {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||
return None
|
||||
return data if data else None
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
||||
"""
|
||||
@@ -52,21 +115,17 @@ class RedisSessionStore:
|
||||
result_items = []
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key_bytes in self.r.keys('session:*'):
|
||||
key = key_bytes.decode('utf-8')
|
||||
for key in self.r.keys('session:*'):
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 解码数据
|
||||
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (decoded_data.get('sessionid') == sessionid and
|
||||
decoded_data.get('apply_id') == apply_id and
|
||||
decoded_data.get('group_id') == group_id):
|
||||
result_items.append(decoded_data)
|
||||
if (data.get('sessionid') == sessionid and
|
||||
data.get('apply_id') == apply_id and
|
||||
data.get('group_id') == group_id):
|
||||
result_items.append(data)
|
||||
|
||||
return result_items
|
||||
|
||||
@@ -76,7 +135,7 @@ class RedisSessionStore:
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
sid = key.decode('utf-8').split(':')[1]
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
@@ -84,12 +143,14 @@ class RedisSessionStore:
|
||||
def update_session(self, session_id, field, value):
    """Update one field of an existing session hash.

    :param session_id: session identifier; the Redis key is ``session:<id>``
    :param field: name of the hash field to overwrite
    :param value: new value for the field
    :return: True when the session existed and was updated, False otherwise
    """
    key = f"session:{session_id}"
    # Check first and write only when the key exists. Queuing EXISTS and
    # HSET together in one pipeline runs the HSET unconditionally, which
    # creates a stray one-field hash for an unknown session id even though
    # the method reports False.
    if not self.r.exists(key):
        return False
    self.r.hset(key, field, value)
    return True
|
||||
|
||||
# ---------------- 删除 ----------------
|
||||
def delete_session(self, session_id):
|
||||
@@ -112,38 +173,67 @@ class RedisSessionStore:
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||
"""
|
||||
seen = set() # 用来记录已出现的唯一组合
|
||||
deleted_count = 0
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
for key_bytes in self.r.keys('session:*'):
|
||||
key = key_bytes.decode('utf-8')
|
||||
data = self.r.hgetall(key)
|
||||
# 第一步:使用 pipeline 批量获取所有 key
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 第二步:使用 pipeline 批量获取所有数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 第三步:在内存中识别重复数据
|
||||
seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key)
|
||||
keys_to_delete = [] # 需要删除的 key 列表
|
||||
|
||||
for key, data in zip(keys, all_data, strict=False):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 获取五个字段的值并解码
|
||||
sessionid = data.get(b'sessionid', b'').decode('utf-8')
|
||||
user_id = data.get(b'id', b'').decode('utf-8') # 对应user_id
|
||||
group_id = data.get(b'group_id', b'').decode('utf-8')
|
||||
messages = data.get(b'messages', b'').decode('utf-8')
|
||||
aimessages = data.get(b'aimessages', b'').decode('utf-8')
|
||||
# 获取五个字段的值
|
||||
sessionid = data.get('sessionid', '')
|
||||
user_id = data.get('id', '')
|
||||
group_id = data.get('group_id', '')
|
||||
messages = data.get('messages', '')
|
||||
aimessages = data.get('aimessages', '')
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,删除该 key
|
||||
self.r.delete(key)
|
||||
deleted_count += 1
|
||||
# 重复,标记为待删除
|
||||
keys_to_delete.append(key)
|
||||
else:
|
||||
# 第一次出现,加入 seen
|
||||
seen.add(identifier)
|
||||
# 第一次出现,记录
|
||||
seen[identifier] = key
|
||||
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}")
|
||||
# 第四步:使用 pipeline 批量删除重复的 key
|
||||
deleted_count = 0
|
||||
if keys_to_delete:
|
||||
# 分批删除,避免单次操作过大
|
||||
batch_size = 1000
|
||||
for i in range(0, len(keys_to_delete), batch_size):
|
||||
batch = keys_to_delete[i:i + batch_size]
|
||||
pipe = self.r.pipeline()
|
||||
for key in batch:
|
||||
pipe.delete(key)
|
||||
pipe.execute()
|
||||
deleted_count += len(batch)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
def find_user_session(self,sessionid):
|
||||
def find_user_session(self, sessionid):
|
||||
user_id = sessionid
|
||||
|
||||
result_items = []
|
||||
@@ -160,44 +250,62 @@ class RedisSessionStore:
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, group_id):
    """Return the newest question/answer pairs stored for one conversation.

    All ``session:*`` hashes are fetched in a single pipeline round trip and
    filtered in memory. A record matches when ``apply_id`` and ``group_id``
    are equal to the arguments and the stored ``sessionid`` equals, or
    contains, the requested ``sessionid``.

    Args:
        sessionid: Session id to look for (exact or substring match).
        apply_id: Application id that must match exactly.
        group_id: Group id that must match exactly.

    Returns:
        Up to six ``{"Query": ..., "Answer": ...}`` dicts, newest first by
        ``starttime``. An empty list is returned when at most one record
        matches.
    """
    import time

    start_time = time.time()
    keys = self.r.keys('session:*')

    if not keys:
        print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
        return []

    # Queue every HGETALL and send them to Redis in one batch.
    pipe = self.r.pipeline()
    for key in keys:
        pipe.hgetall(key)
    all_data = pipe.execute()

    matched_items = []
    for data in all_data:
        if not data:
            continue
        if data.get('apply_id') != apply_id or data.get('group_id') != group_id:
            continue

        stored_sessionid = data.get('sessionid', '')
        # Exact match first; the substring test is only attempted on two
        # strings so a None/non-str session id can no longer raise TypeError.
        is_match = stored_sessionid == sessionid or (
            isinstance(sessionid, str)
            and isinstance(stored_sessionid, str)
            and sessionid in stored_sessionid
        )
        if not is_match:
            continue

        matched_items.append({
            "Query": self._fix_encoding(data.get('messages')),
            "Answer": self._fix_encoding(data.get('aimessages')),
            "starttime": data.get('starttime', ''),
        })

    # Newest first, keep six, and hide the helper sort key from callers.
    matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
    result_items = matched_items[:6]
    for item in result_items:
        item.pop('starttime', None)

    # A single (or no) record is treated as "no usable history".
    if len(result_items) <= 1:
        result_items = []

    elapsed_time = time.time() - start_time
    print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")

    return result_items
|
||||
|
||||
|
||||
# Module-level session store built from the application settings; an empty
# Redis password is passed on as None.
store = RedisSessionStore(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    db=settings.REDIS_DB,
    password=settings.REDIS_PASSWORD or None,
    session_id=str(uuid.uuid4()),
)
|
||||
@@ -44,7 +44,7 @@ class VerifyTool:
|
||||
async def model_1(self, state: State) -> State:
|
||||
llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||
response_content = await llm_client.chat(
|
||||
messages=[{"role": "system", "content": self.system_prompt}] + _to_openai_messages(state["messages"])
|
||||
messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])]
|
||||
)
|
||||
return {
|
||||
"agent1_response": response_content,
|
||||
|
||||
@@ -63,7 +63,7 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
# 获取 embedder 配置
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
|
||||
23
api/app/core/memory/analytics/__init__.py
Normal file
23
api/app/core/memory/analytics/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Memory Analytics Module
|
||||
|
||||
This module provides analytics and insights for the memory system.
|
||||
|
||||
Available functions:
|
||||
- get_hot_memory_tags: Get hot memory tags by frequency
|
||||
- MemoryInsight: Generate memory insight reports
|
||||
- get_recent_activity_stats: Get recent activity statistics
|
||||
- generate_user_summary: Generate user summary
|
||||
"""
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
|
||||
__all__ = [
|
||||
"get_hot_memory_tags",
|
||||
"MemoryInsight",
|
||||
"get_recent_activity_stats",
|
||||
"generate_user_summary",
|
||||
]
|
||||
198
api/app/core/memory/analytics/api_docs_parser.py
Normal file
198
api/app/core/memory/analytics/api_docs_parser.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Any, List, Tuple
|
||||
|
||||
|
||||
def _parse_meta_block(md_text: str) -> Dict[str, Any]:
|
||||
sections: Dict[str, Any] = {}
|
||||
m = re.search(r"```javascript([\s\S]*?)```", md_text)
|
||||
if not m:
|
||||
return sections
|
||||
block = m.group(1)
|
||||
search_opts: List[Dict[str, Any]] = []
|
||||
status_codes: List[Dict[str, Any]] = []
|
||||
for line in block.splitlines():
|
||||
s = line.strip()
|
||||
if not s:
|
||||
continue
|
||||
msw = re.match(r"search_switch:?(\d+)\s*(?:((.*?)))?", s)
|
||||
if msw:
|
||||
val = msw.group(1)
|
||||
desc = msw.group(2) or ""
|
||||
search_opts.append({"value": val, "desc": desc})
|
||||
continue
|
||||
mcode = re.match(r"code:(\d+)\.\s*(.*)", s)
|
||||
if mcode:
|
||||
code = mcode.group(1)
|
||||
desc = mcode.group(2).strip()
|
||||
status_codes.append({"code": code, "desc": desc})
|
||||
continue
|
||||
if search_opts:
|
||||
sections["search_switch"] = search_opts
|
||||
if status_codes:
|
||||
sections["status_code"] = status_codes
|
||||
return sections
|
||||
|
||||
|
||||
def _extract_code_block(md_lines: List[str], start_idx: int) -> Tuple[str, int]:
|
||||
content_lines: List[str] = []
|
||||
i = start_idx
|
||||
while i < len(md_lines) and md_lines[i].strip() == "":
|
||||
i += 1
|
||||
if i >= len(md_lines):
|
||||
return "", i
|
||||
start_line = md_lines[i].strip()
|
||||
if not re.match(r"^`{3,}.*", start_line):
|
||||
return "", i
|
||||
i += 1
|
||||
while i < len(md_lines):
|
||||
line = md_lines[i]
|
||||
if re.match(r"^`{3,}.*", line.strip()):
|
||||
i += 1
|
||||
break
|
||||
content_lines.append(line)
|
||||
i += 1
|
||||
return "\n".join(content_lines).strip(), i
|
||||
|
||||
|
||||
def _parse_sections(md_text: str) -> List[Dict[str, Any]]:
|
||||
lines = md_text.splitlines()
|
||||
sections: List[Dict[str, Any]] = []
|
||||
i = 0
|
||||
current: Dict[str, Any] | None = None
|
||||
|
||||
def _clean_inline(s: str) -> str:
|
||||
s = s.strip()
|
||||
if s.startswith("`") and s.endswith("`"):
|
||||
s = s[1:-1]
|
||||
return s.strip()
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
if line.startswith("# ") and ":" in line:
|
||||
name = line.split(":", 1)[1].strip()
|
||||
current = {"name": name}
|
||||
sections.append(current)
|
||||
i += 1
|
||||
continue
|
||||
if current is not None and line.strip().startswith("### "):
|
||||
s = line.strip()
|
||||
if "请求端口" in s:
|
||||
m = re.search(r"请求端口(.*)$", s)
|
||||
if m:
|
||||
current["path"] = _clean_inline(m.group(1))
|
||||
i += 1
|
||||
continue
|
||||
if "请求方式" in s:
|
||||
m = re.search(r"请求方式[::](.*)$", s)
|
||||
if m:
|
||||
current["method"] = _clean_inline(m.group(1))
|
||||
i += 1
|
||||
continue
|
||||
if s.startswith("### 描述"):
|
||||
i += 1
|
||||
desc_lines: List[str] = []
|
||||
while i < len(lines):
|
||||
nl = lines[i]
|
||||
if nl.strip().startswith("### "):
|
||||
break
|
||||
if re.match(r"^`{3,}.*", nl.strip()):
|
||||
break
|
||||
desc_lines.append(nl.strip())
|
||||
i += 1
|
||||
current["desc"] = "\n".join([x for x in desc_lines if x]).strip() or None
|
||||
continue
|
||||
if s.startswith("### 输入"):
|
||||
if "无" in s:
|
||||
current["input"] = "无"
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
block, i = _extract_code_block(lines, i)
|
||||
current["input"] = block or None
|
||||
continue
|
||||
if s.startswith("### 输出"):
|
||||
i += 1
|
||||
block, i = _extract_code_block(lines, i)
|
||||
current["output"] = block or None
|
||||
continue
|
||||
if s.startswith("### 请求体参数"):
|
||||
i += 1
|
||||
params, i = _parse_body_params_table(lines, i)
|
||||
if params:
|
||||
current["body_params"] = params
|
||||
continue
|
||||
i += 1
|
||||
return sections
|
||||
|
||||
|
||||
def parse_api_docs(file_path: str) -> Dict[str, Any]:
    """Parse the markdown API document at *file_path*.

    Args:
        file_path: Path of the markdown file (read as UTF-8).

    Returns:
        ``{"title": ..., "meta": ..., "sections": ...}`` where *title* is the
        text of the first line starting with ``#`` (``""`` when there is
        none), *meta* comes from ``_parse_meta_block`` and *sections* from
        ``_parse_sections``.

    Raises:
        FileNotFoundError: If *file_path* is not an existing regular file.
    """
    if not os.path.isfile(file_path):
        raise FileNotFoundError(file_path)

    with open(file_path, "r", encoding="utf-8") as handle:
        md_text = handle.read()

    heading = re.search(r"^#\s*(.+)$", md_text, re.MULTILINE)
    if heading:
        title = heading.group(1).strip()
    else:
        title = ""

    meta = _parse_meta_block(md_text)
    return {"title": title, "meta": meta, "sections": _parse_sections(md_text)}
|
||||
|
||||
|
||||
def get_default_docs_path() -> str:
    """Return the default location of the markdown API document.

    The path is ``<three directories above this file>/src/analytics/API接口.md``.
    """
    # NOTE(review): this module is shown at app/core/memory/analytics/, so the
    # three hops end in app/core and the result is
    # app/core/src/analytics/API接口.md - confirm the document really lives there.
    base_dir = __file__
    for _ in range(3):
        base_dir = os.path.dirname(base_dir)
    return os.path.join(base_dir, "src", "analytics", "API接口.md")
|
||||
|
||||
|
||||
def parse_api_docs_default() -> Dict[str, Any]:
    """Parse the API document found at ``get_default_docs_path()``."""
    default_path = get_default_docs_path()
    return parse_api_docs(default_path)
|
||||
|
||||
|
||||
def get_section_descriptions(file_path: str) -> Dict[str, str]:
    """Map each section name of the document to its description.

    Args:
        file_path: Path of the markdown API document.

    Returns:
        ``{section name: description}``. Sections without a name are left
        out and a missing description becomes ``""``.
    """
    parsed_sections = parse_api_docs(file_path).get("sections", [])
    return {
        section.get("name"): section.get("desc") or ""
        for section in parsed_sections
        if section.get("name")
    }
|
||||
|
||||
|
||||
def get_section_descriptions_default() -> Dict[str, str]:
    """Return the section descriptions of the document at the default path."""
    default_path = get_default_docs_path()
    return get_section_descriptions(default_path)
|
||||
|
||||
|
||||
def _parse_body_params_table(md_lines: List[str], start_idx: int) -> Tuple[List[Dict[str, Any]], int]:
|
||||
rows: List[str] = []
|
||||
i = start_idx
|
||||
while i < len(md_lines) and md_lines[i].strip() == "":
|
||||
i += 1
|
||||
if i >= len(md_lines) or not md_lines[i].strip().startswith("|"):
|
||||
return [], i
|
||||
header = md_lines[i].strip()
|
||||
i += 1
|
||||
if i < len(md_lines) and md_lines[i].strip().startswith("|"):
|
||||
i += 1
|
||||
while i < len(md_lines) and md_lines[i].strip().startswith("|"):
|
||||
rows.append(md_lines[i].strip())
|
||||
i += 1
|
||||
headers = [h.strip() for h in header.strip('|').split('|')]
|
||||
out: List[Dict[str, Any]] = []
|
||||
for r in rows:
|
||||
cols = [c.strip() for c in r.strip('|').split('|')]
|
||||
if len(cols) != len(headers):
|
||||
continue
|
||||
item: Dict[str, Any] = {}
|
||||
for k, v in zip(headers, cols, strict=False):
|
||||
if k == "参数名":
|
||||
item["name"] = v
|
||||
elif k == "类型":
|
||||
item["type"] = v
|
||||
elif k == "是否必填":
|
||||
item["required"] = v
|
||||
elif k == "描述":
|
||||
item["desc"] = v
|
||||
else:
|
||||
item[k] = v
|
||||
out.append(item)
|
||||
return out, i
|
||||
204
api/app/core/memory/analytics/hot_memory_tags.py
Normal file
204
api/app/core/memory/analytics/hot_memory_tags.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
from neo4j import GraphDatabase
|
||||
from typing import List, Tuple
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ------------------- Self-contained path resolution -------------------
# Lets the script be started from any directory and still import the
# project's modules.
try:
    # Two levels above the directory that contains this file.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
except NameError:
    # __file__ is undefined (e.g. some interactive interpreters): fall back
    # to the current working directory.
    project_root = os.path.abspath(os.path.join(os.getcwd()))

src_path = os.path.join(project_root, 'src')

# Put both directories at the front of sys.path; 'src' is inserted first so
# 'project_root' ends up ahead of it when both are new.
for _candidate in (src_path, project_root):
    if _candidate not in sys.path:
        sys.path.insert(0, _candidate)
# ---------------------------------------------------------------------
|
||||
|
||||
# 现在路径已经配置好,我们可以使用绝对导入
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
import json
|
||||
|
||||
# Pydantic model describing the structured output requested from the LLM.
# The class docstring and the field description are left untouched because
# pydantic exposes both in the generated schema.
class FilteredTags(BaseModel):
    """用于接收LLM筛选后的核心标签列表的模型。"""

    # Tags the LLM kept after filtering.
    meaningful_tags: List[str] = Field(
        ...,
        description="从原始列表中筛选出的具有核心代表意义的名词列表。",
    )
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], llm_client) -> List[str]:
    """Ask the LLM to keep only the representative core nouns among *tags*.

    Args:
        tags: Raw tag names.
        llm_client: Client exposing an awaitable
            ``response_structured(messages=..., response_model=...)``.

    Returns:
        The ``meaningful_tags`` list returned by the LLM. When *tags* is
        empty it is returned unchanged without calling the LLM, and when the
        call fails for any reason the original *tags* are returned so the
        caller can carry on.
    """
    # Nothing to filter: skip the (slow, billable) LLM round trip.
    if not tags:
        return tags

    try:
        tag_list_str = ", ".join(tags)
        messages = [
            {
                "role": "system",
                "content": "你是一位顶级的文本分析专家,任务是提炼、筛选并合并最具体、最核心的名词。你的目标是识别具体的事件、地点、物体或作品,并严格执行以下步骤:\n\n1. **筛选**: 严格过滤掉以下类型的词语:\n * **抽象概念或训练活动**: 任何描述抽象品质、训练项目或研究过程的词语(例如:'核心力量', '实际的历史研究', '团队合作')。\n * **动作或过程词**: 任何描述具体动作或过程的词语(例如:'打篮球', '快攻', '远投')。\n * **描述性短语**: 任何描述状态、关系或感受的短语(例如:'配合越来越默契')。\n * **过于宽泛的类别**: 过于笼统的分类(例如:'历史剧')。\n\n2. **合并**: 在筛选后,对语义相近或存在包含关系的词语进行合并,只保留最核心、最具代表性的一个。\n * 例如,在“篮球赛”和“篮球场”中,“篮球赛”是更核心的事件,应保留“篮球赛”。\n\n你的最终输出应该是一个精炼的、无重复概念的列表,只包含最具体、最具有代表性的名词。\n\n**示例**:\n输入: ['篮球赛', '篮球场', '核心力量', '实际的历史研究', '《二战全史》', '攀岩']\n筛选后: ['篮球赛', '篮球场', '《二战全史》', '攀岩']\n合并后最终输出: ['篮球赛', '《二战全史》', '攀岩']"
            },
            {
                "role": "user",
                "content": f"请从以下标签列表中筛选出核心名词: {tag_list_str}"
            }
        ]

        structured_response = await llm_client.response_structured(
            messages=messages,
            response_model=FilteredTags
        )
        return structured_response.meaningful_tags

    except Exception as e:
        # Deliberate best effort: report the failure and fall back to the
        # unfiltered tags so the pipeline keeps going.
        print(f"LLM筛选过程中发生错误: {e}")
        return tags
|
||||
|
||||
def get_db_connection():
    """Create a synchronous Neo4j driver from the project configuration.

    The URI and user name are read from ``settings``; the password is read
    from the ``NEO4J_PASSWORD`` environment variable.

    Returns:
        The driver returned by ``GraphDatabase.driver``.

    Raises:
        ValueError: If the URI or user name is missing, or if
            ``NEO4J_PASSWORD`` is not set.
    """
    uri, user = settings.NEO4J_URI, settings.NEO4J_USERNAME
    # The password is only ever taken from the environment.
    password = os.getenv("NEO4J_PASSWORD")

    if not (uri and user):
        raise ValueError("在 config.json 中未找到 Neo4j 的 'uri' 或 'username'。")
    if not password:
        raise ValueError("NEO4J_PASSWORD 环境变量未设置。")

    return GraphDatabase.driver(uri, auth=(user, password))
|
||||
|
||||
def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> List[Tuple[str, int]]:
    """Query the most frequent, unfiltered entity names and their counts.

    Entities of type ``'人物'``, entities without a name and a fixed list of
    excluded names are left out.

    Args:
        group_id: The group id, or the user id when *by_user* is true.
        limit: Maximum number of tags to return.
        by_user: Match on ``user_id`` instead of ``group_id``.

    Returns:
        ``(name, frequency)`` tuples ordered by descending frequency.
    """
    names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']

    # Only the matched property differs between the two query variants; it is
    # chosen from these two literals, never from caller input.
    owner_field = "user_id" if by_user else "group_id"
    query = (
        "MATCH (e:ExtractedEntity) "
        f"WHERE e.{owner_field} = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
        "RETURN e.name AS name, count(e) AS frequency "
        "ORDER BY frequency DESC "
        "LIMIT $limit"
    )

    driver = get_db_connection()
    try:
        with driver.session() as session:
            records = session.run(query, id=group_id, limit=limit, names_to_exclude=names_to_exclude)
            return [(record["name"], record["frequency"]) for record in records]
    finally:
        # The driver is created per call, so it is always closed here.
        driver.close()
|
||||
|
||||
async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
    """Return the hot tags that survive LLM filtering, with their counts.

    The raw top-*limit* tags are read from the database, their names are
    passed through ``filter_tags_with_llm`` and the tags the LLM kept are
    returned in their original (frequency) order.

    Args:
        group_id: The group id, or the user id when *by_user* is true.
            Falls back to ``SELECTED_GROUP_ID`` when empty.
        limit: Number of raw tags fetched and offered to the LLM.
        by_user: Query by ``user_id`` instead of ``group_id``.

    Returns:
        ``(tag, frequency)`` tuples; empty when the database has no tags.
    """
    group_id = group_id or SELECTED_GROUP_ID

    # 1. The database helper is synchronous; run it in a worker thread so it
    #    does not stall other coroutines scheduled on the same event loop.
    raw_tags_with_freq = await asyncio.to_thread(
        get_raw_tags_from_db, group_id, limit, by_user=by_user
    )
    if not raw_tags_with_freq:
        return []

    raw_tag_names = [tag for tag, _ in raw_tags_with_freq]

    # 2. The LLM id is read from the definitions module at call time rather
    #    than from the value imported at module load.
    from app.core.memory.utils.config import definitions as config_defs
    llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
    meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client)

    # 3. Keep the original frequency and order; a set makes each lookup O(1).
    kept = set(meaningful_tag_names)
    return [(tag, freq) for tag, freq in raw_tags_with_freq if tag in kept]
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: print the hot tags of the configured group and
    # store them under "hot_memory_tags" in Signboard.json.
    print("开始获取热门记忆标签...")
    try:
        group_id_to_query = SELECTED_GROUP_ID
        top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query))

        if not top_tags:
            print(f"在 Group ID '{group_id_to_query}' 中没有找到符合条件的实体标签。")
        else:
            print(f"热门记忆标签 (Group ID: {group_id_to_query}, 经LLM筛选):")
            for tag, frequency in top_tags:
                print(f"- {tag} (数量: {frequency})")

            # --- Merge the result into the shared Signboard.json ---
            from app.core.config import settings
            settings.ensure_memory_output_dir()
            signboard_path = settings.get_memory_output_path("Signboard.json")
            payload = {
                "group_id": group_id_to_query,
                "hot_tags": [{"name": t, "frequency": f} for t, f in top_tags]
            }
            try:
                # Other keys already present in the file are preserved.
                existing = {}
                if os.path.exists(signboard_path):
                    with open(signboard_path, "r", encoding="utf-8") as rf:
                        existing = json.load(rf)
                existing["hot_memory_tags"] = payload
                with open(signboard_path, "w", encoding="utf-8") as wf:
                    json.dump(existing, wf, ensure_ascii=False, indent=2)
                print(f"已写入 {signboard_path} -> hot_memory_tags")
            except Exception as e:
                print(f"写入 Signboard.json 失败: {e}")
    except Exception as e:
        print(f"执行过程中发生严重错误: {e}")
        print("请检查:")
        print("1. Neo4j数据库服务是否正在运行。")
        print("2. 'config.json'中的配置是否正确。")
        print("3. 相关的环境变量 (如 NEO4J_PASSWORD, DASHSCOPE_API_KEY) 是否已正确设置。")
|
||||
343
api/app/core/memory/analytics/memory_insight.py
Normal file
343
api/app/core/memory/analytics/memory_insight.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
This module provides the MemoryInsight class for analyzing user memory data.
|
||||
|
||||
This script can be executed directly to generate a memory insight report for a test user.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
|
||||
# Running this file directly requires the parent directory of this package
# on sys.path so that the imports used by other modules resolve.
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if src_path not in sys.path:
    sys.path.insert(0, src_path)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
|
||||
|
||||
# Pydantic models describing the structured output requested from the LLM.
class TagClassification(BaseModel):
    """
    Represents the classification of a tag into a specific domain.
    """

    # Domain chosen by the LLM for one tag.
    domain: str = Field(
        ...,
        description="The domain the tag belongs to, chosen from the predefined list.",
        examples=[
            "教育", "学习", "工作", "旅行", "家庭",
            "运动", "社交", "娱乐", "健康", "其他",
        ],
    )
|
||||
|
||||
class InsightReport(BaseModel):
    """
    Represents the final insight report generated by the LLM.
    """

    # Free-text report body.
    report: str = Field(
        ...,
        description=(
            "A comprehensive insight report in Chinese, summarizing the user's memory patterns."
        ),
    )
|
||||
|
||||
|
||||
class MemoryInsight:
|
||||
"""
|
||||
Provides insights into user memories by analyzing various aspects of their data.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
"""关闭数据库连接。"""
|
||||
await self.neo4j_connector.close()
|
||||
|
||||
async def get_domain_distribution(self) -> dict[str, float]:
|
||||
"""
|
||||
Calculates the distribution of memory domains based on hot tags.
|
||||
"""
|
||||
hot_tags = await get_hot_memory_tags(self.user_id)
|
||||
if not hot_tags:
|
||||
return {}
|
||||
|
||||
domain_counts = Counter()
|
||||
for tag, _ in hot_tags:
|
||||
prompt = f"""请将以下标签归类到最合适的领域中。
|
||||
|
||||
可选领域及其关键词:
|
||||
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
|
||||
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
|
||||
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
|
||||
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
|
||||
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
|
||||
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
|
||||
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
|
||||
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
|
||||
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
|
||||
- 其他:确实无法归入以上任何类别的内容
|
||||
|
||||
标签: {tag}
|
||||
|
||||
分析步骤:
|
||||
1. 仔细理解标签的核心含义和使用场景
|
||||
2. 对比各个领域的关键词,找到最匹配的领域
|
||||
3. 特别注意:
|
||||
- 历史、科学、文化等知识性内容应归类为"学习"
|
||||
- 学校、课程、考试等正式教育场景应归类为"教育"
|
||||
- 只有在标签完全不属于上述9个具体领域时,才选择"其他"
|
||||
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
|
||||
|
||||
请直接返回最合适的领域名称。"""
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
# 直接调用并等待结果
|
||||
classification = await self.llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=TagClassification,
|
||||
)
|
||||
if classification and hasattr(classification, 'domain') and classification.domain:
|
||||
domain_counts[classification.domain] += 1
|
||||
|
||||
total_tags = sum(domain_counts.values())
|
||||
if total_tags == 0:
|
||||
return {}
|
||||
|
||||
domain_distribution = {
|
||||
domain: count / total_tags for domain, count in domain_counts.items()
|
||||
}
|
||||
return dict(
|
||||
sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)
|
||||
)
|
||||
|
||||
async def get_active_periods(self) -> list[int]:
|
||||
"""
|
||||
Identifies the top 2 most active months for the user.
|
||||
Only returns months if there is valid and diverse time data.
|
||||
|
||||
This method checks if the time data represents real user memory timestamps
|
||||
rather than auto-generated system timestamps by verifying:
|
||||
1. Time data exists and is parseable
|
||||
2. Time data is distributed across multiple months (not concentrated in 1-2 months)
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> ''
|
||||
RETURN d.created_at AS creation_time
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query)
|
||||
|
||||
if not records:
|
||||
return []
|
||||
|
||||
month_counts = Counter()
|
||||
valid_dates_count = 0
|
||||
for record in records:
|
||||
creation_time_str = record.get("creation_time")
|
||||
if not creation_time_str:
|
||||
continue
|
||||
try:
|
||||
# 尝试解析时间字符串
|
||||
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
|
||||
month_counts[dt_object.month] += 1
|
||||
valid_dates_count += 1
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
# 如果解析失败,跳过这条记录
|
||||
continue
|
||||
|
||||
# 如果没有有效的时间数据,返回空列表
|
||||
if not month_counts or valid_dates_count == 0:
|
||||
return []
|
||||
|
||||
# 检查时间分布是否过于集中(可能是批量导入的数据)
|
||||
# 如果超过80%的数据集中在1-2个月,认为这是系统时间戳而非真实时间
|
||||
unique_months = len(month_counts)
|
||||
if unique_months <= 2:
|
||||
# 只有1-2个月有数据,很可能是批量导入
|
||||
most_common_count = month_counts.most_common(1)[0][1]
|
||||
if most_common_count / valid_dates_count > 0.8:
|
||||
# 超过80%集中在一个月,认为是系统时间戳
|
||||
return []
|
||||
|
||||
# 如果时间分布较为分散(3个月以上),认为是真实时间数据
|
||||
if unique_months >= 3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
# 2个月的情况,检查是否分布均匀
|
||||
if unique_months == 2:
|
||||
counts = list(month_counts.values())
|
||||
# 如果两个月的数据量相差不大(比例在0.3-3之间),认为是真实数据
|
||||
ratio = min(counts) / max(counts)
|
||||
if ratio > 0.3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
# 其他情况返回空列表
|
||||
return []
|
||||
|
||||
async def get_social_connections(self) -> dict | None:
|
||||
"""
|
||||
Finds the user with whom the most memories are shared.
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (d1:Dialogue {{group_id: '{self.user_id}'}})<-[:MENTIONS]-(s:Statement)-[:MENTIONS]->(d2:Dialogue)
|
||||
WHERE d1 <> d2
|
||||
RETURN d2.group_id AS other_user_id, COUNT(s) AS common_statements
|
||||
ORDER BY common_statements DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query)
|
||||
if not records:
|
||||
return None
|
||||
|
||||
most_connected_user = records[0]["other_user_id"]
|
||||
common_memories_count = records[0]["common_statements"]
|
||||
|
||||
time_range_query = f"""
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.group_id IN ['{self.user_id}', '{most_connected_user}']
|
||||
RETURN min(d.created_at) AS start_time, max(d.created_at) AS end_time
|
||||
"""
|
||||
time_records = await self.neo4j_connector.execute_query(time_range_query)
|
||||
start_year, end_year = "N/A", "N/A"
|
||||
if time_records and time_records[0]["start_time"]:
|
||||
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
|
||||
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
|
||||
|
||||
return {
|
||||
"user_id": most_connected_user,
|
||||
"common_memories_count": common_memories_count,
|
||||
"time_range": f"{start_year}-{end_year}",
|
||||
}
|
||||
|
||||
async def generate_insight_report(self) -> str:
|
||||
"""
|
||||
Generates the final insight report in natural language.
|
||||
"""
|
||||
domain_dist, active_periods, social_conn = await asyncio.gather(
|
||||
self.get_domain_distribution(),
|
||||
self.get_active_periods(),
|
||||
self.get_social_connections(),
|
||||
)
|
||||
|
||||
prompt_parts = []
|
||||
|
||||
if domain_dist:
|
||||
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
|
||||
prompt_parts.append(f"- 核心领域: 用户的记忆主要集中在 {top_domains}。")
|
||||
|
||||
if active_periods:
|
||||
months_str = " 和 ".join(map(str, active_periods))
|
||||
prompt_parts.append(f"- 活跃时段: 用户在每年的 {months_str} 月最为活跃。")
|
||||
|
||||
if social_conn:
|
||||
prompt_parts.append(
|
||||
f"- 社交关联: 与用户\"{social_conn['user_id']}\"拥有最多共同记忆({social_conn['common_memories_count']}条),时间范围主要在 {social_conn['time_range']}。"
|
||||
)
|
||||
|
||||
if not prompt_parts:
|
||||
return "暂无足够数据生成洞察报告。"
|
||||
|
||||
system_prompt = '''你是一位资深的个人记忆分析师。你的任务是根据我提供的要点,为用户生成一段简洁、自然、个性化的记忆洞察报告。
|
||||
|
||||
重要规则:
|
||||
1. 报告需要将所有要点流畅地串联成一个段落
|
||||
2. 语言风格要亲切、易于理解,就像和朋友聊天一样
|
||||
3. 不要添加任何额外的解释或标题,直接输出报告内容
|
||||
4. 只使用我提供的要点,不要编造或推测任何信息
|
||||
5. 如果某个维度没有数据(如没有活跃时段信息),就不要在报告中提及该维度
|
||||
|
||||
例如,如果输入是:
|
||||
- 核心领域: 用户的记忆主要集中在 旅行(38%), 工作(24%), 家庭(21%)。
|
||||
- 活跃时段: 用户在每年的 4 和 10 月最为活跃。
|
||||
- 社交关联: 与用户"张明"拥有最多共同记忆(47条),时间范围主要在 2017-2020。
|
||||
|
||||
你的输出应该是:
|
||||
"您的记忆集中在旅行(38%)、工作(24%)和家庭(21%)三大领域。每年4月和10月是您最活跃的记录期,可能与春秋季旅行计划相关。您与'张明'共同拥有最多记忆(47条),主要集中在2017-2020年间。"
|
||||
|
||||
如果输入只有:
|
||||
- 核心领域: 用户的记忆主要集中在 教育(65%), 学习(25%)。
|
||||
|
||||
你的输出应该是:
|
||||
"您的记忆主要集中在教育(65%)和学习(25%)两大领域,显示出您对知识和成长的持续关注。"'''
|
||||
|
||||
user_prompt = "\n".join(prompt_parts)
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
response = await self.llm_client.chat(messages=messages)
|
||||
|
||||
return response.content
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the database connection.
|
||||
"""
|
||||
await self.neo4j_connector.close()
|
||||
|
||||
|
||||
async def main():
    """
    Initializes and runs the memory insight analysis for a test user.

    The report is printed and then merged into ``User-Dashboard.json`` under
    the ``"memory_insight"`` key. Failures are printed, not raised, and the
    database connection is closed in all cases.
    """
    # The configured group id doubles as the test user id.
    test_user_id = SELECTED_GROUP_ID
    print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")

    insight = None
    try:
        insight = MemoryInsight(user_id=test_user_id)
        report = await insight.generate_insight_report()
        print("--- 记忆洞察报告 ---")
        print(report)
        print("---------------------")

        # Persisting the report is best effort: a failure is reported but
        # does not discard the report that was already printed.
        try:
            from app.core.config import settings
            settings.ensure_memory_output_dir()
            output_dir = settings.MEMORY_OUTPUT_DIR
            # A directory error is no longer swallowed silently; it is
            # reported by the handler below.
            os.makedirs(output_dir, exist_ok=True)
            dashboard_path = os.path.join(output_dir, "User-Dashboard.json")

            # Other keys already present in the file are preserved.
            existing = {}
            if os.path.exists(dashboard_path):
                with open(dashboard_path, "r", encoding="utf-8") as rf:
                    existing = json.load(rf)
            existing["memory_insight"] = {
                "group_id": test_user_id,
                "report": report
            }
            with open(dashboard_path, "w", encoding="utf-8") as wf:
                json.dump(existing, wf, ensure_ascii=False, indent=2)
            print(f"已写入 {dashboard_path} -> memory_insight")
        except Exception as e:
            print(f"写入 User-Dashboard.json 失败: {e}")
    except Exception as e:
        print(f"生成报告时出错: {e}")
    finally:
        if insight:
            await insight.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # On Windows (Python 3.8+) the selector event loop policy is installed
    # before the async entry point runs.
    on_windows = sys.platform.startswith('win')
    if on_windows and sys.version_info >= (3, 8):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(main())
|
||||
202
api/app/core/memory/analytics/recent_activity_stats.py
Normal file
202
api/app/core/memory/analytics/recent_activity_stats.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import os
|
||||
import re
|
||||
import glob
|
||||
import json
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT
|
||||
except Exception:
|
||||
# Fallback: derive project root from this file location
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
def _get_latest_prompt_log_path() -> str | None:
    """Locate the most recently modified prompt log file.

    Searches ``PROJECT_ROOT/logs/prompts`` for files named
    ``prompt_logs-*.log``.

    Returns:
        The path with the greatest modification time, or ``None`` when the
        directory does not exist or holds no matching file.
    """
    prompts_dir = os.path.join(PROJECT_ROOT, "logs", "prompts")
    if os.path.isdir(prompts_dir):
        matches = glob.glob(os.path.join(prompts_dir, "prompt_logs-*.log"))
        if matches:
            # max() keeps the first path among equal mtimes, which is the same
            # element a stable descending sort would place at index 0.
            return max(matches, key=os.path.getmtime)
    return None
|
||||
|
||||
|
||||
def _get_all_prompt_logs() -> list[str]:
    """Collect every prompt log file, oldest first.

    Both ``PROJECT_ROOT/logs/prompts`` and ``<cwd>/logs/prompts`` are scanned
    for ``prompt_logs-*.log`` files; directories that do not exist are skipped.

    Returns:
        The distinct matching paths sorted by modification time ascending.
    """
    search_dirs = (
        os.path.join(PROJECT_ROOT, "logs", "prompts"),
        os.path.join(os.getcwd(), "logs", "prompts"),
    )
    found = []
    for directory in search_dirs:
        if not os.path.isdir(directory):
            continue
        found.extend(glob.glob(os.path.join(directory, "prompt_logs-*.log")))

    # Identical path strings (e.g. when both directories coincide) collapse here.
    unique_paths = set(found)
    return sorted(unique_paths, key=os.path.getmtime)
|
||||
|
||||
|
||||
def _get_any_logs_recursive() -> list[str]:
    """Fallback search: every ``*.log`` file anywhere below ``PROJECT_ROOT``.

    Returns:
        Matching paths sorted by modification time ascending.
    """
    pattern = os.path.join(PROJECT_ROOT, "**", "*.log")
    return sorted(glob.glob(pattern, recursive=True), key=os.path.getmtime)
|
||||
|
||||
|
||||
# Log-line patterns used by parse_stats_from_log, compiled once at import time.
_CHUNK_RENDER_RE = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===")
_TRIPLET_START_RE = re.compile(r"\[Triplet\].*statements_to_process\s*=\s*(\d+)")
_TRIPLET_DONE_RE = re.compile(
    r"\[Triplet\].*completed,\s*total_triplets\s*=\s*(\d+),\s*total_entities\s*=\s*(\d+)"
)
_TEMPORAL_DONE_RE = re.compile(
    r"\[Temporal\].*completed,\s*extracted_valid_ranges\s*=\s*(\d+)"
)


def parse_stats_from_log(log_path: str) -> dict:
    """Parse extraction statistics from one prompt log file.

    The file is read line by line as UTF-8 (undecodable bytes are ignored).
    Each line is tested against the patterns in a fixed order and contributes
    to at most one counter.

    Args:
        log_path: Path of the log file to read.

    Returns:
        A dict with keys:
        - chunk_count: int, number of rendered statement-extraction prompts
        - statements_count: int, sum of ``statements_to_process`` values
        - triplet_entities_count: int, sum of ``total_entities`` values
        - triplet_relations_count: int, sum of ``total_triplets`` values
        - temporal_count: int, sum of ``extracted_valid_ranges`` values
        - log_path: the ``log_path`` argument, echoed back

    Raises:
        OSError: If the file cannot be opened.
    """
    chunk_count = 0
    statements_count = 0
    triplet_entities_count = 0
    triplet_relations_count = 0
    temporal_count = 0

    with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            # Each chunk triggers one statement-extraction prompt render.
            if _CHUNK_RENDER_RE.search(line):
                chunk_count += 1
                continue

            match = _TRIPLET_START_RE.search(line)
            if match:
                # The capture is all digits; int() can only reject it when the
                # digit run exceeds the interpreter's int-conversion limit.
                try:
                    statements_count += int(match.group(1))
                except ValueError:
                    pass
                continue

            match = _TRIPLET_DONE_RE.search(line)
            if match:
                try:
                    triplet_relations_count += int(match.group(1))
                    triplet_entities_count += int(match.group(2))
                except ValueError:
                    pass
                continue

            match = _TEMPORAL_DONE_RE.search(line)
            if match:
                try:
                    temporal_count += int(match.group(1))
                except ValueError:
                    pass

    return {
        "chunk_count": chunk_count,
        "statements_count": statements_count,
        "triplet_entities_count": triplet_entities_count,
        "triplet_relations_count": triplet_relations_count,
        "temporal_count": temporal_count,
        "log_path": log_path,
    }
|
||||
|
||||
|
||||
def get_recent_activity_stats() -> Tuple[dict, str]:
    """Aggregate extraction statistics over all discovered log files.

    Prompt logs are looked up first; when none are found, every ``*.log`` file
    below the project root is used instead.

    Returns:
        ``(stats, message)``. ``stats`` holds the five summed counters plus a
        ``log_path`` entry (``None`` when no log file exists, otherwise a short
        description of how many files were combined and which one is newest).
        ``message`` is a human-readable status string.
    """
    counter_keys = (
        "chunk_count",
        "statements_count",
        "triplet_entities_count",
        "triplet_relations_count",
        "temporal_count",
    )
    totals = dict.fromkeys(counter_keys, 0)

    # Fall back to the recursive search only when no prompt log was found.
    log_files = _get_all_prompt_logs() or _get_any_logs_recursive()
    if not log_files:
        totals["log_path"] = None
        return totals, "未找到日志文件,请确认已运行过提取流程。"

    for log_file in log_files:
        parsed = parse_stats_from_log(log_file)
        for key in counter_keys:
            totals[key] += parsed.get(key, 0)

    # Files are ordered by mtime ascending, so the last one is the newest.
    totals["log_path"] = f"{len(log_files)} 个日志文件,最新:{log_files[-1]}"
    return totals, "成功汇总 logs 目录中所有提示日志。"
|
||||
|
||||
|
||||
def _format_summary(stats: dict) -> str:
|
||||
"""Format a Chinese summary string from stats."""
|
||||
log_info = stats.get("log_path") or "(无)"
|
||||
return (
|
||||
"最近记忆活动统计\n"
|
||||
f"- 日志文件:{log_info}\n"
|
||||
f"- 数据分块:共 {stats.get('chunk_count', 0)} 块\n"
|
||||
f"- 句子提取:共 {stats.get('statements_count', 0)} 个句子\n"
|
||||
f"- 三元组提取:实体 {stats.get('triplet_entities_count', 0)} 个,关系 {stats.get('triplet_relations_count', 0)} 条\n"
|
||||
f"- 时间提取:共提取 {stats.get('temporal_count', 0)} 条时间信息\n"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: print the aggregated statistics, then merge them
    # into the shared Signboard.json file.
    stats, msg = get_recent_activity_stats()
    print(msg)
    print(_format_summary(stats))

    # --- Write the result into the unified Signboard.json ---
    try:
        # Use the globally configured output path
        from app.core.config import settings
        settings.ensure_memory_output_dir()
        output_dir = settings.MEMORY_OUTPUT_DIR
        # Best-effort directory creation; a failure here is ignored and will
        # surface below when the file is opened.
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception:
            pass
        signboard_path = os.path.join(output_dir, "Signboard.json")
        # Load the existing document (if any) so other top-level keys survive.
        existing = {}
        if os.path.exists(signboard_path):
            with open(signboard_path, "r", encoding="utf-8") as rf:
                existing = json.load(rf)
        existing["recent_activity_stats"] = stats
        with open(signboard_path, "w", encoding="utf-8") as wf:
            json.dump(existing, wf, ensure_ascii=False, indent=2)
        print(f"已写入 {signboard_path} -> recent_activity_stats")
    except Exception as e:
        # Any failure while writing (including unreadable existing JSON) is
        # reported on stdout rather than raised.
        print(f"写入 Signboard.json 失败: {e}")
|
||||
152
api/app/core/memory/analytics/user_summary.py
Normal file
152
api/app/core/memory/analytics/user_summary.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Generate a concise "关于我" style user summary using data from Neo4j
|
||||
and the existing LLM configuration (mirrors hot_memory_tags.py setup).
|
||||
|
||||
Usage:
|
||||
python -m app.core.memory.analytics.user_summary
(no command-line options are parsed; the summary is generated for SELECTED_GROUP_ID)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
# Ensure absolute imports work whether executed directly or via module
|
||||
# Best-effort sys.path setup: any error is swallowed so that a failure here
# never blocks the imports that follow.
try:
    # Directory two levels above the one containing this file.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
    src_path = os.path.join(project_root, 'src')
    # Prepend each entry only if it is not already on sys.path.
    if src_path not in sys.path:
        sys.path.insert(0, src_path)
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
except Exception:
    pass
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
|
||||
|
||||
@dataclass
class StatementRecord:
    """One statement row used as context when building a user summary."""

    # Statement text, taken from the ``statement`` column of a query row.
    statement: str
    # Value of the ``created_at`` column of the row; may be None.
    created_at: str | None
|
||||
|
||||
|
||||
class UserSummary:
    """Builds a short textual summary for one user/group id.

    The summary is produced by an LLM from two inputs: the group's most
    frequent entities and a sample of its statements stored in Neo4j.
    Call ``close()`` when done to release the Neo4j connector.
    """

    def __init__(self, user_id: str):
        """Create the Neo4j connector and the LLM client for ``user_id``."""
        self.user_id = user_id
        self.connector = Neo4jConnector()
        # The LLM id is read from the definitions module at call time rather
        # than from a name bound when this module was imported.
        from app.core.memory.utils.config import definitions as config_defs
        self.llm = get_llm_client(config_defs.SELECTED_LLM_ID)

    async def close(self):
        """Close the underlying Neo4j connector."""
        await self.connector.close()

    async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]:
        """Fetch up to ``limit`` statements of this group, newest first."""
        cypher = " ".join(
            [
                "MATCH (s:Statement)",
                "WHERE s.group_id = $group_id AND s.statement IS NOT NULL",
                "RETURN s.statement AS statement, s.created_at AS created_at",
                "ORDER BY created_at DESC LIMIT $limit",
            ]
        )
        rows = await self.connector.execute_query(cypher, group_id=self.user_id, limit=limit)

        collected: List[StatementRecord] = []
        for row in rows:
            # Rows that cannot be converted are skipped instead of failing the call.
            try:
                record = StatementRecord(
                    statement=row.get("statement", ""),
                    created_at=row.get("created_at"),
                )
            except Exception:
                continue
            collected.append(record)
        return collected

    async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
        """Return (entity, frequency) pairs via the shared hot-tag logic."""
        # get_hot_memory_tags internally filters out non-meaningful nouns with LLM
        top_entities = await get_hot_memory_tags(self.user_id, limit=limit)
        return top_entities

    async def generate(self) -> str:
        """Generate a Chinese "about me" style summary using the LLM.

        Returns:
            The ``content`` attribute of the LLM chat response.
        """
        # 1) Collect context
        entities = await self._get_top_entities(limit=40)
        statements = await self._get_recent_statements(limit=100)

        formatted_entities = [f"{name} ({freq})" for name, freq in entities]
        entity_lines = formatted_entities[:20]

        non_blank = []
        for record in statements:
            # Skip statements that are None, empty or whitespace only.
            if (record.statement or '').strip():
                non_blank.append(record.statement.strip())
        statement_samples = non_blank[:20]

        # 2) Compose prompt
        system_prompt = (
            "你是一位中文信息压缩助手。请基于提供的实体与语句,"
            "生成非常简洁的用户摘要,禁止臆测或虚构。要求:\n"
            "- 3–4 句,总字数不超过 120;\n"
            "- 先交代身份/城市,其次长期兴趣或习惯,最后给一两项代表性经历;\n"
            "- 避免形容词堆砌与空话,不用项目符号,不分段;\n"
            "- 使用客观的第三人称描述,语气克制、中立。"
        )

        if entity_lines:
            entity_text = ", ".join(entity_lines)
        else:
            entity_text = "(空)"
        if statement_samples:
            statement_text = " | ".join(statement_samples)
        else:
            statement_text = "(空)"

        user_prompt = "\n".join(
            [
                f"用户ID: {self.user_id}",
                "核心实体与频次: " + entity_text,
                "代表性语句样本: " + statement_text,
            ]
        )

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        # 3) Call LLM
        response = await self.llm.chat(messages=messages)
        return response.content
|
||||
|
||||
|
||||
async def generate_user_summary(user_id: str | None = None) -> str:
    """Generate the user summary for ``user_id`` and release resources.

    Args:
        user_id: Group/user id to summarise. When omitted or empty, the
            configured ``SELECTED_GROUP_ID`` is used.

    Returns:
        The summary text produced by ``UserSummary.generate``.
    """
    # Read the default through the module attribute at call time, the same way
    # UserSummary.__init__ resolves SELECTED_LLM_ID. A name bound with
    # ``from ... import SELECTED_GROUP_ID`` keeps the value seen at import time.
    # NOTE(review): presumably a configuration reload rebinds this attribute on
    # the definitions module -- confirm against that module.
    from app.core.memory.utils.config import definitions as config_defs

    effective_group_id = user_id or config_defs.SELECTED_GROUP_ID
    svc = UserSummary(effective_group_id)
    try:
        return await svc.generate()
    finally:
        # Always close the connector, even when generation fails.
        await svc.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: generate the summary for the configured group, print
    # it, then merge it into the shared User-Dashboard.json file.
    print("开始生成用户摘要…")
    try:
        # Use the configured default group_id (no argument is passed).
        summary = asyncio.run(generate_user_summary())
        print("\n— 用户摘要 —\n")
        print(summary)

        # Write the result into the unified User-Dashboard.json
        try:
            from app.core.config import settings
            settings.ensure_memory_output_dir()
            output_dir = settings.MEMORY_OUTPUT_DIR
            # Best-effort directory creation; a failure here is ignored and
            # will surface below when the file is opened.
            try:
                os.makedirs(output_dir, exist_ok=True)
            except Exception:
                pass
            dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
            # Load the existing document (if any) so other top-level keys survive.
            existing = {}
            if os.path.exists(dashboard_path):
                with open(dashboard_path, "r", encoding="utf-8") as rf:
                    existing = json.load(rf)
            existing["user_summary"] = {
                "group_id": SELECTED_GROUP_ID,
                "summary": summary
            }
            with open(dashboard_path, "w", encoding="utf-8") as wf:
                json.dump(existing, wf, ensure_ascii=False, indent=2)
            print(f"已写入 {dashboard_path} -> user_summary")
        except Exception as e:
            # A write failure is reported but does not discard the printed summary.
            print(f"写入 User-Dashboard.json 失败: {e}")
    except Exception as e:
        # Summary generation itself failed; print the error and a checklist.
        print(f"生成摘要失败: {e}")
        print("请检查: 1) Neo4j 是否可用;2) config.json 与 .env 的 LLM/Neo4j 配置是否正确;3) 数据是否包含该用户的内容。")
|
||||
132
api/app/core/memory/config.json
Normal file
132
api/app/core/memory/config.json
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"llm_list": [
|
||||
{
|
||||
"llm_name": "qwen2.5-14b-instruct-awq",
|
||||
"api_base": "http://175.27.131.196:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen2.5-14b-instruct-awq",
|
||||
"api_base": "http://175.27.131.196:9090/v1",
|
||||
"api_key": "OPENAI_API_AGENT_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen2.5-14b",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen2.5-14b-instruct-awq",
|
||||
"api_base": "http://175.27.131.196:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen3-14b",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/deepseek-r1-0528-qwen3-8b",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"api_key": "OPENAI_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/qwen3-235b-a22b-instruct-2507",
|
||||
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"api_key": "DASHSCOPE_API_KEY"
|
||||
}
|
||||
,
|
||||
{
|
||||
"llm_name": "openai/qwen-plus",
|
||||
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"api_key": "DASHSCOPE_API_KEY"
|
||||
},
|
||||
{
|
||||
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||
},
|
||||
{
|
||||
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
}
|
||||
],
|
||||
"embedding_list": [
|
||||
{
|
||||
"embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
"api_base": "http://119.45.239.97:11434/v1",
|
||||
"dimension": 768
|
||||
},
|
||||
{
|
||||
"embedding_name": "openai/bge-m3",
|
||||
"api_base": "http://43.137.4.24:9090/v1",
|
||||
"dimension": 1024
|
||||
}
|
||||
],
|
||||
"neo4j": {
|
||||
"uri": "bolt://1.94.111.67:7687",
|
||||
"username": "neo4j"
|
||||
},
|
||||
"chunker_list": [
|
||||
{
|
||||
"chunker_strategy": "TokenChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"chunk_overlap": 56,
|
||||
"tokenizer_or_token_counter": "character"
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"min_characters_per_chunk": 50
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "SemanticChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 1024,
|
||||
"threshold": 0.8,
|
||||
"min_sentences": 2,
|
||||
"skip_window": 1,
|
||||
"min_characters_per_chunk": 100
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "LateChunker",
|
||||
"embedding_model": "all-MiniLM-L6-v2",
|
||||
"chunk_size": 2048,
|
||||
"min_characters_per_chunk": 24
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "NeuralChunker",
|
||||
"embedding_model": "mirth/chonky_modernbert_base_1",
|
||||
"min_characters_per_chunk": 24
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "LLMChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 1000,
|
||||
"min_characters_per_chunk": 100
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "HybridChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"threshold": 0.8,
|
||||
"min_characters_per_chunk": 100
|
||||
},
|
||||
{
|
||||
"chunker_strategy": "SentenceChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 2048,
|
||||
"chunk_overlap": 128,
|
||||
"min_sentences_per_chunk": 1,
|
||||
"min_characters_per_sentence": 12,
|
||||
"delim": [".", "!", "?", "\n"],
|
||||
"include_delim": "prev",
|
||||
"tokenizer_or_token_counter": "character"
|
||||
}
|
||||
],
|
||||
"langfuse": {
|
||||
"enabled": true
|
||||
},
|
||||
"agenta": {
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
12
api/app/core/memory/data/testdata.json
Normal file
12
api/app/core/memory/data/testdata.json
Normal file
File diff suppressed because one or more lines are too long
5
api/app/core/memory/dbrun.json
Normal file
5
api/app/core/memory/dbrun.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"selections": {
|
||||
"config_id": "1"
|
||||
}
|
||||
}
|
||||
1
api/app/core/memory/evaluation/__init__.py
Normal file
1
api/app/core/memory/evaluation/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Evaluation package with dataset-specific pipelines and a unified runner."""
|
||||
30
api/app/core/memory/evaluation/benchmark.md
Normal file
30
api/app/core/memory/evaluation/benchmark.md
Normal file
@@ -0,0 +1,30 @@
|
||||
⏬数据集下载地址:
|
||||
Locomo10.json:https://github.com/snap-research/locomo/tree/main/data
|
||||
LongMemEval_oracle.json:https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
|
||||
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
|
||||
上方数据集下载好后全部放入app/core/memory/data文件夹中
|
||||
|
||||
全流程基准测试运行:
|
||||
locomo:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32
|
||||
LongMemEval:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group
|
||||
memsciqa:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci
|
||||
|
||||
单独检索评估运行命令:
|
||||
python -m app.core.memory.evaluation.locomo.locomo_test
|
||||
python -m app.core.memory.evaluation.longmemeval.test_eval
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa-test
|
||||
需要先在项目中修改需要检测评估的group_id。
|
||||
|
||||
参数及解释:
|
||||
● --dataset longmemeval - 指定数据集
|
||||
● --sample-size 10 - 评估10个样本
|
||||
● --start-index 0 - 从第0个样本开始
|
||||
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
|
||||
● --search-limit 8 - 检索限制8条
|
||||
● --context-char-budget 4000 - 上下文字符预算4000
|
||||
● --search-type hybrid - 使用混合检索
|
||||
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文
|
||||
● --reset-group - 运行前清空组数据
|
||||
100
api/app/core/memory/evaluation/common/metrics.py
Normal file
100
api/app/core/memory/evaluation/common/metrics.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import math
|
||||
import re
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def _normalize(text: str) -> List[str]:
|
||||
"""Lowercase, strip punctuation, and split into tokens."""
|
||||
text = text.lower().strip()
|
||||
# Python's re doesn't support \p classes; use a simple non-word filter
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
tokens = [t for t in text.split() if t]
|
||||
return tokens
|
||||
|
||||
|
||||
def exact_match(pred: str, ref: str) -> float:
    """Return 1.0 when both texts normalize to the same token list, else 0.0."""
    pred_tokens = _normalize(pred)
    ref_tokens = _normalize(ref)
    return 1.0 if pred_tokens == ref_tokens else 0.0
|
||||
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
    """Jaccard similarity between the token sets of ``pred`` and ``ref``.

    Two empty token sets count as a perfect match (1.0); exactly one empty
    set scores 0.0.
    """
    pred_vocab = set(_normalize(pred))
    ref_vocab = set(_normalize(ref))
    if pred_vocab and ref_vocab:
        return len(pred_vocab & ref_vocab) / len(pred_vocab | ref_vocab)
    # At least one side is empty: perfect only when both are.
    return 0.0 if (pred_vocab or ref_vocab) else 1.0
|
||||
|
||||
|
||||
def f1_score(pred: str, ref: str) -> float:
    """F1 over the sets of distinct tokens of ``pred`` and ``ref``.

    Token multiplicity is ignored. Two empty inputs score 1.0; exactly one
    empty input, or no shared token, scores 0.0.
    """
    pred_types = set(_normalize(pred))
    ref_types = set(_normalize(ref))
    if not pred_types and not ref_types:
        return 1.0
    if not pred_types or not ref_types:
        return 0.0

    overlap = len(pred_types & ref_types)
    if overlap == 0:
        # Precision and recall are both zero in this case.
        return 0.0
    precision = overlap / len(pred_types)
    recall = overlap / len(ref_types)
    return 2 * precision * recall / (precision + recall)
|
||||
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
    """Unigram BLEU (BLEU-1) with clipping and brevity penalty.

    Returns 0.0 when ``pred`` has no tokens.
    """
    pred_tokens = _normalize(pred)
    ref_tokens = _normalize(ref)
    pred_len = len(pred_tokens)
    if pred_len == 0:
        return 0.0

    # How many times each reference token may still be matched.
    remaining: Dict[str, int] = {}
    for tok in ref_tokens:
        remaining[tok] = remaining.get(tok, 0) + 1

    # Clipped count: a predicted token only scores while the reference still
    # has an unmatched occurrence of it.
    matched = 0
    for tok in pred_tokens:
        if remaining.get(tok, 0) > 0:
            remaining[tok] -= 1
            matched += 1
    precision = matched / pred_len

    # Brevity penalty: only predictions not longer than the reference are penalised.
    ref_len = len(ref_tokens)
    if pred_len > ref_len:
        brevity = 1.0
    else:
        brevity = math.exp(1 - ref_len / pred_len)
    return brevity * precision
|
||||
|
||||
|
||||
def percentile(values: List[float], p: float) -> float:
    """Linearly interpolated percentile of ``values``.

    Args:
        values: Samples in any order; an empty list yields 0.0.
        p: Fraction used to locate the rank ``(len(values) - 1) * p``
            (the callers in this module pass values between 0 and 1).

    Returns:
        The sample at that rank, interpolated between its two neighbours when
        the rank is not an integer.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    rank = (len(ordered) - 1) * p
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # The rank falls exactly on a sample.
        return ordered[int(rank)]
    fraction = rank - lower
    return ordered[lower] + fraction * (ordered[upper] - ordered[lower])
|
||||
|
||||
|
||||
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
    """Summarise latencies as mean, median, 95th percentile and IQR.

    Returns:
        A dict with keys ``mean``, ``p50``, ``p95`` and ``iqr`` (p75 - p25);
        all four are 0.0 for an empty input.
    """
    if not latencies_ms:
        return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}

    q25, q50, q75, q95 = [
        percentile(latencies_ms, fraction) for fraction in (0.25, 0.50, 0.75, 0.95)
    ]
    average = sum(latencies_ms) / len(latencies_ms)
    return {"mean": average, "p50": q50, "p95": q95, "iqr": q75 - q25}
|
||||
|
||||
|
||||
def avg_context_tokens(contexts: List[str]) -> float:
    """Mean number of normalized tokens per context (0.0 for no contexts)."""
    if not contexts:
        return 0.0
    total_tokens = 0
    for context in contexts:
        total_tokens += len(_normalize(context))
    return total_tokens / len(contexts)
|
||||
60
api/app/core/memory/evaluation/dialogue_queries.py
Normal file
60
api/app/core/memory/evaluation/dialogue_queries.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Dialogue search queries for evaluation purposes.
|
||||
This file contains Cypher queries for searching dialogues, entities, and chunks.
|
||||
Placed in evaluation directory to avoid circular imports with src modules.
|
||||
"""
|
||||
|
||||
# Entity search queries
# Exact match on Entity.name; parameter: $name.
SEARCH_ENTITIES_BY_NAME = """
MATCH (e:Entity)
WHERE e.name = $name
RETURN e
"""

# Substring match on Entity.name (CONTAINS); parameter: $name.
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
MATCH (e:Entity)
WHERE e.name CONTAINS $name
RETURN e
"""

# Chunk search queries
# Substring match on Chunk.content; parameter: $content.
SEARCH_CHUNKS_BY_CONTENT = """
MATCH (c:Chunk)
WHERE c.content CONTAINS $content
RETURN c
"""

# Dialogue search queries
# Exact match on Dialogue.dialog_id; parameter: $dialog_id.
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE d.dialog_id = $dialog_id
RETURN d
"""

# Substring match on Dialogue.content; parameter: $q.
SEARCH_DIALOGUES_BY_CONTENT = """
MATCH (d:Dialogue)
WHERE d.content CONTAINS $q
RETURN d
"""

# Cosine-similarity search over Dialogue.dialog_embedding, computed inside the
# query with reduce(): score = dot(q, v) / (|q| * |v|), or 0.0 when either norm
# is zero. Parameters: $embedding (query vector), $group_id (NULL disables the
# group filter), $threshold (strict lower bound on score), $limit.
# NOTE(review): the dot product iterates over the indices of q and reads v[i];
# this assumes both vectors have the same length -- confirm.
# NOTE(review): the result column dialog_id is taken from d.id, while
# SEARCH_DIALOGUE_BY_DIALOG_ID filters on d.dialog_id -- verify the intended
# property.
DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q
MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL
AND ($group_id IS NULL OR d.group_id = $group_id)
WITH d, q, d.dialog_embedding AS v
WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at,
score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
326
api/app/core/memory/evaluation/extraction_utils.py
Normal file
326
api/app/core/memory/evaluation/extraction_utils.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import SELECTED_CHUNKER_STRATEGY, SELECTED_EMBEDDING_ID
|
||||
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
|
||||
# Import from database module
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
|
||||
# Cypher queries for evaluation
|
||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
||||
|
||||
|
||||
async def ingest_contexts_via_full_pipeline(
|
||||
contexts: List[str],
|
||||
group_id: str,
|
||||
chunker_strategy: str | None = None,
|
||||
embedding_name: str | None = None,
|
||||
save_chunk_output: bool = False,
|
||||
save_chunk_output_path: str | None = None,
|
||||
) -> bool:
|
||||
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator
|
||||
|
||||
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
|
||||
This function mirrors the steps in main(), but starts from raw text contexts.
|
||||
Args:
|
||||
contexts: List of dialogue texts, each containing lines like "role: message".
|
||||
group_id: Group ID to assign to generated DialogData and graph nodes.
|
||||
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
|
||||
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
||||
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
||||
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
|
||||
Returns:
|
||||
True if data saved successfully, False otherwise.
|
||||
"""
|
||||
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY
|
||||
embedding_name = embedding_name or SELECTED_EMBEDDING_ID
|
||||
|
||||
# Initialize llm client with graceful fallback
|
||||
llm_client = None
|
||||
llm_available = True
|
||||
try:
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
|
||||
llm_available = False
|
||||
|
||||
# Step A: Build DialogData list from contexts with robust parsing
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
dialog_data_list: List[DialogData] = []
|
||||
|
||||
for idx, ctx in enumerate(contexts):
|
||||
messages: List[ConversationMessage] = []
|
||||
|
||||
# Improved parsing: capture multi-line message blocks, normalize roles
|
||||
pattern = r"^\s*(用户|AI|assistant|user)\s*[::]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[::]|\Z)"
|
||||
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
|
||||
|
||||
if matches:
|
||||
for m in matches:
|
||||
raw_role = m.group(1).strip()
|
||||
content = m.group(2).strip()
|
||||
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=content))
|
||||
else:
|
||||
# Fallback: line-by-line parsing
|
||||
for raw in ctx.split("\n"):
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)$', line)
|
||||
if m:
|
||||
role = m.group(1).strip()
|
||||
msg = m.group(2).strip()
|
||||
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=msg))
|
||||
else:
|
||||
# Final fallback: treat as user message
|
||||
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
|
||||
messages.append(ConversationMessage(role=default_role, msg=line))
|
||||
|
||||
context_model = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context_model,
|
||||
ref_id=f"pipeline_item_{idx}",
|
||||
group_id=group_id,
|
||||
user_id="default_user",
|
||||
apply_id="default_application",
|
||||
)
|
||||
# Generate chunks
|
||||
dialog.chunks = await chunker.process_dialogue(dialog)
|
||||
dialog_data_list.append(dialog)
|
||||
|
||||
if not dialog_data_list:
|
||||
print("No dialogs to process for ingestion.")
|
||||
return False
|
||||
|
||||
# Optionally save chunking outputs for debugging
|
||||
if save_chunk_output:
|
||||
try:
|
||||
def _serialize_datetime(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
default_path = settings.get_memory_output_path("chunker_test_output.txt")
|
||||
out_path = save_chunk_output_path or default_path
|
||||
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
|
||||
print(f"Saved chunking results to: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to save chunking results: {e}")
|
||||
|
||||
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线
|
||||
if not llm_available:
|
||||
print("[Ingestion] Skipping extraction pipeline (no LLM).")
|
||||
return False
|
||||
|
||||
# 初始化 embedder 客户端
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
try:
|
||||
embedder_config_dict = get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to initialize embedder client: {e}")
|
||||
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
|
||||
return False
|
||||
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 初始化并运行 ExtractionOrchestrator
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=connector,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# 创建一个包装的 orchestrator 来修复时间提取器的输出
|
||||
# 保存原始的 _assign_extracted_data 方法
|
||||
original_assign = orchestrator._assign_extracted_data
|
||||
|
||||
def clean_temporal_value(value):
|
||||
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
# 处理字符串形式的 'null', 'None', 空字符串等
|
||||
if value.lower() in ('null', 'none', '') or value.strip() == '':
|
||||
return None
|
||||
return value
|
||||
|
||||
async def patched_assign_extracted_data(*args, **kwargs):
|
||||
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
|
||||
result = await original_assign(*args, **kwargs)
|
||||
|
||||
# 清理返回的 dialog_data_list 中的 temporal_validity
|
||||
for dialog in result:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
# 清理 valid_at 和 invalid_at
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
return result
|
||||
|
||||
# 替换方法
|
||||
orchestrator._assign_extracted_data = patched_assign_extracted_data
|
||||
|
||||
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理
|
||||
original_create = orchestrator._create_nodes_and_edges
|
||||
|
||||
async def patched_create_nodes_and_edges(dialog_data_list_arg):
|
||||
"""包装方法:在创建节点前再次清理 temporal_validity"""
|
||||
# 最后一次清理,确保万无一失
|
||||
for dialog in dialog_data_list_arg:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
|
||||
return await original_create(dialog_data_list_arg)
|
||||
|
||||
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges
|
||||
|
||||
# 运行完整的提取流水线
|
||||
# orchestrator.run 返回 7 个元素的元组
|
||||
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
) = result
|
||||
|
||||
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
|
||||
|
||||
# Step G: 生成记忆摘要
|
||||
print("[Ingestion] Generating memory summaries...")
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
Memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
|
||||
summaries = await Memory_summary_generation(
|
||||
chunked_dialogs=dialog_data_list,
|
||||
llm_client=llm_client,
|
||||
embedding_id=embedding_name or SELECTED_EMBEDDING_ID
|
||||
)
|
||||
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
|
||||
summaries = []
|
||||
|
||||
# Step H: Save to Neo4j
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
chunk_nodes=chunk_nodes,
|
||||
statement_nodes=statement_nodes,
|
||||
entity_nodes=entity_nodes,
|
||||
entity_edges=entity_entity_edges,
|
||||
statement_chunk_edges=statement_chunk_edges,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
connector=connector
|
||||
)
|
||||
|
||||
# Save memory summaries separately
|
||||
if summaries:
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, connector)
|
||||
await add_memory_summary_statement_edges(summaries, connector)
|
||||
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to save summary nodes: {e}")
|
||||
|
||||
await connector.close()
|
||||
if success:
|
||||
print("Successfully saved extracted data to Neo4j!")
|
||||
else:
|
||||
print("Failed to save data to Neo4j")
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"Failed to save data to Neo4j: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def handle_context_processing(args):
    """Gather dialogue contexts from parsed CLI arguments and run the pipeline.

    Inline contexts (``args.contexts``) are collected first, followed by the
    non-blank, whitespace-stripped lines of ``args.context_file`` when a file
    is given.

    Args:
        args: Parsed argument namespace exposing ``contexts``,
            ``context_file`` and ``context_group_id``.

    Returns:
        ``False`` when the context file cannot be read or when no contexts
        were collected; otherwise whatever ``main_from_contexts`` returns.
    """
    collected = list(args.contexts) if args.contexts else []

    file_path = args.context_file
    if file_path:
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    stripped = raw_line.strip()
                    # Blank lines carry no context and are dropped.
                    if stripped:
                        collected.append(stripped)
        except Exception as e:
            # Best-effort CLI helper: report the problem and signal failure.
            print(f"Error reading context file: {e}")
            return False

    if collected:
        return await main_from_contexts(collected, args.context_group_id)

    print("No contexts provided for processing.")
    return False
|
||||
|
||||
|
||||
async def main_from_contexts(contexts: List[str], group_id: str):
    """Ingest caller-supplied dialogue contexts through the full pipeline.

    The contexts are handed to ``ingest_contexts_via_full_pipeline`` together
    with the module-level ``SELECTED_CHUNKER_STRATEGY`` and
    ``SELECTED_EMBEDDING_ID`` settings, with chunk output saving enabled.

    Args:
        contexts: Dialogue context strings to ingest.
        group_id: Group identifier forwarded to the ingestion pipeline.

    Returns:
        The success flag returned by ``ingest_contexts_via_full_pipeline``.
    """
    print("=== Running pipeline from provided contexts ===")

    pipeline_kwargs = {
        "contexts": contexts,
        "group_id": group_id,
        "chunker_strategy": SELECTED_CHUNKER_STRATEGY,
        "embedding_name": SELECTED_EMBEDDING_ID,
        "save_chunk_output": True,
    }
    success = await ingest_contexts_via_full_pipeline(**pipeline_kwargs)

    outcome = (
        "Successfully processed and saved contexts to Neo4j!"
        if success
        else "Failed to process contexts."
    )
    print(outcome)

    return success
|
||||
568
api/app/core/memory/evaluation/locomo/locomo_benchmark.py
Normal file
568
api/app/core/memory/evaluation/locomo/locomo_benchmark.py
Normal file
@@ -0,0 +1,568 @@
|
||||
"""
|
||||
LoCoMo Benchmark Script
|
||||
|
||||
This module provides the main entry point for running LoCoMo benchmark evaluations.
|
||||
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
|
||||
in a clean, maintainable way.
|
||||
|
||||
Usage:
|
||||
python locomo_benchmark.py --sample_size 20 --search_type hybrid
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
def load_dotenv():
|
||||
pass
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
SELECTED_EMBEDDING_ID
|
||||
)
|
||||
from app.core.memory.utils.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
f1_score,
|
||||
bleu1,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
avg_context_tokens
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_metrics import (
|
||||
locomo_f1_score,
|
||||
locomo_multi_f1,
|
||||
get_category_name
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_utils import (
|
||||
load_locomo_data,
|
||||
extract_conversations,
|
||||
resolve_temporal_references,
|
||||
select_and_format_information,
|
||||
retrieve_relevant_information,
|
||||
ingest_conversations_if_needed
|
||||
)
|
||||
|
||||
|
||||
def _mean(values: List[float]) -> float:
    """Return the arithmetic mean of *values*, or 0.0 for an empty list."""
    return sum(values) / len(values) if values else 0.0


async def run_locomo_benchmark(
    sample_size: int = 20,
    group_id: Optional[str] = None,
    search_type: str = "hybrid",
    search_limit: int = 12,
    context_char_budget: int = 8000,
    reset_group: bool = False,
    skip_ingest: bool = False,
    output_dir: Optional[str] = None
) -> Dict[str, Any]:
    """
    Run the LoCoMo benchmark evaluation end to end.

    Pipeline:
    1. Load QA pairs for conversation index 0 from ``locomo10.json``.
    2. Unless ``skip_ingest`` is set, extract one conversation and ingest it.
       Ingestion failures are reported and the evaluation continues.
    3. For every question: retrieve information, build the context, ask the
       LLM, and compute F1 / BLEU-1 / Jaccard / LoCoMo F1.
    4. Aggregate overall, per-category, latency and context statistics and
       write them to a timestamped JSON file (printed to the console if the
       file cannot be written).

    Only the first conversation is ingested and only its QA pairs are
    evaluated, so every question has corresponding data to retrieve.

    Args:
        sample_size: Number of QA pairs to evaluate (from first conversation).
        group_id: Database group ID; falls back to ``SELECTED_GROUP_ID``.
        search_type: "keyword", "embedding", or "hybrid".
        search_limit: Max documents to retrieve per query.
        context_char_budget: Max characters for the formatted context.
        reset_group: Forwarded as ``reset`` to
            ``ingest_conversations_if_needed`` (clear-and-re-ingest is
            documented as not implemented).
        skip_ingest: If True, skip ingestion and use existing data in Neo4j.
        output_dir: Directory for the result file; defaults to a ``results``
            directory next to this module.

    Returns:
        Dictionary with metrics, latency, context statistics and per-question
        samples. If the dataset cannot be loaded, a dictionary containing only
        ``error`` and ``timestamp`` is returned instead.
    """
    # Use default group_id if not provided
    group_id = group_id or SELECTED_GROUP_ID

    # Determine data path, falling back to the current working directory
    data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
    if not os.path.exists(data_path):
        data_path = os.path.join(os.getcwd(), "data", "locomo10.json")

    print(f"\n{'='*60}")
    print("🚀 Starting LoCoMo Benchmark Evaluation")
    print(f"{'='*60}")
    print("📊 Configuration:")
    print(f" Sample size: {sample_size}")
    print(f" Group ID: {group_id}")
    print(f" Search type: {search_type}")
    print(f" Search limit: {search_limit}")
    print(f" Context budget: {context_char_budget} chars")
    print(f" Data path: {data_path}")
    print(f"{'='*60}\n")

    # Step 1: Load LoCoMo data
    print("📂 Loading LoCoMo dataset...")
    try:
        # Only load QA pairs from the first conversation (index 0)
        # since we only ingest the first conversation into the database
        qa_items = load_locomo_data(data_path, sample_size, conversation_index=0)
        print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n")
    except Exception as e:
        print(f"❌ Failed to load data: {e}")
        return {
            "error": f"Data loading failed: {e}",
            "timestamp": datetime.now().isoformat()
        }

    # Step 2: Extract conversations and ingest if needed
    if skip_ingest:
        print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
        print(f" Group ID: {group_id}\n")
    else:
        print("💾 Checking database ingestion...")
        try:
            conversations = extract_conversations(data_path, max_dialogues=1)
            print(f"📝 Extracted {len(conversations)} conversations")

            # Always ingest for now (ingestion check not implemented)
            print(f"🔄 Ingesting conversations into group '{group_id}'...")
            success = await ingest_conversations_if_needed(
                conversations=conversations,
                group_id=group_id,
                reset=reset_group
            )

            if success:
                print("✅ Ingestion completed successfully\n")
            else:
                print("⚠️ Ingestion may have failed, continuing anyway\n")

        except Exception as e:
            # Deliberately best-effort: evaluation proceeds on existing data.
            print(f"❌ Ingestion failed: {e}")
            print("⚠️ Continuing with evaluation (database may be empty)\n")

    # Tracking variables
    latencies_search: List[float] = []
    latencies_llm: List[float] = []
    context_counts: List[int] = []
    context_chars: List[int] = []
    context_tokens: List[int] = []

    # Metric lists
    f1_scores: List[float] = []
    bleu1_scores: List[float] = []
    jaccard_scores: List[float] = []
    locomo_f1_scores: List[float] = []

    # Per-category tracking
    category_counts: Dict[str, int] = {}
    category_f1: Dict[str, List[float]] = {}
    category_bleu1: Dict[str, List[float]] = {}
    category_jaccard: Dict[str, List[float]] = {}
    category_locomo_f1: Dict[str, List[float]] = {}

    # Detailed samples
    samples: List[Dict[str, Any]] = []

    # Fixed anchor date for temporal resolution
    anchor_date = datetime(2023, 5, 8)

    # Step 3: Initialize clients
    print("🔧 Initializing clients...")
    connector = Neo4jConnector()

    # Everything after the connector exists runs under try/finally so the
    # connector is closed even when LLM/embedder initialization fails.
    try:
        llm_client = get_llm_client(SELECTED_LLM_ID)

        # Initialize embedder
        cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
        embedder = OpenAIEmbedderClient(
            model_config=RedBearModelConfig.model_validate(cfg_dict)
        )
        print("✅ Clients initialized\n")

        # Step 4: Process questions
        print(f"🔍 Processing {len(qa_items)} questions...")
        print(f"{'='*60}\n")

        for idx, item in enumerate(qa_items, 1):
            question = item.get("question", "")
            ground_truth = item.get("answer", "")
            category = get_category_name(item)

            # Ensure ground truth is a string
            ground_truth_str = str(ground_truth) if ground_truth is not None else ""

            print(f"[{idx}/{len(qa_items)}] Category: {category}")
            print(f"❓ Question: {question}")
            print(f"✅ Ground Truth: {ground_truth_str}")

            # Step 4a: Retrieve relevant information
            # perf_counter is monotonic, so the interval cannot be distorted
            # by system clock adjustments.
            t_search_start = time.perf_counter()
            try:
                retrieved_info = await retrieve_relevant_information(
                    question=question,
                    group_id=group_id,
                    search_type=search_type,
                    search_limit=search_limit,
                    connector=connector,
                    embedder=embedder
                )
                search_latency = (time.perf_counter() - t_search_start) * 1000
                print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
            except Exception as e:
                print(f"❌ Retrieval failed: {e}")
                retrieved_info = []
                # A failed search is recorded as 0.0 ms and is included in
                # the latency statistics.
                search_latency = 0.0
            latencies_search.append(search_latency)

            # Step 4b: Select and format context
            context_text = select_and_format_information(
                retrieved_info=retrieved_info,
                question=question,
                max_chars=context_char_budget
            )

            # Resolve temporal references
            context_text = resolve_temporal_references(context_text, anchor_date)

            # Add reference date to context
            if context_text:
                context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
            else:
                context_text = "No relevant context found."

            # Track context statistics (tokens are whitespace-separated words)
            context_counts.append(len(retrieved_info))
            context_chars.append(len(context_text))
            context_tokens.append(len(context_text.split()))

            print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")

            # Step 4c: Generate answer with LLM
            messages = [
                {
                    "role": "system",
                    "content": (
                        "You are a precise QA assistant. Answer following these rules:\n"
                        "1) Extract the EXACT information mentioned in the context\n"
                        "2) For time questions: calculate actual dates from relative times\n"
                        "3) Return ONLY the answer text in simplest form\n"
                        "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
                        "5) If no clear answer found, respond with 'Unknown'"
                    )
                },
                {
                    "role": "user",
                    "content": f"Question: {question}\n\nContext:\n{context_text}"
                }
            ]

            t_llm_start = time.perf_counter()
            try:
                response = await llm_client.chat(messages=messages)
                llm_latency = (time.perf_counter() - t_llm_start) * 1000

                # Extract prediction from response (object or raw dict form)
                if hasattr(response, 'content'):
                    prediction = response.content.strip()
                elif isinstance(response, dict):
                    prediction = response["choices"][0]["message"]["content"].strip()
                else:
                    prediction = "Unknown"

                print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")

            except Exception as e:
                print(f"❌ LLM failed: {e}")
                prediction = "Unknown"
                # A failed call is recorded as 0.0 ms, like a failed search.
                llm_latency = 0.0
            latencies_llm.append(llm_latency)

            # Step 4d: Calculate metrics
            f1_val = f1_score(prediction, ground_truth_str)
            bleu1_val = bleu1(prediction, ground_truth_str)
            jaccard_val = jaccard(prediction, ground_truth_str)

            # LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop)
            if item.get("category") == 1:
                locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
            else:
                locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)

            # Accumulate metrics
            f1_scores.append(f1_val)
            bleu1_scores.append(bleu1_val)
            jaccard_scores.append(jaccard_val)
            locomo_f1_scores.append(locomo_f1_val)

            # Track by category
            category_counts[category] = category_counts.get(category, 0) + 1
            category_f1.setdefault(category, []).append(f1_val)
            category_bleu1.setdefault(category, []).append(bleu1_val)
            category_jaccard.setdefault(category, []).append(jaccard_val)
            category_locomo_f1.setdefault(category, []).append(locomo_f1_val)

            print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
                  f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
            print()

            # Save sample details
            samples.append({
                "question": question,
                "ground_truth": ground_truth_str,
                "prediction": prediction,
                "category": category,
                "metrics": {
                    "f1": f1_val,
                    "bleu1": bleu1_val,
                    "jaccard": jaccard_val,
                    "locomo_f1": locomo_f1_val
                },
                "retrieval": {
                    "num_docs": len(retrieved_info),
                    "context_length": len(context_text)
                },
                "timing": {
                    "search_ms": search_latency,
                    "llm_ms": llm_latency
                }
            })

    finally:
        # Close connector
        await connector.close()

    # Step 5: Aggregate results
    print(f"\n{'='*60}")
    print("📊 Aggregating Results")
    print(f"{'='*60}\n")

    # Overall metrics
    overall_metrics = {
        "f1": _mean(f1_scores),
        "bleu1": _mean(bleu1_scores),
        "jaccard": _mean(jaccard_scores),
        "locomo_f1": _mean(locomo_f1_scores)
    }

    # Per-category metrics
    by_category: Dict[str, Dict[str, Any]] = {}
    for cat in category_counts:
        by_category[cat] = {
            "count": category_counts[cat],
            "f1": _mean(category_f1.get(cat, [])),
            "bleu1": _mean(category_bleu1.get(cat, [])),
            "jaccard": _mean(category_jaccard.get(cat, [])),
            "locomo_f1": _mean(category_locomo_f1.get(cat, []))
        }

    # Latency statistics
    latency = {
        "search": latency_stats(latencies_search),
        "llm": latency_stats(latencies_llm)
    }

    # Context statistics
    context_stats = {
        "avg_retrieved_docs": _mean(context_counts),
        "avg_context_chars": _mean(context_chars),
        "avg_context_tokens": _mean(context_tokens)
    }

    # Build result dictionary
    result = {
        "dataset": "locomo",
        "sample_size": len(qa_items),
        "timestamp": datetime.now().isoformat(),
        "params": {
            "group_id": group_id,
            "search_type": search_type,
            "search_limit": search_limit,
            "context_char_budget": context_char_budget,
            "llm_id": SELECTED_LLM_ID,
            "embedding_id": SELECTED_EMBEDDING_ID
        },
        "overall_metrics": overall_metrics,
        "by_category": by_category,
        "latency": latency,
        "context_stats": context_stats,
        "samples": samples
    }

    # Step 6: Save results
    if output_dir is None:
        output_dir = os.path.join(
            os.path.dirname(__file__),
            "results"
        )

    os.makedirs(output_dir, exist_ok=True)

    # Generate timestamped filename
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json")

    try:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"✅ Results saved to: {output_path}\n")
    except Exception as e:
        # Saving is best-effort: fall back to dumping the result on stdout.
        print(f"❌ Failed to save results: {e}")
        print("📊 Printing results to console instead:\n")
        print(json.dumps(result, ensure_ascii=False, indent=2))

    return result
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point for the LoCoMo benchmark.

    Builds the argument parser, loads environment variables via
    ``load_dotenv``, runs ``run_locomo_benchmark`` with the parsed options and
    prints a summary of the returned result. When the result carries an
    ``error`` key, only the error is printed.
    """
    parser = argparse.ArgumentParser(
        description="Run LoCoMo benchmark evaluation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, add_argument keyword options), registered in display order.
    cli_options = (
        ("--sample_size", {
            "type": int,
            "default": 20,
            "help": "Number of QA pairs to evaluate",
        }),
        ("--group_id", {
            "type": str,
            "default": None,
            "help": "Database group ID for retrieval (uses default if not specified)",
        }),
        ("--search_type", {
            "type": str,
            "default": "hybrid",
            "choices": ["keyword", "embedding", "hybrid"],
            "help": "Search strategy to use",
        }),
        ("--search_limit", {
            "type": int,
            "default": 12,
            "help": "Maximum number of documents to retrieve per query",
        }),
        ("--context_char_budget", {
            "type": int,
            "default": 8000,
            "help": "Maximum characters for context",
        }),
        ("--reset_group", {
            "action": "store_true",
            "help": "Clear and re-ingest data (not implemented)",
        }),
        ("--skip_ingest", {
            "action": "store_true",
            "help": "Skip data ingestion and use existing data in Neo4j",
        }),
        ("--output_dir", {
            "type": str,
            "default": None,
            "help": "Directory to save results (uses default if not specified)",
        }),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Load environment variables
    load_dotenv()

    # Run benchmark: every parsed option name matches a keyword parameter of
    # run_locomo_benchmark, so the namespace is forwarded as-is.
    result = asyncio.run(run_locomo_benchmark(**vars(args)))

    # Print summary
    print(f"\n{'='*60}")

    # An error result has no metrics to show.
    if 'error' in result:
        print("❌ Benchmark Failed!")
        print(f"{'='*60}")
        print(f"Error: {result['error']}")
        return

    print("🎉 Benchmark Complete!")
    print(f"{'='*60}")
    print("📊 Final Results:")
    print(f" Sample size: {result.get('sample_size', 0)}")
    print(f" F1: {result['overall_metrics']['f1']:.3f}")
    print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}")
    print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}")
    print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}")

    if result.get('context_stats'):
        print("\n📈 Context Statistics:")
        print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
        print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
        print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")

    if result.get('latency'):
        print("\n⏱️ Latency Statistics:")
        print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
              f"P50: {result['latency']['search']['p50']:.1f}ms, "
              f"P95: {result['latency']['search']['p95']:.1f}ms")
        print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
              f"P50: {result['latency']['llm']['p50']:.1f}ms, "
              f"P95: {result['latency']['llm']['p95']:.1f}ms")

    if result.get('by_category'):
        print("\n📂 Results by Category:")
        for cat, metrics in result['by_category'].items():
            print(f" {cat}:")
            print(f" Count: {metrics['count']}")
            print(f" F1: {metrics['f1']:.3f}")
            print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
            print(f" Jaccard: {metrics['jaccard']:.3f}")

    print(f"\n{'='*60}\n")


# Run the CLI only when this module is executed as a script.
if __name__ == "__main__":
    main()
|
||||
225
api/app/core/memory/evaluation/locomo/locomo_metrics.py
Normal file
225
api/app/core/memory/evaluation/locomo/locomo_metrics.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
LoCoMo-specific metric calculations.
|
||||
|
||||
This module provides clean, simplified implementations of metrics used for
|
||||
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
    """
    Normalize an answer string for LoCoMo scoring.

    The input is stringified (``None`` becomes ``""``), lower-cased, and then
    commas, the stop words ``a`` / ``an`` / ``the`` / ``and``, and every
    character that is neither a word character nor whitespace are each
    replaced by a space. Finally runs of whitespace are collapsed and the
    ends trimmed.

    Args:
        text: Input text to normalize

    Returns:
        The normalized, single-spaced text

    Examples:
        >>> normalize_text("The cat, and the dog")
        'cat dog'
        >>> normalize_text("Hello, World!")
        'hello world'
    """
    cleaned = "" if text is None else str(text)
    cleaned = cleaned.lower()

    # Applied in order: commas, stop words, remaining punctuation.
    blank_out_patterns = (
        r"[\,]",
        r"\b(a|an|the|and)\b",
        r"[^\w\s]",
    )
    for pattern in blank_out_patterns:
        cleaned = re.sub(pattern, " ", cleaned)

    # split() with no arguments both trims and collapses whitespace.
    return " ".join(cleaned.split())
|
||||
|
||||
|
||||
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
    """
    Compute the set-based token F1 between a prediction and a single answer.

    Both strings are passed through ``normalize_text`` and split on
    whitespace; the resulting tokens are compared as sets, so repeated
    tokens count once.

    Args:
        prediction: Model's predicted answer
        ground_truth: Correct answer

    Returns:
        F1 score between 0.0 and 1.0; 0.0 when either side has no tokens
        after normalization or when the token sets do not overlap

    Examples:
        >>> locomo_f1_score("Paris", "Paris")
        1.0
        >>> locomo_f1_score("The cat", "cat")
        1.0
        >>> locomo_f1_score("dog", "cat")
        0.0
    """
    def _token_set(answer):
        # None is treated as an empty answer; anything else is stringified.
        raw = "" if answer is None else str(answer)
        return set(normalize_text(raw).split())

    predicted = _token_set(prediction)
    expected = _token_set(ground_truth)

    # Nothing to compare when either side is empty.
    if not (predicted and expected):
        return 0.0

    overlap = len(predicted & expected)
    precision = overlap / len(predicted)
    recall = overlap / len(expected)

    denominator = precision + recall
    if denominator == 0:
        return 0.0

    return 2 * precision * recall / denominator
|
||||
|
||||
|
||||
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
    """
    Compute the LoCoMo F1 for answers made of comma-separated parts.

    Prediction and ground truth are each split on commas (blank parts are
    discarded). Every ground-truth part is scored against all prediction
    parts with ``locomo_f1_score`` and keeps its best score; the result is
    the mean of those best scores.

    Args:
        prediction: Model's predicted answer (may contain multiple comma-separated answers)
        ground_truth: Correct answer (may contain multiple comma-separated answers)

    Returns:
        Mean best-match F1 over the ground-truth parts (0.0 to 1.0); 0.0 when
        either side has no non-blank parts

    Examples:
        >>> locomo_multi_f1("Paris, London", "Paris, London")
        1.0
        >>> locomo_multi_f1("Paris", "Paris, London")
        0.5
        >>> locomo_multi_f1("Paris, Berlin", "Paris, London")
        0.5
    """
    def _split_answers(raw):
        # None is treated as an empty answer; anything else is stringified.
        text = "" if raw is None else str(raw)
        pieces = (piece.strip() for piece in text.split(','))
        return [piece for piece in pieces if piece]

    candidates = _split_answers(prediction)
    references = _split_answers(ground_truth)

    if not candidates or not references:
        return 0.0

    # Best-matching candidate per reference answer.
    best_scores = [
        max(locomo_f1_score(candidate, reference) for candidate in candidates)
        for reference in references
    ]

    return sum(best_scores) / len(best_scores)
|
||||
|
||||
|
||||
def get_category_name(item: Dict[str, Any]) -> str:
    """
    Return a display name for the category of a QA item.

    Fields are consulted in the order "cat", "category", "type":

    - "cat" and "type" are only used when they hold a non-blank string.
    - "category" may be an int (1 -> "Multi-Hop", 2 -> "Temporal",
      3 -> "Open Domain", 4 -> "Single-Hop"; any other int -> "unknown")
      or a non-blank string.

    String values are stripped and matched case-insensitively against the
    known aliases; a string that matches no alias is returned stripped but
    otherwise unchanged.

    Args:
        item: QA item dictionary containing category information.

    Returns:
        The resolved category name, or "unknown" when no field applies.

    Examples:
        >>> get_category_name({"category": 1})
        'Multi-Hop'
        >>> get_category_name({"cat": "temporal"})
        'Temporal'
        >>> get_category_name({"type": "Single-Hop"})
        'Single-Hop'
    """
    names_by_number = {
        1: "Multi-Hop",
        2: "Temporal",
        3: "Open Domain",
        4: "Single-Hop",
    }
    names_by_alias = {
        "single-hop": "Single-Hop",
        "singlehop": "Single-Hop",
        "single hop": "Single-Hop",
        "multi-hop": "Multi-Hop",
        "multihop": "Multi-Hop",
        "multi hop": "Multi-Hop",
        "open domain": "Open Domain",
        "opendomain": "Open Domain",
        "temporal": "Temporal",
    }

    def _is_text(value) -> bool:
        # Only non-blank strings count as usable text categories.
        return isinstance(value, str) and bool(value.strip())

    def _from_text(value: str) -> str:
        # Alias lookup is case-insensitive; unknown text passes through.
        stripped = value.strip()
        return names_by_alias.get(stripped.lower(), stripped)

    # 1) "cat": string category.
    cat_value = item.get("cat")
    if _is_text(cat_value):
        return _from_text(cat_value)

    # 2) "category": numeric id or string.
    category_value = item.get("category")
    if isinstance(category_value, int):
        return names_by_number.get(category_value, "unknown")
    if _is_text(category_value):
        return _from_text(category_value)

    # 3) "type": last-resort string field.
    type_value = item.get("type")
    if _is_text(type_value):
        return _from_text(type_value)

    return "unknown"
|
||||
796
api/app/core/memory/evaluation/locomo/locomo_test.py
Normal file
796
api/app/core/memory/evaluation/locomo/locomo_test.py
Normal file
@@ -0,0 +1,796 @@
|
||||
# file name: locomo_test.py
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
# 1
|
||||
# Add the project root (the parent of this file's directory) to sys.path.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
# Key step: put the "src" directory at the very front of sys.path so that
# modules are loaded from the current repository.
src_dir = os.path.join(project_root, "src")
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

# Load environment variables via python-dotenv before anything reads them.
load_dotenv()
|
||||
|
||||
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
|
||||
def _loc_normalize(text: str) -> str:
|
||||
text = str(text) if text is not None else ""
|
||||
text = text.lower()
|
||||
text = re.sub(r"[\,]", " ", text)
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
text = " ".join(text.split())
|
||||
return text
|
||||
|
||||
# Try to import the basic metrics from metrics.py.
try:
    from common.metrics import f1_score, bleu1, jaccard
    print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e:
    print(f"❌ 从 metrics.py 导入失败: {e}")

    # Fall back to local implementations.
    def f1_score(pred: str, ref: str) -> float:
        """Set-based token F1 between the normalized prediction and reference.

        Returns 1.0 when both sides normalize to nothing and 0.0 when only
        one of them does.
        """
        pred_tokens = _loc_normalize("" if pred is None else str(pred)).split()
        ref_tokens = _loc_normalize("" if ref is None else str(ref)).split()

        if not (pred_tokens or ref_tokens):
            return 1.0
        if not (pred_tokens and ref_tokens):
            return 0.0

        pred_vocab = set(pred_tokens)
        ref_vocab = set(ref_tokens)
        overlap = len(pred_vocab & ref_vocab)
        precision = overlap / len(pred_vocab)
        recall = overlap / len(ref_vocab)
        if precision + recall == 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)

    def bleu1(pred: str, ref: str) -> float:
        """Clipped unigram precision multiplied by a brevity penalty."""
        hyp_tokens = _loc_normalize("" if pred is None else str(pred)).split()
        ref_tokens = _loc_normalize("" if ref is None else str(ref)).split()
        if not hyp_tokens:
            return 0.0

        def _count(tokens):
            # Token -> number of occurrences.
            counts = {}
            for token in tokens:
                counts[token] = counts.get(token, 0) + 1
            return counts

        ref_counts = _count(ref_tokens)
        hyp_counts = _count(hyp_tokens)

        # A predicted token is credited at most as often as the reference has it.
        clipped = sum(
            min(count, ref_counts.get(token, 0))
            for token, count in hyp_counts.items()
        )

        hyp_len = len(hyp_tokens)
        ref_len = len(ref_tokens)
        precision = clipped / max(hyp_len, 1)

        # No penalty when the prediction is longer than the reference.
        if hyp_len > ref_len or hyp_len == 0:
            brevity_penalty = 1.0
        else:
            brevity_penalty = math.exp(1 - ref_len / max(hyp_len, 1))

        return brevity_penalty * precision

    def jaccard(pred: str, ref: str) -> float:
        """Jaccard similarity of the normalized token sets.

        Returns 1.0 when both sets are empty and 0.0 when only one is.
        """
        pred_vocab = set(_loc_normalize("" if pred is None else str(pred)).split())
        ref_vocab = set(_loc_normalize("" if ref is None else str(ref)).split())

        if not (pred_vocab or ref_vocab):
            return 1.0
        if not (pred_vocab and ref_vocab):
            return 0.0
        return len(pred_vocab & ref_vocab) / len(pred_vocab | ref_vocab)
|
||||
|
||||
# Try to import the LoCoMo-specific metrics from qwen_search_eval.py.
try:
    # Make the evaluation directory importable.
    evaluation_dir = os.path.join(project_root, "evaluation")
    if evaluation_dir not in sys.path:
        sys.path.insert(0, evaluation_dir)

    # Try the package path first, then the bare module name.
    try:
        from locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
        print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
    except ImportError:
        from qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
        print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")

except ImportError as e:
    print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")

    # Fall back to local implementations of the LoCoMo-specific helpers.
    def _resolve_relative_times(text: str, anchor: datetime) -> str:
        """Replace relative English time phrases with ISO dates.

        Handles today / yesterday / tomorrow, "N days ago", "in N days",
        "last week" and "next week" (case-insensitive), each computed
        relative to *anchor*. ``None`` is treated as "".
        """
        def _iso(day_offset: int) -> str:
            # ISO date that lies `day_offset` days from the anchor.
            return (anchor + timedelta(days=day_offset)).date().isoformat()

        resolved = "" if text is None else str(text)

        for pattern, day_offset in (
            (r"\btoday\b", 0),
            (r"\byesterday\b", -1),
            (r"\btomorrow\b", 1),
        ):
            resolved = re.sub(pattern, _iso(day_offset), resolved, flags=re.IGNORECASE)

        resolved = re.sub(
            r"\b(\d+)\s+days\s+ago\b",
            lambda m: _iso(-int(m.group(1))),
            resolved,
            flags=re.IGNORECASE,
        )
        resolved = re.sub(
            r"\bin\s+(\d+)\s+days\b",
            lambda m: _iso(int(m.group(1))),
            resolved,
            flags=re.IGNORECASE,
        )

        for pattern, day_offset in (
            (r"\blast\s+week\b", -7),
            (r"\bnext\s+week\b", 7),
        ):
            resolved = re.sub(pattern, _iso(day_offset), resolved, flags=re.IGNORECASE)

        return resolved

    def loc_f1_score(prediction: str, ground_truth: str) -> float:
        """Set-based token F1; 0.0 when either side normalizes to nothing."""
        pred_tokens = _loc_normalize(prediction).split()
        truth_tokens = _loc_normalize(ground_truth).split()
        if not (pred_tokens and truth_tokens):
            return 0.0

        pred_vocab = set(pred_tokens)
        truth_vocab = set(truth_tokens)
        overlap = len(pred_vocab & truth_vocab)
        precision = overlap / len(pred_vocab)
        recall = overlap / len(truth_vocab)
        if precision + recall > 0:
            return 2 * precision * recall / (precision + recall)
        return 0.0

    def loc_multi_f1(prediction: str, ground_truth: str) -> float:
        """Mean, over comma-separated reference answers, of the best F1
        any comma-separated predicted answer achieves."""
        candidates = [p.strip() for p in str(prediction).split(',') if p.strip()]
        references = [g.strip() for g in str(ground_truth).split(',') if g.strip()]
        if not (candidates and references):
            return 0.0

        best_scores = [
            max(loc_f1_score(candidate, reference) for candidate in candidates)
            for reference in references
        ]
        return sum(best_scores) / len(best_scores)
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
    """Select contexts by question-keyword relevance within a character budget.

    Each context is scored from keyword occurrences, its length and its
    position in the input list; contexts are then taken in descending score
    order while they fit into ``max_chars``. Progress is printed to stdout.

    Args:
        contexts: Candidate context strings.
        question: The question whose keywords drive the scoring.
        max_chars: Character budget for the selected contexts (separators
            and the truncation marker are not counted against it).

    Returns:
        The selected contexts joined by blank lines, or "" when ``contexts``
        is empty.
    """
    if not contexts:
        return ""

    # Extract question keywords (keep only meaningful words)
    question_lower = question.lower()
    stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
    question_words = set(re.findall(r'\b\w+\b', question_lower))
    question_words = {word for word in question_words if word not in stop_words and len(word) > 2}

    print(f"🔍 问题关键词: {question_words}")

    # Score every context
    scored_contexts = []
    for i, context in enumerate(contexts):
        context_lower = context.lower()
        score = 0

        # Keyword-match score (substring match against the lower-cased context)
        keyword_matches = 0
        for word in question_words:
            if word in context_lower:
                keyword_matches += 1
                # The more often a keyword occurs, the higher the score
                score += context_lower.count(word) * 2

        # Length score (a moderate length is preferred)
        context_len = len(context)
        if 100 < context_len < 2000:  # ideal length range
            score += 5
        elif context_len >= 2000:  # too long, may contain irrelevant information
            score += 2

        # Bonus for the first few contexts (usually more relevant)
        if i < 3:
            score += 3

        scored_contexts.append((score, context, keyword_matches))

    # Sort by score, highest first
    scored_contexts.sort(key=lambda x: x[0], reverse=True)

    # Take high-scoring contexts until the character limit is reached
    selected = []
    total_chars = 0
    selected_count = 0

    print("📊 上下文相关性分析:")
    for score, context, matches in scored_contexts[:5]:  # only show the top 5
        print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")

    for score, context, matches in scored_contexts:
        if total_chars + len(context) <= max_chars:
            selected.append(context)
            total_chars += len(context)
            selected_count += 1
        else:
            # If this context scores highly but does not fit, try to excerpt it
            if score > 10 and total_chars < max_chars - 500:
                remaining = max_chars - total_chars
                # Keep only the lines that contain a keyword
                lines = context.split('\n')
                relevant_lines = []
                current_chars = 0

                for line in lines:
                    line_lower = line.lower()
                    line_relevance = any(word in line_lower for word in question_words)

                    # The budget is checked before adding, so the final line may overshoot it
                    if line_relevance and current_chars < remaining - 100:
                        relevant_lines.append(line)
                        current_chars += len(line)

                if relevant_lines:
                    truncated = '\n'.join(relevant_lines)
                    if len(truncated) > 100:  # make sure there is enough content
                        selected.append(truncated + "\n[相关内容截断...]")
                        total_chars += len(truncated)
                        selected_count += 1
            # NOTE(review): assumes this break belongs to the `else` branch, i.e.
            # selection stops at the first context that does not fit - confirm.
            break  # do not try to add any more contexts

    result = "\n\n".join(selected)
    print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
    return result
|
||||
|
||||
|
||||
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
    """Derive retrieval parameters from question complexity and run progress.

    Args:
        question: The question text.
        question_index: Zero-based position of the question in the run.
        total_questions: Number of questions in the run. A value of 0 is
            treated as "no progress" instead of raising ZeroDivisionError.

    Returns:
        A dict with two int entries:
        - "limit": base limit (12, 16 for questions of more than 8 words,
          20 when both temporal and multi-hop cue words occur) reduced by
          up to 30% as the run progresses, never below 8.
        - "max_chars": 12000 at the start of the run, shrinking linearly
          to 8000 at its end.
    """
    # Analyse question complexity; cue words are matched as substrings.
    question_lower = question.lower()
    word_count = len(question.split())
    has_temporal = any(word in question_lower for word in ('when', 'date', 'time', 'ago'))
    has_multi_hop = any(word in question_lower for word in ('and', 'both', 'also', 'while'))

    # Fraction of the run already completed; an empty run counts as 0.
    progress_factor = question_index / total_questions if total_questions else 0.0

    base_limit = 12
    if has_temporal and has_multi_hop:
        base_limit = 20
    elif word_count > 8:
        base_limit = 16

    # Gradually tighten the retrieval range as the run proceeds.
    adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))

    # Shrink the character budget as the run proceeds.
    max_chars = 8000 + 4000 * (1 - progress_factor)

    return {
        "limit": adjusted_limit,
        "max_chars": int(max_chars),
    }
|
||||
|
||||
|
||||
class EnhancedEvaluationMonitor:
    """Track per-question performance and decide when to reset connections.

    The caller increments ``question_count`` itself; the monitor keeps a
    full performance history plus a window of the five most recent F1
    scores.
    """

    def __init__(self, reset_interval=5, performance_threshold=0.6):
        # Configuration
        self.reset_interval = reset_interval
        self.performance_threshold = performance_threshold
        # Running state
        self.question_count = 0
        self.consecutive_low_scores = 0
        self.performance_history = []
        self.recent_f1_scores = []

    def should_reset_connections(self, current_f1=None):
        """Return True when a reset is due, by count or by performance.

        A reset is due whenever ``question_count`` is a multiple of
        ``reset_interval``, or when ``current_f1`` has been below the
        threshold on two consecutive non-periodic calls.
        """
        # Periodic reset
        if self.question_count % self.reset_interval == 0:
            return True

        # Performance-driven reset
        is_low = current_f1 is not None and current_f1 < self.performance_threshold
        if not is_low:
            self.consecutive_low_scores = 0
            return False

        self.consecutive_low_scores += 1
        if self.consecutive_low_scores < 2:
            return False

        # Two low scores in a row trigger a reset
        print("🚨 连续低分,触发紧急重置")
        self.consecutive_low_scores = 0
        return True

    def record_performance(self, question_index, metrics, context_length, retrieved_docs):
        """Append one measurement to the history and the recent-F1 window."""
        snapshot = {
            'index': question_index,
            'metrics': metrics,
            'context_length': context_length,
            'retrieved_docs': retrieved_docs,
            'timestamp': time.time()
        }
        self.performance_history.append(snapshot)

        # Keep only the five most recent F1 scores
        self.recent_f1_scores.append(metrics['f1'])
        if len(self.recent_f1_scores) > 5:
            del self.recent_f1_scores[0]

    def get_recent_performance(self):
        """Return the mean of the recent F1 window, or 0.5 when it is empty."""
        window = self.recent_f1_scores
        return sum(window) / len(window) if window else 0.5

    def get_performance_trend(self):
        """Compare the last five F1 scores with the five before them.

        Returns "degrading" when the recent mean is below 80% of the
        earlier mean, "improving" when it is above 110% of it, and
        "stable" otherwise or when there is too little history.
        """
        history = self.performance_history
        if len(history) < 2:
            return "stable"

        recent_metrics = [entry['metrics']['f1'] for entry in history[-5:]]
        earlier_metrics = [entry['metrics']['f1'] for entry in history[-10:-5]]
        if min(len(recent_metrics), len(earlier_metrics)) < 2:
            return "stable"

        recent_avg = sum(recent_metrics) / len(recent_metrics)
        earlier_avg = sum(earlier_metrics) / len(earlier_metrics)

        if recent_avg < earlier_avg * 0.8:
            return "degrading"
        if recent_avg > earlier_avg * 1.1:
            return "improving"
        return "stable"
|
||||
|
||||
|
||||
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
    """Adjust the dynamic retrieval parameters using recent performance.

    Starts from ``get_dynamic_search_params`` and then:

    - widens retrieval when ``recent_performance`` is below 0.5
      (limit +5 capped at 25, max_chars +2000 capped at 12000);
    - narrows it when ``recent_performance`` is above 0.8
      (limit -2 floored at 8, max_chars -1000 floored at 6000);
    - narrows it again when the question lies within 0.2 of the midpoint
      of the run (limit -2 floored at 10, max_chars -1000 floored at 7000).

    Returns:
        The adjusted parameter dict with "limit" and "max_chars".
    """
    # Baseline parameters
    base_params = get_dynamic_search_params(question, question_index, total_questions)

    # Performance-adaptive adjustment
    if recent_performance < 0.5:
        # Poor recent results: widen retrieval to gather more context
        base_params.update(
            limit=min(base_params["limit"] + 5, 25),
            max_chars=min(base_params["max_chars"] + 2000, 12000),
        )
        print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
    elif recent_performance > 0.8:
        # Good recent results: tighten retrieval for precision
        base_params.update(
            limit=max(base_params["limit"] - 2, 8),
            max_chars=max(base_params["max_chars"] - 1000, 6000),
        )
        print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")

    # Special handling for questions near the middle of the run
    if abs(question_index / total_questions - 0.5) < 0.2:
        print("🎯 中间阶段:使用更精确的检索策略")
        # Fewer, higher-quality results
        base_params.update(
            limit=max(base_params["limit"] - 2, 10),
            max_chars=max(base_params["max_chars"] - 1000, 7000),
        )

    return base_params
|
||||
|
||||
|
||||
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
    """Select contexts, filtering more strictly in the middle of the run.

    When the question lies within 0.2 of the midpoint of the run, contexts
    are pre-filtered: each question keyword found in a context is worth 3
    points, any digit in the context adds 2, and contexts scoring below 3
    are dropped. The remaining contexts are handed to
    ``smart_context_selection``.

    Returns:
        The selected context text, or "" when ``contexts`` is empty.
    """
    if not contexts:
        return ""

    # Distance of this question from the centre of the sequence
    if abs(question_index / total_questions - 0.5) < 0.2:
        print("🎯 中间阶段:使用严格上下文筛选")

        # Extract question keywords
        stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
        question_words = {
            word
            for word in re.findall(r'\b\w+\b', question.lower())
            if word not in stop_words and len(word) > 2
        }

        # Keep only highly relevant contexts
        kept_contexts = []
        for context in contexts:
            context_lower = context.lower()
            relevance_score = 3 * sum(1 for word in question_words if word in context_lower)

            # Extra credit for digits/dates, which matter for factual questions
            if any(char.isdigit() for char in context):
                relevance_score += 2

            # Only contexts scoring at least 3 survive
            if relevance_score < 3:
                print(f" - 过滤低分上下文: 得分={relevance_score}")
                continue
            kept_contexts.append(context)

        contexts = kept_contexts
        print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")

    # Delegate to the regular keyword-based selection
    return smart_context_selection(contexts, question, max_chars)
|
||||
|
||||
|
||||
async def run_enhanced_evaluation():
    """Run the enhanced LoCoMo QA evaluation and return aggregated results.

    For each of the first 20 QA items in ``data/locomo10.json`` the function
    retrieves context through the hybrid search service, asks the LLM for an
    answer and scores it with F1, BLEU-1, Jaccard and LoCoMo F1. Progress is
    printed to stdout.

    Returns:
        A dict with the keys "questions" (per-question records),
        "overall_metrics", "category_metrics", "retrieval_stats",
        "performance_trend", "timestamp", "enhanced_strategy",
        "reset_interval" and "total_questions_processed". Averages are
        taken over the questions that were actually recorded, so a run that
        aborts part-way still reports meaningful means.

    Raises:
        OSError / json.JSONDecodeError: if the dataset cannot be read or
            parsed. Errors inside the question loop are printed and do not
            propagate.
    """
    # Project-level dependencies are imported inside the function.
    from app.repositories.neo4j.neo4j_connector import Neo4jConnector
    from app.repositories.neo4j.graph_search import search_graph_by_embedding
    from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
    from app.core.models.base import RedBearModelConfig
    from app.core.memory.utils.llm.llm_utils import get_llm_client
    from app.core.memory.utils.config.config_utils import get_embedder_config
    from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_EMBEDDING_ID

    # Load the dataset: <memory dir>/data/locomo10.json, where the memory
    # directory is three levels above this file.
    current_file = os.path.abspath(__file__)
    evaluation_dir = os.path.dirname(os.path.dirname(current_file))  # evaluation directory
    memory_dir = os.path.dirname(evaluation_dir)  # memory directory
    data_path = os.path.join(memory_dir, "data", "locomo10.json")
    with open(data_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    qa_items = []
    if isinstance(raw, list):
        for entry in raw:
            qa_items.extend(entry.get("qa", []))
    else:
        qa_items.extend(raw.get("qa", []))

    items = qa_items[:20]  # number of questions to evaluate

    # Initialise the enhanced monitor
    monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)

    llm = get_llm_client(SELECTED_LLM_ID)

    # Initialise the embedder
    # NOTE(review): the embedder is constructed but not used below - confirm
    # whether it is still needed.
    cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
    embedder = OpenAIEmbedderClient(
        model_config=RedBearModelConfig.model_validate(cfg_dict)
    )

    # Initialise the connector
    connector = Neo4jConnector()

    # Initialise the result dictionary
    results = {
        "questions": [],
        "overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
        "category_metrics": {},
        "retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
        "performance_trend": "stable",
        "timestamp": datetime.now().isoformat(),
        "enhanced_strategy": True
    }

    total_f1 = 0.0
    total_bleu1 = 0.0
    total_jaccard = 0.0
    total_loc_f1 = 0.0
    total_context_length = 0
    total_retrieved_docs = 0
    category_stats = {}

    try:
        for i, item in enumerate(items):
            monitor.question_count += 1

            # Recent performance drives the reset decision
            recent_performance = monitor.get_recent_performance()

            # Enhanced reset decision
            should_reset = monitor.should_reset_connections(current_f1=recent_performance)
            if should_reset and i > 0:
                print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
                await connector.close()
                connector = Neo4jConnector()  # create a fresh connection
                print("✅ 连接重置完成")

            q = item.get("question", "")
            ref = item.get("answer", "")
            ref_str = str(ref) if ref is not None else ""

            print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
            print(f"✅ 真实答案: {ref_str}")

            # Per-category statistics
            raw_category = item.get("category")
            category = "Unknown"
            if raw_category == 1:
                category = "Multi-Hop"
            elif raw_category == 2:
                category = "Temporal"
            elif raw_category == 3:
                category = "Open Domain"
            elif raw_category == 4:
                category = "Single-Hop"

            # Enhanced retrieval parameters
            search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
            search_limit = search_params["limit"]
            max_chars = search_params["max_chars"]

            print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")

            # Use the project's standard hybrid retrieval
            t0 = time.time()
            contexts_all = []

            try:
                # Use the unified search service
                from app.core.memory.storage_services.search import run_hybrid_search

                print("🔀 使用混合搜索服务...")

                # NOTE(review): the adaptive `search_limit` is computed, logged
                # and recorded, but the search itself uses a fixed limit of 20 -
                # confirm whether that is intended.
                search_results = await run_hybrid_search(
                    query_text=q,
                    search_type="hybrid",
                    group_id="locomo_sk",
                    limit=20,
                    include=["statements", "chunks", "entities", "summaries"],
                    alpha=0.6,  # BM25 weight
                    embedding_id=SELECTED_EMBEDDING_ID
                )

                # The search service returns one dict with a list per result kind
                chunks = search_results.get("chunks", [])
                statements = search_results.get("statements", [])
                entities = search_results.get("entities", [])
                summaries = search_results.get("summaries", [])

                print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")

                # Build the context list: chunks, statements and summaries first
                for c in chunks:
                    content = str(c.get("content", "")).strip()
                    if content:
                        contexts_all.append(content)

                for s in statements:
                    stmt_text = str(s.get("statement", "")).strip()
                    if stmt_text:
                        contexts_all.append(stmt_text)

                for sm in summaries:
                    summary_text = str(sm.get("summary", "")).strip()
                    if summary_text:
                        contexts_all.append(summary_text)

                # Entity summary: at most the top 3 scored entities, to limit noise
                scored = [e for e in entities if e.get("score") is not None]
                top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
                if top_entities:
                    summary_lines = []
                    for e in top_entities:
                        name = str(e.get("name", "")).strip()
                        etype = str(e.get("entity_type", "")).strip()
                        score = e.get("score")
                        if name:
                            meta = []
                            if etype:
                                meta.append(f"type={etype}")
                            if isinstance(score, (int, float)):
                                meta.append(f"score={score:.3f}")
                            summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
                    if summary_lines:
                        contexts_all.append("\n".join(summary_lines))

                print(f"📊 有效上下文数量: {len(contexts_all)}")
            except Exception as e:
                # Retrieval is best-effort: continue with no context
                print(f"❌ 检索失败: {e}")
                contexts_all = []

            t1 = time.time()
            search_time = (t1 - t0) * 1000

            # Enhanced context selection
            context_text = ""
            if contexts_all:
                context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)

                # Final protective truncation if the selection is still too long
                if len(context_text) > max_chars:
                    print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
                    context_text = context_text[:max_chars] + "\n\n[最终截断...]"

                # Resolve relative time expressions against a fixed date for consistency
                anchor_date = datetime(2023, 5, 8)
                context_text = _resolve_relative_times(context_text, anchor_date)

                context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text

                print(f"📝 最终上下文长度: {len(context_text)} 字符")

                # Preview several contexts, not just the first one
                print("🔍 上下文预览:")
                for j, context in enumerate(contexts_all[:3]):  # show the first 3 contexts
                    preview = context[:150].replace('\n', ' ')
                    print(f" 上下文{j+1}: {preview}...")

                # Debug aid: is the reference answer present in the retrieved contexts?
                if ref_str and ref_str.strip():
                    answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
                    print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")

            else:
                print("❌ 没有检索到有效上下文")
                context_text = "No relevant context found."

            # LLM answer
            messages = [
                {"role": "system", "content": (
                    "You are a precise QA assistant. Answer following these rules:\n"
                    "1) Extract the EXACT information mentioned in the context\n"
                    "2) For time questions: calculate actual dates from relative times\n"
                    "3) Return ONLY the answer text in simplest form\n"
                    "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
                    "5) If no clear answer found, respond with 'Unknown'"
                )},
                {"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
            ]

            t2 = time.time()
            try:
                # Asynchronous call
                resp = await llm.chat(messages=messages)
                # Accept either an object with `.content` or an OpenAI-style dict
                pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
            except Exception as e:
                # Generation is best-effort: score the question as "Unknown"
                print(f"❌ LLM 生成失败: {e}")
                pred = "Unknown"
            t3 = time.time()
            llm_time = (t3 - t2) * 1000

            # Compute the metrics with the imported (or fallback) functions
            f1_val = f1_score(pred, ref_str)
            bleu1_val = bleu1(pred, ref_str)
            jaccard_val = jaccard(pred, ref_str)
            loc_f1_val = loc_f1_score(pred, ref_str)

            print(f"🤖 LLM 回答: {pred}")
            print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
            print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")

            # Update the running totals
            total_f1 += f1_val
            total_bleu1 += bleu1_val
            total_jaccard += jaccard_val
            total_loc_f1 += loc_f1_val
            total_context_length += len(context_text)
            total_retrieved_docs += len(contexts_all)

            if category not in category_stats:
                category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}

            category_stats[category]["count"] += 1
            category_stats[category]["f1_sum"] += f1_val
            category_stats[category]["b1_sum"] += bleu1_val
            category_stats[category]["j_sum"] += jaccard_val
            category_stats[category]["loc_f1_sum"] += loc_f1_val

            # Record the performance sample
            metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
            monitor.record_performance(i, metrics, len(context_text), len(contexts_all))

            # Store the per-question result
            question_result = {
                "question": q,
                "ground_truth": ref_str,
                "prediction": pred,
                "category": category,
                "metrics": metrics,
                "retrieval": {
                    "retrieved_documents": len(contexts_all),
                    "context_length": len(context_text),
                    "search_limit": search_limit,
                    "max_chars": max_chars,
                    "recent_performance": recent_performance
                },
                "timing": {
                    "search_ms": search_time,
                    "llm_ms": llm_time
                }
            }

            results["questions"].append(question_result)

            print("="*60)

    except Exception as e:
        print(f"❌ 评估过程中发生错误: {e}")
        # Even on failure, fall through and report what was collected so far
        import traceback
        traceback.print_exc()

    finally:
        await connector.close()

    # Overall metrics, averaged over the questions actually recorded so that
    # an aborted run does not dilute the means with unprocessed questions.
    n = len(results["questions"])
    if n > 0:
        results["overall_metrics"] = {
            "f1": total_f1 / n,
            "b1": total_bleu1 / n,
            "j": total_jaccard / n,
            "loc_f1": total_loc_f1 / n
        }

        for category, stats in category_stats.items():
            count = stats["count"]
            results["category_metrics"][category] = {
                "count": count,
                "f1": stats["f1_sum"] / count,
                "bleu1": stats["b1_sum"] / count,
                "jaccard": stats["j_sum"] / count,
                "loc_f1": stats["loc_f1_sum"] / count
            }

        results["retrieval_stats"]["avg_context_length"] = total_context_length / n
        results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n

    # Performance trend analysis
    results["performance_trend"] = monitor.get_performance_trend()
    results["reset_interval"] = monitor.reset_interval
    results["total_questions_processed"] = monitor.question_count

    return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Announce the run and the enhancements it applies.
    for banner_line in (
        "🚀 运行增强版完整评估(解决中间性能衰减问题)...",
        "📋 增强特性:",
        " - 双重重置策略:定期重置 + 性能驱动重置",
        " - 动态检索参数:基于近期性能自适应调整",
        " - 中间阶段严格筛选:提高上下文质量要求",
        " - 连续性能监控:实时检测性能衰减",
    ):
        print(banner_line)

    result = asyncio.run(run_enhanced_evaluation())

    # Overall metrics
    print("\n📊 最终评估结果:")
    print("总体指标:")
    overall = result['overall_metrics']
    print(f" F1: {overall['f1']:.4f}")
    print(f" BLEU-1: {overall['b1']:.4f}")
    print(f" Jaccard: {overall['j']:.4f}")
    print(f" LoCoMo F1: {overall['loc_f1']:.4f}")

    # Per-category metrics
    print("\n分类别指标:")
    for category, metrics in result['category_metrics'].items():
        print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")

    # Retrieval statistics
    print("\n检索统计:")
    stats = result['retrieval_stats']
    print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
    print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")

    # Run summary
    print(f"\n性能趋势: {result['performance_trend']}")
    print(f"重置间隔: 每{result['reset_interval']}个问题")
    print(f"处理问题总数: {result['total_questions_processed']}")
    print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}")

    # Save the results to a "results" directory next to this file,
    # resolved from the file's absolute path.
    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"\n详细结果已保存到: {output_file}")
|
||||
626
api/app/core/memory/evaluation/locomo/locomo_utils.py
Normal file
626
api/app/core/memory/evaluation/locomo/locomo_utils.py
Normal file
@@ -0,0 +1,626 @@
|
||||
"""
|
||||
LoCoMo Utilities Module
|
||||
|
||||
This module provides helper functions for the LoCoMo benchmark evaluation:
|
||||
- Data loading from JSON files
|
||||
- Conversation extraction for ingestion
|
||||
- Temporal reference resolution
|
||||
- Context selection and formatting
|
||||
- Retrieval wrapper functions
|
||||
- Ingestion wrapper functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.core.memory.utils.definitions import PROJECT_ROOT
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
|
||||
|
||||
def load_locomo_data(
    data_path: str,
    sample_size: int,
    conversation_index: int = 0
) -> List[Dict[str, Any]]:
    """
    Load the QA pairs of one conversation from a LoCoMo JSON file.

    The file is expected to hold either a list of conversation objects (each
    carrying a "qa" list) or a single such object.

    Args:
        data_path: Path to the LoCoMo JSON file (e.g. locomo10.json).
        sample_size: Upper bound on the number of QA items returned; the
            result is ``qa_items[:sample_size]``.
        conversation_index: Index of the conversation whose QA pairs are
            loaded. For a single-object file only index 0 is accepted.

    Returns:
        The first ``sample_size`` QA items of the chosen conversation. An
        empty list is returned when the chosen list entry is not a dict or
        has no "qa" key.

    Raises:
        FileNotFoundError: If data_path does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        IndexError: If conversation_index is beyond the available
            conversations.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")

    with open(data_path, "r", encoding="utf-8") as handle:
        dataset = json.load(handle)

    # Single-object file: behaves like a dataset with exactly one conversation.
    if not isinstance(dataset, list):
        if conversation_index != 0:
            raise IndexError(
                f"Conversation index {conversation_index} out of range. "
                f"Dataset has only 1 conversation."
            )
        return list(dataset.get("qa", []))[:sample_size]

    if conversation_index >= len(dataset):
        raise IndexError(
            f"Conversation index {conversation_index} out of range. "
            f"Dataset has {len(dataset)} conversations."
        )

    conversation = dataset[conversation_index]
    has_qa = isinstance(conversation, dict) and "qa" in conversation
    selected_qa = list(conversation["qa"]) if has_qa else []

    # Cap the result at the requested sample size.
    return selected_qa[:sample_size]
|
||||
|
||||
|
||||
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
    """
    Extract conversation texts from LoCoMo data for ingestion.

    Each of the first ``max_dialogues`` entries is flattened into one string
    with a "speaker: text" line per message. Sessions are emitted in numeric
    order of their ``session_<n>`` key, so ``session_2`` precedes
    ``session_10`` (a plain lexicographic sort would scramble the dialogue
    chronology once there are ten or more sessions).

    Args:
        data_path: Path to the LoCoMo JSON file (e.g. locomo10.json).
        max_dialogues: Maximum number of dialogues to extract (default: 1).

    Returns:
        One string per dialogue that produced at least one non-empty message.
        Messages without a speaker are attributed to "User"; messages whose
        text is empty after stripping are skipped.

    Raises:
        FileNotFoundError: If data_path does not exist.
        json.JSONDecodeError: If the file is not valid JSON.

    Example output:
        [
            "User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
            "User: I love hiking.\\nAI: Where do you like to hike?\\n..."
        ]
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")

    with open(data_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    def _session_sort_key(key: str):
        # "session_<n>" keys sort by n; any other "session_*" list key is
        # placed after them, ordered by name, to keep the result deterministic.
        match = re.fullmatch(r"session_(\d+)", key)
        if match:
            return (0, int(match.group(1)), key)
        return (1, 0, key)

    # Ensure we have a list of entries
    entries = raw if isinstance(raw, list) else [raw]

    contents: List[str] = []

    for entry in entries[:max_dialogues]:
        if not isinstance(entry, dict):
            continue

        conv = entry.get("conversation", {})
        if not isinstance(conv, dict):
            continue

        # Only list-valued session_* keys hold messages; sibling keys such as
        # string-valued metadata are ignored.
        session_keys = sorted(
            (
                key for key, val in conv.items()
                if isinstance(val, list)
                and isinstance(key, str)
                and key.startswith("session_")
            ),
            key=_session_sort_key,
        )

        lines: List[str] = []
        for key in session_keys:
            for msg in conv[key]:
                if not isinstance(msg, dict):
                    continue

                role = msg.get("speaker") or "User"
                text = str(msg.get("text") or "").strip()
                if not text:
                    continue

                lines.append(f"{role}: {text}")

        if lines:
            contents.append("\n".join(lines))

    return contents
|
||||
|
||||
|
||||
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
    """
    Resolve relative temporal references to absolute dates.

    Relative expressions are replaced (case-insensitively) by ISO dates
    computed from ``anchor_date``.

    Supported patterns:
    - today, yesterday, tomorrow
    - X day(s) ago, in X day(s)
    - last week, next week (approximated as 7 days)

    Args:
        text: Text containing temporal references. ``None`` is treated as an
            empty string; other non-string values are converted with ``str``.
        anchor_date: Reference date for resolution (datetime object).

    Returns:
        Text with temporal references replaced by ISO dates (YYYY-MM-DD).
        An expression whose resolved date would fall outside the range
        ``datetime`` supports is left unchanged instead of raising.

    Example:
        >>> anchor = datetime(2023, 5, 8)
        >>> resolve_temporal_references("I saw him yesterday", anchor)
        'I saw him 2023-05-07'
    """
    resolved = str(text) if text is not None else ""

    def _shift(match: re.Match[str], days: int) -> str:
        # Computed lazily, only when the phrase actually occurs, so an
        # unrepresentable date never breaks unrelated text.
        try:
            return (anchor_date + timedelta(days=days)).date().isoformat()
        except OverflowError:
            return match.group(0)

    def _fixed(days: int):
        return lambda match: _shift(match, days)

    def _counted(sign: int):
        def _repl(match: re.Match[str]) -> str:
            try:
                count = int(match.group(1))
            except ValueError:
                # Digit runs too long for int() conversion are left as-is.
                return match.group(0)
            return _shift(match, sign * count)
        return _repl

    # Applied in this order; the patterns do not overlap.
    rules = (
        (r"\btoday\b", _fixed(0)),
        (r"\byesterday\b", _fixed(-1)),
        (r"\btomorrow\b", _fixed(1)),
        (r"\b(\d+)\s+days?\s+ago\b", _counted(-1)),
        (r"\bin\s+(\d+)\s+days?\b", _counted(1)),
        (r"\blast\s+week\b", _fixed(-7)),
        (r"\bnext\s+week\b", _fixed(7)),
    )
    for pattern, replacement in rules:
        resolved = re.sub(pattern, replacement, resolved, flags=re.IGNORECASE)

    return resolved
|
||||
|
||||
|
||||
def select_and_format_information(
    retrieved_info: List[str],
    question: str,
    max_chars: int = 8000
) -> str:
    """
    Select and format the most relevant retrieved information for an LLM prompt.

    Each retrieved piece is scored against the question and pieces are taken
    in descending score order while they fit into the character budget.

    Scoring:
    - +2 for every occurrence of a question keyword (substring match,
      case-insensitive; stop words and words of 1-2 characters are ignored)
    - +5 for a length strictly between 100 and 2000 characters, +2 for 2000+
    - +3 for the first three pieces of ``retrieved_info``

    A piece that does not fit is skipped. If it scored above 10 and more than
    500 characters of budget remain, an excerpt made of its keyword-bearing
    lines is added instead (kept at least 100 characters below the remaining
    budget and followed by a truncation marker); selection stops after such
    an excerpt.

    Args:
        retrieved_info: Retrieved information strings (chunks, statements, entities).
        question: Question being answered.
        max_chars: Budget for the selected text. The "\\n\\n" separators and
            the truncation marker are not counted against it.

    Returns:
        The selected pieces joined by double newlines, or "" when
        ``retrieved_info`` is empty or nothing fits.

    Example:
        >>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
        >>> select_and_format_information(contexts, "Where did Alice go?", max_chars=100)
        'Alice went to Paris\\n\\nAlice visited the Eiffel Tower\\n\\nBob likes pizza'
    """
    if not retrieved_info:
        return ""

    # Extract question keywords (filter out stop words and short words)
    stop_words = {
        'what', 'when', 'where', 'who', 'why', 'how',
        'did', 'do', 'does', 'is', 'are', 'was', 'were',
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
    }
    question_words = {
        word for word in re.findall(r'\b\w+\b', question.lower())
        if word not in stop_words and len(word) > 2
    }

    # Score each context
    scored_contexts = []
    for position, context in enumerate(retrieved_info):
        context_lower = context.lower()
        score = sum(context_lower.count(word) * 2 for word in question_words)

        context_len = len(context)
        if 100 < context_len < 2000:
            score += 5
        elif context_len >= 2000:
            score += 2

        if position < 3:
            score += 3

        scored_contexts.append((score, context))

    # Highest score first; the sort is stable, so ties keep retrieval order.
    scored_contexts.sort(key=lambda item: item[0], reverse=True)

    selected = []
    total_chars = 0

    for score, context in scored_contexts:
        if total_chars + len(context) <= max_chars:
            selected.append(context)
            total_chars += len(context)
            continue

        # Too large to fit whole: only high-scoring contexts are worth an
        # excerpt, and only while a meaningful amount of budget is left.
        if score <= 10 or total_chars >= max_chars - 500:
            continue

        excerpt_budget = max_chars - total_chars - 100
        relevant_lines = []
        excerpt_len = 0
        for line in context.split('\n'):
            line_lower = line.lower()
            if not any(word in line_lower for word in question_words):
                continue
            # Length of the excerpt if this line (plus its joining newline)
            # were added; checked before appending so the budget holds.
            projected = excerpt_len + len(line) + (1 if relevant_lines else 0)
            if projected <= excerpt_budget:
                relevant_lines.append(line)
                excerpt_len = projected

        truncated = '\n'.join(relevant_lines)
        if len(truncated) > 100:
            selected.append(truncated + "\n[Content truncated...]")
            total_chars += len(truncated)
            break

    return "\n\n".join(selected)
|
||||
|
||||
|
||||
def _collect_texts(records: List[Dict[str, Any]], field: str) -> List[str]:
    """Return the stripped, non-empty ``field`` value of every record."""
    texts = []
    for record in records:
        value = str(record.get(field, "")).strip()
        if value:
            texts.append(value)
    return texts


def _entity_summary_block(entities: List[Dict[str, Any]], top_n: int = 3) -> Optional[str]:
    """Format the ``top_n`` entities as "EntitySummary: ..." lines, or None if none has a name."""
    scored = [e for e in entities if e.get("score") is not None]
    # Prefer the highest-scored entities; without any score keep input order.
    if scored:
        top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:top_n]
    else:
        top_entities = entities[:top_n]

    summary_lines = []
    for entity in top_entities:
        name = str(entity.get("name", "")).strip()
        if not name:
            continue
        etype = str(entity.get("entity_type", "")).strip()
        score = entity.get("score")
        meta = []
        if etype:
            meta.append(f"type={etype}")
        if isinstance(score, (int, float)):
            meta.append(f"score={score:.3f}")
        suffix = (' [' + '; '.join(meta) + ']') if meta else ''
        summary_lines.append(f"EntitySummary: {name}{suffix}")

    return "\n".join(summary_lines) if summary_lines else None


def _contexts_from_graph_results(results: Dict[str, Any]) -> List[str]:
    """Turn a chunks/statements/summaries/entities result dict into context strings."""
    contexts = (
        _collect_texts(results.get("chunks", []), "content")
        + _collect_texts(results.get("statements", []), "statement")
        + _collect_texts(results.get("summaries", []), "summary")
    )
    entities = results.get("entities", [])
    if entities:
        entity_block = _entity_summary_block(entities)
        if entity_block:
            contexts.append(entity_block)
    return contexts


async def retrieve_relevant_information(
    question: str,
    group_id: str,
    search_type: str,
    search_limit: int,
    connector: Any,
    embedder: Any
) -> List[str]:
    """
    Retrieve context strings for a question from the memory graph.

    Dispatches on ``search_type``:
    - "embedding": ``search_graph_by_embedding``
    - "keyword": ``search_graph``
    - anything else (intended: "hybrid"): ``run_hybrid_search``; if that call
      raises or returns nothing usable, falls back to
      ``search_graph_by_embedding``

    For embedding/hybrid results the output contains chunk contents, then
    statements, then summaries, then one block describing up to three
    entities. For keyword results it contains dialogue contents, statements
    and one line naming up to five entities.

    Args:
        question: Question to search for.
        group_id: Group ID forwarded to the search functions.
        search_type: "keyword", "embedding", or "hybrid".
        search_limit: Result limit forwarded to the search functions.
        connector: Connector instance forwarded to the graph search functions.
        embedder: Embedder client forwarded to the embedding search.

    Returns:
        List of context strings. Any failure is reported on stdout and yields
        an empty list; this function does not raise for search errors.
    """
    # Imported lazily, as in the rest of this module's design, so importing the
    # module does not require the search stack.
    from app.repositories.neo4j.graph_search import (
        search_graph,
        search_graph_by_embedding
    )
    from app.core.memory.storage_services.search import run_hybrid_search

    result_keys = ("chunks", "statements", "entities", "summaries")

    async def _embedding_search() -> Dict[str, Any]:
        return await search_graph_by_embedding(
            connector=connector,
            embedder_client=embedder,
            query_text=question,
            group_id=group_id,
            limit=search_limit,
            include=["chunks", "statements", "entities", "summaries"],
        )

    try:
        if search_type == "embedding":
            return _contexts_from_graph_results(await _embedding_search())

        if search_type == "keyword":
            search_results = await search_graph(
                connector=connector,
                q=question,
                group_id=group_id,
                limit=search_limit
            )
            contexts = (
                _collect_texts(search_results.get("dialogues", []), "content")
                + _collect_texts(search_results.get("statements", []), "statement")
            )
            entity_names = [
                str(e.get("name", "")).strip()
                for e in search_results.get("entities", [])[:5]
                if e.get("name")
            ]
            if entity_names:
                contexts.append(f"EntitySummary: {', '.join(entity_names)}")
            return contexts

        # Hybrid search with fallback to embedding search.
        try:
            search_results = await run_hybrid_search(
                query_text=question,
                search_type=search_type,
                group_id=group_id,
                limit=search_limit,
                include=["chunks", "statements", "entities", "summaries"],
                output_path=None,
            )
            if not (search_results and isinstance(search_results, dict)):
                raise ValueError("Hybrid search returned empty results")

            # Flat structure first; if it holds nothing, accept the nested
            # "reranked_results" structure for backward compatibility.
            if not any(search_results.get(key, []) for key in result_keys):
                reranked = search_results.get("reranked_results", {})
                if not (reranked and isinstance(reranked, dict)):
                    raise ValueError("Hybrid search returned empty results")
                search_results = reranked
        except Exception as exc:
            print(f"[Retrieval] Hybrid search failed ({exc}); falling back to embedding search")
            search_results = await _embedding_search()

        return _contexts_from_graph_results(search_results)

    except Exception as exc:
        # Best-effort retrieval: report the failure and return no context
        # rather than aborting the caller.
        print(f"[Retrieval] Failed to retrieve information: {exc}")
        return []
|
||||
|
||||
|
||||
async def ingest_conversations_if_needed(
    conversations: List[str],
    group_id: str,
    reset: bool = False
) -> bool:
    """
    Ingest conversation texts through the external full extraction pipeline.

    Thin wrapper around ``ingest_contexts_via_full_pipeline``: the
    conversations are passed as its ``contexts`` argument together with
    ``group_id`` and ``save_chunk_output=True``. What the pipeline does with
    them is defined in that helper, not here.

    Args:
        conversations: Raw conversation texts, e.g.
            ["User: I went to Paris. AI: When was that?", ...]
        group_id: Target group ID forwarded to the pipeline.
        reset: Accepted for interface compatibility; not used by this wrapper.

    Returns:
        The pipeline's result on success; False if the pipeline raised (the
        error is printed with an "[Ingestion]" prefix).
    """
    try:
        return await ingest_contexts_via_full_pipeline(
            contexts=conversations,
            group_id=group_id,
            save_chunk_output=True
        )
    except Exception as exc:
        print(f"[Ingestion] Failed to ingest conversations: {exc}")
        return False
|
||||
858
api/app/core/memory/evaluation/locomo/qwen_search_eval.py
Normal file
858
api/app/core/memory/evaluation/locomo/qwen_search_eval.py
Normal file
@@ -0,0 +1,858 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
import statistics
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
import re
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
|
||||
|
||||
|
||||
# Follows the F1 computation of evaluation/locomo/evaluation.py
# (external dependencies removed, implemented inline).
def _loc_normalize(text: str) -> str:
    """
    Normalize text for LoCoMo-style token comparison.

    Lower-cases the input, turns commas into spaces, removes the standalone
    words a/an/the/and, replaces every remaining character that is neither a
    word character nor whitespace with a space, and collapses whitespace.
    ``None`` is treated as an empty string.
    """
    cleaned = ("" if text is None else str(text)).lower()
    # Order matters only for readability of the intermediate text; each pass
    # replaces its matches with a single space.
    for pattern in (r"[\,]", r"\b(a|an|the|and)\b", r"[^\w\s]"):
        cleaned = re.sub(pattern, " ", cleaned)
    return " ".join(cleaned.split())
|
||||
|
||||
# Added: normalize relative times to absolute dates (limited support:
# today/yesterday/tomorrow/X day(s) ago/in X day(s)/last week/next week).
def _resolve_relative_times(text: str, anchor: datetime) -> str:
    """
    Replace relative time expressions in ``text`` with ISO dates.

    Matching is case-insensitive. "last week"/"next week" are approximated as
    7 days. Both "day" and "days" are accepted in "X day(s) ago" and
    "in X day(s)". ``None`` is treated as an empty string.

    Args:
        text: Text that may contain relative time expressions.
        anchor: Date the expressions are resolved against.

    Returns:
        The text with every supported expression replaced by YYYY-MM-DD. An
        expression whose date would be out of the range ``datetime`` supports
        is left unchanged instead of raising.
    """
    resolved = str(text) if text is not None else ""

    def _shift(match: re.Match[str], days: int) -> str:
        # Evaluated only when the phrase occurs; out-of-range dates keep the
        # original wording.
        try:
            return (anchor + timedelta(days=days)).date().isoformat()
        except OverflowError:
            return match.group(0)

    def _fixed(days: int):
        return lambda match: _shift(match, days)

    def _counted(sign: int):
        def _repl(match: re.Match[str]) -> str:
            try:
                count = int(match.group(1))
            except ValueError:
                return match.group(0)
            return _shift(match, sign * count)
        return _repl

    rules = (
        (r"\btoday\b", _fixed(0)),
        (r"\byesterday\b", _fixed(-1)),
        (r"\btomorrow\b", _fixed(1)),
        (r"\b(\d+)\s+days?\s+ago\b", _counted(-1)),
        (r"\bin\s+(\d+)\s+days?\b", _counted(1)),
        (r"\blast\s+week\b", _fixed(-7)),
        (r"\bnext\s+week\b", _fixed(7)),
    )
    for pattern, replacement in rules:
        resolved = re.sub(pattern, replacement, resolved, flags=re.IGNORECASE)
    return resolved
|
||||
|
||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
    """
    Single-answer F1 over the sets of normalized tokens.

    Both inputs are normalized with ``_loc_normalize`` and compared as token
    sets (duplicates do not count; no stemming is applied). ``None`` inputs
    are treated as empty strings.

    Returns:
        The harmonic mean of set precision and recall, or 0.0 when either
        side has no tokens or the sets do not overlap.
    """
    pred_text = "" if prediction is None else str(prediction)
    gold_text = "" if ground_truth is None else str(ground_truth)

    pred_tokens = set(_loc_normalize(pred_text).split())
    gold_tokens = set(_loc_normalize(gold_text).split())
    if not pred_tokens or not gold_tokens:
        return 0.0

    overlap = len(pred_tokens & gold_tokens)
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    if precision + recall > 0:
        return 2 * precision * recall / (precision + recall)
    return 0.0
|
||||
|
||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
    """
    Multi-answer F1 for comma-separated answers.

    Prediction and ground truth are split on commas. For every ground-truth
    part the best ``loc_f1_score`` against any predicted part is taken, and
    those best scores are averaged. ``None`` inputs are treated as empty
    strings.

    Returns:
        The averaged score, or 0.0 when either side has no non-empty part.
    """
    pred_text = "" if prediction is None else str(prediction)
    gold_text = "" if ground_truth is None else str(ground_truth)

    candidates = [part.strip() for part in pred_text.split(',') if part.strip()]
    references = [part.strip() for part in gold_text.split(',') if part.strip()]
    if not candidates or not references:
        return 0.0

    best_scores = [
        max(loc_f1_score(candidate, reference) for candidate in candidates)
        for reference in references
    ]
    return sum(best_scores) / len(best_scores)
|
||||
|
||||
# Canonical LoCoMo category names: supports the numeric "category" field as
# well as the string "cat"/"type" fields.
CATEGORY_MAP_NUM_TO_NAME = dict(
    (
        (4, "Single-Hop"),
        (1, "Multi-Hop"),
        (3, "Open Domain"),
        (2, "Temporal"),
    )
)

# Lower-cased spelling variants mapped to their canonical category name.
_TYPE_ALIASES = {
    alias: canonical
    for canonical, aliases in (
        ("Single-Hop", ("single-hop", "singlehop", "single hop")),
        ("Multi-Hop", ("multi-hop", "multihop", "multi hop")),
        ("Open Domain", ("open domain", "opendomain")),
        ("Temporal", ("temporal",)),
    )
    for alias in aliases
}
|
||||
|
||||
def get_category_label(item: Dict[str, Any]) -> str:
    """
    Return the canonical category name of a QA item.

    Lookup order:
    1. a non-blank string in "cat" (mapped through ``_TYPE_ALIASES``,
       otherwise returned stripped as-is)
    2. an integer in "category" (mapped through ``CATEGORY_MAP_NUM_TO_NAME``,
       "unknown" for unmapped numbers)
    3. a non-blank string in "type" (handled like "cat")

    Returns "unknown" when none of the fields applies.
    """
    def _label_from_text(value: Any):
        # None signals "field not usable"; alias values are never None.
        if isinstance(value, str) and value.strip():
            stripped = value.strip()
            return _TYPE_ALIASES.get(stripped.lower(), stripped)
        return None

    label = _label_from_text(item.get("cat"))
    if label is not None:
        return label

    number = item.get("category")
    if isinstance(number, int):
        return CATEGORY_MAP_NUM_TO_NAME.get(number, "unknown")

    label = _label_from_text(item.get("type"))
    return label if label is not None else "unknown"
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
    """
    Select contexts for a question by keyword relevance within a character budget.

    Each context is scored: +2 per occurrence of a question keyword
    (case-insensitive substring match; stop words and words of 1-2 characters
    are ignored), +5 for a length strictly between 100 and 2000 characters or
    +2 for 2000+, and +3 for the first three contexts. Contexts are taken in
    descending score order while they fit into ``max_chars``.

    A context that does not fit is skipped. If it scored above 10 and more
    than 500 characters of budget remain, an excerpt of its keyword-bearing
    lines is added instead (kept at least 100 characters below the remaining
    budget and followed by a truncation marker); selection stops after such
    an excerpt.

    Progress information is printed to stdout.

    Args:
        contexts: Candidate context strings.
        question: Question used to derive keywords.
        max_chars: Budget for the selected text. The "\\n\\n" separators and
            the truncation marker are not counted against it.

    Returns:
        The selected contexts joined by double newlines, or "" when
        ``contexts`` is empty or nothing fits.
    """
    if not contexts:
        return ""

    # Extract question keywords (meaningful words only)
    stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
    question_words = {
        word for word in re.findall(r'\b\w+\b', question.lower())
        if word not in stop_words and len(word) > 2
    }

    print(f"🔍 问题关键词: {question_words}")

    # Score every context
    scored_contexts = []
    for i, context in enumerate(contexts):
        context_lower = context.lower()
        occurrences = [context_lower.count(word) for word in question_words]
        keyword_matches = sum(1 for count in occurrences if count)
        score = 2 * sum(occurrences)

        # Moderate lengths score best; very long contexts may carry noise.
        context_len = len(context)
        if 100 < context_len < 2000:
            score += 5
        elif context_len >= 2000:
            score += 2

        # Earlier contexts get a bonus.
        if i < 3:
            score += 3

        scored_contexts.append((score, context, keyword_matches))

    # Highest score first; the sort is stable, so ties keep input order.
    scored_contexts.sort(key=lambda x: x[0], reverse=True)

    selected = []
    total_chars = 0
    selected_count = 0

    print("📊 上下文相关性分析:")
    for score, context, matches in scored_contexts[:5]:  # show the top 5 only
        print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")

    for score, context, matches in scored_contexts:
        if total_chars + len(context) <= max_chars:
            selected.append(context)
            total_chars += len(context)
            selected_count += 1
            continue

        # Too large to fit whole: only high-scoring contexts are worth an
        # excerpt, and only while a meaningful amount of budget is left.
        if score <= 10 or total_chars >= max_chars - 500:
            continue

        excerpt_budget = max_chars - total_chars - 100
        relevant_lines = []
        excerpt_len = 0
        for line in context.split('\n'):
            line_lower = line.lower()
            if not any(word in line_lower for word in question_words):
                continue
            # Length of the excerpt if this line (plus its joining newline)
            # were added; checked before appending so the budget holds.
            projected = excerpt_len + len(line) + (1 if relevant_lines else 0)
            if projected <= excerpt_budget:
                relevant_lines.append(line)
                excerpt_len = projected

        truncated = '\n'.join(relevant_lines)
        if len(truncated) > 100:  # make sure the excerpt has enough content
            selected.append(truncated + "\n[相关内容截断...]")
            total_chars += len(truncated)
            selected_count += 1
            break  # budget is essentially used up; stop adding contexts

    result = "\n\n".join(selected)
    print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
    return result
|
||||
|
||||
|
||||
def get_search_params_by_category(category: str):
    """
    Return retrieval parameters tuned for a question category.

    Args:
        category: Category name ("Multi-Hop", "Temporal", "Open Domain" or
            "Single-Hop").

    Returns:
        A dict with "limit" and "max_chars". Unrecognized categories get
        ``{"limit": 16, "max_chars": 12000}``.
    """
    fallback = {"limit": 16, "max_chars": 12000}
    # (category, limit, max_chars); rebuilt on every call so callers receive
    # a dict they can modify freely.
    tuned = {
        name: {"limit": limit, "max_chars": max_chars}
        for name, limit, max_chars in (
            ("Multi-Hop", 20, 15000),
            ("Temporal", 16, 10000),
            ("Open Domain", 24, 18000),
            ("Single-Hop", 12, 8000),
        )
    }
    return tuned.get(category, fallback)
|
||||
|
||||
|
||||
def _collect_graph_contexts(
    chunks: List[Dict[str, Any]],
    statements: List[Dict[str, Any]],
    summaries: List[Dict[str, Any]],
    entities: List[Dict[str, Any]],
) -> List[str]:
    """Flatten chunk/statement/summary/entity hits into context strings."""
    contexts: List[str] = []
    for hits, field in ((chunks, "content"), (statements, "statement"), (summaries, "summary")):
        for hit in hits:
            text = str(hit.get(field, "")).strip()
            if text:
                contexts.append(text)

    # Entity summary: keep at most the three best-scored entities to limit noise;
    # when no entity carries a score, fall back to the first three.
    scored = [e for e in entities if e.get("score") is not None]
    top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
    summary_lines = []
    for e in top_entities:
        name = str(e.get("name", "")).strip()
        etype = str(e.get("entity_type", "")).strip()
        score = e.get("score")
        if name:
            meta = []
            if etype:
                meta.append(f"type={etype}")
            if isinstance(score, (int, float)):
                meta.append(f"score={score:.3f}")
            summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
    if summary_lines:
        contexts.append("\n".join(summary_lines))
    return contexts


async def run_locomo_eval(
    sample_size: int = 1,
    group_id: str | None = None,
    search_limit: int = 8,
    context_char_budget: int = 4000,
    llm_temperature: float = 0.0,
    llm_max_tokens: int = 32,
    search_type: str = "hybrid",
    output_path: str | None = None,
    skip_ingest_if_exists: bool = True,
    llm_timeout: float = 10.0,
    llm_max_retries: int = 1
) -> Dict[str, Any]:
    """Run the LoCoMo QA evaluation end to end.

    Loads ``data/locomo10.json``, ingests the first conversation into the
    graph store, then for each of the first ``sample_size`` QA pairs
    retrieves context, queries the LLM and scores the prediction.

    Args:
        sample_size: Number of QA pairs to evaluate.
        group_id: Retrieval group; defaults to ``SELECTED_GROUP_ID``.
        search_type: ``"embedding"``, ``"keyword"``, or anything else for
            hybrid search (which falls back to embedding search on failure
            or empty results).
        output_path: If given, the result dict is also written there as JSON.
        search_limit, context_char_budget, llm_temperature, llm_max_tokens,
        skip_ingest_if_exists, llm_timeout, llm_max_retries: Only echoed in
            ``result["params"]``; the per-category presets from
            ``get_search_params_by_category`` drive retrieval and ingestion
            is always performed.

    Returns:
        A dict with overall and per-category metrics, context statistics,
        latency statistics, per-sample details and the run parameters.
    """
    group_id = group_id or SELECTED_GROUP_ID
    data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
    if not os.path.exists(data_path):
        data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
    with open(data_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    # LoCoMo layout: a list of conversation objects, each with its own "qa" list.
    entries = raw if isinstance(raw, list) else [raw]

    # Only the first conversation is ingested.
    max_dialogues_to_ingest = 1
    contents: List[str] = []
    print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest} 条")

    for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
        if not isinstance(entry, dict):
            continue

        conv = entry.get("conversation", {})
        sample_id = entry.get("sample_id", f"unknown_{i}")
        print(f"🔍 处理对话 {i+1}: {sample_id}")

        lines: List[str] = []
        session_count = 0
        if isinstance(conv, dict):
            # Collect the messages of every session_* list.
            for key, val in conv.items():
                if isinstance(val, list) and key.startswith("session_"):
                    session_count += 1
                    for msg in val:
                        role = msg.get("speaker") or "用户"
                        text = msg.get("text") or ""
                        text = str(text).strip()
                        if not text:
                            continue
                        lines.append(f"{role}: {text}")
        print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")

        if not lines:
            print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
            continue

        contents.append("\n".join(lines))

    print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")

    # QA pairs are drawn from all conversations, capped at sample_size.
    items: List[Dict[str, Any]] = [qa for entry in entries for qa in entry.get("qa", [])][:sample_size]
    print(f"🎯 将评测 {len(items)} 个QA对,数据库中只包含 {len(contents)} 个对话")

    connector = Neo4jConnector()
    # Everything after the connector is created runs under try/finally so the
    # connection is closed even when ingestion or client setup fails.
    try:
        print("🔄 强制重新摄入纯净的对话数据...")
        await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)

        llm_client = get_llm_client(SELECTED_LLM_ID)
        cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
        embedder = OpenAIEmbedderClient(
            model_config=RedBearModelConfig.model_validate(cfg_dict)
        )

        latencies_llm: List[float] = []
        latencies_search: List[float] = []
        # Context diagnostics.
        per_query_context_counts: List[int] = []
        per_query_context_avg_tokens: List[float] = []
        per_query_context_chars: List[int] = []
        per_query_context_tokens_total: List[int] = []
        # Per-sample debugging details.
        samples: List[Dict[str, Any]] = []
        # Generic metrics.
        f1s: List[float] = []
        b1s: List[float] = []
        jss: List[float] = []
        # LoCoMo-style category-specific F1 (multi-answer F1 for category 1).
        loc_f1s: List[float] = []
        # Per-category aggregation.
        cat_counts: Dict[str, int] = {}
        cat_f1s: Dict[str, List[float]] = {}
        cat_b1s: Dict[str, List[float]] = {}
        cat_jss: Dict[str, List[float]] = {}
        cat_loc_f1s: Dict[str, List[float]] = {}

        for item in items:
            q = item.get("question", "")
            ref = item.get("answer", "")
            # Answers may be numbers; metrics need strings.
            ref_str = str(ref) if ref is not None else ""
            cat = get_category_label(item)

            print(f"\n=== 处理问题: {q} ===")

            # Retrieval parameters depend on the question category.
            search_params = get_search_params_by_category(cat)
            adjusted_limit = search_params["limit"]
            max_chars = search_params["max_chars"]

            print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")

            t0 = time.time()
            contexts_all: List[str] = []
            search_results = None
            # Reset per question so the debug dump below never sees hits
            # left over from a previous question.
            chunks: List[Dict[str, Any]] = []
            statements: List[Dict[str, Any]] = []
            entities: List[Dict[str, Any]] = []
            summaries: List[Dict[str, Any]] = []
            dialogs: List[Dict[str, Any]] = []

            try:
                if search_type == "embedding":
                    search_results = await search_graph_by_embedding(
                        connector=connector,
                        embedder_client=embedder,
                        query_text=q,
                        group_id=group_id,
                        limit=adjusted_limit,
                        include=["chunks", "statements", "entities", "summaries"],
                    )
                    chunks = search_results.get("chunks", [])
                    statements = search_results.get("statements", [])
                    entities = search_results.get("entities", [])
                    summaries = search_results.get("summaries", [])

                    print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")

                    contexts_all.extend(_collect_graph_contexts(chunks, statements, summaries, entities))

                elif search_type == "keyword":
                    search_results = await search_graph(
                        connector=connector,
                        q=q,
                        group_id=group_id,
                        limit=adjusted_limit
                    )
                    dialogs = search_results.get("dialogues", [])
                    statements = search_results.get("statements", [])
                    entities = search_results.get("entities", [])
                    print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")

                    for d in dialogs:
                        content = str(d.get("content", "")).strip()
                        if content:
                            contexts_all.append(content)
                    for s in statements:
                        stmt_text = str(s.get("statement", "")).strip()
                        if stmt_text:
                            contexts_all.append(stmt_text)
                    # Keyword hits may carry no score, so entities are only listed by name.
                    if entities:
                        entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
                        if entity_names:
                            contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")

                else:  # hybrid
                    print("🔀 使用混合检索(带回退机制)...")
                    try:
                        search_results = await run_hybrid_search(
                            query_text=q,
                            search_type=search_type,
                            group_id=group_id,
                            limit=adjusted_limit,
                            include=["chunks", "statements", "entities", "summaries"],
                            output_path=None,
                        )

                        if search_results and isinstance(search_results, dict):
                            # Current result shape: hit lists sit at the top level.
                            chunks = search_results.get("chunks", [])
                            statements = search_results.get("statements", [])
                            entities = search_results.get("entities", [])
                            summaries = search_results.get("summaries", [])

                            if chunks or statements or entities or summaries:
                                print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
                            else:
                                # Older result shape: hits nested under "reranked_results".
                                reranked = search_results.get("reranked_results", {})
                                if reranked and isinstance(reranked, dict):
                                    chunks = reranked.get("chunks", [])
                                    statements = reranked.get("statements", [])
                                    entities = reranked.get("entities", [])
                                    summaries = reranked.get("summaries", [])
                                    print(f"✅ 混合检索成功(使用旧格式reranked结果): {len(chunks)} chunks, {len(statements)} 陈述")
                                else:
                                    raise ValueError("混合检索返回空结果")
                        else:
                            raise ValueError("混合检索返回空结果")

                    except Exception as e:
                        # Any hybrid failure (including an empty result) falls back to embedding search.
                        print(f"❌ 混合检索失败: {e},回退到嵌入检索")
                        search_results = await search_graph_by_embedding(
                            connector=connector,
                            embedder_client=embedder,
                            query_text=q,
                            group_id=group_id,
                            limit=adjusted_limit,
                            include=["chunks", "statements", "entities", "summaries"],
                        )
                        chunks = search_results.get("chunks", [])
                        statements = search_results.get("statements", [])
                        entities = search_results.get("entities", [])
                        summaries = search_results.get("summaries", [])
                        print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")

                    # Shared by the hybrid result and its embedding fallback.
                    contexts_all.extend(_collect_graph_contexts(chunks, statements, summaries, entities))

                # Drop every context that literally contains the reference answer.
                filtered_contexts = []
                for context in contexts_all:
                    content = str(context)
                    if ref_str and ref_str.strip() and ref_str.strip() in content:
                        print("🚫 过滤掉包含标准答案的上下文")
                        continue
                    filtered_contexts.append(context)

                print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
                contexts_all = filtered_contexts

                # Debug dump of the first two hits of each kind.
                print("🔍 检索结果详情:")
                if search_results:
                    output_data = {
                        "statements": [
                            {
                                "statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""),
                                "score": s.get("score", 0.0)
                            }
                            for s in statements[:2]
                        ],
                        "dialogues": [
                            {
                                "uuid": d.get("uuid", ""),
                                "group_id": d.get("group_id", ""),
                                "content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
                                "score": d.get("score", 0.0)
                            }
                            for d in dialogs[:2]
                        ],
                        "entities": [
                            {
                                "name": e.get("name", ""),
                                "entity_type": e.get("entity_type", ""),
                                "score": e.get("score", 0.0)
                            }
                            for e in entities[:2]
                        ]
                    }
                    print(json.dumps(output_data, ensure_ascii=False, indent=2))
                else:
                    print(" 无检索结果")

            except Exception as e:
                # A failed retrieval is evaluated as "no context" rather than aborting the run.
                print(f"❌ {search_type}检索失败: {e}")
                contexts_all = []
                search_results = None

            t1 = time.time()
            latencies_search.append((t1 - t0) * 1000)

            context_text = ""
            if contexts_all:
                context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)

                # Final safety net in case the selection still exceeds the budget.
                if len(context_text) > max_chars:
                    print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
                    context_text = context_text[:max_chars] + "\n\n[最终截断...]"

                # Resolve relative time expressions against a fixed anchor date for reproducibility.
                anchor_date = datetime(2023, 5, 8)
                context_text = _resolve_relative_times(context_text, anchor_date)

                context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text

                print(f"📝 最终上下文长度: {len(context_text)} 字符")

                print("🔍 上下文预览:")
                for j, context in enumerate(contexts_all[:3]):
                    preview = context[:150].replace('\n', ' ')
                    print(f" 上下文{j+1}: {preview}...")

            else:
                print("❌ 没有检索到有效上下文")
                context_text = "No relevant context found."

            # Context diagnostics.
            per_query_context_counts.append(len(contexts_all))
            per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
            per_query_context_chars.append(len(context_text))
            per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))

            messages = [
                {"role": "system", "content": (
                    "You are a precise QA assistant. Answer following these rules:\n"
                    "1) Extract the EXACT information mentioned in the context\n"
                    "2) For time questions: calculate actual dates from relative times\n"
                    "3) Return ONLY the answer text in simplest form\n"
                    "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
                    "5) If no clear answer found, respond with 'Unknown'"
                )},
                {"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
            ]

            t2 = time.time()
            resp = await llm_client.chat(messages=messages)
            t3 = time.time()
            latencies_llm.append((t3 - t2) * 1000)

            # Accept either an object with .content or an OpenAI-style dict.
            pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")

            f1_val = common_f1(str(pred), ref_str)
            b1_val = bleu1(str(pred), ref_str)
            j_val = jaccard(str(pred), ref_str)

            f1s.append(f1_val)
            b1s.append(b1_val)
            jss.append(j_val)

            cat_counts[cat] = cat_counts.get(cat, 0) + 1
            cat_f1s.setdefault(cat, []).append(f1_val)
            cat_b1s.setdefault(cat, []).append(b1_val)
            cat_jss.setdefault(cat, []).append(j_val)

            # LoCoMo F1: category 1 uses the multi-answer variant, every other
            # category the single-answer one.
            if item.get("category") in [1]:
                loc_val = loc_multi_f1(str(pred), ref_str)
            else:
                loc_val = loc_f1_score(str(pred), ref_str)
            loc_f1s.append(loc_val)
            cat_loc_f1s.setdefault(cat, []).append(loc_val)

            samples.append({
                "question": q,
                "answer": ref_str,
                "category": cat,
                "prediction": pred,
                "metrics": {
                    "f1": f1_val,
                    "b1": b1_val,
                    "j": j_val,
                    "loc_f1": loc_val
                },
                "retrieval": {
                    "retrieved_documents": len(contexts_all),
                    "context_length": len(context_text),
                    "search_limit": adjusted_limit,
                    "max_chars": max_chars
                },
                "timing": {
                    "search_ms": (t1 - t0) * 1000,
                    "llm_ms": (t3 - t2) * 1000
                }
            })

            print(f"🤖 LLM 回答: {pred}")
            print(f"✅ 正确答案: {ref_str}")
            print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")

        def _mean(vals: List[float]) -> float:
            """Arithmetic mean, 0.0 for an empty list."""
            return (sum(vals) / len(vals)) if vals else 0.0

        def _percentile(sorted_vals: List[float], p: float) -> float:
            """Linearly interpolated percentile of an already sorted list."""
            if not sorted_vals:
                return 0.0
            if len(sorted_vals) == 1:
                return sorted_vals[0]
            k = (len(sorted_vals) - 1) * p
            f = int(k)
            c = f + 1 if f + 1 < len(sorted_vals) else f
            if f == c:
                return sorted_vals[f]
            return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f)

        # Per-category averages plus Jaccard dispersion (std, IQR).
        by_category: Dict[str, Dict[str, float | int]] = {}
        for c in cat_counts:
            j_list = cat_jss.get(c, [])
            j_sorted = sorted(j_list)
            j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0
            j_q75 = _percentile(j_sorted, 0.75)
            j_q25 = _percentile(j_sorted, 0.25)
            by_category[c] = {
                "count": cat_counts[c],
                "f1": _mean(cat_f1s.get(c, [])),
                "b1": _mean(cat_b1s.get(c, [])),
                "j": _mean(j_list),
                "j_std": j_std,
                "j_iqr": (j_q75 - j_q25) if j_list else 0.0,
                "loc_f1": _mean(cat_loc_f1s.get(c, [])),
            }

        # Sum of LoCoMo F1 per category.
        cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}

        result = {
            "dataset": "locomo",
            "items": len(items),
            "metrics": {
                "f1": sum(f1s) / max(len(f1s), 1),
                "b1": sum(b1s) / max(len(b1s), 1),
                "j": sum(jss) / max(len(jss), 1),
                "loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1),
            },
            "by_category": by_category,
            "category_counts": cat_counts,
            "cum_accuracy_by_category": cum_accuracy_by_category,
            "context": {
                "avg_tokens": _mean(per_query_context_avg_tokens),
                "avg_chars": _mean(per_query_context_chars),
                "count_avg": _mean(per_query_context_counts),
                "avg_memory_tokens": _mean(per_query_context_tokens_total),
            },
            "latency": {
                "search": latency_stats(latencies_search),
                "llm": latency_stats(latencies_llm),
            },
            "samples": samples,
            "params": {
                "group_id": group_id,
                "search_limit": search_limit,
                "context_char_budget": context_char_budget,
                "search_type": search_type,
                "llm_id": SELECTED_LLM_ID,
                "retrieval_embedding_id": SELECTED_EMBEDDING_ID,
                "skip_ingest_if_exists": skip_ingest_if_exists,
                "llm_timeout": llm_timeout,
                "llm_max_retries": llm_max_retries,
                "llm_temperature": llm_temperature,
                "llm_max_tokens": llm_max_tokens
            },
            "timestamp": datetime.now().isoformat()
        }
        if output_path:
            try:
                # dirname is "" for a bare filename, and os.makedirs("") raises.
                out_dir = os.path.dirname(output_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                print(f"✅ 结果已保存到: {output_path}")
            except Exception as e:
                # Saving is best-effort; the result is still returned to the caller.
                print(f"❌ 保存结果失败: {e}")
        return result
    finally:
        await connector.close()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse options, run the LoCoMo evaluation, print a summary."""
    parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
    add = parser.add_argument
    add("--sample_size", type=int, default=1, help="Number of samples to evaluate")
    add("--group_id", type=str, default=None, help="Group ID for retrieval")
    add("--search_limit", type=int, default=8, help="Search limit per query")
    add("--context_char_budget", type=int, default=12000, help="Max characters for context")
    add("--llm_temperature", type=float, default=0.0, help="LLM temperature")
    add("--llm_max_tokens", type=int, default=32, help="LLM max tokens")
    add("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type")
    add("--output_path", type=str, default=None, help="Output path for results")
    add("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists")
    add("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds")
    add("--llm_max_retries", type=int, default=1, help="LLM max retries")
    args = parser.parse_args()

    load_dotenv()

    # Every CLI option is named after the run_locomo_eval keyword it feeds,
    # so the parsed namespace can be forwarded as keyword arguments.
    result = asyncio.run(run_locomo_eval(**vars(args)))

    overall = result['metrics']
    latency = result['latency']
    summary = [
        "\n" + "="*50,
        "📊 最终评测结果:",
        f" 样本数量: {result['items']}",
        f" F1: {overall['f1']:.3f}",
        f" BLEU-1: {overall['b1']:.3f}",
        f" Jaccard: {overall['j']:.3f}",
        f" LoCoMo F1: {overall['loc_f1']:.3f}",
        f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符",
        f" 平均检索延迟: {latency['search']['mean']:.1f}ms",
        f" 平均LLM延迟: {latency['llm']['mean']:.1f}ms",
    ]
    for row in summary:
        print(row)

    per_category = result['by_category']
    if not per_category:
        return

    print("\n📈 按类别细分:")
    for cat, metrics in per_category.items():
        print(f" {cat}:")
        print(f" 样本数: {metrics['count']}")
        print(f" F1: {metrics['f1']:.3f}")
        print(f" LoCoMo F1: {metrics['loc_f1']:.3f}")
        print(f" Jaccard: {metrics['j']:.3f} (±{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})")


if __name__ == "__main__":
    main()
|
||||
1344
api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py
Normal file
1344
api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py
Normal file
File diff suppressed because it is too large
Load Diff
1315
api/app/core/memory/evaluation/longmemeval/test_eval.py
Normal file
1315
api/app/core/memory/evaluation/longmemeval/test_eval.py
Normal file
File diff suppressed because it is too large
Load Diff
301
api/app/core/memory/evaluation/memsciqa/evaluate_qa.py
Normal file
301
api/app/core/memory/evaluation/memsciqa/evaluate_qa.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
    """Pick the contexts most relevant to ``question`` within a character budget.

    Each context is scored by question-keyword occurrences (stop words and
    words of two characters or fewer are ignored), plus a bonus for its
    length band and for being among the first three inputs. Contexts are
    taken in descending score order until one does not fit; that one may
    contribute only its keyword-bearing lines, after which selection stops.

    Args:
        contexts: Candidate context strings; ``None`` or empty entries are skipped.
        question: The question used to derive keywords.
        max_chars: Budget for the selected context characters. The ``"\\n\\n"``
            separators and the truncation marker are not counted.

    Returns:
        The selected contexts joined by blank lines, or ``""`` if nothing
        was selected.
    """
    if not contexts:
        return ""
    import re

    stop_words = {
        'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were',
        'the', 'a', 'an', 'and', 'or', 'but'
    }
    question_words = {
        w for w in re.findall(r"\b\w+\b", (question or "").lower())
        if w not in stop_words and len(w) > 2
    }

    scored = []
    for i, ctx in enumerate(contexts):
        if not ctx:
            # None/empty entries carry no information and would only add
            # blank separators to the output.
            continue
        ctx_lower = ctx.lower()
        score = sum(ctx_lower.count(w) * 2 for w in question_words)
        length = len(ctx)
        if 100 < length < 2000:
            score += 5
        elif length >= 2000:
            score += 2
        if i < 3:
            # Earlier inputs are assumed to be the better-ranked hits.
            score += 3
        scored.append((score, ctx))

    # Stable sort: equal scores keep their input order.
    scored.sort(key=lambda x: x[0], reverse=True)

    selected: List[str] = []
    total = 0
    for score, ctx in scored:
        if total + len(ctx) <= max_chars:
            selected.append(ctx)
            total += len(ctx)
            continue

        # First context that does not fit: keep its keyword lines if it
        # scored well and enough budget is left, then stop.
        if score > 10 and total < max_chars - 200:
            remaining = max_chars - total
            rel_lines: List[str] = []
            cur = 0
            for line in ctx.split('\n'):
                if cur >= remaining - 50:
                    break
                if not any(w in line.lower() for w in question_words):
                    continue
                sep = 1 if rel_lines else 0
                # Clip so the joined text never exceeds the remaining budget,
                # even when a single line is longer than the budget itself.
                piece = line[:remaining - cur - sep]
                rel_lines.append(piece)
                cur += sep + len(piece)
            truncated = '\n'.join(rel_lines)
            if len(truncated) > 50:
                selected.append(truncated + "\n[相关内容截断...]")
                total += len(truncated)
        break
    return "\n\n".join(selected)
|
||||
|
||||
|
||||
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
    """Render the ``dialog`` turns of an msc_self_instruct item as text.

    Each turn with non-empty ``text`` becomes a ``"speaker: text"`` line;
    turns without text are dropped. Lines are joined with newlines.
    """
    turns = dialog_obj.get("dialog", [])
    rendered = [
        f"{turn.get('speaker', '')}: {turn.get('text', '')}"
        for turn in turns
        if turn.get("text", "")
    ]
    return "\n".join(rendered)
|
||||
|
||||
|
||||
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Combine dialogues from embedding and keyword searches (embedding first)."""
|
||||
if results is None:
|
||||
return []
|
||||
emb = []
|
||||
kw = []
|
||||
if isinstance(results.get("embedding_search"), dict):
|
||||
emb = results.get("embedding_search", {}).get("dialogues", []) or []
|
||||
elif isinstance(results.get("dialogues"), list):
|
||||
emb = results.get("dialogues", []) or []
|
||||
if isinstance(results.get("keyword_search"), dict):
|
||||
kw = results.get("keyword_search", {}).get("dialogues", []) or []
|
||||
seen = set()
|
||||
merged: List[Dict[str, Any]] = []
|
||||
for d in emb:
|
||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
||||
if k not in seen:
|
||||
merged.append(d)
|
||||
seen.add(k)
|
||||
for d in kw:
|
||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
||||
if k not in seen:
|
||||
merged.append(d)
|
||||
seen.add(k)
|
||||
return merged
|
||||
|
||||
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid") -> Dict[str, Any]:
    """Run the memsciqa (msc_self_instruct) QA evaluation end to end.

    Reads the first ``sample_size`` lines of ``data/msc_self_instruct.jsonl``,
    ingests one dialogue transcript per sample, then for every sample
    retrieves context, asks the LLM and scores the prediction.

    Args:
        sample_size: Number of JSONL lines (samples) to evaluate.
        group_id: Retrieval group; defaults to ``SELECTED_GROUP_ID``.
        search_limit: Hit limit passed to the search; in hybrid mode also
            caps how many dialogue/statement contexts are collected.
        context_char_budget: Character budget for the selected context.
        llm_temperature, llm_max_tokens: Accepted but not used by this function.
        search_type: Passed to ``run_hybrid_search``; ``"hybrid"`` selects
            the merged embedding + keyword handling.

    Returns:
        A dict with accuracy/F1/BLEU-1/Jaccard, search and LLM latency
        statistics, and the average context token count.
    """
    # Imported once here rather than on every loop iteration.
    from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard

    group_id = group_id or SELECTED_GROUP_ID
    data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
    if not os.path.exists(data_path):
        data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
    with open(data_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
    # One context per sample: the full dialogue transcript.
    contexts: List[str] = [build_context_from_dialog(item) for item in items]
    await ingest_contexts_via_full_pipeline(contexts, group_id)

    llm_client = get_llm_client(SELECTED_LLM_ID)

    connector = Neo4jConnector()
    latencies_llm: List[float] = []
    latencies_search: List[float] = []
    contexts_used: List[str] = []
    correct_flags: List[float] = []
    f1s: List[float] = []
    b1s: List[float] = []
    jss: List[float] = []
    try:
        for item in items:
            question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
            reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
            t0 = time.time()
            try:
                results = await run_hybrid_search(
                    query_text=question,
                    search_type=search_type,
                    group_id=group_id,
                    limit=search_limit,
                    include=["dialogues", "statements", "entities"],
                    output_path=None,
                )
            except Exception as e:
                # Best-effort: a failed search is scored as "no context",
                # but the failure is reported instead of being hidden.
                print(f"❌ {search_type}检索失败: {e}")
                results = None
            t1 = time.time()
            latencies_search.append((t1 - t0) * 1000)

            # Build contexts from dialogues, statements and an entity summary.
            contexts_all: List[str] = []
            if results:
                if search_type == "hybrid":
                    emb = results.get("embedding_search")
                    kw = results.get("keyword_search")
                    if not isinstance(emb, dict) and not isinstance(kw, dict):
                        # Neither nested section is present: read the hit
                        # lists from the top level of the result instead.
                        emb, kw = results, {}
                    else:
                        emb = emb if isinstance(emb, dict) else {}
                        kw = kw if isinstance(kw, dict) else {}
                    all_dialogs = list(emb.get("dialogues") or []) + list(kw.get("dialogues") or [])
                    all_statements = list(emb.get("statements") or []) + list(kw.get("statements") or [])
                    all_entities = list(emb.get("entities") or []) + list(kw.get("entities") or [])

                    # De-duplicate by text and cap the number of collected contexts.
                    seen_texts = set()
                    for d in all_dialogs:
                        text = str(d.get("content", "")).strip()
                        if text and text not in seen_texts:
                            contexts_all.append(text)
                            seen_texts.add(text)
                        if len(contexts_all) >= search_limit:
                            break
                    for s in all_statements:
                        text = str(s.get("statement", "")).strip()
                        if text and text not in seen_texts:
                            contexts_all.append(text)
                            seen_texts.add(text)
                        if len(contexts_all) >= search_limit:
                            break
                    # Entity summary: at most three distinct names.
                    names = []
                    for e in all_entities:
                        name = str(e.get("name", "")).strip()
                        if name and name not in names:
                            names.append(name)
                        if len(names) >= 3:
                            break
                    if names:
                        contexts_all.append("EntitySummary: " + ", ".join(names))
                else:
                    dialogs = results.get("dialogues", [])
                    statements = results.get("statements", [])
                    entities = results.get("entities", [])
                    for d in dialogs:
                        text = str(d.get("content", "")).strip()
                        if text:
                            contexts_all.append(text)
                    for s in statements:
                        text = str(s.get("statement", "")).strip()
                        if text:
                            contexts_all.append(text)
                    names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")]
                    if names:
                        contexts_all.append("EntitySummary: " + ", ".join(names))

            context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
            if not context_text:
                context_text = "No relevant context found."
            # Keep the full context so the token average reflects what the
            # LLM actually received, not a 200-character preview.
            contexts_used.append(context_text)

            messages = [
                {"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
                {"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
            ]
            t2 = time.time()
            resp = await llm_client.chat(messages=messages)
            t3 = time.time()
            latencies_llm.append((t3 - t2) * 1000)
            # Accept an object with .content, an OpenAI-style dict, or anything printable.
            pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
            # Exact match is kept alongside the overlap metrics.
            correct_flags.append(exact_match(pred, reference))
            f1s.append(f1_score(str(pred), str(reference)))
            b1s.append(bleu1(str(pred), str(reference)))
            jss.append(jaccard(str(pred), str(reference)))

        acc = sum(correct_flags) / max(len(correct_flags), 1)
        ctx_avg_tokens = avg_context_tokens(contexts_used)
        result = {
            "dataset": "memsciqa",
            "items": len(items),
            "metrics": {
                "accuracy": acc,
                # Token-overlap metrics averaged over all evaluated items.
                "f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
                "bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
                "jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
            },
            "latency": {
                "search": latency_stats(latencies_search),
                "llm": latency_stats(latencies_llm),
            },
            "avg_context_tokens": ctx_avg_tokens,
        }
        return result
    finally:
        await connector.close()
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the memsciqa evaluation.

    Loads environment variables, parses the evaluation options from the
    command line, runs ``run_memsciqa_eval`` to completion on a fresh event
    loop and prints the resulting report as indented JSON.
    """
    load_dotenv()

    cli = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
    # One (flag, argparse keyword arguments) pair per supported option.
    option_specs = (
        ("--sample-size", {"type": int, "default": 1, "help": "评测样本数量"}),
        ("--group-id", {"type": str, "default": None, "help": "可选 group_id,默认取 runtime.json"}),
        ("--search-limit", {"type": int, "default": 8, "help": "每类检索最大返回数"}),
        ("--context-char-budget", {"type": int, "default": 4000, "help": "上下文字符预算"}),
        ("--llm-temperature", {"type": float, "default": 0.0, "help": "LLM 温度"}),
        ("--llm-max-tokens", {"type": int, "default": 64, "help": "LLM 最大生成长度"}),
        (
            "--search-type",
            {
                "type": str,
                "choices": ["keyword", "embedding", "hybrid"],
                "default": "hybrid",
                "help": "检索类型",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    evaluation = run_memsciqa_eval(
        sample_size=options.sample_size,
        group_id=options.group_id,
        search_limit=options.search_limit,
        context_char_budget=options.context_char_budget,
        llm_temperature=options.llm_temperature,
        llm_max_tokens=options.llm_max_tokens,
        search_type=options.search_type,
    )
    report = asyncio.run(evaluation)
    print(json.dumps(report, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
|
||||
561
api/app/core/memory/evaluation/memsciqa/memsciqa-test.py
Normal file
561
api/app/core/memory/evaluation/memsciqa/memsciqa-test.py
Normal file
@@ -0,0 +1,561 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
import re
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
# 路径与模块导入保持与现有评估脚本一致
|
||||
import sys
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
if _p not in sys.path:
|
||||
sys.path.insert(0, _p)
|
||||
|
||||
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config_utils import get_embedder_config
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
try:
    from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
except ImportError:
    # Fallback: minimal whitespace-token implementations, used only when the
    # shared metrics module does not provide these names.
    from collections import Counter

    def f1_score(pred: str, ref: str) -> float:
        """Return the token-level F1 between *pred* and *ref*.

        Tokens are lower-cased whitespace splits. Overlap is counted as a
        multiset intersection, so repeated tokens are matched at most as
        often as they occur in both strings (identical strings score 1.0).
        Returns 0.0 when either side has no tokens or nothing overlaps.
        """
        pred_tokens = pred.lower().split()
        ref_tokens = ref.lower().split()
        if not pred_tokens or not ref_tokens:
            return 0.0
        overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
        if overlap == 0:
            return 0.0
        precision = overlap / len(pred_tokens)
        recall = overlap / len(ref_tokens)
        return 2 * precision * recall / (precision + recall)

    def bleu1(pred: str, ref: str) -> float:
        """Return the clipped unigram precision of *pred* against *ref*.

        Each predicted token is credited at most as many times as it occurs
        in the reference, so repeating one matching word cannot inflate the
        score. No brevity penalty is applied. Returns 0.0 when either side
        has no tokens.
        """
        pred_tokens = pred.lower().split()
        ref_tokens = ref.lower().split()
        if not pred_tokens or not ref_tokens:
            return 0.0
        clipped = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
        return clipped / len(pred_tokens)

    def jaccard(pred: str, ref: str) -> float:
        """Return the Jaccard similarity of the two lower-cased token sets.

        Returns 0.0 when both strings are empty (empty union).
        """
        pred_set = set(pred.lower().split())
        ref_set = set(ref.lower().split())
        union = len(pred_set | ref_set)
        if union == 0:
            return 0.0
        return len(pred_set & ref_set) / union
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
    """Pick the contexts most relevant to *question* and join them within a budget.

    Each context is scored by how often the question's keywords occur in it
    (2 points per occurrence), plus a length bonus (5 for 101-1999 chars,
    2 for 2000+ chars) and a position bonus (3 for the first three entries).
    Contexts are then taken in descending score order while their combined
    length stays within ``max_chars``. The first context that does not fit
    ends the selection; if it scored above 10 and more than 200 chars of
    budget remain, only its keyword-matching lines are kept (clipped to the
    remaining budget minus a 50-char reserve) and a truncation marker is
    appended.

    Args:
        contexts: Candidate context strings; ``None`` or empty entries are
            ignored.
        question: The question whose keywords drive the scoring; ``None`` is
            treated as an empty question.
        max_chars: Character budget for the selected context bodies. The
            ``"\n\n"`` separators between contexts are not counted.

    Returns:
        The selected contexts joined by blank lines, or ``""`` when nothing
        was selected.
    """
    if not contexts:
        return ""

    stop_words = {
        'what','when','where','who','why','how','did','do','does','is','are','was','were',
        'the','a','an','and','or','but'
    }
    question_words = {
        w
        for w in re.findall(r"\b\w+\b", (question or "").lower())
        if w not in stop_words and len(w) > 2
    }

    scored = []
    for i, ctx in enumerate(contexts):
        if not ctx:
            # None/empty entries carry no text; skipping them keeps the
            # position bonus of the remaining entries unchanged.
            continue
        ctx_lower = ctx.lower()
        score = sum(ctx_lower.count(w) * 2 for w in question_words)
        length = len(ctx)
        if 100 < length < 2000:
            score += 5
        elif length >= 2000:
            score += 2
        if i < 3:
            score += 3
        scored.append((score, ctx))

    # Stable sort: equally scored contexts keep their retrieval order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    selected: List[str] = []
    total = 0
    for score, ctx in scored:
        if total + len(ctx) <= max_chars:
            selected.append(ctx)
            total += len(ctx)
            continue

        # First context that does not fit: optionally keep its relevant lines,
        # then stop selecting.
        if score > 10 and total < max_chars - 200:
            line_budget = (max_chars - total) - 50  # reserve room for the marker
            kept: List[str] = []
            used = 0
            for line in ctx.split('\n'):
                room = line_budget - used
                if room <= 0:
                    break
                if any(w in line.lower() for w in question_words):
                    piece = line[:room]  # clip so one long line cannot overshoot
                    kept.append(piece)
                    used += len(piece) + 1  # +1 for the joining newline
            if kept:
                truncated = '\n'.join(kept)
                if len(truncated) > 50:
                    selected.append(truncated + "\n[相关内容截断...]")
                    total += len(truncated)
        break

    return "\n\n".join(selected)
|
||||
|
||||
|
||||
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
    """Extract keywords from a question.

    The question is lower-cased and split into word tokens (hyphens are kept
    inside a token). Stop words and tokens shorter than three characters are
    dropped, duplicates are removed while preserving first-seen order, and
    collection stops once ``max_keywords`` keywords have been gathered.
    """
    ignored = {
        'what','when','where','who','why','how','did','do','does','is','are','was','were',
        'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this'
    }
    lowered = (question or "").lower()

    picked: List[str] = []
    for token in re.findall(r"\b[\w-]+\b", lowered):
        if len(token) < 3 or token in ignored or token in picked:
            continue
        picked.append(token)
        if len(picked) >= max_keywords:
            break
    return picked
|
||||
|
||||
|
||||
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
|
||||
"""对上下文进行简单相关性打分,仅用于控制台可视化。
|
||||
|
||||
评分: score = match_count*200 + min(len(text), 100000)/100
|
||||
"""
|
||||
results = []
|
||||
for ctx in contexts:
|
||||
tl = (ctx or "").lower()
|
||||
match_count = sum(1 for k in keywords if k in tl)
|
||||
length = len(ctx)
|
||||
score = match_count * 200 + min(length, 100000) / 100.0
|
||||
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
|
||||
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
|
||||
return results[:max(top_n, 0)]
|
||||
|
||||
|
||||
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
|
||||
|
||||
|
||||
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
    """Load the memsciqa dataset from a JSON Lines file.

    Blank lines, lines that are not valid JSON and lines whose JSON value is
    not an object are skipped so that one bad record does not abort the run.

    Args:
        data_path: Path to the ``.jsonl`` dataset file.

    Returns:
        The parsed records, in file order.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"未找到数据集: {data_path}")

    items: List[Dict[str, Any]] = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Skip the malformed line but keep reading.
                continue
            # Callers access records with .get(), so only JSON objects are usable.
            if isinstance(record, dict):
                items.append(record)
    return items
|
||||
|
||||
|
||||
async def run_memsciqa_test(
    sample_size: int = 3,
    group_id: str | None = None,
    search_limit: int = 8,
    context_char_budget: int = 4000,
    llm_temperature: float = 0.0,
    llm_max_tokens: int = 64,
    search_type: str = "embedding",
    data_path: str | None = None,
    start_index: int = 0,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Run the memsciqa retrieval + QA test and return aggregated metrics.

    For every selected sample the function retrieves chunks, statements,
    entities and summaries from the graph, builds a context with
    ``smart_context_selection``, asks the LLM for an answer and scores it
    with F1 / BLEU-1 / Jaccard. No data is ingested here; the graph is only
    read.

    Args:
        sample_size: Number of samples to evaluate; ``None`` or ``<= 0``
            evaluates everything from ``start_index`` onwards.
        group_id: Graph group to search; defaults to ``"group_memsci"``.
        search_limit: Maximum number of results requested from the search.
        context_char_budget: Character budget passed to the context selector.
        llm_temperature: Recorded in the result ``params`` only; it is not
            passed to the LLM call in this function.
        llm_max_tokens: Recorded in the result ``params`` only; it is not
            passed to the LLM call in this function.
        search_type: ``"keyword"`` uses ``search_graph``; ``"embedding"`` and
            ``"hybrid"`` both use ``search_graph_by_embedding``.
        data_path: Dataset path; when omitted, ``data/msc_self_instruct.jsonl``
            is looked up under ``PROJECT_ROOT`` and then the working directory.
        start_index: Index of the first sample to evaluate.
        verbose: Print per-sample progress and diagnostics.

    Returns:
        A dict with aggregated ``metrics``, ``context`` and ``latency``
        statistics, the per-sample ``samples`` list, the run ``params`` and a
        ``timestamp``.

    Raises:
        FileNotFoundError: If the dataset file cannot be located.
    """
    # Default to the dedicated memsci group ID.
    group_id = group_id or "group_memsci"

    # Resolve the dataset path (project root first, then the working directory).
    if not data_path:
        proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
        cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
        if os.path.exists(proj_path):
            data_path = proj_path
        elif os.path.exists(cwd_path):
            data_path = cwd_path
        else:
            raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl,请确保其存在于项目根目录或当前工作目录的 data 目录下。")

    # Load the data and slice out the requested window.
    all_items = load_dataset_memsciqa(data_path)
    if sample_size is None or sample_size <= 0:
        items = all_items[start_index:]
    else:
        items = all_items[start_index:start_index + sample_size]

    # Initialise the LLM (pure test: no ingestion).
    llm = get_llm_client(SELECTED_LLM_ID)

    connector = Neo4jConnector()
    try:
        # The embedder is only needed for the embedding-based search path.
        embedder = None
        if search_type in ("embedding", "hybrid"):
            cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
            embedder = OpenAIEmbedderClient(
                model_config=RedBearModelConfig.model_validate(cfg_dict)
            )

        latencies_llm: List[float] = []
        latencies_search: List[float] = []
        # Full context texts, kept for token statistics.
        contexts_used: List[str] = []
        per_query_context_chars: List[int] = []
        per_query_context_counts: List[int] = []
        # Exact-match flags are collected but not reported in the result.
        correct_flags: List[float] = []
        f1s: List[float] = []
        b1s: List[float] = []
        jss: List[float] = []
        samples: List[Dict[str, Any]] = []

        total_items = len(items)
        for idx, item in enumerate(items):
            if verbose:
                print(f"\n🧪 评估样本: {idx+1}/{total_items}")
            question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
            reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")

            # Retrieval: chunks / statements / entities / summaries.
            t0 = time.time()
            results = None
            try:
                if search_type in ("embedding", "hybrid"):
                    results = await search_graph_by_embedding(
                        connector=connector,
                        embedder_client=embedder,
                        query_text=question,
                        group_id=group_id,
                        limit=search_limit,
                        include=["chunks", "statements", "entities", "summaries"],  # chunks rather than dialogues
                    )
                elif search_type == "keyword":
                    results = await search_graph(
                        connector=connector,
                        q=question,
                        group_id=group_id,
                        limit=search_limit,
                        include=["chunks", "statements", "entities", "summaries"],  # chunks rather than dialogues
                    )
            except Exception as e:
                # Best effort: a failed search is scored as "no context", but
                # the reason is surfaced instead of being dropped silently.
                results = None
                if verbose:
                    print(f"⚠️ 检索异常: {e}")
            t1 = time.time()
            search_ms = (t1 - t0) * 1000
            latencies_search.append(search_ms)

            # Build the candidate contexts: chunks, statements, summaries, entities.
            contexts_all: List[str] = []
            retrieved_counts: Dict[str, int] = {}
            if results:
                chunks = results.get("chunks", [])
                statements = results.get("statements", [])
                entities = results.get("entities", [])
                summaries = results.get("summaries", [])
                retrieved_counts = {
                    "chunks": len(chunks),
                    "statements": len(statements),
                    "entities": len(entities),
                    "summaries": len(summaries),
                }
                # Chunks first, then statements, then summaries.
                for c in chunks:
                    text = str(c.get("content", "")).strip()
                    if text:
                        contexts_all.append(text)
                for s in statements:
                    text = str(s.get("statement", "")).strip()
                    if text:
                        contexts_all.append(text)
                for sm in summaries:
                    text = str(sm.get("summary", "")).strip()
                    if text:
                        contexts_all.append(text)
                # Entity summary: at most the three highest-scored entities
                # (or the first three when no entity carries a score).
                scored = [e for e in entities if e.get("score") is not None]
                top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
                if top_entities:
                    summary_lines = []
                    for e in top_entities:
                        name = str(e.get("name", "")).strip()
                        etype = str(e.get("entity_type", "")).strip()
                        score = e.get("score")
                        if name:
                            meta = []
                            if etype:
                                meta.append(f"type={etype}")
                            if isinstance(score, (int, float)):
                                meta.append(f"score={score:.3f}")
                            summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
                    if summary_lines:
                        contexts_all.append("\n".join(summary_lines))

            if verbose:
                if retrieved_counts:
                    print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
                print(f"📊 有效上下文数量: {len(contexts_all)}")
                q_keywords = extract_question_keywords(question, max_keywords=8)
                if q_keywords:
                    print(f"🔍 问题关键词: {set(q_keywords)}")
                if contexts_all:
                    analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
                    if analysis:
                        print("📊 上下文相关性分析:")
                        for a in analysis:
                            print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")
                    # Preview the retrieved contexts to help explain "Unknown" answers.
                    print("🔎 上下文预览(最多前10条,每条截断展示):")
                    for i, ctx in enumerate(contexts_all[:10]):
                        preview = str(ctx).replace("\n", " ")
                        if len(preview) > 300:
                            preview = preview[:300] + "..."
                        print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")
                    # Report whether the reference answer appears in any context.
                    ref_lower = (str(reference) or "").lower()
                    if ref_lower:
                        hits = []
                        for i, ctx in enumerate(contexts_all):
                            if ref_lower in str(ctx).lower():
                                hits.append(i+1)
                        print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))

            context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
            if not context_text:
                context_text = "No relevant context found."
            contexts_used.append(context_text)
            per_query_context_chars.append(len(context_text))
            per_query_context_counts.append(len(contexts_all))

            if verbose:
                selected_count = (context_text.count("\n\n") + 1) if context_text else 0
                print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
                # Show the joined context so it can be checked for the answer.
                concat_preview = context_text.replace("\n", " ")
                if len(concat_preview) > 600:
                    concat_preview = concat_preview[:600] + "..."
                print(f"🧵 拼接上下文预览: {concat_preview}")

            messages = [
                {
                    "role": "system",
                    "content": (
                        "You are a QA assistant. Answer in English. Follow these guidelines:\n"
                        "1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
                        "2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
                        "3) Keep your answer brief and to the point;\n"
                        "4) Do not add explanations or additional text beyond the answer."
                    ),
                },
                {"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
            ]

            t2 = time.time()
            try:
                resp = await llm.chat(messages=messages)
                # Accept the response shapes different LLM clients may return.
                if hasattr(resp, 'content'):
                    pred = resp.content.strip()
                elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
                    pred = resp["choices"][0]["message"]["content"].strip()
                elif isinstance(resp, dict) and "content" in resp:
                    pred = resp["content"].strip()
                elif isinstance(resp, str):
                    pred = resp.strip()
                else:
                    pred = "Unknown"
                    print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")

                # If the model said "Unknown" although the reference answer is
                # present in the retrieved contexts, flag a likely prompt issue.
                if pred.lower() in ["unknown", ""]:
                    ref_lower = (str(reference) or "").lower()
                    if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
                        print("⚠️ 参考答案在上下文中存在但LLM返回Unknown,检查提示词")
            except Exception as e:
                # A failed LLM call is scored as "Unknown" so the run continues.
                pred = "Unknown"
                print(f"⚠️ LLM调用异常: {e}")
            t3 = time.time()
            llm_ms = (t3 - t2) * 1000
            latencies_llm.append(llm_ms)

            exact = exact_match(pred, reference)
            correct_flags.append(exact)
            f1_val = f1_score(str(pred), str(reference))
            b1_val = bleu1(str(pred), str(reference))
            j_val = jaccard(str(pred), str(reference))
            f1s.append(f1_val)
            b1s.append(b1_val)
            jss.append(j_val)

            if verbose:
                print(f"🤖 LLM 回答: {pred}")
                print(f"✅ 正确答案: {reference}")
                print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
                print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")

            # Per-sample record.
            samples.append({
                "question": str(question),
                "answer": str(reference),
                "prediction": str(pred),
                "metrics": {
                    "f1": f1_val,
                    "b1": b1_val,
                    "j": j_val
                },
                "retrieval": {
                    "retrieved_documents": len(contexts_all),
                    "context_length": len(context_text),
                    "search_limit": search_limit,
                    "max_chars": context_char_budget
                },
                "timing": {
                    "search_ms": search_ms,
                    "llm_ms": llm_ms
                }
            })

        # Aggregate the overall metrics.
        ctx_avg_tokens = avg_context_tokens(contexts_used)
        result = {
            "dataset": "memsciqa",
            "items": len(items),
            "metrics": {
                "f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
                "b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
                "j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
            },
            "context": {
                "avg_tokens": ctx_avg_tokens,
                "avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
                "count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
                "avg_memory_tokens": 0.0
            },
            "latency": {
                "search": latency_stats(latencies_search),
                "llm": latency_stats(latencies_llm),
            },
            "samples": samples,
            "params": {
                "group_id": group_id,
                "search_limit": search_limit,
                "context_char_budget": context_char_budget,
                "llm_temperature": llm_temperature,
                "llm_max_tokens": llm_max_tokens,
                "search_type": search_type,
                "start_index": start_index,
                "llm_id": SELECTED_LLM_ID,
                "retrieval_embedding_id": SELECTED_EMBEDDING_ID
            },
            "timestamp": datetime.now().isoformat(),
        }
    finally:
        # Always release the Neo4j connection, even when the run fails midway.
        # Closing is best effort: a close failure must not mask the outcome.
        try:
            await connector.close()
        except Exception as e:
            if verbose:
                print(f"⚠️ 关闭 Neo4j 连接失败: {e}")
    return result
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the memsciqa test script.

    Parses the options, runs ``run_memsciqa_test``, prints the result as JSON
    and saves it to ``--output`` or, by default, to a timestamped file in the
    ``results`` directory next to this script. A failure to save is reported
    on the console and does not abort the program.
    """
    load_dotenv()
    parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
    parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)")
    parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)")
    parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
    parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)")
    parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
    parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
    parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
    parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
    parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型(hybrid 等同于 embedding)")
    parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl)")
    parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径(JSON)")
    # NOTE: with default=True this flag never changes anything; it is kept so
    # existing command lines that pass --verbose keep working. Use --quiet to
    # turn logging off.
    parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
    parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
    args = parser.parse_args()

    # --all overrides --sample-size (0 means "everything").
    sample_size = 0 if args.all else args.sample_size

    verbose_flag = False if args.quiet else args.verbose
    result = asyncio.run(
        run_memsciqa_test(
            sample_size=sample_size,
            group_id=args.group_id,
            search_limit=args.search_limit,
            context_char_budget=args.context_char_budget,
            llm_temperature=args.llm_temperature,
            llm_max_tokens=args.llm_max_tokens,
            search_type=args.search_type,
            data_path=args.data_path,
            start_index=args.start_index,
            verbose=verbose_flag,
        )
    )

    print(json.dumps(result, ensure_ascii=False, indent=2))

    # Save the result.
    out_path = args.output
    if not out_path:
        eval_dir = os.path.dirname(os.path.abspath(__file__))
        dataset_results_dir = os.path.join(eval_dir, "results")
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
    try:
        # A bare filename has an empty dirname, and os.makedirs("") raises,
        # so only create the directory when there is one to create.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"\n💾 结果已保存: {out_path}")
    except Exception as e:
        print(f"⚠️ 结果保存失败: {e}")


if __name__ == "__main__":
    main()
|
||||
150
api/app/core/memory/evaluation/run_eval.py
Normal file
150
api/app/core/memory/evaluation/run_eval.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
# Add src directory to Python path for proper imports when running from evaluation directory
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
|
||||
|
||||
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
||||
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
||||
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
|
||||
|
||||
|
||||
async def run(
|
||||
dataset: str,
|
||||
sample_size: int,
|
||||
reset_group: bool,
|
||||
group_id: str | None,
|
||||
judge_model: str | None = None,
|
||||
search_limit: int | None = None,
|
||||
context_char_budget: int | None = None,
|
||||
llm_temperature: float | None = None,
|
||||
llm_max_tokens: int | None = None,
|
||||
search_type: str | None = None,
|
||||
start_index: int | None = None,
|
||||
max_contexts_per_item: int | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
|
||||
if reset_group:
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
await connector.delete_group(group_id)
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
if dataset == "locomo":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_locomo_eval(**kwargs)
|
||||
|
||||
if dataset == "memsciqa":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_memsciqa_eval(**kwargs)
|
||||
|
||||
if dataset == "longmemeval":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
if start_index is not None:
|
||||
kwargs["start_index"] = start_index
|
||||
if max_contexts_per_item is not None:
|
||||
kwargs["max_contexts_per_item"] = max_contexts_per_item
|
||||
return await run_longmemeval_test(**kwargs)
|
||||
raise ValueError(f"未知数据集: {dataset}")
|
||||
|
||||
|
||||
def main():
    """Unified command-line entry point for the evaluation datasets.

    Parses the options, dispatches to ``run`` for the chosen dataset, prints
    the result as JSON and writes it to ``--output`` or, by default, to
    ``<this directory>/<dataset>/results/<dataset>_<sample_size>.json``.
    """
    load_dotenv()

    cli = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo")
    cli.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
    cli.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
    cli.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
    cli.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
    cli.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名")
    cli.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
    cli.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
    cli.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
    cli.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens(不提供则使用各脚本默认)")
    cli.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
    # Only forwarded to longmemeval; the other datasets ignore these.
    cli.add_argument("--start-index", type=int, default=None, help="仅 longmemeval:起始样本索引(不提供则用脚本默认)")
    cli.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval:每条样本摄入的上下文数量上限(不提供则用脚本默认)")
    cli.add_argument("--output", type=str, default=None, help="可选:将评估结果保存到指定文件路径(JSON);不提供时默认保存到 evaluation/<dataset>/results 目录")
    opts = cli.parse_args()

    report = asyncio.run(
        run(
            dataset=opts.dataset,
            sample_size=opts.sample_size,
            reset_group=opts.reset_group,
            group_id=opts.group_id,
            judge_model=opts.judge_model,
            search_limit=opts.search_limit,
            context_char_budget=opts.context_char_budget,
            llm_temperature=opts.llm_temperature,
            llm_max_tokens=opts.llm_max_tokens,
            search_type=opts.search_type,
            start_index=opts.start_index,
            max_contexts_per_item=opts.max_contexts_per_item,
        )
    )
    print(json.dumps(report, ensure_ascii=False, indent=2))

    # Resolve the output file: an explicit --output wins, otherwise use the
    # per-dataset results directory next to this script.
    out_path = opts.output
    if not out_path:
        here = os.path.dirname(os.path.abspath(__file__))
        out_path = os.path.join(
            here,
            opts.dataset,
            "results",
            f"{opts.dataset}_{opts.sample_size}.json",
        )

    # Create the parent directory only when there is one and it is missing.
    parent_dir = os.path.dirname(out_path)
    needs_dir = bool(parent_dir) and not os.path.exists(parent_dir)
    if needs_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(out_path, "w", encoding="utf-8") as sink:
        json.dump(report, sink, ensure_ascii=False, indent=2)
    print(f"\n结果已保存到: {out_path}")


if __name__ == "__main__":
    main()
|
||||
@@ -1,5 +1,8 @@
|
||||
"""
|
||||
MemSci 记忆系统主入口
|
||||
MemSci 记忆系统主入口 - 重构版本
|
||||
|
||||
该模块是重构后的记忆系统主入口,使用新的模块化架构。
|
||||
旧版本入口(app/core/memory/src/main.py)已删除。
|
||||
|
||||
主要功能:
|
||||
1. 协调整个知识提取流水线
|
||||
@@ -319,7 +322,7 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
logger.info(f"Timing details saved to: {log_file}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"✓ 流水线执行完成")
|
||||
print("✓ 流水线执行完成")
|
||||
print(f"✓ 总耗时: {total_time:.2f} 秒")
|
||||
print(f"✓ 详细日志: {log_file}")
|
||||
print("=" * 60)
|
||||
|
||||
@@ -18,6 +18,10 @@ class EntityDedupDecision(BaseModel):
|
||||
This model represents the LLM's decision on whether two entities
|
||||
refer to the same real-world entity and should be merged.
|
||||
|
||||
Note: Aliases are extracted during the triplet extraction phase and automatically
|
||||
merged during entity merging. LLM only needs to decide whether to merge and which
|
||||
entity to keep as canonical.
|
||||
|
||||
Attributes:
|
||||
same_entity: Whether the two entities refer to the same real-world entity
|
||||
confidence: Model confidence in the decision (0.0 to 1.0)
|
||||
@@ -36,6 +40,10 @@ class EntityDisambDecision(BaseModel):
|
||||
This model represents the LLM's decision on whether two entities with
|
||||
the same name but different types should be merged or kept separate.
|
||||
|
||||
Note: Aliases are extracted during the triplet extraction phase and automatically
|
||||
merged during entity merging. LLM only needs to decide whether to merge and which
|
||||
entity to keep as canonical.
|
||||
|
||||
Attributes:
|
||||
should_merge: Whether the two entities should be merged despite type difference
|
||||
confidence: Model confidence in the decision (0.0 to 1.0)
|
||||
|
||||
@@ -27,6 +27,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
import re
|
||||
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
from app.core.memory.utils.alias_utils import validate_aliases
|
||||
|
||||
|
||||
def parse_historical_datetime(v):
|
||||
@@ -260,27 +261,66 @@ class ChunkNode(Node):
|
||||
|
||||
class ExtractedEntityNode(Node):
    """Node representing an extracted entity in the knowledge graph.

    This class represents entities extracted from dialogue statements. Each entity
    has a primary name and can have multiple aliases (alternative names). The aliases
    feature enables better entity deduplication and disambiguation by tracking all
    known names for an entity.

    Attributes:
        entity_idx: Unique numeric identifier for the entity
        statement_id: ID of the statement this entity was extracted from
        entity_type: Type/category of the entity (e.g., PERSON, ORGANIZATION, LOCATION)
        description: Textual description of the entity
        aliases: List of alternative names for the entity. This field:
            - Stores all known alternative names in the SAME LANGUAGE as the entity name
            - Is cleaned by ``validate_aliases`` before model validation (see
              ``validate_aliases_field``)
            - Is used in fuzzy matching to improve entity deduplication
            - Is populated during triplet extraction and entity merging processes
            - Has a recommended maximum of 50 aliases per entity
            - CRITICAL: Aliases must be in the same language as the entity name (no translation)
        name_embedding: Optional embedding vector for the entity name
        fact_summary: Summary of facts about this entity
        connect_strength: Classification of connection strength ('Strong', 'Weak', or 'Both')
        config_id: Configuration ID used to process this entity (integer or string)
    """
    entity_idx: int = Field(..., description="Unique identifier for the entity")
    statement_id: str = Field(..., description="Statement this entity was extracted from")
    entity_type: str = Field(..., description="Type of the entity")
    description: str = Field(..., description="Entity description")
    # Declared exactly once: a second declaration of the same field silently
    # replaces the first one in the class body.
    aliases: List[str] = Field(
        default_factory=list,
        description="Entity aliases - alternative names for this entity"
    )
    name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
    fact_summary: str = Field(..., description="Summary of the fact about this entity")
    connect_strength: str = Field(..., description="Strong VS Weak about this entity")
    config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")

    @field_validator('aliases', mode='before')
    @classmethod
    def validate_aliases_field(cls, v):  # field validator: cleans the raw aliases value
        """Validate and clean aliases field using utility function.

        The raw value is handed to ``validate_aliases`` before pydantic validates
        the field. According to that helper's contract (defined outside this
        class -- confirm there), it filters out:
        - None values
        - Empty strings
        - Non-string types (after converting to string)
        - Duplicate values

        Args:
            v: The raw aliases value (can be None, list, or other types)

        Returns:
            A cleaned list of unique string aliases

        Example:
            >>> # Input: [None, "", "alias1", "alias1", 123]
            >>> # Output: ["alias1", "123"]
        """
        return validate_aliases(v)
|
||||
|
||||
|
||||
class MemorySummaryNode(Node):
|
||||
|
||||
@@ -24,6 +24,8 @@ class Entity(BaseModel):
|
||||
name_embedding: Optional embedding vector for the entity name
|
||||
type: Type/category of the entity (e.g., 'Person', 'Organization')
|
||||
description: Textual description of the entity
|
||||
aliases: List of alternative names for the entity (e.g., abbreviations, full names,
|
||||
different language expressions). Extracted during triplet extraction phase.
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
@@ -35,6 +37,10 @@ class Entity(BaseModel):
|
||||
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
||||
type: str = Field(..., description="Type/category of the entity")
|
||||
description: str = Field(..., description="Description of the entity")
|
||||
aliases: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Alternative names for this entity (abbreviations, full names, translations, etc.)"
|
||||
)
|
||||
|
||||
|
||||
class Triplet(BaseModel):
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
from typing import Any, List
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
from chonkie import (
|
||||
SemanticChunker,
|
||||
RecursiveChunker,
|
||||
RecursiveRules,
|
||||
LateChunker,
|
||||
NeuralChunker,
|
||||
SentenceChunker,
|
||||
TokenChunker,
|
||||
)
|
||||
|
||||
from app.core.memory.models.config_models import ChunkerConfig
|
||||
from app.core.memory.models.message_models import DialogData, Chunk
|
||||
try:
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
except Exception:
|
||||
# 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入
|
||||
OpenAIClient = Any
|
||||
|
||||
|
||||
class LLMChunker:
    """LLM-driven chunking strategy.

    The text is sent to a chat model that is asked to return a JSON object with
    a ``chunks`` array whose items carry a ``text`` field. Any failure (transport
    error, non-JSON answer, unexpected shape) is reported on stdout and turned
    into an empty list so the caller can apply its own fallback.
    """

    class _SimpleChunk:
        """Minimal chunk record exposing ``text``, ``start_index`` and ``end_index``."""

        def __init__(self, text, index):
            self.text = text
            # Offsets are rough placeholders (100 characters per chunk), not
            # real positions inside the source text.
            self.start_index = index * 100
            self.end_index = (index + 1) * 100

    def __init__(self, llm_client: "OpenAIClient", chunk_size: int = 1000, max_input_chars: int = 5000):
        """Store the client and sizing parameters.

        Args:
            llm_client: Object exposing ``achat(messages)`` or ``chat(messages)``.
            chunk_size: Target paragraph length (characters) quoted in the prompt.
            max_input_chars: Only this many leading characters of the text are
                sent to the model; the remainder is not chunked.
        """
        self.llm_client = llm_client
        self.chunk_size = chunk_size
        self.max_input_chars = max_input_chars

    async def _request_completion(self, messages):
        """Call the client, awaiting it whenever the call is asynchronous."""
        achat = getattr(self.llm_client, "achat", None)
        if achat is not None:
            return await achat(messages)

        chat = self.llm_client.chat
        if asyncio.iscoroutinefunction(chat):
            # Running a coroutine function through to_thread would only create
            # an un-awaited coroutine object, so await it directly instead.
            return await chat(messages)
        return await asyncio.to_thread(chat, messages)

    @staticmethod
    def _response_text(response) -> str:
        """Extract the textual content from the response shapes handled here."""
        if hasattr(response, 'choices') and len(response.choices) > 0:
            return response.choices[0].message.content
        if hasattr(response, 'content'):
            return response.content
        return str(response)

    @staticmethod
    def _strip_code_fence(content: str) -> str:
        """Return the JSON payload, unwrapping a Markdown code fence if present."""
        if "```json" in content:
            return content.split("```json")[1].split("```")[0].strip()
        if "```" in content:
            return content.split("```")[1].split("```")[0].strip()
        return content

    async def __call__(self, text: str) -> List[Any]:
        """Ask the LLM to split ``text`` and return the resulting chunk objects.

        Args:
            text: Text to split (truncated to ``max_input_chars`` for the prompt).

        Returns:
            A list of chunk objects in model order; empty when the request or
            the parsing fails.
        """
        prompt = (
            "\n"
            f"请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。\n"
            "请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。\n"
            "\n"
            "文本内容:\n"
            f"{text[:self.max_input_chars]}\n"
        )

        messages = [
            {"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
            {"role": "user", "content": prompt}
        ]

        try:
            response = await self._request_completion(messages)
            payload = json.loads(self._strip_code_fence(self._response_text(response)))

            # Accept both {"chunks": [...]} and a bare JSON array.
            raw_chunks = payload if isinstance(payload, list) else payload.get("chunks", [])

            texts = []
            for entry in raw_chunks:
                piece = entry.get("text") if isinstance(entry, dict) else entry
                # Skip malformed entries instead of discarding the whole answer.
                if isinstance(piece, str):
                    texts.append(piece)

            return [self._SimpleChunk(piece, i) for i, piece in enumerate(texts)]

        except Exception as e:
            print(f"LLM分块失败: {e}")
            # Empty list on failure; the caller is expected to fall back.
            return []
|
||||
|
||||
|
||||
class HybridChunker:
    """Hybrid strategy: fixed-size split first, semantic re-chunking for longer texts.

    The text is first cut by a character-based ``TokenChunker``. Up to three
    base chunks are returned unchanged; otherwise the base chunks are re-joined
    with single spaces and the joined text is handed to a ``SemanticChunker``.
    """

    def __init__(self, semantic_threshold: float = 0.8, base_chunk_size: int = 300):
        """Create both underlying chunkers.

        Args:
            semantic_threshold: Threshold passed to the semantic chunker.
            base_chunk_size: Chunk size (characters) of the base chunker.
        """
        self.semantic_threshold = semantic_threshold
        self.base_chunk_size = base_chunk_size
        self.base_chunker = TokenChunker(tokenizer="character", chunk_size=base_chunk_size)
        self.semantic_chunker = SemanticChunker(threshold=semantic_threshold)

    def __call__(self, text: str) -> List[Any]:
        """Chunk ``text`` and return the chunk objects of whichever stage applied."""
        pieces = self.base_chunker(text)

        # Only texts that produce more than three base chunks get the semantic pass.
        if len(pieces) > 3:
            rejoined = " ".join(piece.text for piece in pieces)
            return self.semantic_chunker(rejoined)

        return pieces
|
||||
|
||||
|
||||
class ChunkerClient:
    """Build the chunker selected by the configuration and apply it to dialogues.

    Besides chunk generation the class offers simple size metrics
    (``evaluate_chunking``) and a plain-text dump of the chunks
    (``save_chunking_results``).
    """

    def __init__(self, chunker_config: "ChunkerConfig", llm_client: "OpenAIClient" = None):
        """Read the configuration and instantiate the matching chunker.

        Args:
            chunker_config: Chunking configuration; ``chunker_strategy`` selects
                the implementation.
            llm_client: LLM client, required only for the ``LLMChunker`` strategy.

        Raises:
            ValueError: If the strategy is unknown, or ``LLMChunker`` is selected
                without an ``llm_client``.
        """
        self.chunker_config = chunker_config
        self.embedding_model = chunker_config.embedding_model
        self.chunk_size = chunker_config.chunk_size
        self.threshold = chunker_config.threshold
        self.language = chunker_config.language
        self.skip_window = chunker_config.skip_window
        self.min_sentences = chunker_config.min_sentences
        self.min_characters_per_chunk = chunker_config.min_characters_per_chunk
        self.llm_client = llm_client

        # Optional settings: read defensively, falling back to defaults.
        self.chunk_overlap = getattr(chunker_config, 'chunk_overlap', 0)
        self.min_sentences_per_chunk = getattr(chunker_config, 'min_sentences_per_chunk', 1)
        self.min_characters_per_sentence = getattr(chunker_config, 'min_characters_per_sentence', 12)
        self.delim = getattr(chunker_config, 'delim', [".", "!", "?", "\n"])
        self.include_delim = getattr(chunker_config, 'include_delim', "prev")
        self.tokenizer_or_token_counter = getattr(chunker_config, 'tokenizer_or_token_counter', "character")

        # Instantiate the concrete chunking strategy.
        if chunker_config.chunker_strategy == "TokenChunker":
            self.chunker = TokenChunker(
                tokenizer=self.tokenizer_or_token_counter,
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
            )
        elif chunker_config.chunker_strategy == "SemanticChunker":
            self.chunker = SemanticChunker(
                embedding_model=self.embedding_model,
                threshold=self.threshold,
                chunk_size=self.chunk_size,
                min_sentences=self.min_sentences,
            )
        elif chunker_config.chunker_strategy == "RecursiveChunker":
            self.chunker = RecursiveChunker(
                rules=RecursiveRules(),
                min_characters_per_chunk=self.min_characters_per_chunk or 50,
                chunk_size=self.chunk_size,
            )
        elif chunker_config.chunker_strategy == "LateChunker":
            self.chunker = LateChunker(
                embedding_model=self.embedding_model,
                chunk_size=self.chunk_size,
                rules=RecursiveRules(),
                min_characters_per_chunk=self.min_characters_per_chunk,
            )
        elif chunker_config.chunker_strategy == "NeuralChunker":
            self.chunker = NeuralChunker(
                model=self.embedding_model,
                min_characters_per_chunk=self.min_characters_per_chunk,
            )
        elif chunker_config.chunker_strategy == "LLMChunker":
            if not llm_client:
                raise ValueError("LLMChunker requires an LLM client")
            self.chunker = LLMChunker(llm_client, self.chunk_size)
        elif chunker_config.chunker_strategy == "HybridChunker":
            self.chunker = HybridChunker(
                semantic_threshold=self.threshold,
                base_chunk_size=self.chunk_size,
            )
        elif chunker_config.chunker_strategy == "SentenceChunker":
            # Some chonkie versions of SentenceChunker reject the
            # tokenizer_or_token_counter argument, so only widely supported
            # parameters are passed.
            self.chunker = SentenceChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
                min_sentences_per_chunk=self.min_sentences_per_chunk,
                min_characters_per_sentence=self.min_characters_per_sentence,
                delim=self.delim,
                include_delim=self.include_delim,
            )
        else:
            raise ValueError(f"Unknown chunker strategy: {chunker_config.chunker_strategy}")

    @staticmethod
    def _split_dialogue_turns(content: str, max_chars: int) -> List[tuple]:
        """Group consecutive speaker turns into ``(text, start, end)`` pieces.

        Turns are located with a regex on the speaker markers. Each piece holds
        as many whole turns as fit into ``max_chars`` (a single oversized turn
        still forms its own piece). ``start``/``end`` are the real offsets in
        ``content`` of the first/last turn of the piece, taken from the regex
        match spans rather than searched for afterwards.
        """
        turn_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'

        pieces = []
        buffer = ""
        buffer_start = 0
        buffer_end = 0

        for match in re.finditer(turn_pattern, content, re.DOTALL):
            turn_text = f"{match.group(1)} {match.group(2).strip()}"

            # Flush the current piece when adding this turn would overflow it.
            if buffer and len(buffer) + len(turn_text) > max_chars:
                pieces.append((buffer, buffer_start, buffer_end))
                buffer = ""

            if buffer:
                buffer += "\n" + turn_text
            else:
                buffer = turn_text
                buffer_start = match.start()
            buffer_end = match.end()

        if buffer:
            pieces.append((buffer, buffer_start, buffer_end))
        return pieces

    async def generate_chunks(self, dialogue: "DialogData"):
        """Chunk ``dialogue.content`` and store the result in ``dialogue.chunks``.

        The configured chunker is tried first (awaited when its ``__call__`` is
        a coroutine function). If anything in that path raises, the dialogue is
        split by speaker turns instead; if that yields nothing or fails as
        well, the whole content becomes one chunk.

        Args:
            dialogue: Dialogue whose ``content`` is chunked; mutated in place.

        Returns:
            The same ``dialogue`` object with ``chunks`` populated.
        """
        try:
            # Pre-process: unify speaker markers and collapse blank-line runs.
            content = dialogue.content
            # NOTE(review): as written both replacements map a marker to itself;
            # presumably one side originally used full-width colons -- confirm.
            content = content.replace('AI:', 'AI:').replace('用户:', '用户:')
            content = re.sub(r'(\n\s*)+\n', '\n\n', content)

            if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
                # Synchronous chunker.
                chunks = self.chunker(content)
            else:
                # Asynchronous chunker (e.g. LLMChunker).
                chunks = await self.chunker(content)

            # Drop empty chunks and chunks below the minimum length.
            min_chars = self.min_characters_per_chunk or 50
            valid_chunks = []
            for c in chunks:
                chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
                if isinstance(chunk_text, str) and len(chunk_text.strip()) >= min_chars:
                    valid_chunks.append(c)

            dialogue.chunks = [
                Chunk(
                    content=c.text if hasattr(c, 'text') else str(c),
                    metadata={
                        "start_index": getattr(c, "start_index", None),
                        "end_index": getattr(c, "end_index", None),
                        "chunker_strategy": self.chunker_config.chunker_strategy,
                    },
                )
                for c in valid_chunks
            ]
            return dialogue

        except Exception as e:
            print(f"分块失败: {e}")

            # Fallback 1: split on dialogue turns.
            try:
                fallback_chunks = [
                    Chunk(
                        content=text,
                        metadata={
                            "start_index": start,
                            "end_index": end,
                            "chunker_strategy": "DialogueTurnFallback",
                        },
                    )
                    for text, start, end in self._split_dialogue_turns(
                        dialogue.content, self.chunk_size or 500
                    )
                ]
            except Exception:
                fallback_chunks = []

            # Fallback 2 (last resort): keep the content as one chunk, also
            # when no speaker marker was found, so the text is never dropped.
            if not fallback_chunks:
                fallback_chunks = [Chunk(
                    content=dialogue.content,
                    metadata={"chunker_strategy": "SingleChunkFallback"},
                )]

            dialogue.chunks = fallback_chunks
            return dialogue

    def evaluate_chunking(self, dialogue: "DialogData") -> dict:
        """Compute size statistics for ``dialogue.chunks``.

        Returns:
            ``{}`` when the dialogue has no chunks; otherwise a dict with the
            strategy name, chunk count, total/average/min/max chunk size,
            standard deviation of the sizes (0 for a single chunk) and the
            ratio of chunk characters to content characters.
        """
        if not getattr(dialogue, 'chunks', None):
            return {}

        chunks = dialogue.chunks
        chunk_sizes = [len(chunk.content) for chunk in chunks]
        total_chars = sum(chunk_sizes)

        return {
            "strategy": self.chunker_config.chunker_strategy,
            "num_chunks": len(chunks),
            "total_characters": total_chars,
            "avg_chunk_size": total_chars / len(chunks),
            "min_chunk_size": min(chunk_sizes),
            "max_chunk_size": max(chunk_sizes),
            "chunk_size_std": np.std(chunk_sizes) if len(chunk_sizes) > 1 else 0,
            "coverage_ratio": total_chars / len(dialogue.content) if dialogue.content else 0,
        }

    def save_chunking_results(self, dialogue: "DialogData", output_path: str):
        """Write the chunks to a text file whose name includes the strategy.

        Args:
            dialogue: Dialogue whose ``chunks`` are written.
            output_path: Base path; ``_<strategy>`` is inserted before the extension.

        Returns:
            The path of the file that was written.
        """
        strategy_name = self.chunker_config.chunker_strategy
        # Insert the strategy name before the file extension.
        base_name, ext = os.path.splitext(output_path)
        strategy_output_path = f"{base_name}_{strategy_name}{ext}"

        with open(strategy_output_path, 'w', encoding='utf-8') as f:
            f.write(f"=== Chunking Strategy: {strategy_name} ===\n")
            f.write(f"Total chunks: {len(dialogue.chunks)}\n")
            f.write(f"Total characters: {sum(len(chunk.content) for chunk in dialogue.chunks)}\n")
            f.write("=" * 60 + "\n\n")

            for i, chunk in enumerate(dialogue.chunks):
                f.write(f"Chunk {i+1}:\n")
                f.write(f"Size: {len(chunk.content)} characters\n")
                if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
                    f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
                f.write(f"Content: {chunk.content}\n")
                f.write("-" * 40 + "\n\n")

        print(f"Chunking results saved to: {strategy_output_path}")
        return strategy_output_path
|
||||
@@ -1,22 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
class EmbedderClient(ABC):
    """Abstract base class for embedding clients.

    The constructor copies the connection settings from the model configuration
    onto the instance; subclasses implement ``response`` to produce embeddings.
    """

    def __init__(self, model_config: "RedBearModelConfig"):
        """Keep the configuration and expose its connection settings.

        Args:
            model_config: Configuration providing ``model_name``, ``provider``,
                ``api_key``, ``base_url`` and ``max_retries``.
        """
        self.config = model_config

        self.model_name = model_config.model_name
        self.provider = model_config.provider
        self.api_key = model_config.api_key
        self.base_url = model_config.base_url
        self.max_retries = model_config.max_retries

    @abstractmethod
    async def response(
        self,
        messages: List[str],
    ) -> List[List[float]]:
        """Embed ``messages`` and return one embedding vector per embedded text."""
        pass
|
||||
@@ -1,37 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from app.core.memory.models.config_models import LLMConfig
|
||||
|
||||
"""
|
||||
model_name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
timeout: float = 30.0 # 请求超时时间(秒)
|
||||
max_retries: int = 3 # 最大重试次数
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
"""
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
class LLMClient(ABC):
    """Abstract base class for chat LLM clients.

    The constructor copies the connection settings from the model configuration
    onto the instance; subclasses implement plain chat and structured output.
    """

    def __init__(self, model_config: "RedBearModelConfig"):
        """Keep the configuration and expose its connection settings.

        Args:
            model_config: Configuration providing ``model_name``, ``provider``,
                ``api_key``, ``base_url`` and ``max_retries``.
        """
        self.config = model_config

        self.model_name = self.config.model_name
        self.provider = self.config.provider
        self.api_key = self.config.api_key
        self.base_url = self.config.base_url
        self.max_retries = self.config.max_retries

    @abstractmethod
    def chat(self, messages: List[Dict[str, str]]) -> Any:
        """Send chat ``messages`` to the model and return its raw response.

        NOTE(review): declared synchronous here, but an implementation may
        override it as a coroutine -- callers should not assume either.
        """
        pass

    @abstractmethod
    async def response_structured(
        self,
        messages: List[Dict[str, str]],
        response_model: "type[BaseModel]",
    ) -> "BaseModel":
        """Query the model and return the answer as a ``response_model`` instance."""
        pass
|
||||
@@ -1,224 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.llm import RedBearLLM
|
||||
from app.core.memory.src.llm_tools.llm_client import LLMClient
|
||||
# from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
|
||||
LANGFUSE_ENABLED=False
|
||||
|
||||
class OpenAIClient(LLMClient):
    """LLM client backed by ``RedBearLLM`` with optional Langfuse tracing.

    ``response_structured`` tries three strategies in order (output parser,
    the model's own structured-output support, manual JSON extraction) and, as
    a last resort, synthesises a placeholder instance of the response model.
    """

    def __init__(self, model_config: RedBearModelConfig, type_: str = "chat"):
        """Create the underlying ``RedBearLLM`` client.

        Args:
            model_config: Model configuration (see ``LLMClient``).
            type_: Forwarded to ``RedBearLLM`` as ``type``.
        """
        super().__init__(model_config)

        # Initialise the Langfuse callback handler only when tracing is enabled.
        self.langfuse_handler = None
        if LANGFUSE_ENABLED:
            try:
                from langfuse.langchain import CallbackHandler
                self.langfuse_handler = CallbackHandler()
            except ImportError:
                # Langfuse not installed: continue without tracing.
                pass
            except Exception as e:
                # Log the error but do not fail initialisation.
                import logging
                logging.warning(f"Failed to initialize Langfuse handler: {e}")

        self.client = RedBearLLM(RedBearModelConfig(
            model_name=self.model_name,
            provider=self.provider,
            api_key=self.api_key,
            base_url=self.base_url,
            max_retries=self.max_retries,
        ), type=type_)

    def _invoke_config(self) -> dict:
        """Return the runnable config, carrying the Langfuse callback if present."""
        config = {}
        if self.langfuse_handler:
            config["callbacks"] = [self.langfuse_handler]
        return config

    async def chat(self, messages: List[Dict[str, str]]) -> Any:
        """Send ``messages`` through a pass-through prompt and return the raw response."""
        template = """{messages}"""
        prompt = ChatPromptTemplate.from_template(template)
        chain = prompt | self.client

        response = await chain.ainvoke({"messages": messages}, config=self._invoke_config())
        return response

    async def response_structured(
        self,
        messages: List[Dict[str, str]],
        response_model: type[BaseModel],
    ) -> BaseModel:
        """Query the model and coerce the answer into ``response_model``.

        Args:
            messages: Chat messages; their ``content`` values are joined with
                blank lines into one question text.
            response_model: Pydantic model class describing the expected answer.

        Returns:
            An instance of ``response_model``. When every strategy fails, a
            placeholder instance from ``_create_minimal_response``.

        Raises:
            ValueError: Only if not even a placeholder instance can be built.
        """
        import logging
        import re

        logger = logging.getLogger(__name__)

        question_text = "\n\n".join([str(m.get("content", "")) for m in messages])
        config = self._invoke_config()

        # Strategy 1: enforce the schema with PydanticOutputParser.
        try:
            parser = PydanticOutputParser(pydantic_object=response_model)
            format_instructions = parser.get_format_instructions()
            prompt = ChatPromptTemplate.from_template("{question}\n{format_instructions}")
            chain = prompt | self.client | parser
            # Pass the config here too so tracing callbacks cover the main path.
            return await chain.ainvoke({
                "question": question_text,
                "format_instructions": format_instructions,
            }, config=config)
        except Exception as e:
            logger.warning(f"PydanticOutputParser失败,尝试备用方法: {str(e)}")

        # Strategy 2: structured output provided by the underlying client.
        prompt = ChatPromptTemplate.from_template("""{question}""")
        try:
            with_so = getattr(self.client, "with_structured_output", None)
            if callable(with_so):
                try:
                    structured_chain = prompt | with_so(response_model, strict=True)
                    parsed = await structured_chain.ainvoke({"question": question_text}, config=config)
                    # parsed may already be a pydantic model or a dict.
                    try:
                        return response_model.model_validate(parsed)
                    except Exception:
                        try:
                            if hasattr(parsed, "model_dump"):
                                return parsed
                            return response_model.model_validate_json(json.dumps(parsed))
                        except Exception:
                            # Fall through to manual parsing below.
                            pass
                except NotImplementedError:
                    logger.warning(
                        f"Model {self.model_name} doesn't support with_structured_output, falling back to manual parsing")
        except Exception as e:
            logger.warning(f"Structured output attempt failed: {e}, falling back to manual parsing")

        # Strategy 3: ask for JSON in plain text and parse it manually.
        try:
            logger.info(f"Using manual parsing fallback for model {self.model_name}")

            json_prompt = ChatPromptTemplate.from_template(
                "{question}\n\n"
                "Please respond with a valid JSON object that matches this schema:\n"
                "{schema}\n\n"
                "Response (JSON only):"
            )
            schema = response_model.model_json_schema()

            chain = json_prompt | self.client
            response = await chain.ainvoke({
                "question": question_text,
                "schema": json.dumps(schema, indent=2)
            }, config=config)

            response_text = str(response.content if hasattr(response, 'content') else response)

            # Take the outermost {...} span of the answer as the JSON candidate.
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if json_match:
                try:
                    parsed_dict = json.loads(json_match.group(0))
                    return response_model.model_validate(parsed_dict)
                except json.JSONDecodeError:
                    pass

            logger.warning(f"Failed to parse JSON from LLM response, creating minimal response")
            return self._create_minimal_response(response_model)

        except Exception as fallback_error:
            logger.error(f"Manual parsing fallback also failed: {fallback_error}")
            # Placeholder response as the last resort.
            return self._create_minimal_response(response_model)

    def _placeholder_for(self, field_type: Any) -> Any:
        """Return a type-appropriate placeholder value for an annotation."""
        import types
        from typing import Union, get_args, get_origin

        origin = get_origin(field_type)

        # Optional[...] / unions: None when allowed, else the first member's placeholder.
        if origin is Union or origin is getattr(types, "UnionType", None):
            members = get_args(field_type)
            if type(None) in members:
                return None
            return self._placeholder_for(members[0])

        if field_type is list or origin is list:
            return []
        if field_type is dict or origin is dict:
            return {}

        # Nested models (direct or indirect BaseModel subclasses). Generic
        # aliases are excluded because issubclass() rejects them.
        if origin is None and isinstance(field_type, type) and issubclass(field_type, BaseModel):
            return self._create_minimal_response(field_type)

        if field_type is str:
            return "信息不足,无法回答"
        if field_type is int:
            return 0
        if field_type is float:
            return 0.0
        if field_type is bool:
            return False
        return None

    def _create_minimal_response(self, response_model: type[BaseModel]) -> BaseModel:
        """Create a minimal valid response based on the model schema.

        Fields with a non-None default (or default factory) use it; every other
        field gets a placeholder derived from its annotation.

        Raises:
            ValueError: If neither the placeholder payload nor a no-argument
                construction yields a valid instance.
        """
        import logging

        try:
            minimal_response = {}

            for field_name, field_info in response_model.model_fields.items():
                # A required field has no usable default: its ``default``
                # attribute holds pydantic's "undefined" sentinel, which must
                # not be copied into the payload.
                if not field_info.is_required():
                    default = field_info.get_default(call_default_factory=True)
                    if default is not None:
                        minimal_response[field_name] = default
                        continue

                minimal_response[field_name] = self._placeholder_for(field_info.annotation)

            return response_model.model_validate(minimal_response)

        except Exception as e:
            logger = logging.getLogger(__name__)
            logger.error(f"Failed to create minimal response: {e}")
            # Last resort: rely entirely on the model's own defaults.
            try:
                return response_model()
            except Exception:
                raise ValueError(f"Unable to create minimal response for {response_model.__name__}") from e
|
||||
@@ -1,26 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from app.core.memory.src.llm_tools.embedder_client import EmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
# from app.models.models_model import ModelType
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
|
||||
|
||||
class OpenAIEmbedderClient(EmbedderClient):
    """Embedding client that delegates to ``RedBearEmbeddings``."""

    def __init__(self, model_config: RedBearModelConfig):
        """Store the model configuration via the base class."""
        super().__init__(model_config)

    async def response(
        self,
        messages: List[str],
    ) -> List[List[float]]:
        """Embed the given messages.

        ``None`` entries are dropped and every remaining entry is converted
        with ``str`` before embedding, so the result can be shorter than
        ``messages``. A fresh embedding model is built on every call.

        Returns:
            The embeddings returned by ``aembed_documents`` for the kept texts.
        """
        texts: List[str] = []
        for message in messages:
            if message is not None:
                texts.append(str(message))

        embedding_config = RedBearModelConfig(
            model_name=self.model_name,
            provider=self.provider,
            api_key=self.api_key,
            base_url=self.base_url,
        )
        embedder = RedBearEmbeddings(embedding_config)
        return await embedder.aembed_documents(texts)
|
||||
@@ -15,7 +15,7 @@ from app.repositories.neo4j.graph_search import (
|
||||
search_graph_by_temporal, search_graph_by_keyword_temporal,
|
||||
search_graph_by_chunk_id
|
||||
)
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.config_models import TemporalSearchParams
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config, get_pipeline_config
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
@@ -564,7 +564,7 @@ async def run_hybrid_search(
|
||||
|
||||
# Validate query is not empty after cleaning
|
||||
if not query_text or not query_text.strip():
|
||||
logger.warning(f"Empty query after cleaning, returning empty results")
|
||||
logger.warning("Empty query after cleaning, returning empty results")
|
||||
return {
|
||||
"keyword_search": {},
|
||||
"embedding_search": {},
|
||||
|
||||
@@ -168,7 +168,7 @@ class DataPreprocessor:
|
||||
except json.JSONDecodeError as line_error:
|
||||
# 如果是单行巨大JSON数组,可能需要特殊处理
|
||||
if line_num == 1 and len(lines) == 1:
|
||||
print(f"检测到单行大型JSON,尝试分块解析...")
|
||||
print("检测到单行大型JSON,尝试分块解析...")
|
||||
# 对于超大单行JSON,尝试使用json.JSONDecoder进行流式解析
|
||||
try:
|
||||
decoder = json.JSONDecoder()
|
||||
|
||||
@@ -81,7 +81,6 @@ class SemanticPruner:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _importance_score(self, message: ConversationMessage) -> int:
|
||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
||||
|
||||
|
||||
@@ -14,6 +14,51 @@ import difflib # 提供字符串相似度计算工具
|
||||
import asyncio
|
||||
import importlib
|
||||
import re
|
||||
# 模块级类型统一工具函数
|
||||
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
|
||||
"""统一实体类型:基于LLM建议或启发式规则选择最合适的类型。
|
||||
|
||||
Args:
|
||||
canonical: 规范实体(保留的实体)
|
||||
losing: 被合并的实体
|
||||
suggested_type: LLM建议的统一类型(可选)
|
||||
"""
|
||||
canonical_type = (getattr(canonical, "entity_type", "") or "").strip()
|
||||
losing_type = (getattr(losing, "entity_type", "") or "").strip()
|
||||
|
||||
if suggested_type and suggested_type.strip():
|
||||
# 优先使用LLM建议的类型
|
||||
canonical.entity_type = suggested_type.strip()
|
||||
elif canonical_type.upper() == "UNKNOWN" and losing_type.upper() != "UNKNOWN":
|
||||
# 如果canonical是UNKNOWN,使用losing的类型
|
||||
canonical.entity_type = losing_type
|
||||
elif canonical_type.upper() != "UNKNOWN" and losing_type.upper() == "UNKNOWN":
|
||||
# 如果losing是UNKNOWN,保持canonical的类型(无需操作)
|
||||
pass
|
||||
elif canonical_type and losing_type and canonical_type != losing_type:
|
||||
# 两个类型都不是UNKNOWN且不同,选择更具体的类型
|
||||
# 启发式规则:
|
||||
# 1. 更长的类型名通常更具体(如 HistoricalPeriod vs Organization)
|
||||
# 2. 包含特定领域词汇的类型更具体(如 MilitaryCapability vs Concept)
|
||||
|
||||
# 定义通用类型(优先级低)
|
||||
generic_types = {"Concept", "Phenomenon", "Condition", "State", "Attribute", "Event"}
|
||||
|
||||
canonical_is_generic = canonical_type in generic_types
|
||||
losing_is_generic = losing_type in generic_types
|
||||
|
||||
if canonical_is_generic and not losing_is_generic:
|
||||
# canonical是通用类型,losing是具体类型,使用losing
|
||||
canonical.entity_type = losing_type
|
||||
elif not canonical_is_generic and losing_is_generic:
|
||||
# losing是通用类型,canonical是具体类型,保持canonical(无需操作)
|
||||
pass
|
||||
elif len(losing_type) > len(canonical_type):
|
||||
# 两者都是具体类型或都是通用类型,选择更长的(通常更具体)
|
||||
canonical.entity_type = losing_type
|
||||
# 否则保持canonical的类型
|
||||
|
||||
|
||||
# 模块级属性融合工具函数(统一行为)
|
||||
def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
# 强弱连接合并
|
||||
@@ -30,18 +75,52 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
else:
|
||||
canonical.connect_strength = next(iter(pair))
|
||||
|
||||
# 别名合并(去重保序)
|
||||
# 别名合并(去重保序,使用标准化工具)
|
||||
try:
|
||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||
|
||||
# 收集所有需要合并的别名
|
||||
all_aliases = []
|
||||
|
||||
# 1. 添加canonical现有的别名
|
||||
existing = getattr(canonical, "aliases", []) or []
|
||||
all_aliases.extend(existing)
|
||||
|
||||
# 2. 添加incoming实体的名称(如果不同于canonical的名称)
|
||||
if incoming_name and incoming_name != canonical_name:
|
||||
all_aliases.append(incoming_name)
|
||||
|
||||
# 3. 添加incoming实体的所有别名
|
||||
incoming = getattr(ent, "aliases", []) or []
|
||||
seen = set()
|
||||
merged_list: List[str] = []
|
||||
for x in existing + incoming:
|
||||
xn = (x or "").strip()
|
||||
if xn and xn not in seen:
|
||||
seen.add(xn)
|
||||
merged_list.append(x)
|
||||
canonical.aliases = merged_list
|
||||
all_aliases.extend(incoming)
|
||||
|
||||
# 4. 标准化并去重(优先使用alias_utils工具函数)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
# 如果导入失败,使用增强的去重逻辑
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
|
||||
# 标准化:转小写用于去重判断
|
||||
alias_normalized = alias_stripped.lower()
|
||||
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
|
||||
# 排序并赋值
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -132,25 +211,25 @@ def accurate_match(
|
||||
# 为避免跨业务组误并,明确以 group_id 为范围边界
|
||||
if key not in canonical_map:
|
||||
canonical_map[key] = ent
|
||||
id_redirect[getattr(ent, "id")] = getattr(ent, "id")
|
||||
id_redirect[ent.id] = ent.id
|
||||
continue
|
||||
canonical = canonical_map[key]
|
||||
|
||||
# 执行精确属性与强弱合并,并建立重定向
|
||||
_merge_attribute(canonical, ent)
|
||||
id_redirect[getattr(ent, "id")] = getattr(canonical, "id")
|
||||
id_redirect[ent.id] = canonical.id
|
||||
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
|
||||
try:
|
||||
k = f"{getattr(canonical, 'group_id')}|{(getattr(canonical, 'name') or '').strip()}|{(getattr(canonical, 'entity_type') or '').strip()}"
|
||||
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||
if k not in exact_merge_map:
|
||||
exact_merge_map[k] = {
|
||||
"canonical_id": getattr(canonical, "id"),
|
||||
"group_id": getattr(canonical, "group_id"),
|
||||
"name": getattr(canonical, "name"),
|
||||
"entity_type": getattr(canonical, "entity_type"),
|
||||
"canonical_id": canonical.id,
|
||||
"group_id": canonical.group_id,
|
||||
"name": canonical.name,
|
||||
"entity_type": canonical.entity_type,
|
||||
"merged_ids": set(),
|
||||
}
|
||||
exact_merge_map[k]["merged_ids"].add(getattr(ent, "id"))
|
||||
exact_merge_map[k]["merged_ids"].add(ent.id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -164,23 +243,33 @@ def fuzzy_match(
|
||||
config: DedupConfig | None = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
|
||||
"""
|
||||
模糊匹配:在精确匹配之后,基于名称/类型相似度与上下文共现,进一步融合高相似实体。
|
||||
模糊匹配:基于名称、别名、类型相似度进行实体去重合并。
|
||||
|
||||
判断因素:
|
||||
- 名称相似度(包含别名匹配):70%权重
|
||||
- 类型相似度:30%权重
|
||||
|
||||
返回: (updated_entities, updated_redirect, fuzzy_merge_records)
|
||||
"""
|
||||
fuzzy_merge_records: List[str] = []
|
||||
|
||||
# ========== 第一层:基础工具函数 ==========
|
||||
|
||||
def _normalize_text(s: str) -> str:
    """Normalize text for comparison.

    Lowercases the input, replaces every run of characters that are neither
    word characters nor in the U+4E00-U+9FFF range with a single space,
    collapses whitespace runs and trims both ends. ``None``/empty input yields
    an empty string. If the pipeline raises (for example on a non-string
    argument), falls back to the lowercased, trimmed ``str()`` of the input.
    """
    try:
        lowered = (s or "").lower()
        without_symbols = re.sub(r"[^\w\u4e00-\u9fff]+", " ", lowered)
        collapsed = re.sub(r"\s+", " ", without_symbols)
        return collapsed.strip()
    except Exception:
        return str(s).lower().strip()
|
||||
|
||||
def _tokenize(s: str) -> List[str]:
    """Split text into tokens.

    The input is first passed through ``_normalize_text``; each maximal run of
    characters in the U+4E00-U+9FFF range and each maximal run of lowercase
    ASCII letters/digits then becomes one token, in order of appearance.
    """
    return re.findall(r"[\u4e00-\u9fff]+|[a-z0-9]+", _normalize_text(s))
|
||||
|
||||
def _jaccard(a_tokens: List[str], b_tokens: List[str]) -> float:
|
||||
"""Jaccard相似度:计算两个token集合的交集/并集"""
|
||||
try:
|
||||
set_a, set_b = set(a_tokens), set(b_tokens)
|
||||
if not set_a and not set_b:
|
||||
@@ -192,10 +281,11 @@ def fuzzy_match(
|
||||
return 0.0
|
||||
|
||||
def _cosine(a: List[float], b: List[float]) -> float:
|
||||
"""余弦相似度:计算两个向量的夹角余弦值"""
|
||||
try:
|
||||
if not a or not b or len(a) != len(b):
|
||||
return 0.0
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
dot = sum(x * y for x, y in zip(a, b, strict=False))
|
||||
na = sum(x * x for x in a) ** 0.5
|
||||
nb = sum(y * y for y in b) ** 0.5
|
||||
if na == 0 or nb == 0:
|
||||
@@ -204,44 +294,146 @@ def fuzzy_match(
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def _name_similarity(e1: ExtractedEntityNode, e2: ExtractedEntityNode):
|
||||
# ========== 第二层:中层工具函数 ==========
|
||||
|
||||
def _has_exact_alias_match(e1: ExtractedEntityNode, e2: ExtractedEntityNode) -> bool:
    """Report whether two entities share at least one identical name.

    For each entity the primary ``name`` and every entry of ``aliases`` are
    trimmed and lowercased; blank results are dropped. The entities match when
    the two resulting name sets intersect, which covers:

    - the primary name of one entity equals an alias of the other,
    - the two primary names are equal,
    - the alias lists overlap.

    Args:
        e1: First entity.
        e2: Second entity.

    Returns:
        True if any normalized name is common to both entities, else False.
    """
    def _collect_names(entity) -> set:
        # Primary name first, then aliases; a missing or None aliases
        # attribute counts as "no aliases".
        candidates = [getattr(entity, "name", "") or ""]
        candidates.extend(getattr(entity, "aliases", []) or [])
        lowered = ((item or "").strip().lower() for item in candidates)
        return {item for item in lowered if item}

    return not _collect_names(e1).isdisjoint(_collect_names(e2))
|
||||
|
||||
# ========== 第三层:高层综合函数 ==========
|
||||
|
||||
def _name_similarity_with_aliases(e1: ExtractedEntityNode, e2: ExtractedEntityNode):
|
||||
"""名称相似度综合评分系统
|
||||
|
||||
综合考虑主名称和别名,计算两个实体的相似度。
|
||||
|
||||
算法:
|
||||
1. 计算主名称的向量相似度和Token Jaccard相似度
|
||||
2. 计算所有别名的Token Jaccard相似度
|
||||
3. 找出所有名称间的最佳匹配
|
||||
4. 使用 _has_exact_alias_match 检测是否存在完全匹配
|
||||
|
||||
评分权重:
|
||||
- 有完全匹配:embedding(40%) + primary_jaccard(20%) + max_alias_sim(40%)
|
||||
- 无完全匹配:embedding(60%) + primary_jaccard(20%) + max_alias_sim(20%)
|
||||
|
||||
Args:
|
||||
e1: 第一个实体
|
||||
e2: 第二个实体
|
||||
|
||||
Returns:
|
||||
tuple: (综合相似度, 向量相似度, 主名称Jaccard, 别名Jaccard,
|
||||
最佳别名匹配度, 是否完全匹配)
|
||||
"""
|
||||
# 1. 主名称向量相似度
|
||||
emb_sim = _cosine(getattr(e1, "name_embedding", []) or [], getattr(e2, "name_embedding", []) or [])
|
||||
|
||||
# 2. 主名称token相似度
|
||||
|
||||
# 2. 主名称token相似度
|
||||
tokens1 = set(_tokenize(getattr(e1, "name", "") or ""))
|
||||
tokens2 = set(_tokenize(getattr(e2, "name", "") or ""))
|
||||
j_primary = _jaccard(list(tokens1), list(tokens2))
|
||||
|
||||
# 3. 获取所有别名
|
||||
j_primary = _jaccard(list(tokens1), list(tokens2))
|
||||
|
||||
# 3. 获取所有别名
|
||||
aliases1 = getattr(e1, "aliases", []) or []
|
||||
aliases2 = getattr(e2, "aliases", []) or []
|
||||
|
||||
# 4. 计算所有别名的token集合(用于整体Jaccard)
|
||||
|
||||
# 4. 计算所有别名的token集合(用于整体Jaccard)
|
||||
alias_tokens1 = set(tokens1)
|
||||
alias_tokens2 = set(tokens2)
|
||||
for a in aliases1:
|
||||
alias_tokens1 |= set(_tokenize(a))
|
||||
for a in aliases2:
|
||||
alias_tokens2 |= set(_tokenize(a))
|
||||
j_primary = _jaccard(list(tokens1), list(tokens2))
|
||||
j_alias = _jaccard(list(alias_tokens1), list(alias_tokens2))
|
||||
s_name = 0.6 * emb_sim + 0.2 * j_primary + 0.2 * j_alias
|
||||
return s_name, emb_sim, j_primary, j_alias
|
||||
|
||||
def _desc_similarity(e1: ExtractedEntityNode, e2: ExtractedEntityNode):
|
||||
"""
|
||||
计算实体描述的相似度(Jaccard + SequenceMatcher)
|
||||
返回: (相似度得分, Jaccard 相似度(词重合), SequenceMatcher 相似度(序列相似))
|
||||
"""
|
||||
d1 = getattr(e1, "description", "") or ""
|
||||
d2 = getattr(e2, "description", "") or ""
|
||||
if not d1 and not d2:
|
||||
return 0.0, 0.0, 0.0
|
||||
t1 = _tokenize(d1)
|
||||
t2 = _tokenize(d2)
|
||||
j = _jaccard(t1, t2)
|
||||
try:
|
||||
seq = difflib.SequenceMatcher(None, _normalize_text(d1), _normalize_text(d2)).ratio()
|
||||
except Exception:
|
||||
seq = 0.0
|
||||
# 平衡词重合与序列相似(更鲁棒)
|
||||
s_desc = 0.5 * j + 0.5 * seq
|
||||
return s_desc, j, seq
|
||||
|
||||
def _canonicalize_type(t: str) -> str: # 扩展类型同义归一
|
||||
|
||||
# 5. 使用 _has_exact_alias_match 检测完全匹配
|
||||
has_exact_match = _has_exact_alias_match(e1, e2)
|
||||
|
||||
# 6. 计算最佳别名匹配度(所有名称两两比较)
|
||||
all_names1 = [getattr(e1, "name", "") or "", *aliases1]
|
||||
all_names2 = [getattr(e2, "name", "") or "", *aliases2]
|
||||
|
||||
max_alias_sim = 0.0
|
||||
|
||||
if has_exact_match:
|
||||
max_alias_sim = 1.0
|
||||
else:
|
||||
for n1 in all_names1:
|
||||
if not n1:
|
||||
continue
|
||||
tokens_n1 = set(_tokenize(n1))
|
||||
|
||||
for n2 in all_names2:
|
||||
if not n2:
|
||||
continue
|
||||
|
||||
tokens_n2 = set(_tokenize(n2))
|
||||
sim = _jaccard(list(tokens_n1), list(tokens_n2))
|
||||
max_alias_sim = max(max_alias_sim, sim)
|
||||
|
||||
# 7. 综合评分
|
||||
if has_exact_match:
|
||||
s_name = 0.4 * emb_sim + 0.2 * j_primary + 0.4 * max_alias_sim
|
||||
else:
|
||||
s_name = 0.6 * emb_sim + 0.2 * j_primary + 0.2 * max_alias_sim
|
||||
|
||||
return s_name, emb_sim, j_primary, j_alias, max_alias_sim, has_exact_match
|
||||
|
||||
# ========== 类型相似度工具函数 ==========
|
||||
|
||||
def _canonicalize_type(t: str) -> str:
|
||||
"""类型标准化:将各种类型别名映射到规范类型"""
|
||||
t = (t or "").strip()
|
||||
if not t:
|
||||
return ""
|
||||
@@ -279,6 +471,7 @@ def fuzzy_match(
|
||||
return t_up
|
||||
|
||||
def _type_similarity(t1: str, t2: str) -> float:
|
||||
"""类型相似度:计算两个类型的相似度(基于规范化和相似度表)"""
|
||||
import difflib
|
||||
c1 = _canonicalize_type(t1)
|
||||
c2 = _canonicalize_type(t2)
|
||||
@@ -313,87 +506,196 @@ def fuzzy_match(
|
||||
t2n = (t2 or "").strip().lower()
|
||||
seq_ratio = difflib.SequenceMatcher(None, t1n, t2n).ratio()
|
||||
return seq_ratio * 0.6
|
||||
# 阈值与权重设定(从配置读取;若无配置则使用 DedupConfig 的默认值)
|
||||
# 阈值与权重设定
|
||||
_defaults = DedupConfig()
|
||||
|
||||
# 核心阈值
|
||||
T_NAME_STRICT = (config.fuzzy_name_threshold_strict if config is not None else _defaults.fuzzy_name_threshold_strict)
|
||||
T_TYPE_STRICT = (config.fuzzy_type_threshold_strict if config is not None else _defaults.fuzzy_type_threshold_strict)
|
||||
T_OVERALL = (config.fuzzy_overall_threshold if config is not None else _defaults.fuzzy_overall_threshold)
|
||||
UNKNOWN_NAME_T = (config.fuzzy_unknown_type_name_threshold if config is not None else _defaults.fuzzy_unknown_type_name_threshold)
|
||||
UNKNOWN_TYPE_T = (config.fuzzy_unknown_type_type_threshold if config is not None else _defaults.fuzzy_unknown_type_type_threshold)
|
||||
W_NAME = (config.name_weight if config is not None else _defaults.name_weight)
|
||||
W_DESC = (config.desc_weight if config is not None else _defaults.desc_weight)
|
||||
W_TYPE = (config.type_weight if config is not None else _defaults.type_weight)
|
||||
CTX_BONUS = (config.context_bonus if config is not None else _defaults.context_bonus) # 上下文共现加分
|
||||
FALL_FLOOR = (config.llm_fallback_floor if config is not None else _defaults.llm_fallback_floor)
|
||||
FALL_CEIL = (config.llm_fallback_ceiling if config is not None else _defaults.llm_fallback_ceiling)
|
||||
|
||||
# 权重:名称70%,类型30%
|
||||
W_NAME = 0.7
|
||||
W_TYPE = 0.3
|
||||
|
||||
|
||||
def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode):
    """Fold the names of a merged-away entity into ``canonical.aliases``.

    Merge strategy:
        1. The primary name of ``canonical`` is left untouched.
        2. The primary name of ``losing`` becomes an alias when it differs.
        3. The aliases of both entities are combined.
        4. The combined list is normalized and deduplicated.

    Normalization is delegated to ``alias_utils.normalize_aliases``. If that
    import or call fails, a local fallback is used: aliases are trimmed,
    deduplicated case-insensitively (first spelling wins), case variants of
    the canonical name are dropped, and the result is sorted.

    Args:
        canonical: The entity that is kept; its ``aliases`` is overwritten.
        losing: The entity being merged away; it is only read.
    """
    canonical_name = (getattr(canonical, "name", "") or "").strip()
    losing_name = (getattr(losing, "name", "") or "").strip()

    # Order matters for the fallback's "first spelling wins" rule:
    # canonical aliases, then the losing primary name, then losing aliases.
    all_aliases = list(getattr(canonical, "aliases", []) or [])
    if losing_name and losing_name != canonical_name:
        all_aliases.append(losing_name)
    all_aliases.extend(getattr(losing, "aliases", []) or [])

    try:
        from app.core.memory.utils.alias_utils import normalize_aliases
        canonical.aliases = normalize_aliases(canonical_name, all_aliases)
    except Exception:
        # Best-effort fallback when the shared helper is unavailable or fails.
        # Seeding the seen-set with the lowercased canonical name makes the
        # "alias equals primary name" check case-insensitive, in line with the
        # case-insensitive deduplication of the aliases themselves. Previously
        # "apple" could remain as an alias of "Apple".
        seen_normalized = {canonical_name.lower()}
        unique_aliases = []
        for alias in all_aliases:
            if not alias:
                continue
            alias_stripped = str(alias).strip()
            alias_key = alias_stripped.lower()
            if not alias_stripped or alias_key in seen_normalized:
                continue
            seen_normalized.add(alias_key)
            unique_aliases.append(alias_stripped)
        canonical.aliases = sorted(unique_aliases)
|
||||
|
||||
# ========== 主循环:遍历所有实体对进行模糊匹配 ==========
|
||||
i = 0
|
||||
while i < len(deduped_entities):
|
||||
a = deduped_entities[i]
|
||||
j = i + 1
|
||||
while j < len(deduped_entities):
|
||||
b = deduped_entities[j]
|
||||
|
||||
# 跳过不同业务组的实体
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
j += 1
|
||||
continue
|
||||
# 上下文共现
|
||||
try:
|
||||
sources_a = {e.source for e in statement_entity_edges if getattr(e, "target", None) == getattr(a, "id", None)}
|
||||
sources_b = {e.source for e in statement_entity_edges if getattr(e, "target", None) == getattr(b, "id", None)}
|
||||
co_ctx = bool(sources_a & sources_b)
|
||||
except Exception:
|
||||
co_ctx = False
|
||||
s_name, emb_sim, j_primary, j_alias = _name_similarity(a, b)
|
||||
s_desc, j_desc, seq_desc = _desc_similarity(a, b)
|
||||
|
||||
# ========== 第一步:计算相似度分数 ==========
|
||||
|
||||
# 1.1 名称+别名相似度(包含完全匹配检测)
|
||||
s_name, emb_sim, j_primary, j_alias, max_alias_sim, has_exact_match = _name_similarity_with_aliases(a, b)
|
||||
|
||||
# 1.2 类型相似度
|
||||
s_type = _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None))
|
||||
|
||||
# ========== 第二步:动态调整阈值 ==========
|
||||
|
||||
# 2.1 检测是否存在UNKNOWN类型
|
||||
unknown_present = (
|
||||
str(getattr(a, "entity_type", "")).upper() == "UNKNOWN"
|
||||
or str(getattr(b, "entity_type", "")).upper() == "UNKNOWN"
|
||||
)
|
||||
|
||||
# 2.2 根据类型设置名称阈值
|
||||
tn = UNKNOWN_NAME_T if unknown_present else T_NAME_STRICT
|
||||
tn = min(tn, 0.88) if co_ctx else tn
|
||||
|
||||
# 2.3 如果有完全别名匹配,降低名称相似度阈值
|
||||
if has_exact_match:
|
||||
tn = min(tn, 0.75)
|
||||
|
||||
# 2.4 设置类型阈值和综合阈值
|
||||
type_threshold = UNKNOWN_TYPE_T if unknown_present else T_TYPE_STRICT
|
||||
tover = T_OVERALL
|
||||
a_cs = (getattr(a, "connect_strength", "") or "").lower()
|
||||
b_cs = (getattr(b, "connect_strength", "") or "").lower()
|
||||
if a_cs in ("strong", "both") or b_cs in ("strong", "both"):
|
||||
tover = 0.80
|
||||
# 综合评分:名称、描述、类型加权 + 上下文加分
|
||||
overall = W_NAME * s_name + W_DESC * s_desc + W_TYPE * s_type + (CTX_BONUS if co_ctx else 0.0)
|
||||
|
||||
# ========== 第三步:计算综合评分 ==========
|
||||
# 公式:overall = 名称权重(70%) × 名称相似度 + 类型权重(30%) × 类型相似度
|
||||
overall = W_NAME * s_name + W_TYPE * s_type
|
||||
|
||||
# ========== 第四步:特殊规则判断(别名完全匹配快速通道)==========
|
||||
|
||||
# 4.1 检查主名称是否相同
|
||||
name_a_normalized = (getattr(a, "name", "") or "").strip().lower()
|
||||
name_b_normalized = (getattr(b, "name", "") or "").strip().lower()
|
||||
same_name = (name_a_normalized == name_b_normalized) and name_a_normalized != ""
|
||||
|
||||
# 4.2 别名匹配特殊规则(满足任一条件即可快速合并)
|
||||
alias_match_merge = False
|
||||
|
||||
# 规则1:别名完全匹配 + 类型相似度 ≥ 0.7
|
||||
if has_exact_match and s_type >= 0.7:
|
||||
alias_match_merge = True
|
||||
|
||||
# 规则2:名称相同 + 别名匹配 + 类型相似度 ≥ 0.5
|
||||
elif same_name and has_exact_match and s_type >= 0.5:
|
||||
alias_match_merge = True
|
||||
|
||||
# 规则3:名称相同 + 别名匹配 + 类型完全相同
|
||||
elif same_name and has_exact_match and s_type >= 1.0:
|
||||
alias_match_merge = True
|
||||
|
||||
if s_name >= tn and s_type >= type_threshold and overall >= tover:
|
||||
# ========== 第五步:最终合并判断 ==========
|
||||
# 满足以下任一条件即执行合并:
|
||||
# 条件A(快速通道):alias_match_merge = True
|
||||
# 条件B(标准通道):s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover
|
||||
if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover):
|
||||
# ========== 第六步:执行实体合并 ==========
|
||||
|
||||
# 6.1 合并别名
|
||||
_merge_entities_with_aliases(a, b)
|
||||
|
||||
# 6.2 合并其他属性(描述、事实摘要、时间范围等)
|
||||
_merge_attribute(a, b)
|
||||
|
||||
# 6.3 记录合并日志
|
||||
try:
|
||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||
fuzzy_merge_records.append(
|
||||
f"[模糊] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | s_name={s_name:.3f}, s_desc={s_desc:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, ctx={co_ctx}"
|
||||
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | "
|
||||
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# 用于处理合并实体后,Statement节点下方无挂载边的情况 后续考虑将其代码逻辑统一由关系去重消歧管理
|
||||
# 建立 ID 重定向:将合并实体 b 的 ID 指向规范实体 a 的 ID
|
||||
|
||||
# 6.4 建立 ID 重定向映射
|
||||
try:
|
||||
canonical_id = id_redirect.get(getattr(a, "id", None), getattr(a, "id", None))
|
||||
losing_id = getattr(b, "id", None)
|
||||
if losing_id and canonical_id:
|
||||
# 将被合并实体的ID指向规范实体
|
||||
id_redirect[losing_id] = canonical_id
|
||||
# 扁平化可能的重定向链:凡是映射到 b.id 的,统一指向 a.id
|
||||
|
||||
# 扁平化重定向链:确保所有指向losing_id的映射都指向canonical_id
|
||||
for k, v in list(id_redirect.items()):
|
||||
if v == losing_id:
|
||||
id_redirect[k] = canonical_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 6.5 从列表中移除被合并的实体
|
||||
deduped_entities.pop(j)
|
||||
continue
|
||||
continue # 不增加j,继续检查当前位置的下一个实体
|
||||
|
||||
# ========== 未达到合并条件:检查下一对 ==========
|
||||
else:
|
||||
try:
|
||||
if s_name >= tn and s_type >= type_threshold and (FALL_FLOOR <= overall < tover) and (overall <= FALL_CEIL):
|
||||
fuzzy_merge_records.append(
|
||||
f"[边界] {a.id}<->{b.id} ({a.group_id}|{a.name}|{a.entity_type} ~ {b.group_id}|{b.name}|{b.entity_type}) | s_name={s_name:.3f}, s_desc={s_desc:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, ctx={co_ctx}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
j += 1
|
||||
j += 1 # 移动到下一个实体
|
||||
i += 1
|
||||
|
||||
return deduped_entities, id_redirect, fuzzy_merge_records
|
||||
@@ -428,24 +730,30 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency)
|
||||
max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds)
|
||||
|
||||
# 动态导入 llm 客户端(统一从 app.core.memory.utils.llm_utils 获取)
|
||||
# 动态导入 llm 客户端(修正导入路径)
|
||||
try:
|
||||
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm_utils")
|
||||
get_llm_client_fn = getattr(llm_utils_mod, "get_llm_client")
|
||||
except Exception:
|
||||
get_llm_client_fn = lambda: None
|
||||
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm.llm_utils")
|
||||
get_llm_client_fn = llm_utils_mod.get_llm_client
|
||||
except Exception as e:
|
||||
llm_records.append(f"[LLM错误] 无法导入 llm_utils 模块: {e}")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
try:
|
||||
llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm")
|
||||
llm_fn = getattr(llm_mod, "llm_dedup_entities_iterative_blocks")
|
||||
except Exception:
|
||||
raise RuntimeError("LLM 模块加载失败:deduplication.entity_dedup_llm 缺少 llm_dedup_entities_iterative_blocks")
|
||||
llm_fn = llm_mod.llm_dedup_entities_iterative_blocks
|
||||
except Exception as e:
|
||||
llm_records.append(f"[LLM错误] 无法导入 entity_dedup_llm 模块: {e}")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
# 获取 LLM 客户端,若环境未配置或抛错则回退为 None
|
||||
# 获取 LLM 客户端
|
||||
try:
|
||||
llm_client = get_llm_client_fn()
|
||||
except Exception:
|
||||
llm_client = None
|
||||
if llm_client is None:
|
||||
llm_records.append("[LLM错误] LLM 客户端初始化失败:返回 None")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
except Exception as e:
|
||||
llm_records.append(f"[LLM错误] 获取 LLM 客户端失败: {e}")
|
||||
return deduped_entities, id_redirect, llm_records
|
||||
|
||||
llm_redirect, llm_records = await llm_fn(
|
||||
entity_nodes=deduped_entities,
|
||||
@@ -527,7 +835,13 @@ async def LLM_disamb_decision(
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
|
||||
# 获取 LLM 客户端并验证
|
||||
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
if llm_client is None:
|
||||
disamb_records.append("[DISAMB错误] LLM 客户端初始化失败:返回 None")
|
||||
return deduped_entities, id_redirect, blocked_pairs, disamb_records
|
||||
|
||||
merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative(
|
||||
entity_nodes=deduped_entities,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
@@ -708,7 +1022,7 @@ def _write_dedup_fusion_report(
|
||||
aggregated_exact_lines: List[str] = []
|
||||
try:
|
||||
for k, info in (exact_merge_map or {}).items():
|
||||
merged_ids = sorted(list(info.get("merged_ids", set())))
|
||||
merged_ids = sorted(info.get("merged_ids", set()))
|
||||
if merged_ids:
|
||||
aggregated_exact_lines.append(
|
||||
f"[精确] 键 {k} 规范实体 {info.get('canonical_id')} 名称 '{info.get('name')}' 类型 {info.get('entity_type')} <- 合并实体IDs {', '.join(merged_ids)}"
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
import asyncio
|
||||
import difflib
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Tuple, Dict
|
||||
import anyio
|
||||
|
||||
@@ -12,6 +14,12 @@ from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
|
||||
from app.core.memory.models.dedup_models import EntityDedupDecision, EntityDisambDecision
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_entity_dedup_prompt
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
_merge_attribute,
|
||||
_unify_entity_type
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- 类型同义归并与相似度 ---
|
||||
@@ -55,13 +63,37 @@ def _simple_type_ok(t1: str | None, t2: str | None) -> bool:
|
||||
return c1 == c2
|
||||
|
||||
|
||||
def parse_llm_response_safe(response_text: str, response_model) -> EntityDedupDecision | EntityDisambDecision | None:
    """Parse an LLM JSON reply into a decision object without raising.

    The text is decoded as JSON and the decoded mapping is expanded as keyword
    arguments into ``response_model``. Any failure is logged as a warning and
    reported to the caller as ``None``.

    Args:
        response_text: JSON text returned by the LLM.
        response_model: Callable/class that builds the decision object from
            keyword arguments (EntityDedupDecision or EntityDisambDecision).

    Returns:
        The constructed decision object, or None when decoding or construction
        fails.
    """
    decision = None
    try:
        decision = response_model(**json.loads(response_text))
    except json.JSONDecodeError as decode_err:
        logger.warning(f"LLM response JSON parsing failed: {decode_err}")
    except Exception as err:
        # Covers everything else, e.g. a non-mapping payload or arguments the
        # model rejects.
        logger.warning(f"LLM response parsing failed: {err}")
    return decision
|
||||
|
||||
|
||||
def _name_embed_sim(a: List[float] | None, b: List[float] | None) -> float: # 计算实体名称嵌入向量的余弦相似度
|
||||
a = a or []
|
||||
b = b or []
|
||||
if not a or not b or len(a) != len(b):
|
||||
return 0.0
|
||||
try:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
dot = sum(x * y for x, y in zip(a, b, strict=False))
|
||||
na = (sum(x * x for x in a)) ** 0.5
|
||||
nb = (sum(y * y for y in b)) ** 0.5
|
||||
if na > 0 and nb > 0:
|
||||
@@ -174,6 +206,7 @@ async def _judge_pair(
|
||||
entity_b=entity_b,
|
||||
context=ctx,
|
||||
json_schema=EntityDedupDecision.model_json_schema(),
|
||||
disambiguation_mode=False, # 去重模式
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -290,6 +323,33 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
|
||||
# 规则2:类型必须兼容(调用_simple_type_ok判断)
|
||||
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
|
||||
continue
|
||||
|
||||
# 规则2.5:过滤掉应该在模糊匹配阶段就被合并的实体对
|
||||
# 如果名称相同且别名有交集,说明应该在模糊匹配阶段就被合并了
|
||||
# 这些实体对不应该进入LLM阶段,避免重复处理
|
||||
try:
|
||||
name_a = (getattr(a, "name", "") or "").strip().lower()
|
||||
name_b = (getattr(b, "name", "") or "").strip().lower()
|
||||
same_name = (name_a == name_b) and name_a != ""
|
||||
|
||||
if same_name:
|
||||
# 检查别名是否有交集
|
||||
names_a = {name_a}
|
||||
names_a |= {(alias or "").strip().lower() for alias in (getattr(a, "aliases", []) or [])}
|
||||
names_a.discard("")
|
||||
|
||||
names_b = {name_b}
|
||||
names_b |= {(alias or "").strip().lower() for alias in (getattr(b, "aliases", []) or [])}
|
||||
names_b.discard("")
|
||||
|
||||
has_alias_overlap = bool(names_a & names_b)
|
||||
|
||||
# 如果名称相同且别名有交集,跳过(应该在模糊匹配阶段处理)
|
||||
if has_alias_overlap:
|
||||
continue
|
||||
except Exception:
|
||||
pass # 如果检查失败,继续处理(保守策略)
|
||||
|
||||
# 规则3:名称相似度达标(文本/嵌入相似度取最大值)
|
||||
txt_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", ""))
|
||||
emb_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", []))
|
||||
@@ -317,6 +377,7 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
|
||||
try:
|
||||
result_list[idx] = await _judge_pair(llm_client, entity_nodes[i], entity_nodes[j], statement_entity_edges, entity_entity_edges)
|
||||
except Exception as e:
|
||||
logger.error(f"Error judging pair ({i}, {j}): {e}", exc_info=True)
|
||||
result_list[idx] = e
|
||||
|
||||
# Limit concurrency using semaphore
|
||||
@@ -349,7 +410,12 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
|
||||
canon_idx = decision.canonical_idx if decision.canonical_idx in (0, 1) else _choose_canonical(a, b)
|
||||
canon = a if canon_idx == 0 else b
|
||||
other = b if canon_idx == 0 else a
|
||||
id_redirect_updates[getattr(other, "id")] = getattr(canon, "id")
|
||||
|
||||
# 应用LLM合并决策:合并属性和统一类型
|
||||
_merge_attribute(canon, other)
|
||||
_unify_entity_type(canon, other, suggested_type=None)
|
||||
|
||||
id_redirect_updates[other.id] = canon.id
|
||||
records.append(
|
||||
f"[LLM合并] 规范实体 {canon.id} 名称 '{getattr(canon, 'name', '')}' <- 合并实体 {other.id} 名称 '{getattr(other, 'name', '')}' | conf={decision.confidence:.3f}, th={th:.3f}, co_ctx={ctx.get('co_occurrence')}"
|
||||
)
|
||||
@@ -508,8 +574,11 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
async def _run_block_wrapper(idx: int, block: List[ExtractedEntityNode]):
|
||||
try:
|
||||
results[idx] = await _run_one_block(idx, block)
|
||||
except Exception as e:
|
||||
except BaseException as e:
|
||||
logger.error(f"Error in block {idx}: {e}", exc_info=True)
|
||||
results[idx] = e
|
||||
if isinstance(e, (KeyboardInterrupt, SystemExit)):
|
||||
raise
|
||||
|
||||
for i in range(len(blocks)):
|
||||
tg.start_soon(_run_block_wrapper, i, blocks[i])
|
||||
@@ -607,6 +676,7 @@ async def llm_disambiguate_pairs_iterative(
|
||||
try:
|
||||
judged[idx] = await _judge_pair_disamb(llm_client, entity_nodes[i], entity_nodes[j], statement_entity_edges, entity_entity_edges)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in disamb pair ({i}, {j}): {e}", exc_info=True)
|
||||
judged[idx] = e
|
||||
|
||||
# Limit concurrency using semaphore
|
||||
@@ -634,6 +704,11 @@ async def llm_disambiguate_pairs_iterative(
|
||||
can_idx = 0 if decision.canonical_idx == 0 else 1
|
||||
canonical = a if can_idx == 0 else b
|
||||
losing = b if can_idx == 0 else a
|
||||
|
||||
# 应用LLM合并决策:合并属性和统一类型
|
||||
_merge_attribute(canonical, losing)
|
||||
_unify_entity_type(canonical, losing, suggested_type=decision.suggested_type)
|
||||
|
||||
merge_redirect[getattr(losing, "id", "")] = getattr(canonical, "id", "")
|
||||
records.append(
|
||||
f"[DISAMB合并] {getattr(losing,'id','')} -> {getattr(canonical,'id','')} | conf={decision.confidence:.2f} | reason={decision.reason} | suggested_type={decision.suggested_type or ''}"
|
||||
@@ -663,6 +738,11 @@ async def llm_disambiguate_pairs_iterative(
|
||||
sb = _strength_rank(getattr(b, "connect_strength", None))
|
||||
canonical = a if sa >= sb else b
|
||||
losing = b if sa >= sb else a
|
||||
|
||||
# 应用LLM合并决策:合并属性和统一类型
|
||||
_merge_attribute(canonical, losing)
|
||||
_unify_entity_type(canonical, losing, suggested_type=decision.suggested_type)
|
||||
|
||||
merge_redirect[getattr(losing, "id", "")] = getattr(canonical, "id", "")
|
||||
# 消歧合并审计
|
||||
records.append(
|
||||
|
||||
@@ -36,8 +36,8 @@ from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
from app.core.memory.src.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# 导入各个提取模块
|
||||
@@ -349,7 +349,7 @@ class ExtractionOrchestrator:
|
||||
if all_responses:
|
||||
try:
|
||||
self.triplet_extractor.save_triplets(all_responses)
|
||||
logger.info(f"三元组数据已保存到文件")
|
||||
logger.info("三元组数据已保存到文件")
|
||||
except Exception as e:
|
||||
logger.error(f"保存三元组到文件失败: {e}", exc_info=True)
|
||||
|
||||
@@ -842,6 +842,7 @@ class ExtractionOrchestrator:
|
||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ logger = get_memory_logger(__name__)
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
from app.core.memory.models.base_response import RobustLLMResponse
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.message_models import DialogData, Statement, TemporalValidityRange
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_temporal_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, TemporalInfo
|
||||
@@ -218,5 +218,5 @@ class TemporalExtractor:
|
||||
f.write(f" - Valid At: {statement.temporal_validity.valid_at}\n")
|
||||
f.write(f" - Invalid At: {statement.temporal_validity.invalid_at}\n")
|
||||
else:
|
||||
f.write(f" - Temporal Validity: Not Extracted\n")
|
||||
f.write(" - Temporal Validity: Not Extracted\n")
|
||||
f.write("\n")
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
from typing import List, Dict
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
|
||||
@@ -58,7 +58,7 @@ async def run_hybrid_search(
|
||||
dict: 搜索结果字典,格式与旧API兼容
|
||||
"""
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Any, Dict, Tuple, List
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.config_models import LLMConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
314
api/app/core/memory/utils/alias_utils.py
Normal file
314
api/app/core/memory/utils/alias_utils.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
Utility functions for entity alias management.
|
||||
|
||||
This module provides functions for validating, adding, merging, and normalizing
|
||||
entity aliases in the knowledge graph system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Any, Dict, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_aliases(v: Any) -> List[str]:
    """Validate and clean an ``aliases`` field value.

    Args:
        v: The raw aliases value. Anything that is not a ``list``
            (including ``None``) is treated as "no aliases".

    Returns:
        A list of unique, whitespace-stripped alias strings in first-seen
        order. String items are kept as they are; ``int`` and ``float``
        items are converted with ``str()``. Falsy items (``None``, ``""``,
        ``0``), booleans and every other type are dropped.
    """
    if not isinstance(v, list):
        return []

    seen: Set[str] = set()
    result: List[str] = []
    for item in v:
        # bool is a subclass of int; without this guard ``True`` would be
        # stored as the alias "True". Checking the type before truthiness
        # also avoids evaluating arbitrary objects in a boolean context.
        if isinstance(item, bool) or not isinstance(item, (str, int, float)):
            continue
        if not item:
            continue
        cleaned = str(item).strip()
        if cleaned and cleaned not in seen:
            seen.add(cleaned)
            result.append(cleaned)
    return result
|
||||
|
||||
|
||||
def add_alias(entity_name: str, current_aliases: List[str], new_alias: str) -> List[str]:
    """Add a single alias to an entity's alias list.

    The alias is stripped of surrounding whitespace before any comparison,
    so a padded copy of the primary name or of an existing alias is
    recognised as a duplicate.

    Args:
        entity_name: The primary name of the entity.
        current_aliases: Current list of aliases. It is never mutated.
        new_alias: The alias to add.

    Returns:
        ``current_aliases`` itself when nothing is added (the alias is not a
        string, is blank, equals the stripped primary name, or is already
        present); otherwise a new list with the stripped alias appended.
    """
    if not isinstance(new_alias, str):
        return current_aliases

    normalized = new_alias.strip()
    if not normalized:
        return current_aliases

    # Compare after stripping both sides so that " Foo " is not accepted as
    # an alias of "Foo".
    if isinstance(entity_name, str) and normalized == entity_name.strip():
        return current_aliases

    if normalized in current_aliases:
        return current_aliases

    return [*current_aliases, normalized]
|
||||
|
||||
|
||||
def merge_aliases(entity_name: str, aliases1: List[str], aliases2: List[str]) -> List[str]:
    """Merge a second alias list into a copy of the first.

    Entries of ``aliases1`` are kept verbatim and in order. Each entry of
    ``aliases2`` is stripped and appended unless it is not a string, is
    blank, equals the stripped primary name, or is already in the result.

    Args:
        entity_name: The primary name of the entity.
        aliases1: First list of aliases (``None`` is treated as empty).
        aliases2: Second list of aliases to merge in (``None`` is treated
            as empty).

    Returns:
        A new list; neither input list is mutated.
    """
    merged: List[str] = list(aliases1 or [])
    # A set gives constant-time duplicate checks instead of scanning the
    # growing list for every candidate.
    seen: Set[str] = set(merged)
    primary = entity_name.strip() if isinstance(entity_name, str) else entity_name

    for alias in aliases2 or []:
        if not isinstance(alias, str):
            continue
        candidate = alias.strip()
        if not candidate or candidate == primary or candidate in seen:
            continue
        seen.add(candidate)
        merged.append(candidate)

    return merged
|
||||
|
||||
|
||||
def normalize_aliases(entity_name: str, aliases: List[str]) -> List[str]:
    """Normalize an alias list.

    Performs the following operations:
    - Strips whitespace from every entry and drops empty ones
    - Removes duplicates using a case-insensitive comparison, keeping the
      spelling of the first occurrence
    - Removes aliases equal to the primary name (case-insensitive)
    - Sorts the result with the default string ordering, which is
      case-sensitive

    Args:
        entity_name: The primary name of the entity. ``None`` is treated as
            an empty name, so nothing is excluded on its account.
        aliases: List of aliases to normalize (``None`` is treated as empty).

    Returns:
        Normalized and sorted list of aliases.
    """
    primary_key = str(entity_name or "").strip().lower()

    # Maps the lowercase form to the first original spelling seen for it.
    first_seen: Dict[str, str] = {}
    for alias in aliases or []:
        if not alias:
            continue

        stripped = str(alias).strip()
        if not stripped:
            continue

        key = stripped.lower()
        # Skip aliases that repeat the primary name (case-insensitive).
        if key == primary_key:
            continue

        # setdefault keeps the first spelling and ignores later variants.
        first_seen.setdefault(key, stripped)

    return sorted(first_seen.values())
|
||||
|
||||
|
||||
|
||||
# Alias-list size guard
MAX_ALIASES = 50  # Maximum number of aliases kept per entity


def merge_aliases_with_limit(
    entity_name: str,
    aliases1: List[str],
    aliases2: List[str],
    max_aliases: int = MAX_ALIASES
) -> List[str]:
    """Merge two alias lists and cap the size of the result.

    Entries are stripped; non-string entries, blank entries and entries
    equal to the stripped primary name are discarded. When more than
    ``max_aliases`` unique aliases remain, only the shortest ones are kept
    (ties broken by string order) and a warning is logged.

    Args:
        entity_name: The primary name of the entity.
        aliases1: First alias list (``None`` is treated as empty).
        aliases2: Second alias list (``None`` is treated as empty).
        max_aliases: Maximum number of aliases to return (default 50).
            Negative values are treated as 0.

    Returns:
        The merged aliases in sorted order, at most ``max_aliases`` entries.
    """
    primary = entity_name.strip() if isinstance(entity_name, str) else entity_name

    unique: Set[str] = set()
    for alias in (*(aliases1 or ()), *(aliases2 or ())):
        if not isinstance(alias, str):
            continue
        cleaned = alias.strip()
        if cleaned and cleaned != primary:
            unique.add(cleaned)

    # A negative limit would otherwise turn the slice below into
    # "drop the last N" instead of "keep the first N".
    max_aliases = max(max_aliases, 0)

    kept: List[str] = list(unique)
    if len(kept) > max_aliases:
        logger.warning(
            f"Aliases exceed limit ({len(kept)} > {max_aliases}) for entity '{entity_name}', "
            f"truncating to {max_aliases} shortest aliases"
        )
        # Sort by length, then by string order so the selection is
        # deterministic, and keep the shortest aliases.
        kept = sorted(kept, key=lambda a: (len(a), a))[:max_aliases]

    return sorted(kept)
|
||||
|
||||
|
||||
def detect_alias_cycles(entities: List[Any]) -> Dict[str, Set[str]]:
    """Detect circular references among entity aliases.

    Builds a directed graph in which entity A points to entity B when one of
    A's aliases equals B's primary name (compared lowercased and stripped),
    then walks the graph depth-first and records every back edge as a cycle.
    Each detected cycle is also logged as a warning.

    The traversal uses an explicit stack rather than recursion, so long
    alias chains cannot hit the interpreter's recursion limit.

    Args:
        entities: Entities to inspect; ``id``, ``name`` and ``aliases`` are
            read with ``getattr``. Entities without an id are ignored, and
            entities without a name cannot be the target of an alias.

    Returns:
        Mapping from a generated key (``"cycle_0"``, ``"cycle_1"``, ...) to
        the set of entity ids that form that cycle. Empty when no cycle is
        found.
    """
    # Map each normalized primary name to its entity id (aliases are not
    # indexed here). A later entity with the same name overwrites an
    # earlier one.
    name_to_entity: Dict[str, str] = {}
    for entity in entities:
        entity_id = getattr(entity, 'id', None)
        entity_name = getattr(entity, 'name', None)
        if not entity_id or not entity_name:
            continue
        name_to_entity[str(entity_name).lower().strip()] = entity_id

    # Build the graph: A -> B when an alias of A matches the name of B.
    connections: Dict[str, Set[str]] = {}
    for entity in entities:
        entity_id = getattr(entity, 'id', None)
        if not entity_id:
            continue

        targets: Set[str] = set()
        connections[entity_id] = targets

        for alias in getattr(entity, 'aliases', []) or []:
            if not alias:
                continue
            # str() guards against non-string aliases in the input.
            target_id = name_to_entity.get(str(alias).lower().strip())
            if target_id is not None and target_id != entity_id:
                targets.add(target_id)

    visited: Set[str] = set()
    on_path: Set[str] = set()
    cycles: Dict[str, Set[str]] = {}

    for start in connections:
        if start in visited:
            continue

        visited.add(start)
        on_path.add(start)
        path: List[str] = [start]
        # One neighbor iterator per node on the current path; resuming an
        # iterator continues that node's loop exactly where it left off.
        pending = [iter(connections[start])]

        while pending:
            for neighbor in pending[-1]:
                if neighbor not in visited:
                    # Descend into the unvisited neighbor.
                    visited.add(neighbor)
                    on_path.add(neighbor)
                    path.append(neighbor)
                    pending.append(iter(connections.get(neighbor, ())))
                    break
                if neighbor in on_path:
                    # Back edge to a node on the current path: a cycle.
                    cycle_start_idx = path.index(neighbor)
                    cycle_key = f"cycle_{len(cycles)}"
                    cycles[cycle_key] = set(path[cycle_start_idx:])

                    logger.warning(
                        f"Detected alias cycle: {' -> '.join(path[cycle_start_idx:])} -> {neighbor}"
                    )
            else:
                # All neighbors handled: leave this node.
                pending.pop()
                on_path.discard(path.pop())

    return cycles
|
||||
|
||||
|
||||
def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> List[str]:
    """Choose a canonical entity for each alias cycle and list the rest.

    Within a cycle the canonical entity is the one with the highest
    ``connect_strength`` (strong > both > weak > anything else), then the
    longest ``description``, then the longest ``fact_summary``. Remaining
    ties are broken by the smallest entity id, so the outcome does not
    depend on set iteration order.

    Args:
        entities: Entities whose ``id``, ``connect_strength``,
            ``description`` and ``fact_summary`` are read with ``getattr``.
        cycles: Cycle groups as returned by ``detect_alias_cycles``. Groups
            with fewer than two ids are skipped.

    Returns:
        The ids of the non-canonical ("losing") entities, in cycle order
        and, within a cycle, from most to least preferred.
    """
    entity_by_id: Dict[str, Any] = {}
    for candidate in entities:
        candidate_id = getattr(candidate, 'id', None)
        if candidate_id:
            entity_by_id[candidate_id] = candidate

    strength_scores = {'strong': 3, 'both': 2, 'weak': 1}

    def _preference_key(entity_id: str) -> tuple:
        """Return a sort key that orders the preferred entity first."""
        # An id missing from entity_by_id yields None; getattr then falls
        # back to the defaults, giving the lowest rank and zero lengths.
        entity = entity_by_id.get(entity_id)
        strength = (getattr(entity, 'connect_strength', '') or '').lower()
        return (
            -strength_scores.get(strength, 0),
            -len(getattr(entity, 'description', '') or ''),
            -len(getattr(entity, 'fact_summary', '') or ''),
            str(entity_id),  # deterministic tie-break
        )

    merge_suggestions: List[str] = []

    for cycle_key, cycle_entity_ids in cycles.items():
        if len(cycle_entity_ids) < 2:
            continue

        ranked = sorted(cycle_entity_ids, key=_preference_key)
        canonical_id = ranked[0]
        losing_ids = ranked[1:]

        logger.info(
            f"Resolving cycle {cycle_key}: canonical={canonical_id}, "
            f"merging={losing_ids}"
        )

        merge_suggestions.extend(losing_ids)

    return merge_suggestions
|
||||
@@ -46,7 +46,7 @@ def get_model_config(model_id: str, db: Session | None = None) -> dict:
|
||||
with open("logs/model_config.log", "a", encoding="utf-8") as f:
|
||||
f.write(f"模型ID: {model_id}\n")
|
||||
f.write(f"模型配置信息:\n{model_config}\n")
|
||||
f.write(f"=============================\n\n")
|
||||
f.write("=============================\n\n")
|
||||
return model_config
|
||||
|
||||
def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict:
|
||||
@@ -75,7 +75,7 @@ def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict:
|
||||
with open("logs/embedder_config.log", "a", encoding="utf-8") as f:
|
||||
f.write(f"嵌入模型ID: {embedding_id}\n")
|
||||
f.write(f"嵌入模型配置信息:\n{model_config}\n")
|
||||
f.write(f"=============================\n\n")
|
||||
f.write("=============================\n\n")
|
||||
return model_config
|
||||
|
||||
def get_neo4j_config() -> dict:
|
||||
|
||||
@@ -273,7 +273,7 @@ def reload_configuration_from_database(config_id: int | str, force_reload: bool
|
||||
# 重新暴露常量
|
||||
_expose_runtime_constants(updated_cfg)
|
||||
|
||||
logger.info(f"[definitions] 配置重新加载成功,已暴露常量")
|
||||
logger.info("[definitions] 配置重新加载成功,已暴露常量")
|
||||
logger.debug(f"[definitions] 配置详情: LLM_ID={updated_cfg.get('selections', {}).get('llm_id')}, "
|
||||
f"EMBEDDING_ID={updated_cfg.get('selections', {}).get('embedding_id')}")
|
||||
|
||||
|
||||
@@ -331,7 +331,7 @@ class LiteLLMConfig:
|
||||
'modules': {}
|
||||
}
|
||||
|
||||
for mod in self.module_stats.keys():
|
||||
for mod in self.module_stats:
|
||||
result['modules'][mod] = {
|
||||
'current_qps': self.module_stats[mod]['current_qps'],
|
||||
'max_qps': self.module_stats[mod]['max_qps'],
|
||||
@@ -394,7 +394,7 @@ class LiteLLMConfig:
|
||||
print(f"📊 {stats['message']}")
|
||||
return
|
||||
|
||||
print(f"\n📊 USAGE SUMMARY")
|
||||
print("\n📊 USAGE SUMMARY")
|
||||
print(f"{'='*50}")
|
||||
print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min")
|
||||
print(f"📈 Requests: {stats['total_requests']}")
|
||||
@@ -404,7 +404,7 @@ class LiteLLMConfig:
|
||||
|
||||
# Module statistics
|
||||
if stats.get('module_stats'):
|
||||
print(f"\n📦 MODULES:")
|
||||
print("\n📦 MODULES:")
|
||||
for module, mod_stats in stats['module_stats'].items():
|
||||
print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}")
|
||||
|
||||
@@ -479,7 +479,7 @@ def print_instant_qps(module: str = None):
|
||||
"""Print instant QPS information"""
|
||||
qps_data = get_instant_qps(module)
|
||||
|
||||
print(f"\n⚡ INSTANT QPS MONITOR")
|
||||
print("\n⚡ INSTANT QPS MONITOR")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if module:
|
||||
@@ -490,14 +490,14 @@ def print_instant_qps(module: str = None):
|
||||
else:
|
||||
# Global stats
|
||||
global_data = qps_data.get('global', {})
|
||||
print(f"🌍 GLOBAL:")
|
||||
print("🌍 GLOBAL:")
|
||||
print(f" Current QPS: {global_data.get('current_qps', 0)}")
|
||||
print(f" Max QPS: {global_data.get('max_qps', 0)}")
|
||||
|
||||
# Module stats
|
||||
modules = qps_data.get('modules', {})
|
||||
if modules:
|
||||
print(f"\n📦 MODULES:")
|
||||
print("\n📦 MODULES:")
|
||||
for mod, data in modules.items():
|
||||
print(f" {mod}:")
|
||||
print(f" Current: {data['current_qps']} QPS")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
@@ -35,26 +35,78 @@
|
||||
|
||||
===判定指引===
|
||||
{% if disambiguation_mode %}
|
||||
- 这是“同名但类型不同”的消歧场景。请判断两者是否指向同一真实世界实体。
|
||||
- 这是"同名但类型不同"的消歧场景。请判断两者是否指向同一真实世界实体。
|
||||
- 综合名称文本/向量相似度、别名、描述、摘要与上下文关系(同源与关系陈述)进行判断。
|
||||
- **别名处理(高优先级)**:
|
||||
* 如果两个实体的别名列表中有交集,这是强烈的同一性信号
|
||||
* 如果一个实体的名称出现在另一个实体的别名中,应视为高置信度匹配
|
||||
* 如果一个实体的别名与另一个实体的名称完全匹配,应视为高置信度匹配
|
||||
* 别名匹配的权重应高于单纯的名称文本相似度
|
||||
- 若无法充分确定,应保守处理:不合并,并建议阻断该对在其他模糊/启发式合并中出现(block_pair=true)。
|
||||
- 若需要合并(should_merge=true),请选择“规范实体”(canonical_idx)并在可能的情况下给出建议统一类型(suggested_type),建议类型需与上下文一致。
|
||||
- 若需要合并(should_merge=true),请选择"规范实体"(canonical_idx)并**必须**给出建议统一类型(suggested_type)。
|
||||
- **类型统一原则(重要)**:
|
||||
* 优先选择更具体、更准确的类型(如 HistoricalPeriod 优于 Organization,MilitaryCapability 优于 Concept)
|
||||
* 如果两个类型都很具体但不同,选择与实体核心语义最匹配的类型
|
||||
* 通用类型(Concept、Phenomenon、Condition、State、Attribute、Event)优先级低于领域特定类型
|
||||
* 建议类型必须与上下文和实体描述一致
|
||||
- 规范实体优先级:连接强度(strong/both)更高者;其余相同则保留描述/摘要更丰富者;再相同时保留实体A(canonical_idx=0)。
|
||||
- **注意**:别名(aliases)已在三元组提取阶段获取,合并时会自动整合,无需在此阶段提取。
|
||||
{% else %}
|
||||
- 若实体类型相同或任一为UNKNOWN/空,可放行作为候选;若类型明显冲突(如人 vs 物品),除非别名与描述高度一致,否则判定不同实体。
|
||||
- **别名匹配优先(最高优先级)**:
|
||||
* 如果实体A的名称与实体B的某个别名完全匹配,应视为高置信度匹配
|
||||
* 如果实体B的名称与实体A的某个别名完全匹配,应视为高置信度匹配
|
||||
* 如果实体A的任一别名与实体B的任一别名完全匹配,应视为高置信度匹配
|
||||
* 别名完全匹配时,即使名称文本相似度较低,也应考虑合并
|
||||
* 别名匹配的置信度应高于单纯的名称相似度匹配
|
||||
- 综合名称文本/向量相似度、别名、描述、摘要以及上下文关系判断是否为同一实体。
|
||||
- 当上下文同源或存在明确的关系陈述支持同一性(例如同一对象反复被提及或别名对应),可以适度降低判定阈值。
|
||||
- 保守决策:当无法充分确定,不要合并(same_entity=false)。
|
||||
- 若需要合并,选择“保留的规范实体”(canonical_idx)为更合适的一个:
|
||||
- 若需要合并,选择"保留的规范实体"(canonical_idx)为更合适的一个:
|
||||
- 优先保留连接强度更强(strong/both)者;其余相同则保留描述/摘要更丰富者;再相同时保留实体A(canonical_idx=0)。
|
||||
- **注意**:别名(aliases)已在三元组提取阶段获取,合并时会自动整合,无需在此阶段提取。
|
||||
{% endif %}
|
||||
|
||||
**Output format**
|
||||
{% if disambiguation_mode %}
|
||||
返回JSON格式,必须包含以下字段:
|
||||
{
|
||||
"should_merge": boolean,
|
||||
"canonical_idx": 0 or 1,
|
||||
"confidence": float (0.0-1.0),
|
||||
"block_pair": boolean,
|
||||
"suggested_type": "string or null",
|
||||
"reason": "string"
|
||||
}
|
||||
|
||||
**字段说明**:
|
||||
- should_merge: 是否应该合并这两个实体(true/false)
|
||||
- canonical_idx: 规范实体的索引,0表示实体A,1表示实体B
|
||||
- confidence: 决策的置信度,范围0.0-1.0
|
||||
- block_pair: 是否阻断该对在其他模糊/启发式合并中出现(true/false)
|
||||
- suggested_type: 建议的统一类型(字符串或null)
|
||||
- reason: 决策理由的简短说明
|
||||
{% else %}
|
||||
返回JSON格式,必须包含以下字段:
|
||||
{
|
||||
"same_entity": boolean,
|
||||
"canonical_idx": 0 or 1,
|
||||
"confidence": float (0.0-1.0),
|
||||
"reason": "string"
|
||||
}
|
||||
|
||||
**字段说明**:
|
||||
- same_entity: 两个实体是否指向同一真实世界实体(true/false)
|
||||
- canonical_idx: 规范实体的索引,0表示实体A,1表示实体B
|
||||
- confidence: 决策的置信度,范围0.0-1.0
|
||||
- reason: 决策理由的简短说明
|
||||
{% endif %}
|
||||
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
2. Ensure all JSON strings are properly closed and comma-separated
|
||||
3. Do not include line breaks within JSON string values
|
||||
4. Test your JSON output mentally to ensure it can be parsed correctly
|
||||
|
||||
The output language should always be the same as the input language.
|
||||
{{ json_schema }}
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -12,7 +12,18 @@ Extract entities and knowledge triplets from the given statement.
|
||||
===Guidelines===
|
||||
|
||||
**Entity Extraction:**
|
||||
- Extract entities with their types and context-independent descriptions
|
||||
- Extract entities with their types, context-independent descriptions, and aliases
|
||||
- **Aliases Extraction (Important):**
|
||||
* **CRITICAL: Extract aliases ONLY in the SAME LANGUAGE as the input text**
|
||||
* **DO NOT translate or add aliases in different languages**
|
||||
* Include common alternative names in the same language (e.g., "北京" → aliases: ["北平", "京城"])
|
||||
* Include abbreviations and full names in the same language (e.g., "联合国" → aliases: ["联合国组织"])
|
||||
* Include nicknames and common variations in the same language (e.g., "纽约" → aliases: ["纽约市", "大苹果"])
|
||||
* If no aliases exist in the same language, use empty array: []
|
||||
* **Examples:**
|
||||
- Chinese input "北京" → aliases: ["北平", "京城"] (NOT ["Beijing", "Peking"])
|
||||
- English input "Beijing" → aliases: ["Peking"] (NOT ["北京", "北平"])
|
||||
- Chinese input "苹果公司" → aliases: ["苹果"] (NOT ["Apple Inc.", "Apple"])
|
||||
- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions
|
||||
- For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value)
|
||||
Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric'
|
||||
@@ -72,19 +83,22 @@ Output:
|
||||
"entity_idx": 0,
|
||||
"name": "I",
|
||||
"type": "Person",
|
||||
"description": "The user"
|
||||
"description": "The user",
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "Paris",
|
||||
"type": "Location",
|
||||
"description": "Capital city of France"
|
||||
"description": "Capital city of France",
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "Louvre",
|
||||
"type": "Location",
|
||||
"description": "World-famous museum located in Paris"
|
||||
"description": "World-famous museum located in Paris",
|
||||
"aliases": ["Louvre Museum"]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -115,19 +129,22 @@ Output:
|
||||
"entity_idx": 0,
|
||||
"name": "John Smith",
|
||||
"type": "Person",
|
||||
"description": "Individual person name"
|
||||
"description": "Individual person name",
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "Google",
|
||||
"type": "Organization",
|
||||
"description": "American technology company"
|
||||
"description": "American technology company",
|
||||
"aliases": ["Google LLC", "Alphabet Inc."]
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "AI product development",
|
||||
"type": "WorkRole",
|
||||
"description": "Artificial intelligence product development work"
|
||||
"description": "Artificial intelligence product development work",
|
||||
"aliases": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -158,19 +175,22 @@ Output:
|
||||
"entity_idx": 0,
|
||||
"name": "我",
|
||||
"type": "Person",
|
||||
"description": "用户本人"
|
||||
"description": "用户本人",
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "巴黎",
|
||||
"type": "Location",
|
||||
"description": "法国首都城市"
|
||||
"description": "法国首都城市",
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "卢浮宫",
|
||||
"type": "Location",
|
||||
"description": "位于巴黎的世界著名博物馆"
|
||||
"description": "位于巴黎的世界著名博物馆",
|
||||
"aliases": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -201,24 +221,27 @@ Output:
|
||||
"entity_idx": 0,
|
||||
"name": "张明",
|
||||
"type": "Person",
|
||||
"description": "个人姓名"
|
||||
"description": "个人姓名",
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "腾讯",
|
||||
"type": "Organization",
|
||||
"description": "中国科技公司"
|
||||
"description": "中国科技公司",
|
||||
"aliases": ["腾讯控股", "腾讯公司"]
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "AI产品开发",
|
||||
"type": "WorkRole",
|
||||
"description": "人工智能产品研发工作"
|
||||
"description": "人工智能产品研发工作",
|
||||
"aliases": []
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 5 (Entity Only):** "Tripod" or "三脚架"
|
||||
**Example 5 (Entity Only - English):** "Tripod"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
@@ -227,7 +250,23 @@ Output:
|
||||
"entity_idx": 0,
|
||||
"name": "Tripod",
|
||||
"type": "Equipment",
|
||||
"description": "Photography equipment accessory"
|
||||
"description": "Photography equipment accessory",
|
||||
"aliases": ["Camera Tripod"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 6 (Entity Only - Chinese):** "三脚架"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": 0,
|
||||
"name": "三脚架",
|
||||
"type": "Equipment",
|
||||
"description": "摄影器材配件",
|
||||
"aliases": ["相机三脚架"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -191,6 +191,9 @@ async def update_memory(solved_data: List[Any], host_id: uuid.UUID) -> str:
|
||||
logging.info(f"成功删除 {success_count} 条检索数据")
|
||||
except Exception as e:
|
||||
logging.error(f"删除数据库中的检索数据失败: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
|
||||
async def _append_json(label: str, data: Any) -> None:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user