Files
MemoryBear/api/app/repositories/data_config_repository.py

712 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""数据配置Repository模块
本模块提供data_config表的数据访问层使用SQLAlchemy ORM进行数据库操作。
包括CRUD操作和Neo4j Cypher查询常量。
Classes:
DataConfigRepository: 数据配置仓储类提供CRUD操作
"""
import uuid
from typing import Dict, List, Optional, Tuple
from app.core.logging_config import get_config_logger, get_db_logger
from app.models.data_config_model import DataConfig
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
)
from sqlalchemy import desc
from sqlalchemy.orm import Session
# 获取数据库专用日志器
db_logger = get_db_logger()
# 获取配置专用日志器
config_logger = get_config_logger()
TABLE_NAME = "data_config"
class DataConfigRepository:
"""数据配置Repository
提供data_config表的数据访问方法包括
- SQLAlchemy ORM 数据库操作
- Neo4j Cypher查询常量
"""
# ==================== Neo4j Cypher 查询常量 ====================
# Dialogue count by group
SEARCH_FOR_DIALOGUE = """
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# Chunk count by group
SEARCH_FOR_CHUNK = """
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# Statement count by group
SEARCH_FOR_STATEMENT = """
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# ExtractedEntity count by group
SEARCH_FOR_ENTITY = """
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# All counts by label and total
SEARCH_FOR_ALL = """
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:Statement) WHERE n.group_id = $group_id RETURN 'Statement' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
"""
# Extracted entity details within group/app/user
SEARCH_FOR_DETIALS = """
MATCH (n:ExtractedEntity)
WHERE n.group_id = $group_id
RETURN n.entity_idx AS entity_idx,
n.connect_strength AS connect_strength,
n.description AS description,
n.entity_type AS entity_type,
n.name AS name,
n.fact_summary AS fact_summary,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.id AS id
"""
# Edges between extracted entities within group/app/user
SEARCH_FOR_EDGES = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id
RETURN
r.group_id AS group_id,
r.apply_id AS apply_id,
r.user_id AS user_id,
elementId(r) AS rel_id,
startNode(r).id AS source_id,
endNode(r).id AS target_id,
r.predicate AS predicate,
r.statement_id AS statement_id,
r.statement AS statement
"""
# Entity graph within group (source node, edge, target node)
SEARCH_FOR_ENTITY_GRAPH = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id
RETURN
{
entity_idx: n.entity_idx,
connect_strength: n.connect_strength,
description: n.description,
entity_type: n.entity_type,
name: n.name,
fact_summary: n.fact_summary,
id: n.id
} AS sourceNode,
{
rel_id: elementId(r),
source_id: startNode(r).id,
target_id: endNode(r).id,
predicate: r.predicate,
statement_id: r.statement_id,
statement: r.statement
} AS edge,
{
entity_idx: m.entity_idx,
connect_strength: m.connect_strength,
description: m.description,
entity_type: m.entity_type,
name: m.name,
fact_summary: m.fact_summary,
id: m.id
} AS targetNode
"""
# ==================== SQLAlchemy ORM 数据库操作方法 ====================
@staticmethod
def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]:
"""构建反思配置更新语句SQLAlchemy text() 命名参数)
Args:
config_id: 配置ID
**kwargs: 反思配置参数
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")
key_where = "config_id = :config_id"
set_fields: List[str] = []
params: Dict = {
"config_id": config_id,
}
# 反思配置字段映射
mapping = {
"enable_self_reflexion": "enable_self_reflexion",
"iteration_period": "iteration_period",
"reflexion_range": "reflexion_range",
"baseline": "baseline",
"reflection_model_id": "reflection_model_id",
"memory_verify": "memory_verify",
"quality_assessment": "quality_assessment",
}
for api_field, db_col in mapping.items():
if api_field in kwargs and kwargs[api_field] is not None:
set_fields.append(f"{db_col} = :{api_field}")
params[api_field] = kwargs[api_field]
if not set_fields:
raise ValueError("No fields to update")
set_fields.append("updated_at = timezone('Asia/Shanghai', now())")
query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}"
return query, params
@staticmethod
def build_select_reflection(config_id: int) -> Tuple[str, Dict]:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args:
config_id: 配置ID
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")
query = (
f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
f"reflection_model_id, memory_verify, quality_assessment, user_id "
f"FROM {TABLE_NAME} WHERE config_id = :config_id"
)
params = {"config_id": config_id}
return query, params
@staticmethod
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
"""构建查询所有配置的语句SQLAlchemy text() 命名参数)
Args:
workspace_id: 工作空间ID
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
query = (
f"SELECT config_id, config_name, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
f"reflection_model_id, memory_verify, quality_assessment, user_id, created_at, updated_at "
f"FROM {TABLE_NAME} WHERE workspace_id = :workspace_id ORDER BY updated_at DESC"
)
params = {"workspace_id": workspace_id}
return query, params
@staticmethod
def create(db: Session, params: ConfigParamsCreate) -> DataConfig:
"""创建数据配置
Args:
db: 数据库会话
params: 配置参数创建模型
Returns:
DataConfig: 创建的配置对象
"""
db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
try:
db_config = DataConfig(
config_name=params.config_name,
config_desc=params.config_desc,
workspace_id=params.workspace_id,
llm_id=params.llm_id,
embedding_id=params.embedding_id,
rerank_id=params.rerank_id,
)
db.add(db_config)
db.flush() # 获取自增ID但不提交事务
db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}")
raise
@staticmethod
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]:
"""更新基础配置
Args:
db: 数据库会话
update: 配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"更新数据配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
return None
# 更新字段
has_update = False
if update.config_name is not None:
db_config.config_name = update.config_name
has_update = True
if update.config_desc is not None:
db_config.config_desc = update.config_desc
has_update = True
if not has_update:
raise ValueError("No fields to update")
db.commit()
db.refresh(db_config)
db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]:
"""更新记忆萃取引擎配置
Args:
db: 数据库会话
update: 萃取配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
return None
# 更新字段映射
field_mapping = {
# 模型选择
"llm_id": "llm",
"embedding_id": "embedding_id",
"rerank_id": "rerank_id",
# 记忆萃取引擎
"enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
"enable_llm_disambiguation": "enable_llm_disambiguation",
"deep_retrieval": "deep_retrieval",
"t_type_strict": "t_type_strict",
"t_name_strict": "t_name_strict",
"t_overall": "t_overall",
"state": "state",
"chunker_strategy": "chunker_strategy",
# 句子提取
"statement_granularity": "statement_granularity",
"include_dialogue_context": "include_dialogue_context",
"max_context": "max_context",
# 剪枝配置
"pruning_enabled": "pruning_enabled",
"pruning_scene": "pruning_scene",
"pruning_threshold": "pruning_threshold",
# 自我反思配置
"enable_self_reflexion": "enable_self_reflexion",
"iteration_period": "iteration_period",
"reflexion_range": "reflexion_range",
"baseline": "baseline",
}
has_update = False
for api_field, db_field in field_mapping.items():
value = getattr(update, api_field, None)
if value is not None:
setattr(db_config, db_field, value)
has_update = True
if not has_update:
raise ValueError("No fields to update")
db.commit()
db.refresh(db_config)
db_logger.info(f"萃取配置更新成功: config_id={update.config_id}")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"更新萃取配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]:
"""更新遗忘引擎配置
Args:
db: 数据库会话
update: 遗忘配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
return None
# 更新字段
has_update = False
if update.lambda_time is not None:
db_config.lambda_time = update.lambda_time
has_update = True
if update.lambda_mem is not None:
db_config.lambda_mem = update.lambda_mem
has_update = True
if update.offset is not None:
db_config.offset = update.offset
has_update = True
if not has_update:
raise ValueError("No fields to update")
db.commit()
db.refresh(db_config)
db_logger.info(f"遗忘配置更新成功: config_id={update.config_id}")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"更新遗忘配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]:
"""获取萃取配置,通过主键查询某条配置
Args:
db: 数据库会话
config_id: 配置ID
Returns:
Optional[Dict]: 萃取配置字典不存在则返回None
"""
db_logger.debug(f"查询萃取配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if not db_config:
db_logger.debug(f"萃取配置不存在: config_id={config_id}")
return None
result = {
"llm_id": db_config.llm_id,
"embedding_id": db_config.embedding_id,
"rerank_id": db_config.rerank_id,
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
"enable_llm_disambiguation": db_config.enable_llm_disambiguation,
"deep_retrieval": db_config.deep_retrieval,
"t_type_strict": db_config.t_type_strict,
"t_name_strict": db_config.t_name_strict,
"t_overall": db_config.t_overall,
"chunker_strategy": db_config.chunker_strategy,
"statement_granularity": db_config.statement_granularity,
"include_dialogue_context": db_config.include_dialogue_context,
"max_context": db_config.max_context,
"pruning_enabled": db_config.pruning_enabled,
"pruning_scene": db_config.pruning_scene,
"pruning_threshold": db_config.pruning_threshold,
"enable_self_reflexion": db_config.enable_self_reflexion,
"iteration_period": db_config.iteration_period,
"reflexion_range": db_config.reflexion_range,
"baseline": db_config.baseline,
}
db_logger.debug(f"萃取配置查询成功: config_id={config_id}")
return result
except Exception as e:
db_logger.error(f"查询萃取配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]:
"""获取遗忘配置,通过主键查询某条配置
Args:
db: 数据库会话
config_id: 配置ID
Returns:
Optional[Dict]: 遗忘配置字典不存在则返回None
"""
db_logger.debug(f"查询遗忘配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if not db_config:
db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
return None
result = {
"lambda_time": db_config.lambda_time,
"lambda_mem": db_config.lambda_mem,
"offset": db_config.offset,
}
db_logger.debug(f"遗忘配置查询成功: config_id={config_id}")
return result
except Exception as e:
db_logger.error(f"查询遗忘配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]:
"""根据ID获取数据配置
Args:
db: 数据库会话
config_id: 配置ID
Returns:
Optional[DataConfig]: 配置对象不存在则返回None
"""
db_logger.debug(f"根据ID查询数据配置: config_id={config_id}")
try:
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if config:
db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})")
else:
db_logger.debug(f"数据配置不存在: config_id={config_id}")
return config
except Exception as e:
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]:
"""Get data config and its associated workspace information
Args:
db: Database session
config_id: Configuration ID
Returns:
Optional[tuple]: (DataConfig, Workspace) tuple, None if not found
Raises:
ValueError: Raised when config exists but workspace doesn't
"""
import time
from app.models.workspace_model import Workspace
start_time = time.time()
# Log configuration loading start
config_logger.info(
"Loading configuration with workspace",
extra={
"operation": "get_config_with_workspace",
"config_id": config_id
}
)
db_logger.debug(f"Querying data config and workspace: config_id={config_id}")
try:
# Use join query to get both config and workspace
result = db.query(DataConfig, Workspace).join(
Workspace, DataConfig.workspace_id == Workspace.id
).filter(DataConfig.config_id == config_id).first()
elapsed_ms = (time.time() - start_time) * 1000
if not result:
# Check if config exists but workspace is missing
config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if config_only:
if config_only.workspace_id is None:
config_logger.error(
"Configuration has no associated workspace ID",
extra={
"operation": "get_config_with_workspace",
"config_id": config_id,
"workspace_id": None,
"load_result": "no_workspace_id",
"elapsed_ms": elapsed_ms
}
)
db_logger.error(f"Data config {config_id} has no associated workspace ID")
raise ValueError(f"Configuration {config_id} has no associated workspace")
else:
config_logger.error(
"Configuration references non-existent workspace",
extra={
"operation": "get_config_with_workspace",
"config_id": config_id,
"workspace_id": str(config_only.workspace_id),
"load_result": "workspace_not_found",
"elapsed_ms": elapsed_ms
}
)
db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}")
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
config_logger.debug(
"Configuration not found",
extra={
"operation": "get_config_with_workspace",
"config_id": config_id,
"load_result": "not_found",
"elapsed_ms": elapsed_ms
}
)
db_logger.debug(f"Data config not found: config_id={config_id}")
return None
config, workspace = result
# Log successful configuration loading
config_logger.info(
"Configuration with workspace loaded successfully",
extra={
"operation": "get_config_with_workspace",
"config_id": config_id,
"config_name": config.config_name,
"workspace_id": str(workspace.id),
"workspace_name": workspace.name,
"tenant_id": str(workspace.tenant_id),
"load_result": "success",
"elapsed_ms": elapsed_ms
}
)
db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
return (config, workspace)
except ValueError:
# Re-raise known business exceptions
raise
except Exception as e:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Failed to load configuration with workspace",
extra={
"operation": "get_config_with_workspace",
"config_id": config_id,
"load_result": "error",
"error_type": type(e).__name__,
"error_message": str(e),
"elapsed_ms": elapsed_ms
},
exc_info=True
)
db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
"""获取所有配置参数
Args:
db: 数据库会话
workspace_id: 工作空间ID用于过滤查询结果
Returns:
List[DataConfig]: 配置列表
"""
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
try:
query = db.query(DataConfig)
if workspace_id:
query = query.filter(DataConfig.workspace_id == workspace_id)
configs = query.order_by(desc(DataConfig.updated_at)).all()
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
return configs
except Exception as e:
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
raise
@staticmethod
def delete(db: Session, config_id: int) -> bool:
"""删除数据配置
Args:
db: 数据库会话
config_id: 配置ID
Returns:
bool: 删除成功返回True配置不存在返回False
"""
db_logger.debug(f"删除数据配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={config_id}")
return False
db.delete(db_config)
db.commit()
db_logger.info(f"数据配置删除成功: config_id={config_id}")
return True
except Exception as e:
db.rollback()
db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}")
raise