[MODIFY] Code optimization
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
"""API Key Repository"""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, func, and_
|
||||
from typing import Optional, List, Tuple
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, func, and_
|
||||
|
||||
from app.models.api_key_model import ApiKey, ApiKeyLog
|
||||
from app.schemas import api_key_schema
|
||||
@@ -11,7 +12,7 @@ from app.schemas import api_key_schema
|
||||
|
||||
class ApiKeyRepository:
|
||||
"""API Key 数据访问层"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, api_key_data: dict) -> ApiKey:
|
||||
"""创建 API Key"""
|
||||
@@ -19,27 +20,27 @@ class ApiKeyRepository:
|
||||
db.add(api_key)
|
||||
db.flush()
|
||||
return api_key
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, api_key_id: uuid.UUID) -> Optional[ApiKey]:
    """Fetch a single API key by its primary key.

    Args:
        db: Database session used for the lookup.
        api_key_id: Primary key of the API key.

    Returns:
        Whatever ``db.get`` yields for ``ApiKey`` and the given ID
        (``None`` when there is no such row).
    """
    api_key = db.get(ApiKey, api_key_id)
    return api_key
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]:
    """Look up an API key by its stored hash value.

    Args:
        db: Database session used for the query.
        key_hash: Hash to match against ``ApiKey.key_hash``.

    Returns:
        The first matching ``ApiKey``, or ``None`` when nothing matches.
    """
    lookup = select(ApiKey).where(ApiKey.key_hash == key_hash)
    matches = db.scalars(lookup)
    return matches.first()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_by_workspace(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
query: api_key_schema.ApiKeyQuery
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
query: api_key_schema.ApiKeyQuery
|
||||
) -> Tuple[List[ApiKey], int]:
|
||||
"""列出工作空间的 API Keys"""
|
||||
stmt = select(ApiKey).where(ApiKey.workspace_id == workspace_id)
|
||||
|
||||
|
||||
# 过滤条件
|
||||
if query.type:
|
||||
stmt = stmt.where(ApiKey.type == query.type)
|
||||
@@ -47,40 +48,39 @@ class ApiKeyRepository:
|
||||
stmt = stmt.where(ApiKey.is_active == query.is_active)
|
||||
if query.resource_id:
|
||||
stmt = stmt.where(ApiKey.resource_id == query.resource_id)
|
||||
|
||||
|
||||
# 总数
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = db.execute(count_stmt).scalar()
|
||||
|
||||
|
||||
# 分页
|
||||
stmt = stmt.order_by(ApiKey.created_at.desc())
|
||||
stmt = stmt.offset((query.page - 1) * query.pagesize).limit(query.pagesize)
|
||||
|
||||
|
||||
items = db.scalars(stmt).all()
|
||||
return list(items), total
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey:
|
||||
def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey | None:
    """Apply a partial update to an API key.

    Entries of ``update_data`` whose value is ``None`` are ignored, so this
    method cannot be used to reset a field to ``None``.

    Args:
        db: Database session; changes are flushed but not committed here.
        api_key_id: Primary key of the API key to update.
        update_data: Mapping of attribute name to new value.

    Returns:
        The updated ``ApiKey``, or ``None`` when the ID does not exist.
    """
    record = db.get(ApiKey, api_key_id)
    if not record:
        return record

    # Only explicitly provided (non-None) values overwrite stored attributes.
    provided = {field: new for field, new in update_data.items() if new is not None}
    for field, new in provided.items():
        setattr(record, field, new)

    record.updated_at = datetime.datetime.now()
    db.flush()
    return record
|
||||
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, api_key_id: uuid.UUID) -> bool:
    """Logically delete an API key by marking it inactive.

    The row is kept; only ``is_active`` is set to ``False``. The change is
    flushed but not committed, so the caller owns the transaction.

    Args:
        db: Database session.
        api_key_id: Primary key of the API key to deactivate.

    Returns:
        ``True`` when the key was found and deactivated, ``False`` when no
        key with that ID exists.
    """
    api_key = db.get(ApiKey, api_key_id)
    if not api_key:
        return False

    # Soft delete only: do not also call db.delete(), which would remove the
    # row that is being flagged as inactive.
    api_key.is_active = False
    db.flush()
    return True
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_usage(db: Session, api_key_id: uuid.UUID) -> bool:
|
||||
"""更新使用统计"""
|
||||
@@ -92,14 +92,14 @@ class ApiKeyRepository:
|
||||
db.flush()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_stats(db: Session, api_key_id: uuid.UUID) -> dict:
|
||||
"""获取使用统计"""
|
||||
api_key = db.get(ApiKey, api_key_id)
|
||||
if not api_key:
|
||||
return {}
|
||||
|
||||
|
||||
# 今日请求数
|
||||
today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
today_count_stmt = select(func.count()).select_from(ApiKeyLog).where(
|
||||
@@ -109,13 +109,13 @@ class ApiKeyRepository:
|
||||
)
|
||||
)
|
||||
requests_today = db.execute(today_count_stmt).scalar() or 0
|
||||
|
||||
|
||||
# 平均响应时间
|
||||
avg_time_stmt = select(func.avg(ApiKeyLog.response_time)).where(
|
||||
ApiKeyLog.api_key_id == api_key_id
|
||||
)
|
||||
avg_response_time = db.execute(avg_time_stmt).scalar()
|
||||
|
||||
|
||||
return {
|
||||
"total_requests": api_key.usage_count,
|
||||
"requests_today": requests_today,
|
||||
@@ -128,7 +128,7 @@ class ApiKeyRepository:
|
||||
|
||||
class ApiKeyLogRepository:
|
||||
"""API Key 日志数据访问层"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, log_data: dict) -> ApiKeyLog:
|
||||
"""创建日志"""
|
||||
@@ -136,3 +136,54 @@ class ApiKeyLogRepository:
|
||||
db.add(log)
|
||||
db.flush()
|
||||
return log
|
||||
|
||||
@staticmethod
|
||||
def list_by_api_key(
    db: Session,
    api_key_id: uuid.UUID,
    filters: dict,
    page: int,
    pagesize: int
) -> Tuple[List[ApiKeyLog], int]:
    """List the log entries of one API key, filtered and paginated.

    Args:
        db: Database session.
        api_key_id: ID of the API key whose logs are requested.
        filters: Optional filter values; falsy or missing entries are ignored.
            Supported keys:
            - start_date: lower bound (inclusive) on ``created_at``
            - end_date: upper bound (inclusive) on ``created_at``
            - status_code: exact HTTP status code
            - endpoint: case-insensitive substring of the endpoint path
        page: Page number; the offset is ``(page - 1) * pagesize``.
        pagesize: Number of rows per page.

    Returns:
        Tuple[List[ApiKeyLog], int]: (logs of the requested page ordered by
        ``created_at`` descending, total number of rows matching the filters).
    """
    conditions = [ApiKeyLog.api_key_id == api_key_id]

    start_date = filters.get('start_date')
    if start_date:
        conditions.append(ApiKeyLog.created_at >= start_date)

    end_date = filters.get('end_date')
    if end_date:
        conditions.append(ApiKeyLog.created_at <= end_date)

    status_code = filters.get('status_code')
    if status_code:
        conditions.append(ApiKeyLog.status_code == status_code)

    endpoint = filters.get('endpoint')
    if endpoint:
        conditions.append(ApiKeyLog.endpoint.ilike(f"%{endpoint}%"))

    filtered = select(ApiKeyLog).where(*conditions)

    # Count over the filtered set before ordering and pagination are applied.
    total = db.execute(
        select(func.count()).select_from(filtered.subquery())
    ).scalar()

    skip = (page - 1) * pagesize
    page_stmt = (
        filtered
        .order_by(ApiKeyLog.created_at.desc())
        .offset(skip)
        .limit(pagesize)
    )
    return list(db.scalars(page_stmt).all()), total
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""基础仓储接口模块
|
||||
|
||||
本模块定义了通用的仓储接口,适用于所有数据库类型(PostgreSQL、Neo4j等)。
|
||||
@@ -14,7 +13,7 @@ from typing import Generic, TypeVar, List, Optional, Dict, Any
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class BaseRepository(ABC, Generic[T]):
|
||||
class BaseRepository[T](ABC):
|
||||
"""基础仓储接口 - 适用于所有数据库类型
|
||||
|
||||
这是一个抽象基类,定义了所有仓储必须实现的基本CRUD操作。
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据配置Repository模块
|
||||
|
||||
本模块提供data_config表的数据访问层,包括SQL查询构建和Neo4j Cypher查询。
|
||||
从 app.core.memory.src.data_config_api.sql_queries 迁移而来。
|
||||
本模块提供data_config表的数据访问层,使用SQLAlchemy ORM进行数据库操作。
|
||||
包括CRUD操作和Neo4j Cypher查询常量。
|
||||
|
||||
Classes:
|
||||
DataConfigRepository: 数据配置仓储类,提供CRUD操作和查询构建
|
||||
DataConfigRepository: 数据配置仓储类,提供CRUD操作
|
||||
"""
|
||||
|
||||
from typing import Dict, Tuple, List
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
import uuid
|
||||
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
@@ -24,15 +27,12 @@ from app.core.logging_config import get_db_logger
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
|
||||
# 表名常量
|
||||
TABLE_NAME = "data_config"
|
||||
|
||||
|
||||
class DataConfigRepository:
|
||||
"""数据配置Repository
|
||||
|
||||
提供data_config表的数据访问方法,包括:
|
||||
- SQL查询构建(PostgreSQL)
|
||||
- SQLAlchemy ORM 数据库操作
|
||||
- Neo4j Cypher查询常量
|
||||
"""
|
||||
|
||||
@@ -136,273 +136,369 @@ class DataConfigRepository:
|
||||
} AS targetNode
|
||||
"""
|
||||
|
||||
# ==================== SQL 查询构建方法 ====================
|
||||
# ==================== SQLAlchemy ORM 数据库操作方法 ====================
|
||||
|
||||
@staticmethod
|
||||
def build_insert(params: ConfigParamsCreate) -> Tuple[str, Dict]:
|
||||
"""构建插入语句(PostgreSQL 命名参数)
|
||||
def create(db: Session, params: ConfigParamsCreate) -> DataConfig:
    """Create a data configuration through the ORM.

    The new object is added to the session and flushed, but not committed;
    committing is left to the caller.

    Args:
        db: Database session.
        params: Creation model carrying the configuration values.

    Returns:
        DataConfig: The newly created configuration object.

    Raises:
        Exception: Any error raised while adding or flushing is logged, the
            session is rolled back, and the error is re-raised.
    """
    db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}")

    # No SQL string building here: a leftover builder path would return a
    # (query, params) tuple before the ORM insert below could run.
    try:
        db_config = DataConfig(
            config_name=params.config_name,
            config_desc=params.config_desc,
            workspace_id=params.workspace_id,
            llm_id=params.llm_id,
            embedding_id=params.embedding_id,
            rerank_id=params.rerank_id,
        )
        db.add(db_config)
        db.flush()  # populate the generated config_id without committing

        db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
        return db_config

    except Exception as e:
        db.rollback()
        db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}")
        raise
|
||||
|
||||
@staticmethod
|
||||
def build_update(update: ConfigUpdate) -> Tuple[str, Dict]:
|
||||
"""构建基础配置更新语句(PostgreSQL 命名参数)
|
||||
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]:
    """Update the basic fields (name and description) of a configuration.

    Args:
        db: Database session; the change is committed here.
        update: Update model; ``config_name`` and ``config_desc`` are applied
            when they are not ``None``.

    Returns:
        Optional[DataConfig]: The refreshed configuration, or ``None`` when
        no configuration has ``update.config_id``.

    Raises:
        ValueError: When neither field is provided. It is raised inside the
            ``try`` block, so it is logged and the session is rolled back
            before it propagates.
    """
    db_logger.debug(f"更新数据配置: config_id={update.config_id}")

    # No SQL string building here: a leftover builder path would return a
    # (query, params) tuple before the ORM update below could run.
    try:
        db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
        if not db_config:
            db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
            return None

        has_update = False
        if update.config_name is not None:
            db_config.config_name = update.config_name
            has_update = True
        if update.config_desc is not None:
            db_config.config_desc = update.config_desc
            has_update = True

        if not has_update:
            raise ValueError("No fields to update")

        db.commit()
        db.refresh(db_config)

        db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})")
        return db_config

    except Exception as e:
        db.rollback()
        db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def build_update_extracted(update: ConfigUpdateExtracted) -> Tuple[str, Dict]:
|
||||
"""构建记忆萃取引擎配置更新语句(PostgreSQL 命名参数)
|
||||
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]:
    """Update the memory-extraction engine settings of a configuration.

    Every field listed below that is present on ``update`` with a non-``None``
    value is copied onto the stored configuration under the same name.

    Args:
        db: Database session; the change is committed here.
        update: Extraction-config update model.

    Returns:
        Optional[DataConfig]: The refreshed configuration, or ``None`` when
        no configuration has ``update.config_id``.

    Raises:
        ValueError: When no updatable field is provided. It is raised inside
            the ``try`` block, so it is logged and the session is rolled back
            before it propagates.
    """
    db_logger.debug(f"更新萃取配置: config_id={update.config_id}")

    # Update-model field names double as DataConfig attribute names.
    # NOTE(review): llm_id used to be written to an attribute named "llm",
    # while create() and get_extracted_config() use "llm_id"; it is aligned
    # here — confirm against the DataConfig model.
    updatable_fields = (
        # model selection
        "llm_id",
        "embedding_id",
        "rerank_id",
        # memory-extraction engine
        "enable_llm_dedup_blockwise",
        "enable_llm_disambiguation",
        "deep_retrieval",
        "t_type_strict",
        "t_name_strict",
        "t_overall",
        "state",
        "chunker_strategy",
        # statement extraction
        "statement_granularity",
        "include_dialogue_context",
        "max_context",
        # pruning
        "pruning_enabled",
        "pruning_scene",
        "pruning_threshold",
        # self-reflexion
        "enable_self_reflexion",
        "iteration_period",
        "reflexion_range",
        "baseline",
    )

    try:
        db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
        if not db_config:
            db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
            return None

        has_update = False
        for field in updatable_fields:
            value = getattr(update, field, None)
            if value is not None:
                setattr(db_config, field, value)
                has_update = True

        if not has_update:
            raise ValueError("No fields to update")

        db.commit()
        db.refresh(db_config)

        db_logger.info(f"萃取配置更新成功: config_id={update.config_id}")
        return db_config

    except Exception as e:
        db.rollback()
        db_logger.error(f"更新萃取配置失败: config_id={update.config_id} - {str(e)}")
        raise
|
||||
|
||||
@staticmethod
|
||||
def build_update_forget(update: ConfigUpdateForget) -> Tuple[str, Dict]:
|
||||
"""构建遗忘引擎配置更新语句(PostgreSQL 命名参数)
|
||||
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]:
    """Update the forgetting-engine settings of a configuration.

    Args:
        db: Database session; the change is committed here.
        update: Forget-config update model; ``lambda_time``, ``lambda_mem``
            and ``offset`` are applied when they are not ``None``.

    Returns:
        Optional[DataConfig]: The refreshed configuration, or ``None`` when
        no configuration has ``update.config_id``.

    Raises:
        ValueError: When none of the three fields is provided. It is raised
            inside the ``try`` block, so it is logged and the session is
            rolled back before it propagates.
    """
    db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")

    # No SQL string building here: a leftover builder path would return a
    # (query, params) tuple before the ORM update below could run.
    try:
        db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
        if not db_config:
            db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
            return None

        has_update = False
        if update.lambda_time is not None:
            db_config.lambda_time = update.lambda_time
            has_update = True
        if update.lambda_mem is not None:
            db_config.lambda_mem = update.lambda_mem
            has_update = True
        if update.offset is not None:
            db_config.offset = update.offset
            has_update = True

        if not has_update:
            raise ValueError("No fields to update")

        db.commit()
        db.refresh(db_config)

        db_logger.info(f"遗忘配置更新成功: config_id={update.config_id}")
        return db_config

    except Exception as e:
        db.rollback()
        db_logger.error(f"更新遗忘配置失败: config_id={update.config_id} - {str(e)}")
        raise
|
||||
|
||||
@staticmethod
|
||||
def build_select_extracted(key: ConfigKey) -> Tuple[str, Dict]:
|
||||
"""构建萃取配置查询语句,通过主键查询某条配置(PostgreSQL 命名参数)
|
||||
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]:
    """Return the extraction settings of one configuration as a dict.

    Args:
        db: Database session.
        config_id: Primary key of the configuration.

    Returns:
        Optional[Dict]: Mapping of each extraction field name to its stored
        value, or ``None`` when the configuration does not exist.

    Raises:
        Exception: Query errors are logged and re-raised.
    """
    # Only the config_id parameter is logged; there is no "key" object in
    # this signature to read the ID from.
    db_logger.debug(f"查询萃取配置: config_id={config_id}")

    extracted_fields = (
        "llm_id",
        "embedding_id",
        "rerank_id",
        "enable_llm_dedup_blockwise",
        "enable_llm_disambiguation",
        "deep_retrieval",
        "t_type_strict",
        "t_name_strict",
        "t_overall",
        "chunker_strategy",
        "statement_granularity",
        "include_dialogue_context",
        "max_context",
        "pruning_enabled",
        "pruning_scene",
        "pruning_threshold",
        "enable_self_reflexion",
        "iteration_period",
        "reflexion_range",
        "baseline",
    )

    try:
        db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
        if not db_config:
            db_logger.debug(f"萃取配置不存在: config_id={config_id}")
            return None

        result = {field: getattr(db_config, field) for field in extracted_fields}

        db_logger.debug(f"萃取配置查询成功: config_id={config_id}")
        return result

    except Exception as e:
        db_logger.error(f"查询萃取配置失败: config_id={config_id} - {str(e)}")
        raise
|
||||
|
||||
@staticmethod
|
||||
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]:
    """Return the forgetting-engine settings of one configuration as a dict.

    Args:
        db: Database session.
        config_id: Primary key of the configuration.

    Returns:
        Optional[Dict]: Dict with ``lambda_time``, ``lambda_mem`` and
        ``offset``, or ``None`` when the configuration does not exist.

    Raises:
        Exception: Query errors are logged and re-raised.
    """
    db_logger.debug(f"查询遗忘配置: config_id={config_id}")

    try:
        row = (
            db.query(DataConfig)
            .filter(DataConfig.config_id == config_id)
            .first()
        )
        if not row:
            db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
            return None

        forget_settings = {
            name: getattr(row, name)
            for name in ("lambda_time", "lambda_mem", "offset")
        }

        db_logger.debug(f"遗忘配置查询成功: config_id={config_id}")
        return forget_settings

    except Exception as e:
        db_logger.error(f"查询遗忘配置失败: config_id={config_id} - {str(e)}")
        raise
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]:
    """Fetch a data configuration by its ID.

    Args:
        db: Database session.
        config_id: Primary key of the configuration.

    Returns:
        Optional[DataConfig]: The configuration, or ``None`` when it does not
        exist.

    Raises:
        Exception: Query errors are logged and re-raised.
    """
    db_logger.debug(f"根据ID查询数据配置: config_id={config_id}")

    try:
        found = (
            db.query(DataConfig)
            .filter(DataConfig.config_id == config_id)
            .first()
        )

        if not found:
            db_logger.debug(f"数据配置不存在: config_id={config_id}")
        else:
            db_logger.debug(f"数据配置查询成功: {found.config_name} (ID: {config_id})")
        return found
    except Exception as e:
        db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
        raise
|
||||
|
||||
@staticmethod
|
||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
    """List configurations, most recently updated first.

    Args:
        db: Database session.
        workspace_id: When given, only configurations of this workspace are
            returned; otherwise all configurations are returned.

    Returns:
        List[DataConfig]: Configurations ordered by ``updated_at`` descending,
        with rows whose ``updated_at`` is NULL placed last.

    Raises:
        Exception: Query errors are logged and re-raised.
    """
    db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")

    try:
        query = db.query(DataConfig)

        if workspace_id:
            query = query.filter(DataConfig.workspace_id == workspace_id)

        # PostgreSQL sorts NULLs first under DESC; NULLS LAST keeps rows that
        # were never updated at the end, as the earlier raw SQL did.
        configs = query.order_by(desc(DataConfig.updated_at).nulls_last()).all()

        db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
        return configs

    except Exception as e:
        db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
        raise
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, config_id: int) -> bool:
    """Delete a data configuration and commit the deletion.

    Args:
        db: Database session.
        config_id: Primary key of the configuration to delete.

    Returns:
        bool: ``True`` when the configuration was deleted, ``False`` when no
        configuration with that ID exists.

    Raises:
        Exception: Errors are logged, the session is rolled back, and the
            error is re-raised.
    """
    db_logger.debug(f"删除数据配置: config_id={config_id}")

    try:
        target = (
            db.query(DataConfig)
            .filter(DataConfig.config_id == config_id)
            .first()
        )
        if target:
            db.delete(target)
            db.commit()

            db_logger.info(f"数据配置删除成功: config_id={config_id}")
            return True

        db_logger.warning(f"数据配置不存在: config_id={config_id}")
        return False

    except Exception as e:
        db.rollback()
        db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}")
        raise
|
||||
|
||||
query = (
|
||||
f"SELECT llm_id, embedding_id, rerank_id, "
|
||||
f"enable_llm_dedup_blockwise, enable_llm_disambiguation, deep_retrieval, "
|
||||
f"t_type_strict, t_name_strict, t_overall, chunker_strategy, "
|
||||
f"statement_granularity, include_dialogue_context, max_context, "
|
||||
f"pruning_enabled, pruning_scene, pruning_threshold, "
|
||||
f"enable_self_reflexion, iteration_period, reflexion_range, baseline "
|
||||
f"FROM {TABLE_NAME} WHERE config_id = %(config_id)s"
|
||||
)
|
||||
params = {"config_id": key.config_id}
|
||||
return query, params
|
||||
|
||||
@staticmethod
|
||||
def build_select_forget(key: ConfigKey) -> Tuple[str, Dict]:
    """Build the SELECT statement for one configuration's forget settings.

    Uses PostgreSQL-style named parameters (``%(name)s``).

    Args:
        key: Configuration key model; only ``config_id`` is read.

    Returns:
        Tuple[str, Dict]: (SQL query string, parameter dict).
    """
    db_logger.debug(f"构建遗忘配置查询语句: config_id={key.config_id}")

    # "offset" is quoted in the column list of the statement.
    select_clause = f"SELECT lambda_time, lambda_mem, \"offset\" "
    from_clause = f"FROM {TABLE_NAME} WHERE config_id = %(config_id)s"
    return select_clause + from_clause, {"config_id": key.config_id}
|
||||
|
||||
@staticmethod
|
||||
def build_select_all(workspace_id = None) -> Tuple[str, Dict]:
    """Build the SELECT statement listing configurations.

    Uses PostgreSQL-style named parameters (``%(name)s``).

    Args:
        workspace_id: Workspace ID (UUID or string). When truthy the query is
            restricted to that workspace; otherwise all rows are selected.

    Returns:
        Tuple[str, Dict]: (SQL query string, parameter dict).
    """
    db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")

    if not workspace_id:
        return f"SELECT * FROM {TABLE_NAME} ORDER BY updated_at DESC NULLS LAST", {}

    # The ID is passed as text and cast back to uuid inside the statement.
    sql = f"SELECT * FROM {TABLE_NAME} WHERE workspace_id = %(workspace_id)s::uuid ORDER BY updated_at DESC NULLS LAST"
    return sql, {"workspace_id": str(workspace_id)}
|
||||
|
||||
@staticmethod
|
||||
def build_delete(key: ConfigParamsDelete) -> Tuple[str, Dict]:
    """Build the DELETE statement removing one configuration by ID.

    Uses PostgreSQL-style named parameters (``%(name)s``).

    Args:
        key: Deletion model; only ``config_id`` is read.

    Returns:
        Tuple[str, Dict]: (SQL query string, parameter dict).
    """
    db_logger.debug(f"构建删除语句: config_id={key.config_id}")

    sql = f"DELETE FROM {TABLE_NAME} WHERE config_id = %(config_id)s"
    bound = {"config_id": key.config_id}
    return sql, bound
|
||||
|
||||
@@ -102,4 +102,40 @@ def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]
|
||||
"""根据 end_user_id 查询对应宿主"""
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_end_user_by_id(end_user_id)
|
||||
return end_user
|
||||
return end_user
|
||||
|
||||
def update_end_user_other_name(
    db: Session,
    end_user_id: uuid.UUID,
    other_name: str
) -> int:
    """Set the ``other_name`` field of an end user identified by ID.

    Args:
        db: Database session; the update is committed here.
        end_user_id: ID of the end user to update.
        other_name: New value for ``other_name``.

    Returns:
        int: Number of rows updated (0 when no end user has that ID).

    Raises:
        Exception: Errors are logged, the session is rolled back, and the
            error is re-raised.
    """
    try:
        # Bulk UPDATE; synchronize_session=False skips refreshing objects
        # already loaded in the session.
        updated_count = (
            db.query(EndUser)
            .filter(EndUser.id == end_user_id)
            .update(
                {EndUser.other_name: other_name},
                synchronize_session=False
            )
        )

        db.commit()
        if updated_count:
            db_logger.info(f"成功更新宿主 {end_user_id} 的 other_name 为: {other_name}")
        else:
            # Nothing matched: do not report success for a no-op update.
            db_logger.warning(f"未找到宿主 {end_user_id},other_name 未更新")
        return updated_count

    except Exception as e:
        db.rollback()
        db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}")
        raise
|
||||
@@ -18,14 +18,25 @@ class ModelConfigRepository:
|
||||
"""模型配置Repository"""
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, model_id: uuid.UUID) -> Optional[ModelConfig]:
|
||||
def get_by_id(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
|
||||
"""根据ID获取模型配置"""
|
||||
db_logger.debug(f"根据ID查询模型配置: model_id={model_id}")
|
||||
db_logger.debug(f"根据ID查询模型配置: model_id={model_id}, tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
model = db.query(ModelConfig).options(
|
||||
query = db.query(ModelConfig).options(
|
||||
joinedload(ModelConfig.api_keys)
|
||||
).filter(ModelConfig.id == model_id).first()
|
||||
).filter(ModelConfig.id == model_id)
|
||||
|
||||
# 添加租户过滤
|
||||
if tenant_id:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
)
|
||||
)
|
||||
|
||||
model = query.first()
|
||||
|
||||
if model:
|
||||
db_logger.debug(f"模型配置查询成功: {model.name} (ID: {model_id})")
|
||||
@@ -37,12 +48,23 @@ class ModelConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_by_name(db: Session, name: str) -> Optional[ModelConfig]:
|
||||
def get_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
|
||||
"""根据名称获取模型配置"""
|
||||
db_logger.debug(f"根据名称查询模型配置: name={name}")
|
||||
db_logger.debug(f"根据名称查询模型配置: name={name}, tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
model = db.query(ModelConfig).filter(ModelConfig.name == name).first()
|
||||
query = db.query(ModelConfig).filter(ModelConfig.name == name)
|
||||
|
||||
# 添加租户过滤
|
||||
if tenant_id:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
)
|
||||
)
|
||||
|
||||
model = query.first()
|
||||
if model:
|
||||
db_logger.debug(f"模型配置查询成功: {model.name}")
|
||||
return model
|
||||
@@ -51,24 +73,30 @@ class ModelConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def search_by_name(db: Session, name: str, limit: int = 10) -> List[ModelConfig]:
|
||||
def search_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
|
||||
"""按名称模糊匹配获取模型配置列表
|
||||
|
||||
Args:
|
||||
name: 模型名称关键词(模糊匹配)
|
||||
tenant_id: 租户ID
|
||||
limit: 返回数量上限
|
||||
Returns:
|
||||
模型配置列表
|
||||
"""
|
||||
db_logger.debug(f"按名称模糊查询模型配置: name~{name}, limit={limit}")
|
||||
db_logger.debug(f"按名称模糊查询模型配置: name~{name}, tenant_id={tenant_id}, limit={limit}")
|
||||
try:
|
||||
models = (
|
||||
db.query(ModelConfig)
|
||||
.filter(ModelConfig.name.ilike(f"%{name}%"))
|
||||
.order_by(ModelConfig.name)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
query = db.query(ModelConfig).filter(ModelConfig.name.ilike(f"%{name}%"))
|
||||
|
||||
# 添加租户过滤
|
||||
if tenant_id:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
)
|
||||
)
|
||||
|
||||
models = query.order_by(ModelConfig.name).limit(limit).all()
|
||||
db_logger.debug(f"模糊查询成功: 返回数量={len(models)}")
|
||||
return models
|
||||
except Exception as e:
|
||||
@@ -76,14 +104,23 @@ class ModelConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_list(db: Session, query: ModelConfigQuery) -> Tuple[List[ModelConfig], int]:
|
||||
def get_list(db: Session, query: ModelConfigQuery, tenant_id: uuid.UUID | None = None) -> Tuple[List[ModelConfig], int]:
|
||||
"""获取模型配置列表"""
|
||||
db_logger.debug(f"查询模型配置列表: {query.dict()}")
|
||||
db_logger.debug(f"查询模型配置列表: {query.dict()}, tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
# 构建查询条件
|
||||
filters = []
|
||||
|
||||
# 添加租户过滤(查询本租户的模型或公开模型)
|
||||
if tenant_id:
|
||||
filters.append(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
)
|
||||
)
|
||||
|
||||
# 支持多个 type 值(使用 IN 查询)
|
||||
if query.type:
|
||||
filters.append(ModelConfig.type.in_(query.type))
|
||||
@@ -132,15 +169,24 @@ class ModelConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_by_type(db: Session, model_type: ModelType, is_active: bool = True) -> List[ModelConfig]:
|
||||
def get_by_type(db: Session, model_type: ModelType, tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]:
|
||||
"""根据类型获取模型配置"""
|
||||
db_logger.debug(f"根据类型查询模型配置: type={model_type}, is_active={is_active}")
|
||||
db_logger.debug(f"根据类型查询模型配置: type={model_type}, tenant_id={tenant_id}, is_active={is_active}")
|
||||
|
||||
try:
|
||||
query = db.query(ModelConfig).options(
|
||||
joinedload(ModelConfig.api_keys)
|
||||
).filter(ModelConfig.type == model_type)
|
||||
|
||||
# 添加租户过滤
|
||||
if tenant_id:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
)
|
||||
)
|
||||
|
||||
if is_active:
|
||||
query = query.filter(ModelConfig.is_active == True)
|
||||
|
||||
@@ -170,14 +216,20 @@ class ModelConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate) -> Optional[ModelConfig]:
|
||||
def update(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
|
||||
"""更新模型配置"""
|
||||
db_logger.debug(f"更新模型配置: model_id={model_id}")
|
||||
db_logger.debug(f"更新模型配置: model_id={model_id}, tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
db_model = db.query(ModelConfig).filter(ModelConfig.id == model_id).first()
|
||||
query = db.query(ModelConfig).filter(ModelConfig.id == model_id)
|
||||
|
||||
# 添加租户过滤(只能更新本租户的模型)
|
||||
if tenant_id:
|
||||
query = query.filter(ModelConfig.tenant_id == tenant_id)
|
||||
|
||||
db_model = query.first()
|
||||
if not db_model:
|
||||
db_logger.warning(f"模型配置不存在: model_id={model_id}")
|
||||
db_logger.warning(f"模型配置不存在或无权限: model_id={model_id}")
|
||||
return None
|
||||
|
||||
# 更新字段
|
||||
@@ -197,20 +249,27 @@ class ModelConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, model_id: uuid.UUID) -> bool:
|
||||
def delete(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool:
|
||||
"""删除模型配置"""
|
||||
db_logger.debug(f"删除模型配置: model_id={model_id}")
|
||||
db_logger.debug(f"删除模型配置: model_id={model_id}, tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
db_model = db.query(ModelConfig).filter(ModelConfig.id == model_id).first()
|
||||
query = db.query(ModelConfig).filter(ModelConfig.id == model_id)
|
||||
|
||||
# 添加租户过滤(只能删除本租户的模型)
|
||||
if tenant_id:
|
||||
query = query.filter(ModelConfig.tenant_id == tenant_id)
|
||||
|
||||
db_model = query.first()
|
||||
if not db_model:
|
||||
db_logger.warning(f"模型配置不存在: model_id={model_id}")
|
||||
db_logger.warning(f"模型配置不存在或无权限: model_id={model_id}")
|
||||
return False
|
||||
|
||||
db.delete(db_model)
|
||||
# 逻辑删除模型配置
|
||||
db_model.is_active = False
|
||||
db.commit()
|
||||
|
||||
db_logger.info(f"模型配置删除成功: model_id={model_id}")
|
||||
db_logger.info(f"模型配置删除成功(逻辑删除): model_id={model_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -350,10 +409,11 @@ class ModelApiKeyRepository:
|
||||
db_logger.warning(f"API Key不存在: api_key_id={api_key_id}")
|
||||
return False
|
||||
|
||||
db.delete(db_api_key)
|
||||
# 逻辑删除 API Key
|
||||
db_api_key.is_active = False
|
||||
db.commit()
|
||||
|
||||
db_logger.info(f"API Key删除成功: api_key_id={api_key_id}")
|
||||
db_logger.info(f"API Key删除成功(逻辑删除): api_key_id={api_key_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -27,7 +27,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
|
||||
edges: List[dict] = []
|
||||
for chunk in chunks:
|
||||
for stmt in getattr(chunk, "statements", []) or []:
|
||||
stable_edge_id = hashlib.sha1(f"{chunk.id}|{stmt.id}".encode("utf-8")).hexdigest()
|
||||
stable_edge_id = hashlib.sha1(f"{chunk.id}|{stmt.id}".encode()).hexdigest()
|
||||
edge = {
|
||||
"id": stable_edge_id,
|
||||
"source": chunk.id,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j仓储基类模块
|
||||
|
||||
本模块提供Neo4j仓储的基类实现,封装了通用的Neo4j节点操作。
|
||||
@@ -57,9 +56,17 @@ class BaseNeo4jRepository(BaseRepository[T]):
|
||||
CREATE (n:{self.node_label} $props)
|
||||
RETURN n
|
||||
"""
|
||||
# 使用model_dump()获取所有字段,包括aliases
|
||||
props = entity.model_dump()
|
||||
|
||||
# 确保aliases字段存在且为列表(针对ExtractedEntity节点)
|
||||
if hasattr(entity, 'aliases'):
|
||||
if props.get('aliases') is None:
|
||||
props['aliases'] = []
|
||||
|
||||
result = await self.connector.execute_query(
|
||||
query,
|
||||
props=entity.model_dump()
|
||||
props=props
|
||||
)
|
||||
return entity
|
||||
|
||||
@@ -97,10 +104,18 @@ class BaseNeo4jRepository(BaseRepository[T]):
|
||||
SET n += $props
|
||||
RETURN n
|
||||
"""
|
||||
# 使用model_dump()获取所有字段,包括aliases
|
||||
props = entity.model_dump()
|
||||
|
||||
# 确保aliases字段存在且为列表(针对ExtractedEntity节点)
|
||||
if hasattr(entity, 'aliases'):
|
||||
if props.get('aliases') is None:
|
||||
props['aliases'] = []
|
||||
|
||||
await self.connector.execute_query(
|
||||
query,
|
||||
id=entity.id,
|
||||
props=entity.model_dump()
|
||||
props=props
|
||||
)
|
||||
return entity
|
||||
|
||||
@@ -142,7 +157,7 @@ class BaseNeo4jRepository(BaseRepository[T]):
|
||||
... )
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clauses = [f"n.{key} = ${key}" for key in filters.keys()]
|
||||
where_clauses = [f"n.{key} = ${key}" for key in filters]
|
||||
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
|
||||
|
||||
query = f"""
|
||||
|
||||
@@ -85,7 +85,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
||||
e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
|
||||
e.aliases = CASE
|
||||
WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
|
||||
THEN CASE WHEN e.aliases IS NULL THEN entity.aliases ELSE e.aliases + entity.aliases END
|
||||
THEN CASE
|
||||
WHEN e.aliases IS NULL THEN entity.aliases
|
||||
ELSE reduce(acc = [], alias IN (e.aliases + entity.aliases) |
|
||||
CASE WHEN alias IN acc THEN acc ELSE acc + alias END)
|
||||
END
|
||||
ELSE e.aliases END,
|
||||
e.name_embedding = CASE
|
||||
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
|
||||
@@ -682,3 +686,63 @@ SET r.group_id = e.group_id,
|
||||
r.expired_at = e.expired_at
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
|
||||
# Entity Merge Query
|
||||
# Cypher statement that folds one ExtractedEntity node ("losing") into another
# ("canonical").
#
# Query parameters: $canonical_id, $losing_id, $merged_aliases.
#
# Steps performed by the query:
#   1. Overwrite canonical.aliases with $merged_aliases.
#   2. For each relationship leaving `losing` whose target has no RELATES_TO
#      from `canonical`, create a RELATES_TO copy from `canonical`.
#   3. For each relationship entering `losing` whose source has no RELATES_TO
#      to `canonical`, create a RELATES_TO copy into `canonical`.
#   4. DETACH DELETE `losing` and return a count as `deleted`.
#
# The row stream is collapsed with WITH DISTINCT between steps 2 and 3 and
# before step 4. Without it, step 2 leaves one row per outgoing relationship,
# so step 3 would run once per such row and create each incoming copy several
# times, and the final count would be inflated by the same factor.
MERGE_ENTITIES = """
MATCH (canonical:ExtractedEntity {id: $canonical_id})
MATCH (losing:ExtractedEntity {id: $losing_id})

// 更新canonical实体的aliases
SET canonical.aliases = $merged_aliases

// 转移所有从losing出发的关系到canonical
WITH canonical, losing
OPTIONAL MATCH (losing)-[r]->(target)
WHERE NOT (canonical)-[:RELATES_TO]->(target)
FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
    CREATE (canonical)-[:RELATES_TO {
        id: rel.id,
        relation_type: rel.relation_type,
        relation_value: rel.relation_value,
        statement: rel.statement,
        source_statement_id: rel.source_statement_id,
        valid_at: rel.valid_at,
        invalid_at: rel.invalid_at,
        group_id: rel.group_id,
        user_id: rel.user_id,
        apply_id: rel.apply_id,
        run_id: rel.run_id,
        created_at: rel.created_at,
        expired_at: rel.expired_at
    }]->(target)
)

// 转移所有指向losing的关系到canonical
WITH DISTINCT canonical, losing
OPTIONAL MATCH (source)-[r]->(losing)
WHERE NOT (source)-[:RELATES_TO]->(canonical)
FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
    CREATE (source)-[:RELATES_TO {
        id: rel.id,
        relation_type: rel.relation_type,
        relation_value: rel.relation_value,
        statement: rel.statement,
        source_statement_id: rel.source_statement_id,
        valid_at: rel.valid_at,
        invalid_at: rel.invalid_at,
        group_id: rel.group_id,
        user_id: rel.user_id,
        apply_id: rel.apply_id,
        run_id: rel.run_id,
        created_at: rel.created_at,
        expired_at: rel.expired_at
    }]->(canonical)
)

// 删除losing实体及其所有关系
WITH DISTINCT losing
DETACH DELETE losing

RETURN count(losing) as deleted
"""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""实体仓储模块
|
||||
|
||||
本模块提供实体节点的数据访问功能。
|
||||
@@ -7,7 +6,7 @@ Classes:
|
||||
EntityRepository: 实体仓储,管理ExtractedEntityNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
@@ -49,9 +48,13 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
|
||||
# 处理datetime字段
|
||||
if isinstance(n.get('created_at'), str):
|
||||
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
||||
if n.get('expired_at') and isinstance(n['expired_at'], str):
|
||||
if n.get('expired_at') and isinstance(n.get('expired_at'), str):
|
||||
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
|
||||
|
||||
# 确保aliases字段存在且为列表
|
||||
if 'aliases' not in n or n['aliases'] is None:
|
||||
n['aliases'] = []
|
||||
|
||||
return ExtractedEntityNode(**n)
|
||||
|
||||
async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]:
|
||||
@@ -66,274 +69,4 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
|
||||
"""
|
||||
return await self.find({"entity_type": entity_type}, limit=limit)
|
||||
|
||||
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[ExtractedEntityNode]:
    """Look up entities whose ``group_id`` property equals the given value.

    Delegates to ``self.find`` with a single-key equality filter.

    Args:
        group_id: Group ID to match.
        limit: Maximum number of results, forwarded to ``self.find``.

    Returns:
        List[ExtractedEntityNode]: The result of ``self.find`` for this filter.
    """
    criteria = {"group_id": group_id}
    matches = await self.find(criteria, limit=limit)
    return matches
|
||||
|
||||
async def find_by_name(
    self,
    name: str,
    group_id: Optional[str] = None,
    limit: int = 100
) -> List[ExtractedEntityNode]:
    """Search entities by name using a substring match (Cypher ``CONTAINS``).

    Args:
        name: Text that the node's ``name`` property must contain.
        group_id: Optional group ID; when truthy, results are also restricted
            to nodes with this ``group_id``.
        limit: Maximum number of results.

    Returns:
        List[ExtractedEntityNode]: Each result row passed through
        ``self._map_to_entity``.
    """
    # Collect the predicate and its bound parameters together so the optional
    # group filter is handled in one place.
    conditions = ["n.name CONTAINS $name"]
    bindings = {"name": name, "limit": limit}
    if group_id:
        conditions.append("n.group_id = $group_id")
        bindings["group_id"] = group_id
    predicate = " AND ".join(conditions)

    query = f"""
    MATCH (n:{self.node_label})
    WHERE {predicate}
    RETURN n
    LIMIT $limit
    """

    rows = await self.connector.execute_query(query, **bindings)
    return list(map(self._map_to_entity, rows))
|
||||
|
||||
async def find_related_entities(
    self,
    entity_id: str,
    relation_type: Optional[str] = None,
    limit: int = 100
) -> List[ExtractedEntityNode]:
    """Fetch entities reached from a given entity via outgoing ``RELATES_TO``.

    Args:
        entity_id: ``id`` of the starting ``ExtractedEntity`` node.
        relation_type: Optional filter; when truthy, only relationships whose
            ``relation_type`` property equals this value are followed.
        limit: Maximum number of results.

    Returns:
        List[ExtractedEntityNode]: Each result row passed through
        ``self._map_to_entity``.
    """
    # Only the relationship pattern differs between the filtered and the
    # unfiltered variant, so build that piece and share the rest of the query.
    bindings = {"entity_id": entity_id}
    if relation_type:
        rel_pattern = "[r:RELATES_TO {relation_type: $relation_type}]"
        bindings["relation_type"] = relation_type
    else:
        rel_pattern = "[r:RELATES_TO]"
    bindings["limit"] = limit

    query = f"""
    MATCH (e1:ExtractedEntity {{id: $entity_id}})-{rel_pattern}->(e2:ExtractedEntity)
    RETURN e2 as n
    LIMIT $limit
    """
    rows = await self.connector.execute_query(query, **bindings)

    return [self._map_to_entity(row) for row in rows]
|
||||
|
||||
async def search_by_embedding(
    self,
    embedding: List[float],
    group_id: Optional[str] = None,
    limit: int = 10,
    min_score: float = 0.7
) -> List[Dict]:
    """Rank entities by similarity between a query vector and ``name_embedding``.

    The score is computed in the query with ``gds.similarity.cosine``. Nodes
    without a ``name_embedding`` are excluded, and only rows whose score is
    strictly greater than ``min_score`` are returned, highest score first.

    Args:
        embedding: Query vector.
        group_id: Optional group ID; when truthy, restricts the search to
            nodes with this ``group_id``.
        limit: Maximum number of results.
        min_score: Exclusive lower bound on the similarity score.

    Returns:
        List[Dict]: One dict per row with keys ``entity`` (the row passed
        through ``self._map_to_entity``) and ``score`` (the row's ``score``,
        or 0.0 when absent).
    """
    bindings = {
        "embedding": embedding,
        "min_score": min_score,
        "limit": limit
    }
    conditions = ["n.name_embedding IS NOT NULL"]
    if group_id:
        conditions.append("n.group_id = $group_id")
        bindings["group_id"] = group_id
    predicate = " AND ".join(conditions)

    query = f"""
    MATCH (n:{self.node_label})
    WHERE {predicate}
    WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score
    WHERE score > $min_score
    RETURN n, score
    ORDER BY score DESC
    LIMIT $limit
    """

    rows = await self.connector.execute_query(query, **bindings)

    ranked = []
    for row in rows:
        ranked.append({
            "entity": self._map_to_entity(row),
            "score": row.get("score", 0.0)
        })
    return ranked
|
||||
|
||||
async def find_by_statement_id(self, statement_id: str) -> List[ExtractedEntityNode]:
    """Look up entities whose ``statement_id`` property equals the given value.

    Delegates to ``self.find`` with a single-key equality filter; no explicit
    limit is passed, so ``self.find``'s own default applies.

    Args:
        statement_id: Statement ID to match.

    Returns:
        List[ExtractedEntityNode]: The result of ``self.find`` for this filter.
    """
    criteria = {"statement_id": statement_id}
    matches = await self.find(criteria)
    return matches
|
||||
|
||||
async def find_strong_entities(
    self,
    group_id: str,
    limit: int = 100
) -> List[ExtractedEntityNode]:
    """Look up entities in a group whose ``connect_strength`` is ``"Strong"``.

    Delegates to ``self.find`` with an equality filter on both properties.

    Args:
        group_id: Group ID to match.
        limit: Maximum number of results, forwarded to ``self.find``.

    Returns:
        List[ExtractedEntityNode]: The result of ``self.find`` for this filter.
    """
    criteria = {"group_id": group_id, "connect_strength": "Strong"}
    matches = await self.find(criteria, limit=limit)
    return matches
|
||||
|
||||
async def get_entity_count_by_type(self, group_id: str) -> Dict[str, int]:
    """Count the ``ExtractedEntity`` nodes of a group, broken down by type.

    Args:
        group_id: Group ID whose entities are counted.

    Returns:
        Dict[str, int]: Mapping from each row's ``entity_type`` to its
        ``count``, in the order the query returns them (count descending).
    """
    query = """
    MATCH (n:ExtractedEntity {group_id: $group_id})
    RETURN n.entity_type as entity_type, count(n) as count
    ORDER BY count DESC
    """
    rows = await self.connector.execute_query(query, group_id=group_id)

    counts = {}
    for row in rows:
        counts[row["entity_type"]] = row["count"]
    return counts
|
||||
|
||||
async def find_by_config_id(
    self,
    config_id: str,
    limit: int = 100
) -> List[ExtractedEntityNode]:
    """Look up entities whose ``config_id`` property equals the given value.

    Delegates to ``self.find`` with a single-key equality filter.

    Args:
        config_id: Configuration ID to match.
        limit: Maximum number of results, forwarded to ``self.find``.

    Returns:
        List[ExtractedEntityNode]: The result of ``self.find`` for this filter.
    """
    criteria = {"config_id": config_id}
    matches = await self.find(criteria, limit=limit)
    return matches
|
||||
|
||||
async def search_by_embedding_with_config(
    self,
    embedding: List[float],
    config_id: Optional[str] = None,
    group_id: Optional[str] = None,
    limit: int = 10,
    min_score: float = 0.7
) -> List[Dict]:
    """Rank entities by vector similarity, optionally filtered by config and group.

    The score is computed in the query with ``gds.similarity.cosine`` between
    ``n.name_embedding`` and the query vector. Nodes without a
    ``name_embedding`` are excluded, and only rows whose score is strictly
    greater than ``min_score`` are returned, highest score first.

    Args:
        embedding: Query vector.
        config_id: Optional configuration ID; when truthy, restricts the
            search to nodes with this ``config_id``.
        group_id: Optional group ID; when truthy, restricts the search to
            nodes with this ``group_id``.
        limit: Maximum number of results.
        min_score: Exclusive lower bound on the similarity score.

    Returns:
        List[Dict]: One dict per row with keys ``entity`` (the row passed
        through ``self._map_to_entity``) and ``score`` (the row's ``score``,
        or 0.0 when absent).
    """
    bindings = {
        "embedding": embedding,
        "min_score": min_score,
        "limit": limit
    }

    # Append one equality condition per optional filter that was supplied,
    # config first and group second.
    predicate = "n.name_embedding IS NOT NULL"
    for prop, value in (("config_id", config_id), ("group_id", group_id)):
        if value:
            predicate += f" AND n.{prop} = ${prop}"
            bindings[prop] = value

    query = f"""
    MATCH (n:{self.node_label})
    WHERE {predicate}
    WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score
    WHERE score > $min_score
    RETURN n, score
    ORDER BY score DESC
    LIMIT $limit
    """

    rows = await self.connector.execute_query(query, **bindings)

    ranked = []
    for row in rows:
        ranked.append({
            "entity": self._map_to_entity(row),
            "score": row.get("score", 0.0)
        })
    return ranked
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ Classes:
|
||||
StatementRepository: 陈述句仓储,管理StatementNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
@@ -76,244 +76,3 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
||||
List[StatementNode]: 陈述句列表
|
||||
"""
|
||||
return await self.find({"chunk_id": chunk_id})
|
||||
|
||||
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[StatementNode]:
    """Look up statements whose ``group_id`` property equals the given value.

    Delegates to ``self.find`` with a single-key equality filter.

    Args:
        group_id: Group ID to match.
        limit: Maximum number of results, forwarded to ``self.find``.

    Returns:
        List[StatementNode]: The result of ``self.find`` for this filter.
    """
    criteria = {"group_id": group_id}
    matches = await self.find(criteria, limit=limit)
    return matches
|
||||
|
||||
async def search_by_embedding(
    self,
    embedding: List[float],
    group_id: Optional[str] = None,
    limit: int = 10,
    min_score: float = 0.7
) -> List[Dict]:
    """Rank statements by similarity between a query vector and ``statement_embedding``.

    The score is computed in the query with ``gds.similarity.cosine``. Nodes
    without a ``statement_embedding`` are excluded, and only rows whose score
    is strictly greater than ``min_score`` are returned, highest score first.

    Args:
        embedding: Query vector.
        group_id: Optional group ID; when truthy, restricts the search to
            nodes with this ``group_id``.
        limit: Maximum number of results.
        min_score: Exclusive lower bound on the similarity score.

    Returns:
        List[Dict]: One dict per row with keys ``statement`` (the row passed
        through ``self._map_to_entity``) and ``score`` (the row's ``score``,
        or 0.0 when absent).
    """
    bindings = {
        "embedding": embedding,
        "min_score": min_score,
        "limit": limit
    }
    conditions = ["n.statement_embedding IS NOT NULL"]
    if group_id:
        conditions.append("n.group_id = $group_id")
        bindings["group_id"] = group_id
    predicate = " AND ".join(conditions)

    query = f"""
    MATCH (n:{self.node_label})
    WHERE {predicate}
    WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
    WHERE score > $min_score
    RETURN n, score
    ORDER BY score DESC
    LIMIT $limit
    """

    rows = await self.connector.execute_query(query, **bindings)

    ranked = []
    for row in rows:
        ranked.append({
            "statement": self._map_to_entity(row),
            "score": row.get("score", 0.0)
        })
    return ranked
|
||||
|
||||
async def search_by_keyword(
    self,
    keyword: str,
    group_id: Optional[str] = None,
    limit: int = 50
) -> List[StatementNode]:
    """Search statements whose ``statement`` text contains a keyword.

    Uses the Cypher ``CONTAINS`` operator on the node's ``statement`` property.

    Args:
        keyword: Text that the statement must contain.
        group_id: Optional group ID; when truthy, restricts the search to
            nodes with this ``group_id``.
        limit: Maximum number of results.

    Returns:
        List[StatementNode]: Each result row passed through
        ``self._map_to_entity``.
    """
    conditions = ["n.statement CONTAINS $keyword"]
    bindings = {"keyword": keyword, "limit": limit}
    if group_id:
        conditions.append("n.group_id = $group_id")
        bindings["group_id"] = group_id
    predicate = " AND ".join(conditions)

    query = f"""
    MATCH (n:{self.node_label})
    WHERE {predicate}
    RETURN n
    LIMIT $limit
    """

    rows = await self.connector.execute_query(query, **bindings)
    return list(map(self._map_to_entity, rows))
|
||||
|
||||
async def find_by_temporal_range(
    self,
    group_id: str,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: int = 100
) -> List[StatementNode]:
    """Query a group's statements, optionally bounded by validity dates.

    The group filter is always applied. When ``start_date`` is given, nodes
    must satisfy ``valid_at >= start_date``. When ``end_date`` is given, nodes
    must have no ``invalid_at`` or ``invalid_at <= end_date``. Both dates are
    sent to the query as ISO-8601 strings. Results are ordered by
    ``created_at`` descending.

    Args:
        group_id: Group ID to match.
        start_date: Optional lower bound on ``valid_at``.
        end_date: Optional upper bound on ``invalid_at``.
        limit: Maximum number of results.

    Returns:
        List[StatementNode]: Each result row passed through
        ``self._map_to_entity``.
    """
    bindings = {"group_id": group_id, "limit": limit}
    predicate = "n.group_id = $group_id"

    if start_date:
        predicate += " AND n.valid_at >= $start_date"
        bindings["start_date"] = start_date.isoformat()

    if end_date:
        predicate += " AND (n.invalid_at IS NULL OR n.invalid_at <= $end_date)"
        bindings["end_date"] = end_date.isoformat()

    query = f"""
    MATCH (n:{self.node_label})
    WHERE {predicate}
    RETURN n
    ORDER BY n.created_at DESC
    LIMIT $limit
    """

    rows = await self.connector.execute_query(query, **bindings)
    return list(map(self._map_to_entity, rows))
|
||||
|
||||
async def find_strong_statements(
    self,
    group_id: str,
    limit: int = 100
) -> List[StatementNode]:
    """Look up statements in a group whose ``connect_strength`` is ``"Strong"``.

    Delegates to ``self.find`` with an equality filter on both properties.

    Args:
        group_id: Group ID to match.
        limit: Maximum number of results, forwarded to ``self.find``.

    Returns:
        List[StatementNode]: The result of ``self.find`` for this filter.
    """
    criteria = {"group_id": group_id, "connect_strength": "Strong"}
    matches = await self.find(criteria, limit=limit)
    return matches
|
||||
|
||||
async def find_by_config_id(
    self,
    config_id: str,
    limit: int = 100
) -> List[StatementNode]:
    """Look up statements whose ``config_id`` property equals the given value.

    Delegates to ``self.find`` with a single-key equality filter.

    Args:
        config_id: Configuration ID to match.
        limit: Maximum number of results, forwarded to ``self.find``.

    Returns:
        List[StatementNode]: The result of ``self.find`` for this filter.
    """
    criteria = {"config_id": config_id}
    matches = await self.find(criteria, limit=limit)
    return matches
|
||||
|
||||
async def search_by_embedding_with_config(
    self,
    embedding: List[float],
    config_id: Optional[str] = None,
    group_id: Optional[str] = None,
    limit: int = 10,
    min_score: float = 0.7
) -> List[Dict]:
    """Rank statements by vector similarity, optionally filtered by config and group.

    The score is computed in the query with ``gds.similarity.cosine`` between
    ``n.statement_embedding`` and the query vector. Nodes without a
    ``statement_embedding`` are excluded, and only rows whose score is
    strictly greater than ``min_score`` are returned, highest score first.

    Args:
        embedding: Query vector.
        config_id: Optional configuration ID; when truthy, restricts the
            search to nodes with this ``config_id``.
        group_id: Optional group ID; when truthy, restricts the search to
            nodes with this ``group_id``.
        limit: Maximum number of results.
        min_score: Exclusive lower bound on the similarity score.

    Returns:
        List[Dict]: One dict per row with keys ``statement`` (the row passed
        through ``self._map_to_entity``) and ``score`` (the row's ``score``,
        or 0.0 when absent).
    """
    bindings = {
        "embedding": embedding,
        "min_score": min_score,
        "limit": limit
    }

    # Append one equality condition per optional filter that was supplied,
    # config first and group second.
    predicate = "n.statement_embedding IS NOT NULL"
    for prop, value in (("config_id", config_id), ("group_id", group_id)):
        if value:
            predicate += f" AND n.{prop} = ${prop}"
            bindings[prop] = value

    query = f"""
    MATCH (n:{self.node_label})
    WHERE {predicate}
    WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
    WHERE score > $min_score
    RETURN n, score
    ORDER BY score DESC
    LIMIT $limit
    """

    rows = await self.connector.execute_query(query, **bindings)

    ranked = []
    for row in rows:
        ranked.append({
            "statement": self._map_to_entity(row),
            "score": row.get("score", 0.0)
        })
    return ranked
|
||||
|
||||
@@ -91,12 +91,13 @@ class TenantRepository:
|
||||
return db_tenant
|
||||
|
||||
def delete_tenant(self, tenant_id: uuid.UUID) -> bool:
    """Logically delete a tenant by marking it inactive.

    The tenant row is kept; only its ``is_active`` flag is cleared. This
    method does not flush or commit the session itself.

    Args:
        tenant_id: ID of the tenant to deactivate.

    Returns:
        bool: ``False`` when ``self.get_tenant_by_id`` finds no tenant,
        ``True`` once the tenant has been flagged inactive.
    """
    db_tenant = self.get_tenant_by_id(tenant_id)
    if not db_tenant:
        return False

    # Soft delete only: the row must not also be handed to ``self.db.delete``,
    # otherwise the record would be removed and the flag would be pointless.
    db_tenant.is_active = False
    return True
|
||||
|
||||
def get_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> List[User]:
|
||||
|
||||
@@ -144,9 +144,10 @@ class UserRepository:
|
||||
db_logger.debug(f"用户不存在: user_id={user_id}")
|
||||
return False
|
||||
|
||||
self.db.delete(user)
|
||||
# 逻辑删除用户
|
||||
user.is_active = False
|
||||
self.db.flush()
|
||||
db_logger.info(f"用户删除成功: {user.username}")
|
||||
db_logger.info(f"用户删除成功(逻辑删除): {user.username}")
|
||||
return True
|
||||
except Exception as e:
|
||||
db_logger.error(f"删除用户失败: user_id={user_id} - {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user