feat: Add base project structure with API and web components
This commit is contained in:
171
api/app/repositories/__init__.py
Normal file
171
api/app/repositories/__init__.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""仓储模块
|
||||
|
||||
本模块提供统一的数据访问层,包括PostgreSQL和Neo4j的仓储实现。
|
||||
|
||||
Classes:
|
||||
RepositoryFactory: 仓储工厂,统一管理所有数据库的仓储实例
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.dialog_repository import DialogRepository
|
||||
from app.repositories.neo4j.statement_repository import StatementRepository
|
||||
from app.repositories.neo4j.entity_repository import EntityRepository
|
||||
from app.repositories.user_repository import UserRepository
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.repositories.app_repository import AppRepository
|
||||
|
||||
|
||||
class RepositoryFactory:
|
||||
"""仓储工厂 - 统一管理所有数据库的仓储
|
||||
|
||||
这个工厂类提供了获取各种仓储实例的统一接口。
|
||||
支持Neo4j图数据库和PostgreSQL关系数据库的仓储。
|
||||
|
||||
Attributes:
|
||||
neo4j_connector: Neo4j连接器实例(可选)
|
||||
db_session: SQLAlchemy数据库会话(可选)
|
||||
|
||||
Example:
|
||||
>>> # 创建工厂实例
|
||||
>>> factory = RepositoryFactory(
|
||||
... neo4j_connector=Neo4jConnector(),
|
||||
... db_session=db_session
|
||||
... )
|
||||
>>>
|
||||
>>> # 获取Neo4j仓储
|
||||
>>> dialog_repo = factory.get_dialog_repository()
|
||||
>>> statement_repo = factory.get_statement_repository()
|
||||
>>>
|
||||
>>> # 获取PostgreSQL仓储
|
||||
>>> knowledge_repo = factory.get_knowledge_repository()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
neo4j_connector: Optional[Neo4jConnector] = None,
|
||||
db_session: Optional[Session] = None
|
||||
):
|
||||
"""初始化仓储工厂
|
||||
|
||||
Args:
|
||||
neo4j_connector: Neo4j连接器实例(可选)
|
||||
db_session: SQLAlchemy数据库会话(可选)
|
||||
"""
|
||||
self.neo4j_connector = neo4j_connector
|
||||
self.db_session = db_session
|
||||
|
||||
# ==================== Neo4j 仓储 ====================
|
||||
|
||||
def get_dialog_repository(self) -> DialogRepository:
|
||||
"""获取对话仓储
|
||||
|
||||
Returns:
|
||||
DialogRepository: 对话仓储实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果Neo4j连接器未初始化
|
||||
"""
|
||||
if not self.neo4j_connector:
|
||||
raise ValueError("Neo4j connector not initialized")
|
||||
return DialogRepository(self.neo4j_connector)
|
||||
|
||||
def get_statement_repository(self) -> StatementRepository:
|
||||
"""获取陈述句仓储
|
||||
|
||||
Returns:
|
||||
StatementRepository: 陈述句仓储实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果Neo4j连接器未初始化
|
||||
"""
|
||||
if not self.neo4j_connector:
|
||||
raise ValueError("Neo4j connector not initialized")
|
||||
return StatementRepository(self.neo4j_connector)
|
||||
|
||||
def get_entity_repository(self) -> EntityRepository:
|
||||
"""获取实体仓储
|
||||
|
||||
Returns:
|
||||
EntityRepository: 实体仓储实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果Neo4j连接器未初始化
|
||||
"""
|
||||
if not self.neo4j_connector:
|
||||
raise ValueError("Neo4j connector not initialized")
|
||||
return EntityRepository(self.neo4j_connector)
|
||||
|
||||
# ==================== PostgreSQL 仓储 ====================
|
||||
# 注意:现有的PostgreSQL仓储保持不变,这里只是提供统一的访问接口
|
||||
# 部分仓储(如knowledge_repository、document_repository)使用函数式接口
|
||||
# 部分仓储(如user_repository、workspace_repository)使用类接口
|
||||
|
||||
def get_user_repository(self) -> UserRepository:
|
||||
"""获取用户仓储
|
||||
|
||||
Returns:
|
||||
UserRepository: 用户仓储实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果数据库会话未初始化
|
||||
"""
|
||||
if not self.db_session:
|
||||
raise ValueError("Database session not initialized")
|
||||
return UserRepository(self.db_session)
|
||||
|
||||
def get_workspace_repository(self) -> WorkspaceRepository:
|
||||
"""获取工作空间仓储
|
||||
|
||||
Returns:
|
||||
WorkspaceRepository: 工作空间仓储实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果数据库会话未初始化
|
||||
"""
|
||||
if not self.db_session:
|
||||
raise ValueError("Database session not initialized")
|
||||
return WorkspaceRepository(self.db_session)
|
||||
|
||||
def get_app_repository(self) -> AppRepository:
|
||||
"""获取应用仓储
|
||||
|
||||
Returns:
|
||||
AppRepository: 应用仓储实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果数据库会话未初始化
|
||||
"""
|
||||
if not self.db_session:
|
||||
raise ValueError("Database session not initialized")
|
||||
return AppRepository(self.db_session)
|
||||
|
||||
def get_db_session(self) -> Session:
|
||||
"""获取数据库会话
|
||||
|
||||
用于访问函数式仓储(如knowledge_repository、document_repository)
|
||||
|
||||
Returns:
|
||||
Session: SQLAlchemy数据库会话
|
||||
|
||||
Raises:
|
||||
ValueError: 如果数据库会话未初始化
|
||||
|
||||
Example:
|
||||
>>> factory = RepositoryFactory(db_session=session)
|
||||
>>> db = factory.get_db_session()
|
||||
>>> # 使用函数式仓储
|
||||
>>> from app.repositories import knowledge_repository
|
||||
>>> knowledges = knowledge_repository.get_knowledges_paginated(db, [], 1, 10)
|
||||
"""
|
||||
if not self.db_session:
|
||||
raise ValueError("Database session not initialized")
|
||||
return self.db_session
|
||||
|
||||
|
||||
# Names exported by ``from app.repositories import *``.
__all__ = ['RepositoryFactory']
|
||||
138
api/app/repositories/api_key_repository.py
Normal file
138
api/app/repositories/api_key_repository.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""API Key Repository"""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, func, and_
|
||||
from typing import Optional, List, Tuple
|
||||
import uuid
|
||||
import datetime
|
||||
|
||||
from app.models.api_key_model import ApiKey, ApiKeyLog
|
||||
from app.schemas import api_key_schema
|
||||
|
||||
|
||||
class ApiKeyRepository:
    """Data-access helpers for ``ApiKey`` rows.

    Every method is a ``@staticmethod`` that receives the SQLAlchemy session
    explicitly.  Writes are flushed but never committed here, so the caller
    owns the transaction boundary.
    """

    @staticmethod
    def create(db: Session, api_key_data: dict) -> ApiKey:
        """Insert a new API key.

        Args:
            db: Active SQLAlchemy session.
            api_key_data: Keyword arguments passed straight to ``ApiKey(...)``.

        Returns:
            The newly added (flushed, uncommitted) ``ApiKey``.
        """
        api_key = ApiKey(**api_key_data)
        db.add(api_key)
        db.flush()
        return api_key

    @staticmethod
    def get_by_id(db: Session, api_key_id: uuid.UUID) -> Optional[ApiKey]:
        """Return the API key with the given primary key, or ``None``."""
        return db.get(ApiKey, api_key_id)

    @staticmethod
    def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]:
        """Return the first API key whose ``key_hash`` matches, or ``None``."""
        stmt = select(ApiKey).where(ApiKey.key_hash == key_hash)
        return db.scalars(stmt).first()

    @staticmethod
    def list_by_workspace(
        db: Session,
        workspace_id: uuid.UUID,
        query: api_key_schema.ApiKeyQuery
    ) -> Tuple[List[ApiKey], int]:
        """List one page of a workspace's API keys, newest first.

        Args:
            db: Active SQLAlchemy session.
            workspace_id: Workspace whose keys are listed.
            query: Filter and paging options.  ``type`` and ``resource_id``
                are applied when truthy, ``is_active`` when it is not
                ``None``; ``page`` is 1-based and ``pagesize`` is the page
                length.

        Returns:
            ``(items, total)`` where ``total`` counts all rows matching the
            filters, ignoring pagination.
        """
        stmt = select(ApiKey).where(ApiKey.workspace_id == workspace_id)

        # Optional filters
        if query.type:
            stmt = stmt.where(ApiKey.type == query.type)
        if query.is_active is not None:
            stmt = stmt.where(ApiKey.is_active == query.is_active)
        if query.resource_id:
            stmt = stmt.where(ApiKey.resource_id == query.resource_id)

        # Total is computed from the filtered statement before ordering and
        # paging are applied.
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total = db.execute(count_stmt).scalar() or 0

        # Pagination
        stmt = stmt.order_by(ApiKey.created_at.desc())
        stmt = stmt.offset((query.page - 1) * query.pagesize).limit(query.pagesize)

        items = db.scalars(stmt).all()
        return list(items), total

    @staticmethod
    def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> Optional[ApiKey]:
        """Apply the non-``None`` entries of ``update_data`` to an API key.

        Entries whose value is ``None`` are skipped, so this method cannot be
        used to reset a column to NULL.

        Args:
            db: Active SQLAlchemy session.
            api_key_id: Primary key of the row to update.
            update_data: Mapping of attribute name to new value.

        Returns:
            The updated ``ApiKey``, or ``None`` when no row has that id.
        """
        api_key = db.get(ApiKey, api_key_id)
        if api_key:
            for key, value in update_data.items():
                if value is not None:
                    setattr(api_key, key, value)
            # NOTE(review): naive local time; confirm the column's timezone
            # expectations.
            api_key.updated_at = datetime.datetime.now()
            db.flush()
        return api_key

    @staticmethod
    def delete(db: Session, api_key_id: uuid.UUID) -> bool:
        """Delete an API key.

        Returns:
            ``True`` if the row existed and was deleted, ``False`` otherwise.
        """
        api_key = db.get(ApiKey, api_key_id)
        if api_key:
            db.delete(api_key)
            db.flush()
            return True
        return False

    @staticmethod
    def update_usage(db: Session, api_key_id: uuid.UUID) -> bool:
        """Record one use of an API key.

        Increments ``usage_count`` and ``quota_used`` by one and stamps
        ``last_used_at`` with the current time.

        Returns:
            ``True`` if the row existed and was updated, ``False`` otherwise.
        """
        api_key = db.get(ApiKey, api_key_id)
        if api_key:
            api_key.usage_count += 1
            api_key.quota_used += 1
            api_key.last_used_at = datetime.datetime.now()
            db.flush()
            return True
        return False

    @staticmethod
    def get_stats(db: Session, api_key_id: uuid.UUID) -> dict:
        """Collect usage statistics for an API key.

        Returns:
            An empty dict when the key does not exist; otherwise a dict with
            ``total_requests``, ``requests_today``, ``quota_used``,
            ``quota_limit``, ``last_used_at`` and ``avg_response_time``
            (``None`` only when there is no average to report).
        """
        api_key = db.get(ApiKey, api_key_id)
        if not api_key:
            return {}

        # Requests logged since local midnight.
        today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
        today_count_stmt = select(func.count()).select_from(ApiKeyLog).where(
            and_(
                ApiKeyLog.api_key_id == api_key_id,
                ApiKeyLog.created_at >= today_start
            )
        )
        requests_today = db.execute(today_count_stmt).scalar() or 0

        # Average response time over every log row of this key.
        avg_time_stmt = select(func.avg(ApiKeyLog.response_time)).where(
            ApiKeyLog.api_key_id == api_key_id
        )
        avg_response_time = db.execute(avg_time_stmt).scalar()

        return {
            "total_requests": api_key.usage_count,
            "requests_today": requests_today,
            "quota_used": api_key.quota_used,
            "quota_limit": api_key.quota_limit,
            "last_used_at": api_key.last_used_at,
            # Test against None, not truthiness: an average of exactly 0 is a
            # real measurement and must be reported as 0.0.
            "avg_response_time": float(avg_response_time) if avg_response_time is not None else None
        }
|
||||
|
||||
|
||||
class ApiKeyLogRepository:
    """Data-access helper for ``ApiKeyLog`` rows."""

    @staticmethod
    def create(db: Session, log_data: dict) -> ApiKeyLog:
        """Add one log row built from ``log_data`` and flush it.

        The session is flushed but not committed; the caller decides when the
        transaction ends.
        """
        entry = ApiKeyLog(**log_data)
        db.add(entry)
        db.flush()
        return entry
|
||||
30
api/app/repositories/app_repository.py
Normal file
30
api/app/repositories/app_repository.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
from app.models.app_model import App
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# 获取数据库专用日志器
|
||||
# Module-level logger; AppRepository reports query results and failures through it.
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class AppRepository:
    """Read access to ``App`` rows through a SQLAlchemy session."""

    def __init__(self, db: Session):
        """Keep a reference to the session used for all queries."""
        self.db = db

    def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> List[App]:
        """Return every app that belongs to the given workspace.

        Any exception raised while querying is logged and re-raised.
        """
        try:
            workspace_apps = self.db.query(App).filter(App.workspace_id == workspace_id)
            found = workspace_apps.all()
            db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(found)} 个应用")
            return found
        except Exception as exc:
            db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {exc!s}")
            raise
|
||||
|
||||
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
    """Function-style shortcut for ``AppRepository.get_apps_by_workspace_id``.

    Wraps ``db`` in a temporary ``AppRepository`` and returns the apps of the
    given workspace.
    """
    return AppRepository(db).get_apps_by_workspace_id(workspace_id)
|
||||
108
api/app/repositories/base_repository.py
Normal file
108
api/app/repositories/base_repository.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""基础仓储接口模块
|
||||
|
||||
本模块定义了通用的仓储接口,适用于所有数据库类型(PostgreSQL、Neo4j等)。
|
||||
遵循仓储模式(Repository Pattern),提供统一的数据访问抽象。
|
||||
|
||||
Classes:
|
||||
BaseRepository: 基础仓储接口,定义CRUD操作的抽象方法
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar, List, Optional, Dict, Any
|
||||
|
||||
# Entity type handled by a concrete repository.
T = TypeVar('T')


class BaseRepository(ABC, Generic[T]):
    """Abstract repository interface shared by all database backends.

    Declares the basic CRUD operations every repository must implement.  The
    class is generic in ``T``, the entity type the repository manages.

    Type Parameters:
        T: Entity type, typically a Pydantic model or an ORM model.

    Methods:
        create: Create a new entity.
        get_by_id: Fetch an entity by its ID.
        update: Update an existing entity.
        delete: Delete an entity.
        find: Query a list of entities by filter conditions.
    """

    @abstractmethod
    async def create(self, entity: T) -> T:
        """Create an entity.

        Args:
            entity: The entity object to create.

        Returns:
            T: The created entity (may carry generated values such as an ID).

        Raises:
            Exception: If creation fails.
        """

    @abstractmethod
    async def get_by_id(self, entity_id: str) -> Optional[T]:
        """Fetch an entity by its ID.

        Args:
            entity_id: Unique identifier of the entity.

        Returns:
            Optional[T]: The entity if found, otherwise ``None``.

        Raises:
            Exception: If the lookup fails.
        """

    @abstractmethod
    async def update(self, entity: T) -> T:
        """Update an entity.

        Args:
            entity: The entity to update (must include its ID).

        Returns:
            T: The updated entity.

        Raises:
            Exception: If the update fails.
        """

    @abstractmethod
    async def delete(self, entity_id: str) -> bool:
        """Delete an entity.

        Args:
            entity_id: ID of the entity to delete.

        Returns:
            bool: ``True`` on successful deletion, otherwise ``False``.

        Raises:
            Exception: If the deletion fails.
        """

    @abstractmethod
    async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]:
        """Query a list of entities.

        Args:
            filters: Mapping of field name to the expected value.
            limit: Maximum number of results to return; defaults to 100.

        Returns:
            List[T]: Entities matching the filters.

        Raises:
            Exception: If the query fails.
        """
|
||||
408
api/app/repositories/data_config_repository.py
Normal file
408
api/app/repositories/data_config_repository.py
Normal file
@@ -0,0 +1,408 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据配置Repository模块
|
||||
|
||||
本模块提供data_config表的数据访问层,包括SQL查询构建和Neo4j Cypher查询。
|
||||
从 app.core.memory.src.data_config_api.sql_queries 迁移而来。
|
||||
|
||||
Classes:
|
||||
DataConfigRepository: 数据配置仓储类,提供CRUD操作和查询构建
|
||||
"""
|
||||
|
||||
from typing import Dict, Tuple, List
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# 获取数据库专用日志器
|
||||
# Module-level logger; the query builders below emit a debug message through it on each call.
db_logger = get_db_logger()
|
||||
|
||||
# 表名常量
|
||||
# Table targeted by every SQL statement that DataConfigRepository builds.
TABLE_NAME = "data_config"
|
||||
|
||||
|
||||
class DataConfigRepository:
    """Query definitions for the ``data_config`` table and related graph data.

    Provides:
    - Neo4j Cypher query constants (all parameterised by ``$group_id``).
    - Static builders that return ``(sql, params)`` pairs using PostgreSQL
      named placeholders (``%(name)s``).  The builders only construct the
      statement; they never execute it.
    """

    # ==================== Neo4j Cypher query constants ====================

    # Dialogue count by group
    SEARCH_FOR_DIALOGUE = """
    MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
    """

    # Chunk count by group
    SEARCH_FOR_CHUNK = """
    MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
    """

    # Statement count by group
    SEARCH_FOR_STATEMENT = """
    MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
    """

    # ExtractedEntity count by group
    SEARCH_FOR_ENTITY = """
    MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
    """

    # All counts by label, plus a total over every node of the group
    SEARCH_FOR_ALL = """
    OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
    UNION ALL
    OPTIONAL MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
    UNION ALL
    OPTIONAL MATCH (n:Statement) WHERE n.group_id = $group_id RETURN 'Statement' AS Label, COUNT(n) AS Count
    UNION ALL
    OPTIONAL MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
    UNION ALL
    OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
    """

    # Extracted entity details within a group
    SEARCH_FOR_DETIALS = """
    MATCH (n:ExtractedEntity)
    WHERE n.group_id = $group_id
    RETURN n.entity_idx AS entity_idx,
        n.connect_strength AS connect_strength,
        n.description AS description,
        n.entity_type AS entity_type,
        n.name AS name,
        n.fact_summary AS fact_summary,
        n.group_id AS group_id,
        n.apply_id AS apply_id,
        n.user_id AS user_id,
        n.id AS id
    """

    # Correctly spelled alias; the misspelled name above is kept so existing
    # callers keep working.
    SEARCH_FOR_DETAILS = SEARCH_FOR_DETIALS

    # Edges between extracted entities within a group
    SEARCH_FOR_EDGES = """
    MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
    WHERE n.group_id = $group_id
    RETURN
        r.group_id AS group_id,
        r.apply_id AS apply_id,
        r.user_id AS user_id,
        elementId(r) AS rel_id,
        startNode(r).id AS source_id,
        endNode(r).id AS target_id,
        r.predicate AS predicate,
        r.statement_id AS statement_id,
        r.statement AS statement
    """

    # Entity graph within a group (source node, edge, target node)
    SEARCH_FOR_ENTITY_GRAPH = """
    MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
    WHERE n.group_id = $group_id
    RETURN
        {
            entity_idx: n.entity_idx,
            connect_strength: n.connect_strength,
            description: n.description,
            entity_type: n.entity_type,
            name: n.name,
            fact_summary: n.fact_summary,
            id: n.id
        } AS sourceNode,
        {
            rel_id: elementId(r),
            source_id: startNode(r).id,
            target_id: endNode(r).id,
            predicate: r.predicate,
            statement_id: r.statement_id,
            statement: r.statement
        } AS edge,
        {
            entity_idx: m.entity_idx,
            connect_strength: m.connect_strength,
            description: m.description,
            entity_type: m.entity_type,
            name: m.name,
            fact_summary: m.fact_summary,
            id: m.id
        } AS targetNode
    """

    # ==================== SQL query builders ====================

    @staticmethod
    def _build_update_sql(update, mapping: Dict[str, str]) -> Tuple[str, Dict]:
        """Build an ``UPDATE ... WHERE config_id = ...`` statement.

        Args:
            update: Update model; must expose ``config_id`` and one attribute
                per key of ``mapping``.
            mapping: Model attribute name -> column name (already quoted where
                needed).  Attributes whose value is ``None`` are skipped.

        Returns:
            Tuple[str, Dict]: (SQL string, parameter dict).

        Raises:
            ValueError: If none of the mapped attributes carries a value.
        """
        params: Dict = {"config_id": update.config_id}
        set_fields: List[str] = []

        for api_field, db_col in mapping.items():
            value = getattr(update, api_field)
            if value is not None:
                set_fields.append(f"{db_col} = %({api_field})s")
                params[api_field] = value

        # The emptiness check has to run before the timestamp assignment is
        # appended; otherwise the list is never empty and the documented
        # ValueError can never be raised.
        if not set_fields:
            raise ValueError("No fields to update")
        set_fields.append("updated_at = timezone('Asia/Shanghai', now())")

        query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + " WHERE config_id = %(config_id)s"
        return query, params

    @staticmethod
    def build_insert(params: ConfigParamsCreate) -> Tuple[str, Dict]:
        """Build the INSERT statement for a new configuration row.

        ``created_at`` is filled in by the database using the Asia/Shanghai
        time zone; ``workspace_id`` is passed as a string and cast to ``uuid``
        in SQL.

        Args:
            params: Configuration creation model.

        Returns:
            Tuple[str, Dict]: (SQL string, parameter dict).
        """
        db_logger.debug(f"构建插入语句: config_name={params.config_name}, workspace_id={params.workspace_id}")

        columns = [
            "config_name",
            "config_desc",
            "workspace_id",
            "llm_id",
            "embedding_id",
            "rerank_id",
            "created_at",
        ]
        placeholders = [
            "%(config_name)s",
            "%(config_desc)s",
            "%(workspace_id)s::uuid",
            "%(llm_id)s",
            "%(embedding_id)s",
            "%(rerank_id)s",
            "timezone('Asia/Shanghai', now())",
        ]
        query = f"INSERT INTO {TABLE_NAME} (" + ",".join(columns) + ") VALUES (" + ",".join(placeholders) + ")"

        # Convert the UUID to a string for the driver; a falsy id becomes NULL.
        workspace_id_str = str(params.workspace_id) if params.workspace_id else None
        params_dict = {
            "config_name": params.config_name,
            "config_desc": params.config_desc,
            "workspace_id": workspace_id_str,
            "llm_id": params.llm_id,
            "embedding_id": params.embedding_id,
            "rerank_id": params.rerank_id,
        }
        return query, params_dict

    @staticmethod
    def build_update(update: ConfigUpdate) -> Tuple[str, Dict]:
        """Build the UPDATE statement for the basic configuration fields.

        Args:
            update: Configuration update model.

        Returns:
            Tuple[str, Dict]: (SQL string, parameter dict).

        Raises:
            ValueError: If there is no field to update.
        """
        db_logger.debug(f"构建更新语句: config_id={update.config_id}")

        mapping = {
            "config_name": "config_name",
            "config_desc": "config_desc",
        }
        return DataConfigRepository._build_update_sql(update, mapping)

    @staticmethod
    def build_update_extracted(update: ConfigUpdateExtracted) -> Tuple[str, Dict]:
        """Build the UPDATE statement for the memory-extraction engine settings.

        Args:
            update: Extraction configuration update model.

        Returns:
            Tuple[str, Dict]: (SQL string, parameter dict).

        Raises:
            ValueError: If there is no field to update.
        """
        db_logger.debug(f"构建萃取配置更新语句: config_id={update.config_id}")

        mapping = {
            # Model selection.  Column names match the ones written by
            # build_insert and read by build_select_extracted.
            "llm_id": "llm_id",
            "embedding_id": "embedding_id",
            "rerank_id": "rerank_id",
            # Memory-extraction engine
            "enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
            "enable_llm_disambiguation": "enable_llm_disambiguation",
            "deep_retrieval": "deep_retrieval",
            "t_type_strict": "t_type_strict",
            "t_name_strict": "t_name_strict",
            "t_overall": "t_overall",
            "state": "state",
            "chunker_strategy": "chunker_strategy",
            # Statement extraction
            "statement_granularity": "statement_granularity",
            "include_dialogue_context": "include_dialogue_context",
            "max_context": "max_context",
            # Pruning
            "pruning_enabled": "pruning_enabled",
            "pruning_scene": "pruning_scene",
            "pruning_threshold": "pruning_threshold",
            # Self-reflexion
            "enable_self_reflexion": "enable_self_reflexion",
            "iteration_period": "iteration_period",
            "reflexion_range": "reflexion_range",
            "baseline": "baseline",
        }
        return DataConfigRepository._build_update_sql(update, mapping)

    @staticmethod
    def build_update_forget(update: ConfigUpdateForget) -> Tuple[str, Dict]:
        """Build the UPDATE statement for the forgetting-engine settings.

        Args:
            update: Forgetting configuration update model.

        Returns:
            Tuple[str, Dict]: (SQL string, parameter dict).

        Raises:
            ValueError: If there is no field to update.
        """
        db_logger.debug(f"构建遗忘配置更新语句: config_id={update.config_id}")

        mapping = {
            # Forgetting engine
            "lambda_time": "lambda_time",
            "lambda_mem": "lambda_mem",
            # OFFSET is a reserved word in PostgreSQL, so the column name has
            # to be wrapped in double quotes.
            "offset": '"offset"',
        }
        return DataConfigRepository._build_update_sql(update, mapping)

    @staticmethod
    def build_select_extracted(key: ConfigKey) -> Tuple[str, Dict]:
        """Build the SELECT for one row's extraction settings, by primary key.

        Args:
            key: Configuration key model.

        Returns:
            Tuple[str, Dict]: (SQL string, parameter dict).
        """
        db_logger.debug(f"构建萃取配置查询语句: config_id={key.config_id}")

        query = (
            "SELECT llm_id, embedding_id, rerank_id, "
            "enable_llm_dedup_blockwise, enable_llm_disambiguation, deep_retrieval, "
            "t_type_strict, t_name_strict, t_overall, chunker_strategy, "
            "statement_granularity, include_dialogue_context, max_context, "
            "pruning_enabled, pruning_scene, pruning_threshold, "
            "enable_self_reflexion, iteration_period, reflexion_range, baseline "
            f"FROM {TABLE_NAME} WHERE config_id = %(config_id)s"
        )
        params = {"config_id": key.config_id}
        return query, params

    @staticmethod
    def build_select_forget(key: ConfigKey) -> Tuple[str, Dict]:
        """Build the SELECT for one row's forgetting settings, by primary key.

        Args:
            key: Configuration key model.

        Returns:
            Tuple[str, Dict]: (SQL string, parameter dict).
        """
        db_logger.debug(f"构建遗忘配置查询语句: config_id={key.config_id}")

        query = (
            # "offset" is double-quoted because it is a reserved word.
            'SELECT lambda_time, lambda_mem, "offset" '
            f"FROM {TABLE_NAME} WHERE config_id = %(config_id)s"
        )
        params = {"config_id": key.config_id}
        return query, params

    @staticmethod
    def build_select_all(workspace_id = None) -> Tuple[str, Dict]:
        """Build the SELECT that lists configurations, most recently updated first.

        Args:
            workspace_id: Workspace ID (UUID or string).  When truthy the
                result is restricted to that workspace; otherwise every row
                is selected.

        Returns:
            Tuple[str, Dict]: (SQL string, parameter dict).
        """
        db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")

        if workspace_id:
            # The UUID is sent as a string and cast back to uuid in SQL.
            query = f"SELECT * FROM {TABLE_NAME} WHERE workspace_id = %(workspace_id)s::uuid ORDER BY updated_at DESC NULLS LAST"
            params = {"workspace_id": str(workspace_id)}
        else:
            query = f"SELECT * FROM {TABLE_NAME} ORDER BY updated_at DESC NULLS LAST"
            params = {}
        return query, params

    @staticmethod
    def build_delete(key: ConfigParamsDelete) -> Tuple[str, Dict]:
        """Build the DELETE statement for one configuration row, by ID.

        Args:
            key: Configuration delete model.

        Returns:
            Tuple[str, Dict]: (SQL string, parameter dict).
        """
        db_logger.debug(f"构建删除语句: config_id={key.config_id}")

        query = f"DELETE FROM {TABLE_NAME} WHERE config_id = %(config_id)s"
        params = {"config_id": key.config_id}
        return query, params
|
||||
153
api/app/repositories/document_repository.py
Normal file
153
api/app/repositories/document_repository.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.document_model import Document
|
||||
from app.schemas import document_schema
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
# Module-level logger; every function below logs its queries and failures through it.
db_logger = get_db_logger()
|
||||
|
||||
|
||||
def get_documents_paginated(
    db: Session,
    filters: list,
    page: int,
    pagesize: int,
    orderby: str = None,
    desc: bool = False
) -> tuple[int, list]:
    """Return one page of documents together with the total match count.

    Args:
        db: Database session.
        filters: Filter conditions, all applied to the query.
        page: 1-based page number.
        pagesize: Number of rows per page.
        orderby: Name of a ``Document`` attribute to sort by; ignored when
            empty or when ``Document`` has no such attribute.
        desc: Sort descending instead of ascending.

    Returns:
        ``(total, items)`` where ``total`` is the number of rows matching the
        filters and ``items`` are the page's rows validated into
        ``document_schema.Document``.
    """
    db_logger.debug(f"Query documents in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")

    try:
        # Base query with every filter condition applied.
        query = db.query(Document).filter(*filters)

        # The total is taken before ordering and paging.
        total = query.count()
        db_logger.debug(f"Total number of document queries: {total}")

        # Optional ordering.
        sort_column = getattr(Document, orderby, None) if orderby else None
        if sort_column is not None:
            ordering = sort_column.desc() if desc else sort_column.asc()
            query = query.order_by(ordering)
            db_logger.debug(f"sort: {orderby}, desc={desc}")

        # Slice out the requested page.
        first_row = (page - 1) * pagesize
        rows = query.offset(first_row).limit(pagesize).all()
        db_logger.info(f"The document paging query has been successful: total={total}, Number of current page={len(rows)}")

        return total, [document_schema.Document.model_validate(row) for row in rows]
    except Exception as exc:
        db_logger.error(f"Querying document pagination failed: page={page}, pagesize={pagesize} - {exc!s}")
        raise
|
||||
|
||||
|
||||
def create_document(db: Session, document: document_schema.DocumentCreate) -> Document:
    """Insert a document row built from ``document`` and commit it.

    On any failure the error is logged, the session is rolled back and the
    exception is re-raised.
    """
    db_logger.debug(f"Create a document record: file_name={document.file_name}")

    try:
        field_values = document.model_dump()
        record = Document(**field_values)
        db.add(record)
        db.commit()
        db_logger.info(f"Document record created successfully: {document.file_name} (ID: {record.id})")
        return record
    except Exception as exc:
        db_logger.error(f"Failed to create a document record: title={document.file_name} - {exc!s}")
        db.rollback()
        raise
|
||||
|
||||
|
||||
def get_document_by_id(db: Session, document_id: uuid.UUID) -> Document | None:
    """Return the document with the given ID, or ``None`` if there is none.

    Query errors are logged and re-raised.
    """
    db_logger.debug(f"Query documents based on ID: document_id={document_id}")

    try:
        matching = db.query(Document).filter(Document.id == document_id)
        found = matching.first()
        if not found:
            db_logger.debug(f"Document does not exist: document_id={document_id}")
        else:
            db_logger.debug(f"Document query successful: {found.file_name} (ID: {document_id})")
        return found
    except Exception as exc:
        db_logger.error(f"Failed to query the document based on the ID: document_id={document_id} - {exc!s}")
        raise
|
||||
|
||||
|
||||
def reset_documents_progress_by_kb_id(db: Session, kb_id: uuid.UUID) -> int:
    """Reset processing progress for every document of a knowledge base.

    Sets ``chunk_num``, ``progress``, ``process_duration`` and ``run`` to 0,
    ``progress_msg`` to ``"Pending"`` and ``updated_at`` to the current time
    in a single bulk UPDATE, then commits.

    Args:
        db: Database session.
        kb_id: Knowledge base ID.

    Returns:
        int: Number of updated documents.
    """
    db_logger.debug(f"Reset the processing progress of all documents under the specified knowledge base: kb_id={kb_id}")
    try:
        # Column values every matching row is reset to.
        reset_values = {
            Document.chunk_num: 0,
            Document.progress: 0,
            Document.progress_msg: "Pending",
            Document.process_duration: 0,
            Document.run: 0,  # Reset run status
            Document.updated_at: datetime.datetime.now(),
        }

        # Bulk update without synchronising objects already in the session.
        affected = (
            db.query(Document)
            .filter(Document.kb_id == kb_id)
            .update(reset_values, synchronize_session=False)
        )

        db.commit()
        db_logger.debug(f"Successfully reset the processing progress of all documents under the specified knowledge base: kb_id: {kb_id}")
        return affected
    except Exception as exc:
        db.rollback()
        db_logger.error(f"Failed to reset the processing progress of all documents under the specified knowledge base: kb_id={kb_id} - {exc!s}")
        raise
|
||||
|
||||
|
||||
|
||||
def delete_document_by_id(db: Session, document_id: uuid.UUID):
    """Delete the document with the given ID and commit.

    The row is looked up first only so its file name can appear in the log.
    A missing row is logged as a warning, not raised.  On any exception the
    error is logged, the session is rolled back and the exception re-raised.
    """
    db_logger.debug(f"Delete document record: document_id={document_id}")

    try:
        # Fetch the row up front purely for the log message.
        existing = db.query(Document).filter(Document.id == document_id).first()
        file_name = existing.file_name if existing else "unknown"

        deleted_rows = db.query(Document).filter(Document.id == document_id).delete()
        db.commit()

        if deleted_rows > 0:
            db_logger.info(f"Document record deleted successfully: {file_name} (ID: {document_id})")
        else:
            db_logger.warning(f"The document record does not exist, and cannot be deleted: document_id={document_id}")
    except Exception as exc:
        db_logger.error(f"Failed to delete document record: document_id={document_id} - {exc!s}")
        db.rollback()
        raise
|
||||
105
api/app/repositories/end_user_repository.py
Normal file
105
api/app/repositories/end_user_repository.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
from app.models.end_user_model import EndUser
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class EndUserRepository:
    """Data-access helper for ``EndUser`` rows, bound to one SQLAlchemy session."""

    def __init__(self, db: Session):
        self.db = db

    def get_end_users_by_app_id(self, app_id: uuid.UUID) -> List[EndUser]:
        """Return every end user whose ``app_id`` equals the given application id.

        On any exception the session is rolled back, the error is logged and
        the exception is re-raised.
        """
        try:
            by_app = self.db.query(EndUser).filter(EndUser.app_id == app_id)
            end_users = by_app.all()
            db_logger.info(f"成功查询应用 {app_id} 下的 {len(end_users)} 个宿主")
            return end_users
        except Exception as e:
            self.db.rollback()
            db_logger.error(f"查询应用 {app_id} 下宿主时出错: {str(e)}")
            raise

    def get_end_user_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
        """Return the end user with the given id, or ``None`` when no row matches.

        On any exception the session is rolled back, the error is logged and
        the exception is re-raised.
        """
        try:
            by_id = self.db.query(EndUser).filter(EndUser.id == end_user_id)
            end_user = by_id.first()
            # Both outcomes are logged at info level; only the message differs.
            outcome = f"成功查询到宿主 {end_user_id}" if end_user else f"未找到宿主 {end_user_id}"
            db_logger.info(outcome)
            return end_user
        except Exception as e:
            self.db.rollback()
            db_logger.error(f"查询宿主 {end_user_id} 时出错: {str(e)}")
            raise

    def get_or_create_end_user(
        self,
        app_id: uuid.UUID,
        other_id: str,
        original_user_id: Optional[str] = None
    ) -> EndUser:
        """Fetch the end user matching ``(app_id, other_id)``, creating it if absent.

        Args:
            app_id: Application id the end user belongs to.
            other_id: Third-party identifier of the end user.
            original_user_id: Accepted for callers' convenience but not
                referenced anywhere in this method.

        Returns:
            The existing row when one matches; otherwise a newly added,
            committed and refreshed ``EndUser``.

        Raises:
            Exception: Any error is logged after the session is rolled back,
                then re-raised.
        """
        try:
            # Look for a row that already carries this (app_id, other_id) pair.
            existing = (
                self.db.query(EndUser)
                .filter(EndUser.app_id == app_id, EndUser.other_id == other_id)
                .first()
            )
            if existing:
                db_logger.debug(f"找到现有终端用户: 应用ID {app_id}、第三方ID {other_id}")
                return existing

            # Nothing found: insert a fresh row and reload it after the commit.
            created = EndUser(app_id=app_id, other_id=other_id)
            self.db.add(created)
            self.db.commit()
            self.db.refresh(created)

            db_logger.info(f"创建新终端用户: (other_id: {other_id}) for app {app_id}")
            return created

        except Exception as e:
            self.db.rollback()
            db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
            raise
|
||||
|
||||
def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
    """Return the ``EndUser`` ORM rows of an application via ``EndUserRepository``."""
    return EndUserRepository(db).get_end_users_by_app_id(app_id)
|
||||
|
||||
def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
    """Return the ``EndUser`` with the given id (or ``None``) via ``EndUserRepository``."""
    return EndUserRepository(db).get_end_user_by_id(end_user_id)
|
||||
121
api/app/repositories/file_repository.py
Normal file
121
api/app/repositories/file_repository.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.file_model import File
|
||||
from app.schemas import file_schema
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
def get_files_paginated(
    db: Session,
    filters: list,
    page: int,
    pagesize: int,
    orderby: str = None,
    desc: bool = False
) -> tuple[int, list]:
    """Return one page of ``File`` rows together with the total match count.

    Args:
        db: SQLAlchemy session to query with.
        filters: Filter expressions; each one is applied with its own
            ``.filter()`` call.
        page: Page number; page 1 starts at offset 0.
        pagesize: Maximum number of rows on the page.
        orderby: Name of a ``File`` attribute to sort by. Ignored when empty
            or when ``File`` has no such attribute.
        desc: Sort descending when true, ascending otherwise.

    Returns:
        ``(total, items)`` where ``total`` is the row count before pagination
        and ``items`` are the page's rows validated into ``file_schema.File``.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    db_logger.debug(f"Query file in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")

    try:
        query = db.query(File)
        for condition in filters:
            query = query.filter(condition)

        # Count before ordering/paging so the total reflects every matching row.
        total = query.count()
        db_logger.debug(f"Total number of file queries: {total}")

        # Resolve the sort attribute; unknown names fall through without sorting.
        order_attr = getattr(File, orderby, None) if orderby else None
        if order_attr is not None:
            ordering = order_attr.desc() if desc else order_attr.asc()
            query = query.order_by(ordering)
            db_logger.debug(f"sort: {orderby}, desc={desc}")

        skip = (page - 1) * pagesize
        items = query.offset(skip).limit(pagesize).all()
        db_logger.info(f"The file paging query has been successful: total={total}, Number of current page={len(items)}")

        validated = [file_schema.File.model_validate(row) for row in items]
        return total, validated
    except Exception as e:
        db_logger.error(f"Querying file pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def create_file(db: Session, file: file_schema.FileCreate) -> File:
    """Insert a ``File`` row built from the schema's dumped fields and commit.

    Args:
        db: SQLAlchemy session used for the insert and commit.
        file: Creation schema; ``file.model_dump()`` supplies the column values.

    Returns:
        The newly added ``File`` ORM object.

    Raises:
        Exception: Any failure is logged, the session is rolled back, and the
            error is re-raised.
    """
    db_logger.debug(f"Create a file record: filename={file.file_name}")

    try:
        column_values = file.model_dump()
        db_file = File(**column_values)
        db.add(db_file)
        db.commit()
        db_logger.info(f"File record created successfully: {file.file_name} (ID: {db_file.id})")
        return db_file
    except Exception as e:
        db_logger.error(f"Failed to create a file record: filename={file.file_name} - {str(e)}")
        db.rollback()
        raise
|
||||
|
||||
|
||||
def get_file_by_id(db: Session, file_id: uuid.UUID) -> File | None:
    """Return the ``File`` row with the given id, or ``None`` when absent.

    Any exception raised by the query is logged and re-raised.
    """
    db_logger.debug(f"Query file based on ID: file_id={file_id}")

    try:
        file = db.query(File).filter(File.id == file_id).first()
        if not file:
            db_logger.debug(f"File does not exist: file_id={file_id}")
        else:
            db_logger.debug(f"File query successful: {file.file_name} (ID: {file_id})")
        return file
    except Exception as e:
        db_logger.error(f"Failed to query the file based on the ID: file_id={file_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def get_files_by_parent_id(db: Session, parent_id: uuid.UUID | None) -> list | None:
    """Return the ``File`` rows under a folder.

    When ``parent_id`` is falsy no parent filter is applied, so every ``File``
    row is returned. Any exception raised by the query is logged and re-raised.
    """
    db_logger.debug(f"Query file based on folder ID: parent_id={parent_id}")

    try:
        all_files = db.query(File)
        # Only narrow the query when a parent folder was actually supplied.
        scoped = all_files.filter(File.parent_id == parent_id) if parent_id else all_files
        files = scoped.all()
        db_logger.debug(f"Folder query file successful: parent_id={parent_id}, file_num={len(files)}")
        return files
    except Exception as e:
        db_logger.error(f"Failed to query files based on folder ID: parent_id={parent_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def delete_file_by_id(db: Session, file_id: uuid.UUID):
    """Delete the ``File`` row whose id equals ``file_id`` and commit.

    The row is fetched first only so its file name can appear in the success
    log. Any failure is logged, the session is rolled back, and the error is
    re-raised.
    """
    db_logger.debug(f"Delete file record: file_id={file_id}")

    try:
        # Look the row up purely to obtain a readable name for the log line.
        existing = db.query(File).filter(File.id == file_id).first()
        filename = existing.file_name if existing else "unknown"

        deleted_rows = db.query(File).filter(File.id == file_id).delete()
        db.commit()

        if deleted_rows > 0:
            db_logger.info(f"File record deleted successfully: {filename} (ID: {file_id})")
            return
        db_logger.warning(f"The file record does not exist, and cannot be deleted: file_id={file_id}")
    except Exception as e:
        db_logger.error(f"Failed to delete file record: file_id={file_id} - {str(e)}")
        db.rollback()
        raise
|
||||
243
api/app/repositories/generic_file_repository.py
Normal file
243
api/app/repositories/generic_file_repository.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
Generic File Repository
|
||||
Handles database operations for generic file uploads.
|
||||
"""
|
||||
import uuid
|
||||
from typing import Optional, List, Tuple, Dict, Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, func
|
||||
|
||||
from app.models.generic_file_model import GenericFile
|
||||
from app.core.upload_enums import UploadContext
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Get database logger
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class GenericFileRepository:
    """Repository for generic file operations.

    Write methods only ``flush()`` the session; none of them commits or rolls
    back, so transaction control stays with the caller.
    """

    # Fields that update_file() refuses to overwrite.
    _PROTECTED_FIELDS = frozenset({'id', 'created_by', 'created_at', 'tenant_id'})

    def __init__(self, db: Session):
        self.db = db

    def create_file(self, file_data: Dict[str, Any]) -> GenericFile:
        """
        Create a new file record in the database.

        Args:
            file_data: Dictionary containing file information; it is expanded
                into ``GenericFile`` keyword arguments.

        Returns:
            GenericFile: Created file record (added and flushed, not committed)

        Raises:
            Exception: If database operation fails (logged, then re-raised)
        """
        db_logger.debug(f"Creating file record: filename={file_data.get('file_name')}")

        try:
            db_file = GenericFile(**file_data)
            self.db.add(db_file)
            self.db.flush()
            db_logger.info(f"File record created successfully: {file_data.get('file_name')} (ID: {db_file.id})")
            return db_file
        except Exception as e:
            db_logger.error(f"Failed to create file record: filename={file_data.get('file_name')} - {str(e)}")
            raise

    def get_file_by_id(self, file_id: uuid.UUID) -> Optional[GenericFile]:
        """
        Get a file by its ID, ignoring soft-deleted rows.

        Args:
            file_id: UUID of the file

        Returns:
            Optional[GenericFile]: File record if found and ``deleted_at`` is
            NULL, None otherwise

        Raises:
            Exception: If the query fails (logged, then re-raised)
        """
        db_logger.debug(f"Querying file by ID: file_id={file_id}")

        try:
            file = self.db.query(GenericFile).filter(
                and_(
                    GenericFile.id == file_id,
                    GenericFile.deleted_at.is_(None)
                )
            ).first()

            if file:
                db_logger.debug(f"File found: {file.file_name} (ID: {file_id})")
            else:
                db_logger.debug(f"File not found: file_id={file_id}")

            return file
        except Exception as e:
            db_logger.error(f"Failed to query file by ID: file_id={file_id} - {str(e)}")
            raise

    def update_file(self, file_id: uuid.UUID, update_data: Dict[str, Any]) -> Optional[GenericFile]:
        """
        Update file metadata.

        Keys of ``update_data`` that are not attributes of the record, or that
        name a protected field (id, created_by, created_at, tenant_id), are
        silently skipped. ``updated_at`` is always set to the current time.

        Args:
            file_id: UUID of the file to update
            update_data: Dictionary containing fields to update

        Returns:
            Optional[GenericFile]: Updated file record if found, None otherwise

        Raises:
            Exception: If the lookup or flush fails (logged, then re-raised)
        """
        db_logger.debug(f"Updating file: file_id={file_id}")

        try:
            file = self.get_file_by_id(file_id)
            if not file:
                db_logger.debug(f"File not found for update: file_id={file_id}")
                return None

            # Update allowed fields
            for field, value in update_data.items():
                if hasattr(file, field) and field not in self._PROTECTED_FIELDS:
                    setattr(file, field, value)

            # Update timestamp
            file.updated_at = datetime.now()

            self.db.flush()
            db_logger.info(f"File updated successfully: {file.file_name} (ID: {file_id})")
            return file
        except Exception as e:
            db_logger.error(f"Failed to update file: file_id={file_id} - {str(e)}")
            raise

    def delete_file(self, file_id: uuid.UUID) -> bool:
        """
        Soft delete a file by setting deleted_at timestamp.

        The record's status becomes "deleted" and ``deleted_at`` / ``updated_at``
        receive the same timestamp.

        Args:
            file_id: UUID of the file to delete

        Returns:
            bool: True if file was deleted, False if not found

        Raises:
            Exception: If the lookup or flush fails (logged, then re-raised)
        """
        db_logger.debug(f"Soft deleting file: file_id={file_id}")

        try:
            file = self.get_file_by_id(file_id)
            if not file:
                db_logger.debug(f"File not found for deletion: file_id={file_id}")
                return False

            # Read the clock once so both audit columns describe the same instant.
            now = datetime.now()
            file.deleted_at = now
            file.status = "deleted"
            file.updated_at = now

            self.db.flush()
            db_logger.info(f"File soft deleted successfully: {file.file_name} (ID: {file_id})")
            return True
        except Exception as e:
            db_logger.error(f"Failed to delete file: file_id={file_id} - {str(e)}")
            raise

    def get_files_by_context(
        self,
        context: UploadContext,
        tenant_id: uuid.UUID,
        page: int = 1,
        pagesize: int = 20,
        status: Optional[str] = "active",
        created_by: Optional[uuid.UUID] = None
    ) -> Tuple[int, List[GenericFile]]:
        """
        Get files by context with pagination, newest first.

        Args:
            context: Upload context (avatar, app_icon, etc.)
            tenant_id: Tenant ID for isolation
            page: Page number (1-indexed)
            pagesize: Number of items per page
            status: File status filter (default: "active"); a falsy value
                disables the filter
            created_by: Optional filter by creator user ID

        Returns:
            Tuple[int, List[GenericFile]]: Total count (before pagination) and
            the files on the requested page

        Raises:
            Exception: If the query fails (logged, then re-raised)
        """
        db_logger.debug(
            f"Querying files by context: context={context}, tenant_id={tenant_id}, "
            f"page={page}, pagesize={pagesize}, status={status}"
        )

        try:
            # Soft-deleted rows are always excluded.
            query = self.db.query(GenericFile).filter(
                and_(
                    GenericFile.context == context,
                    GenericFile.tenant_id == tenant_id,
                    GenericFile.deleted_at.is_(None)
                )
            )

            # Apply status filter
            if status:
                query = query.filter(GenericFile.status == status)

            # Apply creator filter
            if created_by:
                query = query.filter(GenericFile.created_by == created_by)

            # Get total count
            total = query.count()
            db_logger.debug(f"Total files found: {total}")

            # Apply pagination and ordering
            files = query.order_by(GenericFile.created_at.desc()).offset((page - 1) * pagesize).limit(pagesize).all()

            db_logger.info(
                f"Files query successful: context={context}, total={total}, "
                f"returned={len(files)}"
            )

            return total, files
        except Exception as e:
            db_logger.error(
                f"Failed to query files by context: context={context}, "
                f"tenant_id={tenant_id} - {str(e)}"
            )
            raise
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def create_file(db: Session, file_data: Dict[str, Any]) -> GenericFile:
    """Create a file record through a ``GenericFileRepository`` bound to ``db``."""
    repository = GenericFileRepository(db)
    return repository.create_file(file_data)
|
||||
|
||||
|
||||
def get_file_by_id(db: Session, file_id: uuid.UUID) -> Optional[GenericFile]:
    """Look a file up by id through a ``GenericFileRepository`` bound to ``db``."""
    repository = GenericFileRepository(db)
    return repository.get_file_by_id(file_id)
|
||||
|
||||
|
||||
def update_file(db: Session, file_id: uuid.UUID, update_data: Dict[str, Any]) -> Optional[GenericFile]:
    """Update file metadata through a ``GenericFileRepository`` bound to ``db``."""
    repository = GenericFileRepository(db)
    return repository.update_file(file_id, update_data)
|
||||
|
||||
|
||||
def delete_file(db: Session, file_id: uuid.UUID) -> bool:
    """Soft delete a file through a ``GenericFileRepository`` bound to ``db``."""
    repository = GenericFileRepository(db)
    return repository.delete_file(file_id)
|
||||
|
||||
|
||||
def get_files_by_context(
    db: Session,
    context: UploadContext,
    tenant_id: uuid.UUID,
    page: int = 1,
    pagesize: int = 20,
    status: Optional[str] = "active",
    created_by: Optional[uuid.UUID] = None
) -> Tuple[int, List[GenericFile]]:
    """Page through files of one context via a ``GenericFileRepository`` bound to ``db``.

    All arguments are forwarded positionally, in order, to
    ``GenericFileRepository.get_files_by_context``.
    """
    repository = GenericFileRepository(db)
    forwarded = (context, tenant_id, page, pagesize, status, created_by)
    return repository.get_files_by_context(*forwarded)
|
||||
211
api/app/repositories/knowledge_repository.py
Normal file
211
api/app/repositories/knowledge_repository.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.schemas import knowledge_schema
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
def get_knowledges_paginated(
    db: Session,
    filters: list,
    page: int,
    pagesize: int,
    orderby: str = None,
    desc: bool = False
) -> tuple[int, list]:
    """Return one page of ``Knowledge`` rows together with the total match count.

    Args:
        db: SQLAlchemy session to query with.
        filters: Filter expressions; each one is applied with its own
            ``.filter()`` call.
        page: Page number; page 1 starts at offset 0.
        pagesize: Maximum number of rows on the page.
        orderby: Name of a ``Knowledge`` attribute to sort by. Ignored when
            empty or when ``Knowledge`` has no such attribute.
        desc: Sort descending when true, ascending otherwise.

    Returns:
        ``(total, items)`` where ``total`` is the row count before pagination
        and ``items`` are the page's rows validated into
        ``knowledge_schema.Knowledge``.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    db_logger.debug(f"Query knowledge base in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")

    try:
        query = db.query(Knowledge)
        for condition in filters:
            query = query.filter(condition)

        # Count before ordering/paging so the total reflects every matching row.
        total = query.count()
        db_logger.debug(f"Total number of knowledge base queries: {total}")

        # Resolve the sort attribute; unknown names fall through without sorting.
        order_attr = getattr(Knowledge, orderby, None) if orderby else None
        if order_attr is not None:
            ordering = order_attr.desc() if desc else order_attr.asc()
            query = query.order_by(ordering)
            db_logger.debug(f"sort: {orderby}, desc={desc}")

        skip = (page - 1) * pagesize
        items = query.offset(skip).limit(pagesize).all()
        db_logger.info(f"The knowledge base paging query has been successful: total={total}, Number of current page={len(items)}")

        validated = [knowledge_schema.Knowledge.model_validate(row) for row in items]
        return total, validated
    except Exception as e:
        db_logger.error(f"Querying knowledge base pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def get_chunded_knowledgeids(
    db: Session,
    filters: list
) -> list:
    """Return the ids of the ``Knowledge`` rows that satisfy every filter.

    Only the ``id`` column is selected, so each result row has a single
    column; that column is unpacked into a flat list.

    Args:
        db: SQLAlchemy session to query with.
        filters: Filter expressions, each applied with its own ``.filter()``.

    Returns:
        A list holding the ``id`` value of every matching row.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    db_logger.debug(f"Query the list of vectorized knowledge base IDs: filters_count={len(filters)}")

    try:
        id_query = db.query(Knowledge.id)
        for condition in filters:
            id_query = id_query.filter(condition)

        items = id_query.all()
        db_logger.info(f"Querying the vectorized knowledge base id list succeeded: count={len(items)}")

        # Each row is a one-column result; keep just that column.
        return [row[0] for row in items]
    except Exception as e:
        db_logger.error(f"Querying the vectorized knowledge base id list failed: {str(e)}")
        raise
|
||||
|
||||
|
||||
def create_knowledge(db: Session, knowledge: knowledge_schema.KnowledgeCreate) -> Knowledge:
    """Insert a ``Knowledge`` row built from the schema's dumped fields and commit.

    Args:
        db: SQLAlchemy session used for the insert and commit.
        knowledge: Creation schema; ``knowledge.model_dump()`` supplies the
            column values.

    Returns:
        The newly added ``Knowledge`` ORM object.

    Raises:
        Exception: Any failure is logged, the session is rolled back, and the
            error is re-raised.
    """
    db_logger.debug(f"Create a knowledge base record: name={knowledge.name}")

    try:
        column_values = knowledge.model_dump()
        db_knowledge = Knowledge(**column_values)
        db.add(db_knowledge)
        db.commit()
        db_logger.info(f"knowledge base record created successfully: {knowledge.name} (ID: {db_knowledge.id})")
        return db_knowledge
    except Exception as e:
        db_logger.error(f"Failed to create a knowledge base record: name={knowledge.name} - {str(e)}")
        db.rollback()
        raise
|
||||
|
||||
|
||||
def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID) -> Knowledge | None:
    """Return the ``Knowledge`` row with the given id, or ``None`` when absent.

    Any exception raised by the query is logged and re-raised.
    """
    db_logger.debug(f"Query knowledge base based on ID: knowledge_id={knowledge_id}")

    try:
        knowledge = db.query(Knowledge).filter(Knowledge.id == knowledge_id).first()
        if not knowledge:
            db_logger.debug(f"knowledge base does not exist: knowledge_id={knowledge_id}")
        else:
            db_logger.debug(f"knowledge base query successful: {knowledge.name} (ID: {knowledge_id})")
        return knowledge
    except Exception as e:
        db_logger.error(f"Failed to query the knowledge base based on the ID: knowledge_id={knowledge_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None:
    """Return the first ``Knowledge`` row matching both name and workspace id.

    Returns ``None`` when no row matches. Any exception raised by the query is
    logged and re-raised.
    """
    db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")

    try:
        by_name = db.query(Knowledge).filter(Knowledge.name == name)
        in_workspace = by_name.filter(Knowledge.workspace_id == workspace_id)
        knowledge = in_workspace.first()
        if not knowledge:
            db_logger.debug(f"knowledge base does not exist: name={name}, workspace_id={workspace_id}")
        else:
            db_logger.debug(f"knowledge base query successful: {name} (ID: {knowledge.id})")
        return knowledge
    except Exception as e:
        db_logger.error(f"Failed to query the knowledge base based on the name and workspace_id: name={name}, workspace_id={workspace_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def delete_knowledge_by_id(db: Session, knowledge_id: uuid.UUID):
    """Delete the ``Knowledge`` row whose id equals ``knowledge_id`` and commit.

    The row is fetched first only so its name can appear in the success log.
    Any failure is logged, the session is rolled back, and the error is
    re-raised.
    """
    db_logger.debug(f"Delete knowledge base record: knowledge_id={knowledge_id}")

    try:
        # Look the row up purely to obtain a readable name for the log line.
        existing = db.query(Knowledge).filter(Knowledge.id == knowledge_id).first()
        knowledge_name = existing.name if existing else "unknown"

        deleted_rows = db.query(Knowledge).filter(Knowledge.id == knowledge_id).delete()
        db.commit()

        if deleted_rows > 0:
            db_logger.info(f"knowledge base record deleted successfully: {knowledge_name} (ID: {knowledge_id})")
            return
        db_logger.warning(f"The knowledge base record does not exist, and cannot be deleted: knowledge_id={knowledge_id}")
    except Exception as e:
        db_logger.error(f"Failed to delete knowledge base record: knowledge_id={knowledge_id} - {str(e)}")
        db.rollback()
        raise
|
||||
|
||||
|
||||
def get_total_doc_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
    """Sum ``doc_num`` over the workspace's ``Knowledge`` rows whose status is 1.

    Returns 0 when the aggregate yields ``None``. Any exception is logged and
    re-raised.
    """
    db_logger.debug(f"Query total doc_num by workspace_id: workspace_id={workspace_id}")

    try:
        from sqlalchemy import func

        summed = (
            db.query(func.sum(Knowledge.doc_num))
            .filter(Knowledge.workspace_id == workspace_id, Knowledge.status == 1)
            .scalar()
        )

        # The aggregate can come back as None; report that as zero.
        total = 0 if summed is None else summed
        db_logger.info(f"Total doc_num query successful: workspace_id={workspace_id}, total={total}")
        return total
    except Exception as e:
        db_logger.error(f"Failed to query total doc_num: workspace_id={workspace_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def get_total_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
    """Sum ``chunk_num`` over the workspace's ``Knowledge`` rows whose status is 1.

    Returns 0 when the aggregate yields ``None``. Any exception is logged and
    re-raised.
    """
    db_logger.debug(f"Query total chunk_num by workspace_id: workspace_id={workspace_id}")

    try:
        from sqlalchemy import func

        summed = (
            db.query(func.sum(Knowledge.chunk_num))
            .filter(Knowledge.workspace_id == workspace_id, Knowledge.status == 1)
            .scalar()
        )

        # The aggregate can come back as None; report that as zero.
        total = 0 if summed is None else summed
        db_logger.info(f"Total chunk_num query successful: workspace_id={workspace_id}, total={total}")
        return total
    except Exception as e:
        db_logger.error(f"Failed to query total chunk_num: workspace_id={workspace_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
    """Count the workspace's ``Knowledge`` rows whose status is 1.

    Any exception raised by the query is logged and re-raised.
    """
    db_logger.debug(f"Query total knowledge base count by workspace_id: workspace_id={workspace_id}")

    try:
        active_in_workspace = db.query(Knowledge).filter(
            Knowledge.workspace_id == workspace_id,
            Knowledge.status == 1
        )
        count = active_in_workspace.count()

        db_logger.info(f"Total knowledge base count query successful: workspace_id={workspace_id}, count={count}")
        return count
    except Exception as e:
        db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
        raise
|
||||
142
api/app/repositories/knowledgeshare_repository.py
Normal file
142
api/app/repositories/knowledgeshare_repository.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.knowledgeshare_model import KnowledgeShare
|
||||
from app.schemas import knowledgeshare_schema
|
||||
from app.core.logging_config import get_db_logger
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy import or_
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
def get_knowledgeshares_paginated(
    db: Session,
    filters: list,
    page: int,
    pagesize: int,
    orderby: str = None,
    desc: bool = False
) -> tuple[int, list]:
    """Return one page of ``KnowledgeShare`` rows together with the total match count.

    Args:
        db: SQLAlchemy session to query with.
        filters: Filter expressions; each one is applied with its own
            ``.filter()`` call.
        page: Page number; page 1 starts at offset 0.
        pagesize: Maximum number of rows on the page.
        orderby: Name of a ``KnowledgeShare`` attribute to sort by. Ignored
            when empty or when ``KnowledgeShare`` has no such attribute.
        desc: Sort descending when true, ascending otherwise.

    Returns:
        ``(total, items)`` where ``total`` is the row count before pagination
        and ``items`` are the page's rows validated into
        ``knowledgeshare_schema.KnowledgeShare``.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    db_logger.debug(
        f"Query knowledge base sharing in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")

    try:
        query = db.query(KnowledgeShare)
        for condition in filters:
            query = query.filter(condition)

        # Count before ordering/paging so the total reflects every matching row.
        total = query.count()
        db_logger.debug(f"Total number of knowledge base sharing queries: {total}")

        # Resolve the sort attribute; unknown names fall through without sorting.
        order_attr = getattr(KnowledgeShare, orderby, None) if orderby else None
        if order_attr is not None:
            ordering = order_attr.desc() if desc else order_attr.asc()
            query = query.order_by(ordering)
            db_logger.debug(f"sort: {orderby}, desc={desc}")

        skip = (page - 1) * pagesize
        items = query.offset(skip).limit(pagesize).all()
        db_logger.info(f"The knowledge base sharing paging query has been successful: total={total}, Number of current page={len(items)}")

        validated = [knowledgeshare_schema.KnowledgeShare.model_validate(row) for row in items]
        return total, validated
    except Exception as e:
        db_logger.error(f"Querying knowledge base sharing pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def get_source_kb_ids_by_target_kb_id(
    db: Session,
    filters: list
) -> list:
    """Return ``source_kb_id`` of every ``KnowledgeShare`` row that satisfies the filters.

    Only the ``source_kb_id`` column is selected, so each result row has a
    single column; that column is unpacked into a flat list.

    Args:
        db: SQLAlchemy session to query with.
        filters: Filter expressions, each applied with its own ``.filter()``.

    Returns:
        A list holding the ``source_kb_id`` value of every matching row.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    db_logger.debug(
        f"Query the original knowledge base id list by sharing the knowledge base: filters_count={len(filters)}")

    try:
        source_query = db.query(KnowledgeShare.source_kb_id)
        for condition in filters:
            source_query = source_query.filter(condition)

        items = source_query.all()
        db_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: count={len(items)}")

        # Each row is a one-column result; keep just that column.
        return [row[0] for row in items]
    except Exception as e:
        db_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: {str(e)}")
        raise
|
||||
|
||||
|
||||
def create_knowledgeshare(db: Session, knowledgeshare: knowledgeshare_schema.KnowledgeShareCreate) -> KnowledgeShare:
    """Insert a ``KnowledgeShare`` row built from the schema's dumped fields and commit.

    Args:
        db: SQLAlchemy session used for the insert and commit.
        knowledgeshare: Creation schema; ``knowledgeshare.model_dump()``
            supplies the column values.

    Returns:
        The newly added ``KnowledgeShare`` ORM object.

    Raises:
        Exception: Any failure is logged, the session is rolled back, and the
            error is re-raised.
    """
    db_logger.debug(f"Create a knowledge base sharing record: source_kb_id={knowledgeshare.source_kb_id}")

    try:
        column_values = knowledgeshare.model_dump()
        db_knowledgeshare = KnowledgeShare(**column_values)
        db.add(db_knowledgeshare)
        db.commit()
        db_logger.info(f"knowledge base sharing record created successfully: (ID: {db_knowledgeshare.id})")
        return db_knowledgeshare
    except Exception as e:
        db_logger.error(f"Failed to create a knowledge base sharing record: source_kb_id={knowledgeshare.source_kb_id} - {str(e)}")
        db.rollback()
        raise
|
||||
|
||||
|
||||
def get_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID) -> KnowledgeShare | None:
    """Return the first ``KnowledgeShare`` whose ``id`` or ``target_kb_id`` matches.

    The given UUID is compared against both columns, so it may be either the
    share's own id or the id of its target knowledge base. Returns ``None``
    when nothing matches. Any exception is logged and re-raised.
    """
    db_logger.debug(f"Query knowledge base sharing based on ID: knowledgeshare_id={knowledgeshare_id}")

    try:
        matches_either = or_(
            KnowledgeShare.id == knowledgeshare_id,
            KnowledgeShare.target_kb_id == knowledgeshare_id
        )
        knowledgeshare = db.query(KnowledgeShare).filter(matches_either).first()
        if not knowledgeshare:
            db_logger.debug(f"knowledge base sharing does not exist: knowledgeshare_id={knowledgeshare_id}")
        else:
            db_logger.debug(f"knowledge base sharing query successful: (ID: {knowledgeshare_id})")
        return knowledgeshare
    except Exception as e:
        db_logger.error(f"Failed to query the knowledge base sharing based on the ID: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def delete_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID):
    """Delete every ``KnowledgeShare`` whose ``id`` or ``target_kb_id`` matches, then commit.

    The given UUID is compared against both columns. Any failure is logged,
    the session is rolled back, and the error is re-raised.
    """
    db_logger.debug(f"Delete knowledge base sharing record: knowledgeshare_id={knowledgeshare_id}")

    try:
        matches_either = or_(
            KnowledgeShare.id == knowledgeshare_id,
            KnowledgeShare.target_kb_id == knowledgeshare_id
        )
        deleted_rows = db.query(KnowledgeShare).filter(matches_either).delete()
        db.commit()

        if deleted_rows > 0:
            db_logger.info(f"knowledge base sharing record deleted successfully: (ID: {knowledgeshare_id})")
            return
        db_logger.warning(f"The knowledge base sharing record does not exist, and cannot be deleted: knowledgeshare_id={knowledgeshare_id}")
    except Exception as e:
        db_logger.error(f"Failed to delete knowledge base sharing record: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
        db.rollback()
        raise
|
||||
110
api/app/repositories/memory_increment_repository.py
Normal file
110
api/app/repositories/memory_increment_repository.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session, aliased
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import datetime
|
||||
|
||||
from app.models.memory_increment_model import MemoryIncrement
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class MemoryIncrementRepository:
    """Data-access helper for ``MemoryIncrement`` rows, bound to one SQLAlchemy session."""

    def __init__(self, db: Session):
        self.db = db

    def get_memory_increments_by_workspace_id(self, workspace_id: uuid.UUID, limit: int) -> List[MemoryIncrement]:
        """Return one ``MemoryIncrement`` per calendar day for a workspace.

        Rows are grouped by the date of ``created_at``; within each day only
        the row with the latest ``created_at`` is kept. The result is ordered
        by ``created_at`` ascending and capped at ``limit`` rows.

        Args:
            workspace_id: Workspace whose increments are queried.
            limit: Maximum number of rows returned.

        Returns:
            The selected ``MemoryIncrement`` rows, oldest first.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            # Rank rows inside each calendar day, newest first, so that
            # row_num == 1 identifies that day's latest record.
            subquery = (
                self.db.query(
                    MemoryIncrement,
                    func.row_number().over(
                        partition_by=func.date(MemoryIncrement.created_at),  # one partition per date
                        order_by=MemoryIncrement.created_at.desc()  # newest first within the date
                    ).label('row_num')
                )
                .filter(MemoryIncrement.workspace_id == workspace_id)
                .subquery()
            )

            memory_increment_alias = aliased(MemoryIncrement, subquery)

            # NOTE(review): ascending order is applied before the limit, so the
            # cap keeps the earliest days rather than the most recent ones -
            # confirm this is what callers expect.
            memory_increments = (
                self.db.query(memory_increment_alias)
                .filter(subquery.c.row_num == 1)  # keep only the latest row of each date
                .order_by(memory_increment_alias.created_at.asc())  # oldest first
                .limit(limit)
                .all()
            )
            db_logger.info(f"成功查询工作空间 {workspace_id} 下的内存增量")
            return memory_increments
        except Exception as e:
            db_logger.error(f"查询工作空间 {workspace_id} 下内存增量时出错: {str(e)}")
            raise

    def get_latest_memory_increment_by_workspace_id(self, workspace_id: uuid.UUID) -> Optional[MemoryIncrement]:
        """Return the workspace's most recent ``MemoryIncrement``, or ``None``.

        Rows are ordered by ``created_at`` descending with ``id`` descending
        as the tie-breaker, and the first one is returned.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            memory_increment = (
                self.db.query(MemoryIncrement)
                .filter(MemoryIncrement.workspace_id == workspace_id)
                .order_by(MemoryIncrement.created_at.desc(), MemoryIncrement.id.desc())
                .first()
            )
            if memory_increment:
                db_logger.info(f"成功查询工作空间 {workspace_id} 下的最新内存增量")
            else:
                db_logger.warning(f"未找到工作空间 {workspace_id} 下的内存增量记录")
            return memory_increment
        except Exception as e:
            db_logger.error(f"查询工作空间 {workspace_id} 下最新内存增量时出错: {str(e)}")
            raise

    def write_memory_increment(
        self,
        workspace_id: uuid.UUID,
        total_num: int
    ) -> MemoryIncrement:
        """Insert a ``MemoryIncrement`` row for the workspace and commit it.

        ``created_at`` and ``updated_at`` are set to the same current local
        timestamp.

        Args:
            workspace_id: Workspace the record belongs to.
            total_num: Value stored in the row's ``total_num`` column.

        Returns:
            The committed and refreshed ``MemoryIncrement``.

        Raises:
            Exception: Any failure rolls the session back, is logged, and is
                re-raised.
        """
        try:
            # Read the clock once so both audit columns agree.
            now = datetime.datetime.now()
            memory_increment = MemoryIncrement(
                workspace_id=workspace_id,
                total_num=total_num,
                created_at=now,
                updated_at=now
            )
            self.db.add(memory_increment)
            self.db.commit()
            self.db.refresh(memory_increment)
            db_logger.info(f"成功写入内存增量: workspace_id={workspace_id}, total_num={total_num}")
            return memory_increment
        except Exception as e:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; do that here so the caller's session stays usable.
            self.db.rollback()
            db_logger.error(f"写入内存增量失败: workspace_id={workspace_id}, total_num={total_num} - {str(e)}")
            raise
|
||||
|
||||
|
||||
def get_memory_increments_by_workspace_id(db: Session, workspace_id: uuid.UUID, limit: int) -> List[MemoryIncrement]:
    """Query a workspace's memory increments through a throwaway repository.

    Args:
        db: SQLAlchemy session handed to ``MemoryIncrementRepository``.
        workspace_id: Workspace whose records are queried.
        limit: Maximum number of rows to return.

    Returns:
        The list of ``MemoryIncrement`` ORM objects produced by the repository.
    """
    return MemoryIncrementRepository(db).get_memory_increments_by_workspace_id(workspace_id, limit)
|
||||
|
||||
def write_memory_increment(
    db: Session,
    workspace_id: uuid.UUID,
    total_num: int
) -> MemoryIncrement:
    """Persist a memory-increment record through a throwaway repository.

    Args:
        db: SQLAlchemy session handed to ``MemoryIncrementRepository``.
        workspace_id: Workspace the record belongs to.
        total_num: Value stored in the record's ``total_num`` field.

    Returns:
        The ``MemoryIncrement`` object returned by the repository.
    """
    return MemoryIncrementRepository(db).write_memory_increment(workspace_id, total_num)
|
||||
|
||||
def get_latest_memory_increment_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> Optional[MemoryIncrement]:
    """Fetch a workspace's newest memory-increment record.

    Args:
        db: SQLAlchemy session handed to ``MemoryIncrementRepository``.
        workspace_id: Workspace whose latest record is wanted.

    Returns:
        Whatever the repository returns: the newest record, or ``None``.
    """
    repository = MemoryIncrementRepository(db)
    latest = repository.get_latest_memory_increment_by_workspace_id(workspace_id)
    return latest
|
||||
386
api/app/repositories/model_repository.py
Normal file
386
api/app/repositories/model_repository.py
Normal file
@@ -0,0 +1,386 @@
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import and_, or_, func, desc
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import uuid
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class ModelConfigRepository:
    """Static data-access helpers for ``ModelConfig`` rows.

    All methods take the SQLAlchemy session as their first argument, log
    through ``db_logger`` and re-raise any exception they encounter.
    """

    @staticmethod
    def get_by_id(db: Session, model_id: uuid.UUID) -> Optional[ModelConfig]:
        """Load one model configuration by primary key, eager-loading its API keys.

        Returns the matching row, or ``None`` when no row has that id.
        """
        db_logger.debug(f"根据ID查询模型配置: model_id={model_id}")

        try:
            lookup = db.query(ModelConfig).options(joinedload(ModelConfig.api_keys))
            config = lookup.filter(ModelConfig.id == model_id).first()

            if not config:
                db_logger.debug(f"模型配置不存在: model_id={model_id}")
                return config

            db_logger.debug(f"模型配置查询成功: {config.name} (ID: {model_id})")
            return config
        except Exception as exc:
            db_logger.error(f"根据ID查询模型配置失败: model_id={model_id} - {str(exc)}")
            raise

    @staticmethod
    def get_by_name(db: Session, name: str) -> Optional[ModelConfig]:
        """Load the first model configuration whose name equals ``name``.

        Returns the matching row, or ``None`` when there is none.
        """
        db_logger.debug(f"根据名称查询模型配置: name={name}")

        try:
            config = (
                db.query(ModelConfig)
                .filter(ModelConfig.name == name)
                .first()
            )
            if config:
                db_logger.debug(f"模型配置查询成功: {config.name}")
            return config
        except Exception as exc:
            db_logger.error(f"根据名称查询模型配置失败: name={name} - {str(exc)}")
            raise

    @staticmethod
    def search_by_name(db: Session, name: str, limit: int = 10) -> List[ModelConfig]:
        """List model configurations whose name contains a keyword.

        The match is a case-insensitive ``ILIKE '%name%'``; results are
        ordered by name.

        Args:
            name: Keyword to look for inside the model name.
            limit: Maximum number of rows returned.

        Returns:
            The matching model configurations.
        """
        db_logger.debug(f"按名称模糊查询模型配置: name~{name}, limit={limit}")
        try:
            pattern = f"%{name}%"
            matches = db.query(ModelConfig).filter(ModelConfig.name.ilike(pattern))
            found = matches.order_by(ModelConfig.name).limit(limit).all()
            db_logger.debug(f"模糊查询成功: 返回数量={len(found)}")
            return found
        except Exception as exc:
            db_logger.error(f"按名称模糊查询模型配置失败: name~{name} - {str(exc)}")
            raise

    @staticmethod
    def get_list(db: Session, query: ModelConfigQuery) -> Tuple[List[ModelConfig], int]:
        """Return one page of model configurations plus the total match count.

        Optional criteria taken from ``query``: ``type`` (IN list),
        ``is_active``, ``is_public``, ``search`` (ILIKE on the name) and
        ``provider`` (joins ``ModelApiKey`` and de-duplicates). The page is
        ordered by ``updated_at`` descending and sliced with ``page`` and
        ``pagesize``.

        Returns:
            A ``(rows_of_current_page, total_matching_rows)`` tuple.
        """
        db_logger.debug(f"查询模型配置列表: {query.dict()}")

        try:
            # Collect the column-level criteria that apply to ModelConfig itself.
            criteria = []

            if query.type:
                criteria.append(ModelConfig.type.in_(query.type))

            if query.is_active is not None:
                criteria.append(ModelConfig.is_active == query.is_active)

            if query.is_public is not None:
                criteria.append(ModelConfig.is_public == query.is_public)

            if query.search:
                pattern = f"%{query.search}%"
                criteria.append(or_(ModelConfig.name.ilike(pattern)))

            listing = db.query(ModelConfig).options(joinedload(ModelConfig.api_keys))

            # The provider lives on ModelApiKey, so filtering by it needs a join.
            if query.provider:
                listing = listing.join(ModelApiKey)
                listing = listing.filter(ModelApiKey.provider == query.provider)
                listing = listing.distinct()

            if criteria:
                listing = listing.filter(and_(*criteria))

            total = listing.count()

            skip = (query.page - 1) * query.pagesize
            page_items = (
                listing.order_by(desc(ModelConfig.updated_at))
                .offset(skip)
                .limit(query.pagesize)
                .all()
            )

            db_logger.debug(f"模型配置列表查询成功: 总数={total}, 当前页={len(page_items)}, type筛选={query.type}")
            return page_items, total

        except Exception as exc:
            db_logger.error(f"查询模型配置列表失败: {str(exc)}")
            raise

    @staticmethod
    def get_by_type(db: Session, model_type: ModelType, is_active: bool = True) -> List[ModelConfig]:
        """List model configurations of one type, ordered by name.

        When ``is_active`` is truthy only active rows are returned; otherwise
        no activity filter is applied. API keys are eager-loaded.
        """
        db_logger.debug(f"根据类型查询模型配置: type={model_type}, is_active={is_active}")

        try:
            selection = (
                db.query(ModelConfig)
                .options(joinedload(ModelConfig.api_keys))
                .filter(ModelConfig.type == model_type)
            )

            if is_active:
                selection = selection.filter(ModelConfig.is_active == True)

            found = selection.order_by(ModelConfig.name).all()
            db_logger.debug(f"根据类型查询模型配置成功: 数量={len(found)}")
            return found

        except Exception as exc:
            db_logger.error(f"根据类型查询模型配置失败: type={model_type} - {str(exc)}")
            raise

    @staticmethod
    def create(db: Session, model_data: dict) -> ModelConfig:
        """Build a ``ModelConfig`` from ``model_data`` and add it to the session.

        The session is not committed here. On failure the session is rolled
        back and the error re-raised.
        """
        db_logger.debug(f"创建模型配置: {model_data.get('name')}")

        try:
            new_config = ModelConfig(**model_data)
            db.add(new_config)

            db_logger.info(f"模型配置已添加到会话: {new_config.name}")
            return new_config

        except Exception as exc:
            db.rollback()
            db_logger.error(f"创建模型配置失败: {model_data.get('name')} - {str(exc)}")
            raise

    @staticmethod
    def update(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate) -> Optional[ModelConfig]:
        """Apply the explicitly-set fields of ``model_data`` to a model configuration.

        Commits and refreshes the row. Returns the updated row, or ``None``
        when no row has ``model_id``. On failure the session is rolled back
        and the error re-raised.
        """
        db_logger.debug(f"更新模型配置: model_id={model_id}")

        try:
            target = db.query(ModelConfig).filter(ModelConfig.id == model_id).first()
            if not target:
                db_logger.warning(f"模型配置不存在: model_id={model_id}")
                return None

            # Only fields the caller actually set are copied onto the row.
            changes = model_data.dict(exclude_unset=True)
            for attr_name, attr_value in changes.items():
                setattr(target, attr_name, attr_value)

            db.commit()
            db.refresh(target)

            db_logger.info(f"模型配置更新成功: {target.name} (ID: {model_id})")
            return target

        except Exception as exc:
            db.rollback()
            db_logger.error(f"更新模型配置失败: model_id={model_id} - {str(exc)}")
            raise

    @staticmethod
    def delete(db: Session, model_id: uuid.UUID) -> bool:
        """Delete a model configuration and commit.

        Returns ``True`` when a row was deleted, ``False`` when no row has
        ``model_id``. On failure the session is rolled back and the error
        re-raised.
        """
        db_logger.debug(f"删除模型配置: model_id={model_id}")

        try:
            target = db.query(ModelConfig).filter(ModelConfig.id == model_id).first()
            if not target:
                db_logger.warning(f"模型配置不存在: model_id={model_id}")
                return False

            db.delete(target)
            db.commit()

            db_logger.info(f"模型配置删除成功: model_id={model_id}")
            return True

        except Exception as exc:
            db.rollback()
            db_logger.error(f"删除模型配置失败: model_id={model_id} - {str(exc)}")
            raise

    @staticmethod
    def get_stats(db: Session) -> Dict[str, Any]:
        """Collect counts of model configurations.

        Returns:
            A dict with ``total_models``, ``active_models``, ``llm_count``,
            ``embedding_count``, ``rerank_count`` and ``provider_stats``
            (provider value -> number of distinct model configs having an
            API key of that provider).
        """
        db_logger.debug("获取模型统计信息")

        try:
            def count_of_type(model_type):
                # Number of model configurations of the given type.
                return db.query(ModelConfig).filter(ModelConfig.type == model_type).count()

            total_models = db.query(ModelConfig).count()
            active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count()

            llm_count = count_of_type(ModelType.LLM)
            embedding_count = count_of_type(ModelType.EMBEDDING)
            rerank_count = count_of_type(ModelType.RERANK)

            # Provider information is stored on ModelApiKey rows.
            distinct_configs = func.count(func.distinct(ModelApiKey.model_config_id))
            per_provider = (
                db.query(ModelApiKey.provider, distinct_configs)
                .group_by(ModelApiKey.provider)
                .all()
            )
            provider_stats = {provider.value: amount for provider, amount in per_provider}

            stats = {
                "total_models": total_models,
                "active_models": active_models,
                "llm_count": llm_count,
                "embedding_count": embedding_count,
                "rerank_count": rerank_count,
                "provider_stats": provider_stats,
            }

            db_logger.debug(f"模型统计信息获取成功: {stats}")
            return stats

        except Exception as exc:
            db_logger.error(f"获取模型统计信息失败: {str(exc)}")
            raise
|
||||
|
||||
|
||||
class ModelApiKeyRepository:
    """Static data-access helpers for ``ModelApiKey`` rows.

    All methods take the SQLAlchemy session as their first argument, log
    through ``db_logger`` and re-raise any exception they encounter.
    """

    @staticmethod
    def get_by_id(db: Session, api_key_id: uuid.UUID) -> Optional[ModelApiKey]:
        """Load one API key by primary key; ``None`` when it does not exist."""
        db_logger.debug(f"根据ID查询API Key: api_key_id={api_key_id}")

        try:
            key_row = (
                db.query(ModelApiKey)
                .filter(ModelApiKey.id == api_key_id)
                .first()
            )
            if key_row:
                db_logger.debug(f"API Key查询成功: {key_row.model_name} (ID: {api_key_id})")
            return key_row
        except Exception as exc:
            db_logger.error(f"根据ID查询API Key失败: api_key_id={api_key_id} - {str(exc)}")
            raise

    @staticmethod
    def get_by_model_config(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]:
        """List the API keys attached to one model configuration.

        When ``is_active`` is truthy only active keys are returned; otherwise
        no activity filter is applied. Rows are ordered by ``priority`` and
        then ``created_at``.
        """
        db_logger.debug(f"根据模型配置ID查询API Key: model_config_id={model_config_id}")

        try:
            selection = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)

            if is_active:
                selection = selection.filter(ModelApiKey.is_active == True)

            ordered = selection.order_by(ModelApiKey.priority, ModelApiKey.created_at)
            key_rows = ordered.all()
            db_logger.debug(f"API Key列表查询成功: 数量={len(key_rows)}")
            return key_rows

        except Exception as exc:
            db_logger.error(f"根据模型配置ID查询API Key失败: model_config_id={model_config_id} - {str(exc)}")
            raise

    @staticmethod
    def create(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
        """Build a ``ModelApiKey`` from the schema and add it to the session.

        The session is not committed here. On failure the session is rolled
        back and the error re-raised.
        """
        db_logger.debug(f"创建API Key: {api_key_data.provider}")

        try:
            new_key = ModelApiKey(**api_key_data.dict())
            db.add(new_key)

            db_logger.info(f"API Key已添加到会话: {new_key.provider}")
            return new_key

        except Exception as exc:
            db.rollback()
            db_logger.error(f"创建API Key失败: {api_key_data.provider} - {str(exc)}")
            raise

    @staticmethod
    def update(db: Session, api_key_id: uuid.UUID, api_key_data: ModelApiKeyUpdate) -> Optional[ModelApiKey]:
        """Apply the explicitly-set fields of ``api_key_data`` to an API key.

        Commits and refreshes the row. Returns the updated row, or ``None``
        when no row has ``api_key_id``. On failure the session is rolled back
        and the error re-raised.
        """
        db_logger.debug(f"更新API Key: api_key_id={api_key_id}")

        try:
            target = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first()
            if not target:
                db_logger.warning(f"API Key不存在: api_key_id={api_key_id}")
                return None

            # Only fields the caller actually set are copied onto the row.
            changes = api_key_data.dict(exclude_unset=True)
            for attr_name, attr_value in changes.items():
                setattr(target, attr_name, attr_value)

            db.commit()
            db.refresh(target)

            db_logger.info(f"API Key更新成功: {target.model_name} (ID: {api_key_id})")
            return target

        except Exception as exc:
            db.rollback()
            db_logger.error(f"更新API Key失败: api_key_id={api_key_id} - {str(exc)}")
            raise

    @staticmethod
    def delete(db: Session, api_key_id: uuid.UUID) -> bool:
        """Delete an API key and commit.

        Returns ``True`` when a row was deleted, ``False`` when no row has
        ``api_key_id``. On failure the session is rolled back and the error
        re-raised.
        """
        db_logger.debug(f"删除API Key: api_key_id={api_key_id}")

        try:
            target = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first()
            if not target:
                db_logger.warning(f"API Key不存在: api_key_id={api_key_id}")
                return False

            db.delete(target)
            db.commit()

            db_logger.info(f"API Key删除成功: api_key_id={api_key_id}")
            return True

        except Exception as exc:
            db.rollback()
            db_logger.error(f"删除API Key失败: api_key_id={api_key_id} - {str(exc)}")
            raise

    @staticmethod
    def update_usage(db: Session, api_key_id: uuid.UUID) -> bool:
        """Increment an API key's usage counter and stamp its last-used time.

        ``usage_count`` is read as a string, parsed to int (a falsy value
        counts as ``"0"``), incremented and written back as a string;
        ``last_used_at`` is set to the database's ``now()``.

        Returns ``True`` after a successful commit, ``False`` when no row has
        ``api_key_id``. On failure the session is rolled back and the error
        re-raised.
        """
        db_logger.debug(f"更新API Key使用统计: api_key_id={api_key_id}")

        try:
            target = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first()
            if not target:
                return False

            previous_uses = int(target.usage_count or "0")
            target.usage_count = str(previous_uses + 1)
            target.last_used_at = func.now()

            db.commit()
            db_logger.debug(f"API Key使用统计更新成功: api_key_id={api_key_id}")
            return True

        except Exception as exc:
            db.rollback()
            db_logger.error(f"更新API Key使用统计失败: api_key_id={api_key_id} - {str(exc)}")
            raise
|
||||
32
api/app/repositories/neo4j/__init__.py
Normal file
32
api/app/repositories/neo4j/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j仓储模块
|
||||
|
||||
本模块包含Neo4j图数据库的仓储实现,用于管理知识图谱的节点和边。
|
||||
|
||||
Modules:
|
||||
neo4j_connector: Neo4j数据库连接器
|
||||
base_neo4j_repository: Neo4j仓储基类
|
||||
dialog_repository: 对话仓储
|
||||
statement_repository: 陈述句仓储
|
||||
entity_repository: 实体仓储
|
||||
cypher_queries: Cypher查询语句
|
||||
graph_search: 图搜索功能
|
||||
graph_saver: 图数据保存功能
|
||||
add_nodes: 添加节点功能
|
||||
add_edges: 添加边功能
|
||||
create_indexes: 创建索引功能
|
||||
"""
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.repositories.neo4j.dialog_repository import DialogRepository
|
||||
from app.repositories.neo4j.statement_repository import StatementRepository
|
||||
from app.repositories.neo4j.entity_repository import EntityRepository
|
||||
|
||||
__all__ = [
|
||||
'Neo4jConnector',
|
||||
'BaseNeo4jRepository',
|
||||
'DialogRepository',
|
||||
'StatementRepository',
|
||||
'EntityRepository',
|
||||
]
|
||||
102
api/app/repositories/neo4j/add_edges.py
Normal file
102
api/app/repositories/neo4j/add_edges.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import List, Optional
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE, MEMORY_SUMMARY_STATEMENT_EDGE_SAVE
|
||||
from app.core.memory.models.message_models import Chunk
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
|
||||
async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnector) -> Optional[List[str]]:
    """Create the chunk -> statement edges in Neo4j for a batch of chunks.

    One edge is built for every (chunk, statement) pair. The edge id is the
    SHA-1 hex digest of ``"<chunk.id>|<stmt.id>"``, so the same pair always
    yields the same id.

    Args:
        chunks: Chunk objects; each chunk's ``statements`` attribute (if any)
            supplies the statements to link.
        connector: Neo4j connector used to run the save query.

    Returns:
        The ``uuid`` values of the records returned by the query, ``[]`` when
        there is nothing to create, or ``None`` if any step raised.
    """
    if not chunks:
        print("No chunks provided to create edges")
        return []

    def build_edge(chunk, stmt):
        # Deterministic id derived from the (chunk, statement) pair.
        pair_key = f"{chunk.id}|{stmt.id}"
        return {
            "id": hashlib.sha1(pair_key.encode("utf-8")).hexdigest(),
            "source": chunk.id,
            "target": stmt.id,
            "group_id": getattr(stmt, 'group_id', None),
            "user_id": getattr(stmt, 'user_id', None),
            "apply_id": getattr(stmt, 'apply_id', None),
            # Fall back to the chunk's run_id when the statement has none.
            "run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
            "created_at": getattr(stmt, 'created_at', None),
            "expired_at": getattr(stmt, 'expired_at', None),
        }

    try:
        edge_payload: List[dict] = []
        for chunk in chunks:
            chunk_statements = getattr(chunk, "statements", []) or []
            edge_payload.extend(build_edge(chunk, stmt) for stmt in chunk_statements)

        if not edge_payload:
            print("No statements found in chunks to create edges")
            return []

        records = await connector.execute_query(
            CHUNK_STATEMENT_EDGE_SAVE,
            chunk_statement_edges=edge_payload
        )
        edge_ids = [] if not records else [rec.get("uuid") for rec in records]
        print(f"Successfully created {len(edge_ids)} chunk-statement edges")
        return edge_ids
    except Exception as exc:
        print(f"Error creating chunk-statement edges: {exc}")
        return None
|
||||
|
||||
async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Create edges from MemorySummary nodes to Statements via their chunk ids.

    For every summary and every chunk id it lists, one edge descriptor is
    sent to the ``MEMORY_SUMMARY_STATEMENT_EDGE_SAVE`` query, which links the
    summary to the statements of that chunk. This supports traversals such as
    summary -> statement -> entity with minimal hops.

    Args:
        summaries: MemorySummaryNode objects whose ``chunk_ids`` are expanded.
        connector: Neo4j connector used to run the save query.

    Returns:
        The ``uuid`` values of the records returned by the query, ``[]`` when
        there is nothing to create, or ``None`` if any step raised (the error
        is printed, matching ``add_chunk_statement_edges``).
    """
    if not summaries:
        return []

    try:
        edges: List[dict] = []
        for s in summaries:
            for chunk_id in getattr(s, "chunk_ids", []) or []:
                edges.append({
                    "summary_id": s.id,
                    "chunk_id": chunk_id,
                    "group_id": s.group_id,
                    "run_id": s.run_id,
                    "created_at": s.created_at.isoformat() if s.created_at else None,
                    "expired_at": s.expired_at.isoformat() if s.expired_at else None,
                })

        if not edges:
            return []

        result = await connector.execute_query(
            MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
            edges=edges
        )
        return [record.get("uuid") for record in result] if result else []
    except Exception as e:
        # Still report failure as None (the documented contract), but do not
        # hide the cause: a silent None made failures impossible to diagnose.
        print(f"Error creating memory summary-statement edges: {e}")
        return None
|
||||
215
api/app/repositories/neo4j/add_nodes.py
Normal file
215
api/app/repositories/neo4j/add_nodes.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def delete_all_nodes(group_id: str, connector: Neo4jConnector):
    """Delete every node whose ``group_id`` property matches, with its edges.

    Only nodes of the given group are removed (``DETACH DELETE`` also drops
    their relationships); nodes of other groups are not matched.

    Args:
        group_id: Value of the ``group_id`` property identifying the nodes.
        connector: Neo4j connector used to run the query.

    Returns:
        Whatever ``connector.execute_query`` returns for the delete query.
    """
    # Pass group_id as a query parameter instead of splicing it into the
    # Cypher text: a value containing a quote would otherwise break the query
    # or let the caller inject arbitrary Cypher.
    result = await connector.execute_query(
        "MATCH (n {group_id: $group_id}) DETACH DELETE n",
        group_id=group_id,
    )
    print(f"All group_id: {group_id} node and edge deleted successfully")
    return result
|
||||
|
||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save a batch of dialogue nodes to Neo4j.

    Each DialogueNode is flattened into a plain dict (timestamps rendered
    with ``isoformat()``, or ``None`` when unset) and the batch is passed to
    the ``DIALOGUE_NODE_SAVE`` query.

    Args:
        dialogues: DialogueNode objects to save.
        connector: Neo4j connector used to run the save query.

    Returns:
        The ``uuid`` of every record returned by the query, ``[]`` when the
        input is empty, or ``None`` if any step raised.
    """
    if not dialogues:
        print("No dialogues to save")
        return []

    def as_iso(moment):
        # Falsy timestamps are stored as None rather than formatted.
        return moment.isoformat() if moment else None

    try:
        payload = [
            {
                "id": node.id,
                "group_id": node.group_id,
                "user_id": node.user_id,
                "apply_id": node.apply_id,
                "run_id": node.run_id,
                "ref_id": node.ref_id,
                "name": node.name,
                "created_at": as_iso(node.created_at),
                "expired_at": as_iso(node.expired_at),
                "content": node.content,
                "dialog_embedding": node.dialog_embedding,
            }
            for node in dialogues
        ]

        records = await connector.execute_query(DIALOGUE_NODE_SAVE, dialogues=payload)

        node_ids = [rec["uuid"] for rec in records]
        print(f"Successfully created {len(node_ids)} dialogue nodes: {node_ids}")
        return node_ids

    except Exception as exc:
        print(f"Error creating dialogue nodes: {exc}")
        return None
|
||||
|
||||
|
||||
async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save a batch of statement nodes to Neo4j.

    Each StatementNode is flattened into a dict of plain values (timestamps
    rendered with ``isoformat()`` or ``None``; ``temporal_info`` reduced to
    its ``.value``; falsy embeddings stored as ``None``) and the batch is
    passed to the ``STATEMENT_NODE_SAVE`` query.

    Args:
        statements: StatementNode objects to save.
        connector: Neo4j connector used to run the save query.

    Returns:
        The ``uuid`` of every record returned by the query, ``[]`` when the
        input is empty, or ``None`` if any step raised.
    """
    if not statements:
        print("No statements to save")
        return []

    def as_iso(moment):
        # Falsy timestamps are stored as None rather than formatted.
        return moment.isoformat() if moment else None

    def flatten(node):
        # Keep only primitive values so the dict can be sent as a query parameter.
        return {
            "id": node.id,
            "name": node.name,
            "group_id": node.group_id,
            "user_id": node.user_id,
            "apply_id": node.apply_id,
            "run_id": node.run_id,
            "chunk_id": node.chunk_id,
            "created_at": as_iso(node.created_at),
            "expired_at": as_iso(node.expired_at),
            "stmt_type": node.stmt_type,
            "temporal_info": node.temporal_info.value,
            "statement": node.statement,
            "connect_strength": node.connect_strength,
            "chunk_embedding": node.chunk_embedding if node.chunk_embedding else None,
            "valid_at": as_iso(node.valid_at),
            "invalid_at": as_iso(node.invalid_at),
            "statement_embedding": node.statement_embedding if node.statement_embedding else None,
        }

    try:
        payload = [flatten(node) for node in statements]

        records = await connector.execute_query(STATEMENT_NODE_SAVE, statements=payload)

        node_ids = [rec["uuid"] for rec in records]
        print(f"Successfully created {len(node_ids)} statement nodes")
        return node_ids

    except Exception as exc:
        print(f"Error creating statement nodes: {exc}")
        return None
|
||||
|
||||
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save a batch of chunk nodes to Neo4j.

    Each ChunkNode is flattened into a dict of plain values; the
    ``start_index`` and ``end_index`` entries of its metadata are lifted to
    top-level keys. The batch is passed to the ``CHUNK_NODE_SAVE`` query.

    Args:
        chunks: ChunkNode objects to save.
        connector: Neo4j connector used to run the save query.

    Returns:
        The ``uuid`` of every record returned by the query, ``[]`` when the
        input is empty, or ``None`` if any step raised.
    """
    if not chunks:
        print("No chunk nodes to add")
        return []

    def as_iso(moment):
        # Falsy timestamps are stored as None rather than formatted.
        return moment.isoformat() if moment else None

    def flatten(node):
        # Metadata is spread into scalar keys instead of being sent as a map.
        meta = node.metadata if node.metadata else {}
        return {
            "id": node.id,
            "name": node.name,
            "group_id": node.group_id,
            "user_id": node.user_id,
            "apply_id": node.apply_id,
            "run_id": node.run_id,
            "created_at": as_iso(node.created_at),
            "expired_at": as_iso(node.expired_at),
            "dialog_id": node.dialog_id,
            "content": node.content,
            "chunk_embedding": node.chunk_embedding if node.chunk_embedding else None,
            "sequence_number": node.sequence_number,
            "start_index": meta.get("start_index"),
            "end_index": meta.get("end_index"),
        }

    try:
        payload = [flatten(node) for node in chunks]

        records = await connector.execute_query(CHUNK_NODE_SAVE, chunks=payload)

        node_ids = [rec["uuid"] for rec in records]
        print(f"Successfully created {len(node_ids)} chunk nodes")
        return node_ids

    except Exception as exc:
        print(f"Error creating chunk nodes: {exc}")
        return None
|
||||
|
||||
|
||||
|
||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save a batch of memory summary nodes to Neo4j.

    Each MemorySummaryNode is flattened into a dict of plain values
    (timestamps rendered with ``isoformat()`` or ``None``; a falsy embedding
    stored as ``None``) and the batch is passed to the
    ``MEMORY_SUMMARY_NODE_SAVE`` query.

    Args:
        summaries: MemorySummaryNode objects to save.
        connector: Neo4j connector used to run the save query.

    Returns:
        The ``uuid`` of every record returned by the query, ``[]`` when the
        input is empty, or ``None`` if any step raised (the error is printed,
        matching the other ``add_*_nodes`` helpers).
    """
    if not summaries:
        print("No memory summary nodes to add")
        return []

    try:
        flattened = [
            {
                "id": s.id,
                "name": s.name,
                "group_id": s.group_id,
                "user_id": s.user_id,
                "apply_id": s.apply_id,
                "run_id": s.run_id,
                "created_at": s.created_at.isoformat() if s.created_at else None,
                "expired_at": s.expired_at.isoformat() if s.expired_at else None,
                "dialog_id": s.dialog_id,
                "chunk_ids": s.chunk_ids,
                "content": s.content,
                "summary_embedding": s.summary_embedding if s.summary_embedding else None,
                "config_id": s.config_id,
            }
            for s in summaries
        ]

        result = await connector.execute_query(
            MEMORY_SUMMARY_NODE_SAVE,
            summaries=flattened
        )
        created_ids = [record.get("uuid") for record in result]
        print(f"Successfully created {len(created_ids)} memory summary nodes")
        return created_ids
    except Exception as e:
        # Still report failure as None (the documented contract), but do not
        # hide the cause: a silent None made failures impossible to diagnose.
        print(f"Error creating memory summary nodes: {e}")
        return None
|
||||
|
||||
|
||||
175
api/app/repositories/neo4j/base_neo4j_repository.py
Normal file
175
api/app/repositories/neo4j/base_neo4j_repository.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j仓储基类模块
|
||||
|
||||
本模块提供Neo4j仓储的基类实现,封装了通用的Neo4j节点操作。
|
||||
|
||||
Classes:
|
||||
BaseNeo4jRepository: Neo4j仓储基类,实现通用的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, TypeVar
|
||||
from app.repositories.base_repository import BaseRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class BaseNeo4jRepository(BaseRepository[T]):
|
||||
"""Neo4j仓储基类 - 实现通用的Neo4j节点操作
|
||||
|
||||
这个基类封装了Neo4j节点的通用CRUD操作,子类只需要实现
|
||||
特定的映射逻辑和业务查询方法。
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
node_label: 节点标签(如"Dialogue", "Statement"等)
|
||||
|
||||
Type Parameters:
|
||||
T: 实体类型,通常是Pydantic模型
|
||||
"""
|
||||
|
||||
def __init__(self, connector: Neo4jConnector, node_label: str):
|
||||
"""初始化Neo4j仓储
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
node_label: 节点标签,用于Cypher查询
|
||||
"""
|
||||
self.connector = connector
|
||||
self.node_label = node_label
|
||||
|
||||
async def create(self, entity: T) -> T:
|
||||
"""创建节点
|
||||
|
||||
将实体对象转换为Neo4j节点并保存到数据库。
|
||||
|
||||
Args:
|
||||
entity: 要创建的实体对象
|
||||
|
||||
Returns:
|
||||
T: 创建后的实体对象
|
||||
|
||||
Example:
|
||||
>>> dialog = DialogueNode(id="123", name="对话1", ...)
|
||||
>>> created = await repository.create(dialog)
|
||||
"""
|
||||
query = f"""
|
||||
CREATE (n:{self.node_label} $props)
|
||||
RETURN n
|
||||
"""
|
||||
result = await self.connector.execute_query(
|
||||
query,
|
||||
props=entity.model_dump()
|
||||
)
|
||||
return entity
|
||||
|
||||
async def get_by_id(self, entity_id: str) -> Optional[T]:
|
||||
"""根据ID获取节点
|
||||
|
||||
Args:
|
||||
entity_id: 节点ID
|
||||
|
||||
Returns:
|
||||
Optional[T]: 找到的实体对象,如果不存在则返回None
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label} {{id: $id}})
|
||||
RETURN n
|
||||
"""
|
||||
result = await self.connector.execute_query(query, id=entity_id)
|
||||
if result:
|
||||
return self._map_to_entity(result[0])
|
||||
return None
|
||||
|
||||
async def update(self, entity: T) -> T:
|
||||
"""更新节点
|
||||
|
||||
更新现有节点的属性。使用SET +=语法合并属性。
|
||||
|
||||
Args:
|
||||
entity: 要更新的实体对象(必须包含id字段)
|
||||
|
||||
Returns:
|
||||
T: 更新后的实体对象
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label} {{id: $id}})
|
||||
SET n += $props
|
||||
RETURN n
|
||||
"""
|
||||
await self.connector.execute_query(
|
||||
query,
|
||||
id=entity.id,
|
||||
props=entity.model_dump()
|
||||
)
|
||||
return entity
|
||||
|
||||
async def delete(self, entity_id: str) -> bool:
|
||||
"""删除节点
|
||||
|
||||
删除指定ID的节点。使用DETACH DELETE同时删除相关的边。
|
||||
|
||||
Args:
|
||||
entity_id: 要删除的节点ID
|
||||
|
||||
Returns:
|
||||
bool: 删除成功返回True,否则返回False
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label} {{id: $id}})
|
||||
DETACH DELETE n
|
||||
RETURN count(n) as deleted
|
||||
"""
|
||||
result = await self.connector.execute_query(query, id=entity_id)
|
||||
return result[0]['deleted'] > 0 if result else False
|
||||
|
||||
async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]:
    """Query nodes by exact-match property filters.

    Args:
        filters: Mapping of property name to the value it must equal. All
            conditions are combined with ``AND``; an empty mapping matches
            every node carrying ``self.node_label``.
        limit: Maximum number of rows to return.

    Returns:
        List[T]: Each result row mapped through ``self._map_to_entity``.

    Example:
        >>> results = await repository.find(
        ...     {"group_id": "group_123", "user_id": "user_456"},
        ...     limit=50
        ... )
    """
    where_clauses = []
    params: Dict[str, Any] = {}
    for position, (key, value) in enumerate(filters.items()):
        # Property names cannot be sent as query parameters, so they are
        # backtick-quoted (embedded backticks doubled) instead of being
        # pasted in raw, which would let a crafted key inject Cypher.
        escaped_key = str(key).replace("`", "``")
        # Generated parameter names keep a filter called "limit" (or any
        # other keyword of execute_query) from clashing with ``limit=``.
        param_name = f"filter_{position}"
        where_clauses.append(f"n.`{escaped_key}` = ${param_name}")
        params[param_name] = value
    where_str = " AND ".join(where_clauses) if where_clauses else "1=1"

    query = f"""
    MATCH (n:{self.node_label})
    WHERE {where_str}
    RETURN n
    LIMIT $limit
    """
    results = await self.connector.execute_query(query, limit=limit, **params)
    return [self._map_to_entity(row) for row in results]
|
||||
|
||||
def _map_to_entity(self, node_data: Dict) -> T:
    """Map raw node data to an entity object.

    This base implementation only raises; subclasses are expected to
    override it with the concrete mapping logic.

    Args:
        node_data: Node data dictionary as returned by a Neo4j query.

    Returns:
        T: The mapped entity object (in overriding implementations).

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    message = "Subclasses must implement _map_to_entity method"
    raise NotImplementedError(message)
|
||||
332
api/app/repositories/neo4j/create_indexes.py
Normal file
332
api/app/repositories/neo4j/create_indexes.py
Normal file
@@ -0,0 +1,332 @@
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def create_fulltext_indexes():
    """Create the full-text indexes used for keyword search.

    Each index is created with ``IF NOT EXISTS`` and the ``cjk`` analyzer.
    Errors are reported on stdout rather than raised, so the remaining setup
    steps of the script can still run.

    Returns:
        bool: ``True`` when every statement ran without raising, ``False``
        when an error was caught and printed.
    """
    # (index name, node label, indexed property)
    # NOTE: a Dialogue index (dialoguesFulltext on Dialogue.content) is
    # currently disabled and therefore not listed here.
    index_specs = (
        ("statementsFulltext", "Statement", "statement"),
        ("entitiesFulltext", "ExtractedEntity", "name"),
        ("chunksFulltext", "Chunk", "content"),
        ("summariesFulltext", "MemorySummary", "content"),
    )

    connector = Neo4jConnector()
    try:
        print("\n" + "=" * 70)
        print("Creating Full-Text Indexes (for keyword search)")
        print("=" * 70)

        for index_name, label, prop in index_specs:
            await connector.execute_query(f"""
            CREATE FULLTEXT INDEX {index_name} IF NOT EXISTS FOR (n:{label}) ON EACH [n.{prop}]
            OPTIONS {{ indexConfig: {{ `fulltext.analyzer`: 'cjk' }} }}
            """)
            print(f"✓ Created: {index_name}")

        print("\nFull-text indexes created successfully with BM25 support.")
        return True
    except Exception as e:
        # Best-effort script boundary: report the failure and signal it to
        # the caller instead of aborting the whole setup run.
        print(f"✗ Error creating full-text indexes: {e}")
        return False
    finally:
        await connector.close()
|
||||
|
||||
|
||||
async def create_vector_indexes(dimensions: int = 1024):
    """Create the vector indexes used for embedding similarity search.

    One cosine-similarity vector index is created (``IF NOT EXISTS``) per
    embedding property. Errors from the database are reported on stdout
    rather than raised.

    Args:
        dimensions: Vector size stored in the indexed properties. Defaults to
            1024, the value this script was written with (the printed note
            associates it with the bge-m3 model).

    Returns:
        bool: ``True`` when every statement ran without raising, ``False``
        when an error was caught and printed.

    Raises:
        ValueError: If ``dimensions`` is not a positive integer.
    """
    # Validated before it is interpolated into the DDL text below.
    if isinstance(dimensions, bool) or not isinstance(dimensions, int) or dimensions <= 0:
        raise ValueError(f"dimensions must be a positive integer, got {dimensions!r}")

    # (index name, node label, embedding property)
    index_specs = (
        ("statement_embedding_index", "Statement", "statement_embedding"),
        ("chunk_embedding_index", "Chunk", "chunk_embedding"),
        ("entity_embedding_index", "ExtractedEntity", "name_embedding"),
        ("summary_embedding_index", "MemorySummary", "summary_embedding"),
        ("dialogue_embedding_index", "Dialogue", "dialog_embedding"),
    )

    connector = Neo4jConnector()
    try:
        print("\n" + "=" * 70)
        print("Creating Vector Indexes (for embedding search)")
        print("=" * 70)
        print("Note: Adjust vector.dimensions if using different embedding model")
        print(f" Current setting: {dimensions} dimensions (default 1024 is for bge-m3)")
        print()

        for index_name, label, prop in index_specs:
            await connector.execute_query(f"""
            CREATE VECTOR INDEX {index_name} IF NOT EXISTS
            FOR (n:{label})
            ON n.{prop}
            OPTIONS {{indexConfig: {{
                `vector.dimensions`: {dimensions},
                `vector.similarity_function`: 'cosine'
            }}}}
            """)
            print(f"✓ Created: {index_name}")

        print("\nVector indexes created successfully!")
        print("\nExpected performance improvement:")
        print(" Before: ~1.4s for embedding search")
        print(" After: ~0.05-0.2s for embedding search (10-30x faster!)")
        return True
    except Exception as e:
        # Best-effort script boundary: report and signal the failure.
        print(f"✗ Error creating vector indexes: {e}")
        return False
    finally:
        await connector.close()
|
||||
|
||||
|
||||
async def create_config_id_indexes():
    """Create indexes on the ``config_id`` property of the main node labels.

    The indexes are meant to speed up filtering nodes by configuration ID.
    Each one is created with ``IF NOT EXISTS``; errors are reported on stdout
    rather than raised.

    Returns:
        bool: ``True`` when every statement ran without raising, ``False``
        when an error was caught and printed.
    """
    # (index name, node label)
    index_specs = (
        ("dialogue_config_id_index", "Dialogue"),
        ("statement_config_id_index", "Statement"),
        ("entity_config_id_index", "ExtractedEntity"),
        ("summary_config_id_index", "MemorySummary"),
    )

    connector = Neo4jConnector()
    try:
        print("\n" + "=" * 70)
        print("Creating Config ID Indexes")
        print("=" * 70)

        for index_name, label in index_specs:
            await connector.execute_query(f"""
            CREATE INDEX {index_name} IF NOT EXISTS
            FOR (n:{label}) ON (n.config_id)
            """)
            print(f"✓ Created: {index_name}")

        print("\nConfig ID indexes created successfully!")
        print("These indexes enable fast filtering by configuration ID.")
        return True
    except Exception as e:
        # Best-effort script boundary: report and signal the failure.
        print(f"✗ Error creating config_id indexes: {e}")
        return False
    finally:
        await connector.close()
|
||||
|
||||
|
||||
async def create_unique_constraints():
    """Create uniqueness constraints on ``id`` for core node labels.

    Covers Dialogue, Statement and Chunk. Each constraint is created with
    ``IF NOT EXISTS``; errors are reported on stdout rather than raised.

    Returns:
        bool: ``True`` when every statement ran without raising, ``False``
        when an error was caught and printed.
    """
    # (constraint name, node label)
    constraint_specs = (
        ("dialog_id_unique", "Dialogue"),
        ("statement_id_unique", "Statement"),
        ("chunk_id_unique", "Chunk"),
    )

    connector = Neo4jConnector()
    try:
        print("\n" + "=" * 70)
        print("Creating Unique Constraints")
        print("=" * 70)

        for constraint_name, label in constraint_specs:
            await connector.execute_query(f"""
            CREATE CONSTRAINT {constraint_name} IF NOT EXISTS
            FOR (n:{label}) REQUIRE n.id IS UNIQUE
            """)
            print(f"✓ Created: {constraint_name}")

        print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.")
        return True
    except Exception as e:
        # Best-effort script boundary: report and signal the failure.
        print(f"✗ Error creating unique constraints: {e}")
        return False
    finally:
        await connector.close()
|
||||
|
||||
|
||||
async def create_all_indexes():
    """Create all indexes and constraints in one go.

    Runs the four setup steps in order and prints a summary. A step that
    returns ``False`` is reported as failed; any other return value
    (including ``None``) is treated as success, so steps that do not report
    a status keep the previous behaviour.

    Returns:
        bool: ``True`` when no step reported failure, ``False`` otherwise.
    """
    print("\n" + "=" * 70)
    print("Neo4j Index & Constraint Setup")
    print("=" * 70)
    print("This will create:")
    print(" 1. Full-text indexes (for keyword/BM25 search)")
    print(" 2. Vector indexes (for embedding similarity search)")
    print(" 3. Config ID indexes (for configuration isolation)")
    print(" 4. Unique constraints (for data integrity)")
    print("=" * 70)

    steps = (
        ("full-text indexes", create_fulltext_indexes),
        ("vector indexes", create_vector_indexes),
        ("config_id indexes", create_config_id_indexes),
        ("unique constraints", create_unique_constraints),
    )
    failed = []
    for description, step in steps:
        # Later steps still run after a failure; each one is independent.
        if await step() is False:
            failed.append(description)

    print("\n" + "=" * 70)
    if failed:
        # Do not claim success when a step reported an error.
        print(f"✗ Setup finished with errors in: {', '.join(failed)}")
    else:
        print("✓ All indexes and constraints created successfully!")
    print("=" * 70)
    print("\nTo verify, run in Neo4j Browser:")
    print(" SHOW INDEXES")
    print(" SHOW CONSTRAINTS")
    print()
    return not failed
|
||||
|
||||
|
||||
async def check_indexes():
    """Print a report of the indexes that currently exist.

    Runs ``SHOW INDEXES``, lists the FULLTEXT, VECTOR and RANGE indexes by
    name, and prints a warning when no vector index exists or when fewer
    than four RANGE indexes have ``config_id`` in their name.
    """
    connector = Neo4jConnector()

    try:
        print("\n" + "=" * 70)
        print("Checking Existing Indexes")
        print("=" * 70)

        result = await connector.execute_query("SHOW INDEXES")

        def of_type(index_type):
            # Rows whose "type" column equals the requested index type.
            return [row for row in result if row.get('type') == index_type]

        fulltext_indexes = of_type('FULLTEXT')
        vector_indexes = of_type('VECTOR')
        range_indexes = of_type('RANGE')

        sections = (
            ("Full-text indexes", fulltext_indexes),
            ("Vector indexes", vector_indexes),
            ("Range indexes (including config_id)", range_indexes),
        )
        for title, rows in sections:
            print(f"\n{title}: {len(rows)}")
            for row in rows:
                print(f" ✓ {row.get('name')}")

        if len(vector_indexes) == 0:
            print("\n⚠️ WARNING: No vector indexes found!")
            print(" Embedding search will be VERY SLOW (~1.4s)")
            print(" Run: python create_indexes.py")

        # Count the RANGE indexes that look like config_id indexes.
        config_id_indexes = [
            row for row in range_indexes if 'config_id' in row.get('name', '')
        ]
        if len(config_id_indexes) < 4:
            print("\n⚠️ WARNING: Not all config_id indexes found!")
            print(f" Expected 4, found {len(config_id_indexes)}")
            print(" Run: python create_indexes.py config_id")

        print("=" * 70)

    finally:
        await connector.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio
    import sys

    # Sub-command name -> coroutine function implementing it.
    commands = {
        "check": check_indexes,
        "fulltext": create_fulltext_indexes,
        "vector": create_vector_indexes,
        "config_id": create_config_id_indexes,
        "constraints": create_unique_constraints,
    }

    if len(sys.argv) <= 1:
        # No sub-command: create everything.
        asyncio.run(create_all_indexes())
    elif sys.argv[1] in commands:
        asyncio.run(commands[sys.argv[1]]())
    else:
        print(f"Unknown command: {sys.argv[1]}")
        print("\nUsage:")
        print(" python create_indexes.py # Create all indexes")
        print(" python create_indexes.py check # Check existing indexes")
        print(" python create_indexes.py fulltext # Create only full-text indexes")
        print(" python create_indexes.py vector # Create only vector indexes")
        print(" python create_indexes.py config_id # Create only config_id indexes")
        print(" python create_indexes.py constraints # Create only constraints")
        # Exit non-zero so shells and CI can detect the bad invocation.
        sys.exit(2)
|
||||
|
||||
684
api/app/repositories/neo4j/cypher_queries.py
Normal file
684
api/app/repositories/neo4j/cypher_queries.py
Normal file
@@ -0,0 +1,684 @@
|
||||
|
||||
# Upsert Dialogue nodes from the $dialogues list, keyed by id. An existing
# n.uuid is preserved; the other listed properties are overwritten.
DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue
MERGE (n:Dialogue {id: dialogue.id})
SET n.uuid = coalesce(n.uuid, dialogue.id),
    n.group_id = dialogue.group_id,
    n.user_id = dialogue.user_id,
    n.apply_id = dialogue.apply_id,
    n.run_id = dialogue.run_id,
    n.ref_id = dialogue.ref_id,
    n.created_at = dialogue.created_at,
    n.expired_at = dialogue.expired_at,
    n.content = dialogue.content,
    n.dialog_embedding = dialogue.dialog_embedding
RETURN n.id AS uuid
"""
|
||||
|
||||
# Upsert Statement nodes from the $statements list, keyed by id.
# NOTE(review): `relevence_info` looks like a misspelling of "relevance", but
# it is the stored property name; renaming it would need a data migration.
STATEMENT_NODE_SAVE = """
UNWIND $statements AS statement
MERGE (s:Statement {id: statement.id})
SET s += {
    id: statement.id,
    group_id: statement.group_id,
    user_id: statement.user_id,
    apply_id: statement.apply_id,
    chunk_id: statement.chunk_id,
    run_id: statement.run_id,
    created_at: statement.created_at,
    expired_at: statement.expired_at,
    stmt_type: statement.stmt_type,
    temporal_info: statement.temporal_info,
    relevence_info: statement.relevence_info,
    statement: statement.statement,
    valid_at: statement.valid_at,
    invalid_at: statement.invalid_at,
    statement_embedding: statement.statement_embedding
}
RETURN s.id AS uuid
"""
|
||||
|
||||
# Upsert Chunk nodes from the $chunks list, keyed by id.
CHUNK_NODE_SAVE = """
UNWIND $chunks AS chunk
MERGE (c:Chunk {id: chunk.id})
SET c += {
    id: chunk.id,
    name: chunk.name,
    group_id: chunk.group_id,
    user_id: chunk.user_id,
    apply_id: chunk.apply_id,
    run_id: chunk.run_id,
    created_at: chunk.created_at,
    expired_at: chunk.expired_at,
    dialog_id: chunk.dialog_id,
    content: chunk.content,
    chunk_embedding: chunk.chunk_embedding,
    sequence_number: chunk.sequence_number,
    start_index: chunk.start_index,
    end_index: chunk.end_index
}
RETURN c.id AS uuid
"""
|
||||
# Bug-fix point
|
||||
|
||||
# Upsert ExtractedEntity nodes from the $entities list, keyed by id. Existing
# non-empty values win over empty incoming ones; created_at keeps the earliest
# and expired_at the latest value; the longer description / fact_summary is
# kept; differing strong/weak connect_strength values collapse to 'both'.
# Incoming aliases are appended only when not already stored, so repeated
# upserts of the same entity do not accumulate duplicate aliases.
EXTRACTED_ENTITY_NODE_SAVE = """
// Upsert entity nodes safely: preserve existing non-empty fields when incoming is empty
UNWIND $entities AS entity
MERGE (e:ExtractedEntity {id: entity.id})
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
    e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END,
    e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
    e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
    e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
    e.created_at = CASE
        WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
        THEN entity.created_at ELSE e.created_at END,
    e.expired_at = CASE
        WHEN entity.expired_at IS NOT NULL AND (e.expired_at IS NULL OR entity.expired_at > e.expired_at)
        THEN entity.expired_at ELSE e.expired_at END,
    e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
    e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
    e.description = CASE
        WHEN entity.description IS NOT NULL AND entity.description <> ''
             AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
        THEN entity.description ELSE e.description END,
    e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
    e.aliases = CASE
        WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
        THEN CASE
            WHEN e.aliases IS NULL THEN entity.aliases
            ELSE e.aliases + [alias IN entity.aliases WHERE NOT alias IN e.aliases]
        END
        ELSE e.aliases END,
    e.name_embedding = CASE
        WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
        ELSE e.name_embedding END,
    e.fact_summary = CASE
        WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
             AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
        THEN entity.fact_summary ELSE e.fact_summary END,
    e.connect_strength = CASE
        WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
        ELSE CASE
            WHEN e.connect_strength = 'strong' AND entity.connect_strength = 'weak' THEN 'both'
            WHEN e.connect_strength = 'weak' AND entity.connect_strength = 'strong' THEN 'both'
            WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength
            ELSE e.connect_strength
        END
    END
RETURN e.id AS uuid
"""
|
||||
|
||||
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
|
||||
# Upsert EXTRACTED_RELATIONSHIP edges between two ExtractedEntity nodes of the
# same group. The MERGE is keyed on the endpoints only, so a later row for the
# same pair overwrites the edge properties set by an earlier one.
ENTITY_RELATIONSHIP_SAVE = """
UNWIND $relationships AS rel
// Match entities by stable id within group, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id})
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id})
// Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate,
    r.statement_id = rel.statement_id,
    r.value = rel.value,
    r.statement = rel.statement,
    r.valid_at = rel.valid_at,
    r.invalid_at = rel.invalid_at,
    r.created_at = rel.created_at,
    r.expired_at = rel.expired_at,
    r.run_id = rel.run_id,
    r.group_id = rel.group_id
RETURN elementId(r) AS uuid
"""
|
||||
|
||||
# In Neo4j 5 and later the id() function is deprecated; elementId() is used instead.
|
||||
|
||||
# Save weak-relationship entities and set e.is_weak = true; the aggregated
# e.relations field is not maintained. Nodes are keyed by (id, run_id).
WEAK_ENTITY_NODE_SAVE = """
UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += {
    name: entity.name,
    group_id: entity.group_id,
    run_id: entity.run_id,
    description: entity.description,
    chunk_id: entity.chunk_id,
    dialog_id: entity.dialog_id
}
// Independent weak flag,仅标记弱关系,不再维护 relations 聚合字段
SET e.is_weak = true
RETURN e.id AS id
"""
|
||||
|
||||
# Create/update the entity nodes for the subject and object of each strong
# triple, setting only e.is_strong = true; the e.relations field is not
# maintained. Nodes are keyed by (id, run_id).
SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id}
// Independent strong flag
SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id}
// Independent strong flag
SET o.is_strong = true
"""
|
||||
|
||||
|
||||
# Upsert MENTIONS edges from a Dialogue (matched by uuid or ref_id) to a
# Statement (matched by id). Edges are de-duplicated by endpoints only; the
# edge properties are overwritten on every run.
DIALOGUE_STATEMENT_EDGE_SAVE = """
UNWIND $dialogue_statement_edges AS edge
// 支持按 uuid 或 ref_id 连接到 Dialogue,避免因来源 ID 不一致而断链
MATCH (dialogue:Dialogue)
WHERE dialogue.uuid = edge.source OR dialogue.ref_id = edge.source
MATCH (statement:Statement {id: edge.target})
// 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id,
    e.group_id = edge.group_id,
    e.created_at = edge.created_at,
    e.expired_at = edge.expired_at
RETURN e.uuid AS uuid
"""
|
||||
|
||||
# In Neo4j 5 and later the id() function is deprecated; elementId() is used instead.
|
||||
|
||||
|
||||
# Upsert CONTAINS edges from a Chunk to a Statement. Both endpoints are
# matched by id within the same run_id; edges are keyed by their own id.
CHUNK_STATEMENT_EDGE_SAVE = """
UNWIND $chunk_statement_edges AS edge
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.group_id = edge.group_id,
    e.run_id = edge.run_id,
    e.created_at = edge.created_at,
    e.expired_at = edge.expired_at
RETURN e.id AS uuid
"""
|
||||
|
||||
# Upsert REFERENCES_ENTITY edges from a Statement (matched by id and run_id)
# to an ExtractedEntity (matched by id and group_id). Edges are de-duplicated
# by endpoints; their properties are overwritten on every run.
STATEMENT_ENTITY_EDGE_SAVE = """
UNWIND $relationships AS rel
// Statement nodes are per-run; keep run_id constraint on statements
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
// Entities are shared across runs within a group; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id})
// Avoid duplicate edges across runs for same endpoints
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.group_id = rel.group_id,
    r.run_id = rel.run_id,
    r.created_at = rel.created_at,
    r.expired_at = rel.expired_at,
    r.connect_strength = rel.connect_strength
RETURN elementId(r) AS uuid
"""
|
||||
|
||||
# Vector-index search on ExtractedEntity.name_embedding. The index is asked
# for $limit * 100 candidates, which are then filtered by the optional
# $group_id and cut down to $limit.
ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL
  AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id,
       e.name AS name,
       e.group_id AS group_id,
       e.entity_type AS entity_type,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
# Embedding-based search: cosine similarity on Statement.statement_embedding
|
||||
# Vector-index search on Statement.statement_embedding. The index is asked
# for $limit * 100 candidates, which are then filtered by the optional
# $group_id and cut down to $limit.
STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL
  AND ($group_id IS NULL OR s.group_id = $group_id)
RETURN s.id AS id,
       s.statement AS statement,
       s.group_id AS group_id,
       s.chunk_id AS chunk_id,
       s.created_at AS created_at,
       s.expired_at AS expired_at,
       s.valid_at AS valid_at,
       s.invalid_at AS invalid_at,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Embedding-based search: cosine similarity on Chunk.chunk_embedding
|
||||
# Vector-index search on Chunk.chunk_embedding. The index is asked for
# $limit * 100 candidates, which are then filtered by the optional $group_id
# and cut down to $limit.
CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL
  AND ($group_id IS NULL OR c.group_id = $group_id)
RETURN c.id AS chunk_id,
       c.group_id AS group_id,
       c.content AS content,
       c.dialog_id AS dialog_id,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Full-text search over Statement nodes ("statementsFulltext" index) with an
# optional $group_id filter. Also returns the id of a Chunk that CONTAINS the
# statement and the ids of the entities the statement references.
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
       s.statement AS statement,
       s.group_id AS group_id,
       s.chunk_id AS chunk_id,
       s.created_at AS created_at,
       s.expired_at AS expired_at,
       s.valid_at AS valid_at,
       s.invalid_at AS invalid_at,
       c.id AS chunk_id_from_rel,
       collect(DISTINCT e.id) AS entity_ids,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
# Find entities whose name matches the given string, using the
# "entitiesFulltext" full-text index, with an optional $group_id filter. Also
# collects the ids of the statements that reference each entity and of the
# chunks containing those statements.
SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($group_id IS NULL OR e.group_id = $group_id)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
       e.name AS name,
       e.group_id AS group_id,
       e.entity_type AS entity_type,
       e.apply_id AS apply_id,
       e.user_id AS user_id,
       e.created_at AS created_at,
       e.expired_at AS expired_at,
       e.entity_idx AS entity_idx,
       e.statement_id AS statement_id,
       e.description AS description,
       e.aliases AS aliases,
       e.name_embedding AS name_embedding,
       e.fact_summary AS fact_summary,
       e.connect_strength AS connect_strength,
       collect(DISTINCT s.id) AS statement_ids,
       collect(DISTINCT c.id) AS chunk_ids,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Full-text search over Chunk nodes ("chunksFulltext" index) with an optional
# $group_id filter. Also collects the ids of the statements each chunk
# contains and of the entities those statements reference.
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
WHERE ($group_id IS NULL OR c.group_id = $group_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
       c.group_id AS group_id,
       c.content AS content,
       c.dialog_id AS dialog_id,
       c.sequence_number AS sequence_number,
       collect(DISTINCT s.id) AS statement_ids,
       collect(DISTINCT e.id) AS entity_ids,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
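# The statements below performed database lookups for second-layer deduplication/disambiguation; they are no longer used under the current plan.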
|
||||
|
||||
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
|
||||
# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """
|
||||
# UNWIND $rows AS row
|
||||
# MATCH (e:ExtractedEntity)
|
||||
# WHERE e.group_id = row.group_id
|
||||
# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name)))
|
||||
# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type)
|
||||
# RETURN row.id AS incoming_id,
|
||||
# e.id AS id,
|
||||
# e.name AS name,
|
||||
# e.group_id AS group_id,
|
||||
# e.entity_idx AS entity_idx,
|
||||
# e.entity_type AS entity_type,
|
||||
# e.description AS description,
|
||||
# e.statement_id AS statement_id,
|
||||
# e.aliases AS aliases,
|
||||
# e.name_embedding AS name_embedding,
|
||||
# e.fact_summary AS fact_summary,
|
||||
# e.connect_strength AS connect_strength,
|
||||
# e.created_at AS created_at,
|
||||
# e.expired_at AS expired_at
|
||||
# """
|
||||
# # 同组group_id下按name contains召回补充
|
||||
# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """
|
||||
# UNWIND $rows AS row
|
||||
# MATCH (e:ExtractedEntity)
|
||||
# WHERE e.group_id = row.group_id
|
||||
# AND toLower(e.name) CONTAINS toLower(row.name)
|
||||
# RETURN row.id AS incoming_id,
|
||||
# e.id AS id,
|
||||
# e.name AS name,
|
||||
# e.group_id AS group_id,
|
||||
# e.entity_idx AS entity_idx,
|
||||
# e.entity_type AS entity_type,
|
||||
# e.description AS description,
|
||||
# e.statement_id AS statement_id,
|
||||
# e.aliases AS aliases,
|
||||
# e.name_embedding AS name_embedding,
|
||||
# e.fact_summary AS fact_summary,
|
||||
# e.connect_strength AS connect_strength,
|
||||
# e.created_at AS created_at,
|
||||
# e.expired_at AS expired_at
|
||||
# """
|
||||
|
||||
# Look up Dialogue nodes by id ($dialog_id), optionally restricted to
# $group_id, newest created_at first.
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE ($group_id IS NULL OR d.group_id = $group_id)
  AND d.id = $dialog_id
RETURN d.id AS dialog_id,
       d.group_id AS group_id,
       d.content AS content,
       d.created_at AS created_at,
       d.expired_at AS expired_at
ORDER BY d.created_at DESC
LIMIT $limit
"""
|
||||
|
||||
# Look up Chunk nodes by id ($chunk_id), optionally restricted to $group_id,
# newest created_at first.
SEARCH_CHUNK_BY_CHUNK_ID = """
MATCH (c:Chunk)
WHERE ($group_id IS NULL OR c.group_id = $group_id)
  AND c.id = $chunk_id
RETURN c.id AS chunk_id,
       c.group_id AS group_id,
       c.content AS content,
       c.dialog_id AS dialog_id,
       c.created_at AS created_at,
       c.expired_at AS expired_at,
       c.sequence_number AS sequence_number
ORDER BY c.created_at DESC
LIMIT $limit
"""
|
||||
|
||||
# Search Statement nodes by optional group/apply/user filters and by time.
# Two alternative time windows are supported and OR-ed together:
#   * created_at within [$start_date, $end_date]
#   * valid_at >= $valid_date and invalid_at <= $invalid_date
# A window only takes part when at least one of its parameters is non-NULL.
# Without that guard a window whose parameters are all NULL is vacuously true
# and, through the OR, disables the other window's filter entirely. When all
# four parameters are NULL no time filtering is applied.
SEARCH_STATEMENTS_BY_TEMPORAL = """
MATCH (s:Statement)
WHERE ($group_id IS NULL OR s.group_id = $group_id)
  AND ($apply_id IS NULL OR s.apply_id = $apply_id)
  AND ($user_id IS NULL OR s.user_id = $user_id)
  AND (
        ($start_date IS NULL AND $end_date IS NULL AND $valid_date IS NULL AND $invalid_date IS NULL)
     OR (($start_date IS NOT NULL OR $end_date IS NOT NULL)
         AND ($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
         AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
     OR (($valid_date IS NOT NULL OR $invalid_date IS NOT NULL)
         AND ($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
         AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))
  )
RETURN s.id AS id,
       s.statement AS statement,
       s.group_id AS group_id,
       s.apply_id AS apply_id,
       s.user_id AS user_id,
       s.chunk_id AS chunk_id,
       s.created_at AS created_at,
       s.valid_at AS valid_at,
       s.invalid_at AS invalid_at,
       collect(DISTINCT s.id) AS statement_ids
ORDER BY datetime(s.created_at) DESC
LIMIT $limit
"""
|
||||
|
||||
# Full-text search over Statement nodes combined with optional
# group/apply/user filters and time filtering. Two alternative time windows
# are OR-ed together:
#   * created_at within [$start_date, $end_date]
#   * valid_at >= $valid_date and invalid_at <= $invalid_date
# A window only takes part when at least one of its parameters is non-NULL.
# Without that guard a window whose parameters are all NULL is vacuously true
# and, through the OR, disables the other window's filter entirely. When all
# four parameters are NULL no time filtering is applied.
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
  AND ($apply_id IS NULL OR s.apply_id = $apply_id)
  AND ($user_id IS NULL OR s.user_id = $user_id)
  AND (
        ($start_date IS NULL AND $end_date IS NULL AND $valid_date IS NULL AND $invalid_date IS NULL)
     OR (($start_date IS NOT NULL OR $end_date IS NOT NULL)
         AND ($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
         AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
     OR (($valid_date IS NOT NULL OR $invalid_date IS NOT NULL)
         AND ($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
         AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))
  )
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
       s.statement AS statement,
       s.group_id AS group_id,
       s.apply_id AS apply_id,
       s.user_id AS user_id,
       s.chunk_id AS chunk_id,
       s.created_at AS created_at,
       s.valid_at AS valid_at,
       s.invalid_at AS invalid_at,
       c.id AS chunk_id_from_rel,
       collect(DISTINCT e.id) AS entity_ids,
       score
ORDER BY s.created_at DESC, score DESC
LIMIT $limit
"""
|
||||
|
||||
# Find Statement nodes whose created_at falls on the calendar date given by
# $created_at (the first 10 characters of the stored value are compared as a
# date), with optional group/apply/user filters. A NULL $created_at matches
# nothing.
SEARCH_STATEMENTS_BY_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
|
||||
|
||||
# Exact-day lookup on validity date.
# Selects Statement nodes whose valid_at, truncated to its first 10 characters
# and parsed with date(), equals date($valid_at). $valid_at is mandatory (a
# NULL value matches nothing); $group_id / $apply_id / $user_id are optional
# filters that are skipped when NULL. Results are ordered by valid_at
# descending and capped at $limit.
SEARCH_STATEMENTS_BY_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
|
||||
|
||||
# "Greater than" lookup on creation date.
# Selects Statement nodes whose creation day is strictly AFTER date($created_at).
# The comparison operator is `>` (the previous text used `=`, which made this
# query a duplicate of the exact-day lookup), mirroring the G_VALID_AT query.
# Only the first 10 characters of created_at (the YYYY-MM-DD part of an ISO
# timestamp) are handed to date(), as in the other date-based queries; date()
# parses date strings, not a 19-character date-time prefix.
# $created_at is mandatory (a NULL value matches nothing); $group_id /
# $apply_id / $user_id are optional filters that are skipped when NULL.
SEARCH_STATEMENTS_G_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) > date($created_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
|
||||
|
||||
# "Less than" lookup on creation date.
# Selects Statement nodes whose creation day is strictly BEFORE date($created_at).
# Only the first 10 characters of created_at (the YYYY-MM-DD part of an ISO
# timestamp) are handed to date(), consistent with the other date-based
# queries; date() parses date strings, not a 19-character date-time prefix.
# $created_at is mandatory (a NULL value matches nothing); $group_id /
# $apply_id / $user_id are optional filters that are skipped when NULL.
SEARCH_STATEMENTS_L_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) < date($created_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
|
||||
|
||||
# "Greater than" lookup on validity date.
# Selects Statement nodes whose valid_at day (first 10 characters, parsed with
# date()) is strictly after date($valid_at). $valid_at is mandatory (a NULL
# value matches nothing); $group_id / $apply_id / $user_id are optional filters
# that are skipped when NULL. Results are ordered by valid_at descending and
# capped at $limit.
SEARCH_STATEMENTS_G_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
|
||||
|
||||
# "Less than" lookup on validity date.
# Selects Statement nodes whose valid_at day (first 10 characters, parsed with
# date()) is strictly before date($valid_at). $valid_at is mandatory (a NULL
# value matches nothing); $group_id / $apply_id / $user_id are optional filters
# that are skipped when NULL. Results are ordered by valid_at descending and
# capped at $limit.
SEARCH_STATEMENTS_L_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
  AND ($apply_id IS NULL OR n.apply_id = $apply_id)
  AND ($user_id IS NULL OR n.user_id = $user_id)
  AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
RETURN n.id AS id,
       n.statement AS statement,
       n.group_id AS group_id,
       n.apply_id AS apply_id,
       n.user_id AS user_id,
       n.chunk_id AS chunk_id,
       n.created_at AS created_at,
       n.valid_at AS valid_at,
       n.invalid_at AS invalid_at,
       collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
|
||||
|
||||
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
|
||||
|
||||
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
|
||||
# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """
|
||||
# UNWIND $rows AS row
|
||||
# MATCH (e:ExtractedEntity)
|
||||
# WHERE e.group_id = row.group_id
|
||||
# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name)))
|
||||
# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type)
|
||||
# RETURN row.id AS incoming_id,
|
||||
# e.id AS id,
|
||||
# e.name AS name,
|
||||
# e.group_id AS group_id,
|
||||
# e.entity_idx AS entity_idx,
|
||||
# e.entity_type AS entity_type,
|
||||
# e.description AS description,
|
||||
# e.statement_id AS statement_id,
|
||||
# e.aliases AS aliases,
|
||||
# e.name_embedding AS name_embedding,
|
||||
# e.fact_summary AS fact_summary,
|
||||
# e.connect_strength AS connect_strength,
|
||||
# e.created_at AS created_at,
|
||||
# e.expired_at AS expired_at
|
||||
# """
|
||||
# # 同组group_id下按name contains召回补充
|
||||
# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """
|
||||
# UNWIND $rows AS row
|
||||
# MATCH (e:ExtractedEntity)
|
||||
# WHERE e.group_id = row.group_id
|
||||
# AND toLower(e.name) CONTAINS toLower(row.name)
|
||||
# RETURN row.id AS incoming_id,
|
||||
# e.id AS id,
|
||||
# e.name AS name,
|
||||
# e.group_id AS group_id,
|
||||
# e.entity_idx AS entity_idx,
|
||||
# e.entity_type AS entity_type,
|
||||
# e.description AS description,
|
||||
# e.statement_id AS statement_id,
|
||||
# e.aliases AS aliases,
|
||||
# e.name_embedding AS name_embedding,
|
||||
# e.fact_summary AS fact_summary,
|
||||
# e.connect_strength AS connect_strength,
|
||||
# e.created_at AS created_at,
|
||||
# e.expired_at AS expired_at
|
||||
# """
|
||||
|
||||
# Update the invalid_at value of a single statement.
# The Statement node is located by the ($group_id, $id) pair and its
# invalid_at property is overwritten with $new_invalid_at. The query returns
# no rows.
UPDATE_STATEMENT_INVALID_AT = """
MATCH (n:Statement {group_id: $group_id, id: $id})
SET n.invalid_at = $new_invalid_at
"""
|
||||
|
||||
# MemorySummary keyword search using the "summariesFulltext" fulltext index.
# $group_id is an optional filter (skipped when NULL). Results are ordered by
# fulltext score, best first, and capped at $limit.
#
# The query deliberately does not expand DERIVED_FROM_STATEMENT relationships:
# nothing from the linked Statement nodes is returned, and expanding them
# yields one identical row per linked statement, so duplicated summaries would
# consume the $limit budget.
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
WHERE ($group_id IS NULL OR m.group_id = $group_id)
RETURN m.id AS id,
       m.name AS name,
       m.group_id AS group_id,
       m.dialog_id AS dialog_id,
       m.chunk_ids AS chunk_ids,
       m.content AS content,
       m.created_at AS created_at,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Vector search over MemorySummary nodes via the 'summary_embedding_index'
# vector index. The index is asked for $limit * 100 candidates so that enough
# rows remain after the post-filters (non-null summary_embedding, optional
# $group_id match); the survivors are ordered by score and cut to $limit.
# NOTE(review): the similarity metric is defined by the index configuration,
# which is not visible here - confirm it is cosine.
MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL
  AND ($group_id IS NULL OR m.group_id = $group_id)
RETURN m.id AS id,
       m.name AS name,
       m.group_id AS group_id,
       m.dialog_id AS dialog_id,
       m.chunk_ids AS chunk_ids,
       m.content AS content,
       m.created_at AS created_at,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Batch upsert of MemorySummary nodes.
# $summaries is a list of maps; each map is merged on its id (an existing node
# is reused, otherwise one is created) and the listed properties are then
# written with `+=`, which overwrites those keys and leaves any other property
# on the node untouched. Returns the id of every upserted node as `uuid`.
MEMORY_SUMMARY_NODE_SAVE = """
UNWIND $summaries AS summary
MERGE (m:MemorySummary {id: summary.id})
SET m += {
    id: summary.id,
    name: summary.name,
    group_id: summary.group_id,
    user_id: summary.user_id,
    apply_id: summary.apply_id,
    run_id: summary.run_id,
    created_at: summary.created_at,
    expired_at: summary.expired_at,
    dialog_id: summary.dialog_id,
    chunk_ids: summary.chunk_ids,
    content: summary.content,
    summary_embedding: summary.summary_embedding,
    config_id: summary.config_id
}
RETURN m.id AS uuid
"""
|
||||
|
||||
# Batch creation of MemorySummary -> Statement provenance edges.
# For each entry of $edges the summary and the chunk are matched by id within
# the same run_id, every Statement of that run contained in the chunk is
# matched, and a DERIVED_FROM_STATEMENT relationship is merged from the summary
# to each of those statements. Relationship metadata (group_id, run_id,
# created_at, expired_at) is (re)written on every call. Returns the element id
# of each relationship as `uuid`.
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE = """
UNWIND $edges AS e
MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id})
MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id})
MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
SET r.group_id = e.group_id,
    r.run_id = e.run_id,
    r.created_at = e.created_at,
    r.expired_at = e.expired_at
RETURN elementId(r) AS uuid
"""
|
||||
185
api/app/repositories/neo4j/dialog_repository.py
Normal file
185
api/app/repositories/neo4j/dialog_repository.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""对话仓储模块
|
||||
|
||||
本模块提供对话节点的数据访问功能。
|
||||
|
||||
Classes:
|
||||
DialogRepository: 对话仓储,管理DialogueNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.core.memory.models.graph_models import DialogueNode
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
class DialogRepository(BaseNeo4jRepository[DialogueNode]):
    """Repository for dialogue nodes.

    Specialises ``BaseNeo4jRepository`` for nodes labelled ``"Dialogue"`` and
    adds convenience lookups by group_id, user_id, ref_id, config_id and
    recency. The equality-based lookups delegate to the inherited ``find``.

    Attributes:
        connector: Neo4j connector instance.
        node_label: Node label, fixed to ``"Dialogue"``.
    """

    def __init__(self, connector: Neo4jConnector):
        """Initialise the dialogue repository.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "Dialogue")

    def _map_to_entity(self, node_data: Dict) -> DialogueNode:
        """Map a query record to a ``DialogueNode``.

        Args:
            node_data: Record returned by a Neo4j query. The node properties
                are read from its ``'n'`` entry when present, otherwise the
                record itself is treated as the property mapping.

        Returns:
            DialogueNode: The dialogue entity built from the node properties.
        """
        # Work on a copy: the caller's record must not be mutated, and the
        # value under 'n' may be a mapping that does not support assignment.
        props = dict(node_data.get('n', node_data))

        # ISO-formatted strings are converted to datetime objects; values of
        # any other type are passed through unchanged.
        created_at = props.get('created_at')
        if isinstance(created_at, str):
            props['created_at'] = datetime.fromisoformat(created_at)

        expired_at = props.get('expired_at')
        if expired_at and isinstance(expired_at, str):
            props['expired_at'] = datetime.fromisoformat(expired_at)

        return DialogueNode(**props)

    async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
        """Find dialogues by group_id.

        Args:
            group_id: Group ID.
            limit: Maximum number of results to return.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find({"group_id": group_id}, limit=limit)

    async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
        """Find dialogues by user_id.

        Args:
            user_id: User ID.
            limit: Maximum number of results to return.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find({"user_id": user_id}, limit=limit)

    async def find_by_ref_id(self, ref_id: str) -> Optional[DialogueNode]:
        """Find a dialogue by ref_id.

        ref_id is the reference ID of the external dialogue system and is
        usually unique, so only the first match is returned.

        Args:
            ref_id: Reference ID.

        Returns:
            Optional[DialogueNode]: The dialogue, or None when none matches.
        """
        matches = await self.find({"ref_id": ref_id}, limit=1)
        if not matches:
            return None
        return matches[0]

    async def find_by_group_and_user(
        self,
        group_id: str,
        user_id: str,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find dialogues matching both group_id and user_id.

        Args:
            group_id: Group ID.
            user_id: User ID.
            limit: Maximum number of results to return.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        filters = {"group_id": group_id, "user_id": user_id}
        return await self.find(filters, limit=limit)

    async def find_recent_dialogs(
        self,
        group_id: str,
        days: int = 7,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find the most recent dialogues of a group.

        Args:
            group_id: Group ID.
            days: Size of the look-back window in days.
            limit: Maximum number of results to return.

        Returns:
            List[DialogueNode]: Dialogues ordered by creation time, newest
            first.
        """
        # NOTE(review): n.created_at is compared with a Cypher datetime value.
        # _map_to_entity accepts created_at stored as an ISO string; if that
        # is how dialogue nodes are persisted this predicate never matches and
        # would need datetime(n.created_at) - confirm the stored type.
        query = f"""
        MATCH (n:{self.node_label})
        WHERE n.group_id = $group_id
          AND n.created_at >= datetime() - duration({{days: $days}})
        RETURN n
        ORDER BY n.created_at DESC
        LIMIT $limit
        """
        records = await self.connector.execute_query(
            query,
            group_id=group_id,
            days=days,
            limit=limit
        )
        return [self._map_to_entity(record) for record in records]

    async def find_by_config_id(
        self,
        config_id: str,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find dialogues by config_id.

        Args:
            config_id: Configuration ID.
            limit: Maximum number of results to return.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        return await self.find({"config_id": config_id}, limit=limit)

    async def find_by_config_and_group(
        self,
        config_id: str,
        group_id: str,
        limit: int = 100
    ) -> List[DialogueNode]:
        """Find dialogues matching both config_id and group_id.

        Filtering on both values restricts the result to dialogues of the
        given group that were processed with the given configuration.

        Args:
            config_id: Configuration ID.
            group_id: Group ID.
            limit: Maximum number of results to return.

        Returns:
            List[DialogueNode]: Matching dialogues.
        """
        filters = {"config_id": config_id, "group_id": group_id}
        return await self.find(filters, limit=limit)
|
||||
339
api/app/repositories/neo4j/entity_repository.py
Normal file
339
api/app/repositories/neo4j/entity_repository.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""实体仓储模块
|
||||
|
||||
本模块提供实体节点的数据访问功能。
|
||||
|
||||
Classes:
|
||||
EntityRepository: 实体仓储,管理ExtractedEntityNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.core.memory.models.graph_models import ExtractedEntityNode
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
    """Repository for extracted-entity nodes.

    Specialises ``BaseNeo4jRepository`` for nodes labelled
    ``"ExtractedEntity"`` and adds lookups by type, name, relationship,
    statement, configuration and embedding similarity.

    Attributes:
        connector: Neo4j connector instance.
        node_label: Node label, fixed to ``"ExtractedEntity"``.
    """

    def __init__(self, connector: Neo4jConnector):
        """Initialise the entity repository.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "ExtractedEntity")

    def _map_to_entity(self, node_data: Dict) -> ExtractedEntityNode:
        """Map a query record to an ``ExtractedEntityNode``.

        Args:
            node_data: Record returned by a Neo4j query. The node properties
                are read from its ``'n'`` entry when present, otherwise the
                record itself is treated as the property mapping.

        Returns:
            ExtractedEntityNode: The entity built from the node properties.
        """
        # Work on a copy: the caller's record must not be mutated, and the
        # value under 'n' may be a mapping that does not support assignment.
        props = dict(node_data.get('n', node_data))

        # ISO-formatted strings are converted to datetime objects; values of
        # any other type are passed through unchanged.
        created_at = props.get('created_at')
        if isinstance(created_at, str):
            props['created_at'] = datetime.fromisoformat(created_at)

        expired_at = props.get('expired_at')
        if expired_at and isinstance(expired_at, str):
            props['expired_at'] = datetime.fromisoformat(expired_at)

        return ExtractedEntityNode(**props)

    async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]:
        """Find entities by entity type.

        Args:
            entity_type: Entity type (e.g. "Person", "Organization").
            limit: Maximum number of results to return.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"entity_type": entity_type}, limit=limit)

    async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[ExtractedEntityNode]:
        """Find entities by group_id.

        Args:
            group_id: Group ID.
            limit: Maximum number of results to return.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"group_id": group_id}, limit=limit)

    async def find_by_name(
        self,
        name: str,
        group_id: Optional[str] = None,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities whose name contains the given text.

        Matching is a case-sensitive substring match (Cypher ``CONTAINS``).

        Args:
            name: Text to look for in the entity name.
            group_id: Optional group ID filter; ignored when empty or None.
            limit: Maximum number of results to return.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        # The predicate and its parameter are added together so the query
        # text and the parameter map cannot drift apart.
        conditions = ["n.name CONTAINS $name"]
        params = {"name": name, "limit": limit}
        if group_id:
            conditions.append("n.group_id = $group_id")
            params["group_id"] = group_id

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {" AND ".join(conditions)}
        RETURN n
        LIMIT $limit
        """
        records = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(record) for record in records]

    async def find_related_entities(
        self,
        entity_id: str,
        relation_type: Optional[str] = None,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities reachable through an outgoing RELATES_TO edge.

        Args:
            entity_id: ID of the source entity.
            relation_type: Optional filter on the ``relation_type`` property
                of the RELATES_TO relationship; ignored when empty or None.
            limit: Maximum number of results to return.

        Returns:
            List[ExtractedEntityNode]: Target entities of the matching edges.
        """
        params = {"entity_id": entity_id, "limit": limit}
        # The relationship property filter is only emitted when requested.
        rel_props = ""
        if relation_type:
            rel_props = " {relation_type: $relation_type}"
            params["relation_type"] = relation_type

        query = f"""
        MATCH (e1:ExtractedEntity {{id: $entity_id}})-[r:RELATES_TO{rel_props}]->(e2:ExtractedEntity)
        RETURN e2 as n
        LIMIT $limit
        """
        records = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(record) for record in records]

    async def search_by_embedding(
        self,
        embedding: List[float],
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search entities by vector similarity of their name embedding.

        Equivalent to ``search_by_embedding_with_config`` without a config_id
        filter; the shared implementation lives there.

        Args:
            embedding: Query vector.
            group_id: Optional group ID filter; ignored when empty or None.
            limit: Maximum number of results to return.
            min_score: Similarity threshold; only scores strictly greater
                than this value are returned.

        Returns:
            List[Dict]: One dict per hit with the keys ``entity``
            (ExtractedEntityNode) and ``score`` (float), best score first.
        """
        return await self.search_by_embedding_with_config(
            embedding,
            config_id=None,
            group_id=group_id,
            limit=limit,
            min_score=min_score,
        )

    async def find_by_statement_id(self, statement_id: str) -> List[ExtractedEntityNode]:
        """Find the entities whose statement_id equals the given statement.

        Args:
            statement_id: Statement ID.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"statement_id": statement_id})

    async def find_strong_entities(
        self,
        group_id: str,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find the strongly connected entities of a group.

        Selects entities whose ``connect_strength`` property equals
        ``"Strong"``.

        Args:
            group_id: Group ID.
            limit: Maximum number of results to return.

        Returns:
            List[ExtractedEntityNode]: Strongly connected entities.
        """
        filters = {"group_id": group_id, "connect_strength": "Strong"}
        return await self.find(filters, limit=limit)

    async def get_entity_count_by_type(self, group_id: str) -> Dict[str, int]:
        """Count the entities of a group per entity type.

        Args:
            group_id: Group ID.

        Returns:
            Dict[str, int]: Mapping from entity type to number of entities,
            inserted in descending count order.
        """
        query = """
        MATCH (n:ExtractedEntity {group_id: $group_id})
        RETURN n.entity_type as entity_type, count(n) as count
        ORDER BY count DESC
        """
        rows = await self.connector.execute_query(query, group_id=group_id)
        return {row["entity_type"]: row["count"] for row in rows}

    async def find_by_config_id(
        self,
        config_id: str,
        limit: int = 100
    ) -> List[ExtractedEntityNode]:
        """Find entities by config_id.

        Args:
            config_id: Configuration ID.
            limit: Maximum number of results to return.

        Returns:
            List[ExtractedEntityNode]: Matching entities.
        """
        return await self.find({"config_id": config_id}, limit=limit)

    async def search_by_embedding_with_config(
        self,
        embedding: List[float],
        config_id: Optional[str] = None,
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search entities by vector similarity, optionally per configuration.

        Scores every candidate entity with ``gds.similarity.cosine`` between
        its ``name_embedding`` and the query vector, so the Graph Data Science
        functions must be available on the server.

        Args:
            embedding: Query vector.
            config_id: Optional configuration ID filter; ignored when empty
                or None.
            group_id: Optional group ID filter; ignored when empty or None.
            limit: Maximum number of results to return.
            min_score: Similarity threshold; only scores strictly greater
                than this value are returned.

        Returns:
            List[Dict]: One dict per hit with the keys ``entity``
            (ExtractedEntityNode) and ``score`` (float), best score first.
        """
        # Each optional predicate is added together with its parameter.
        conditions = ["n.name_embedding IS NOT NULL"]
        params = {
            "embedding": embedding,
            "min_score": min_score,
            "limit": limit
        }
        if config_id:
            conditions.append("n.config_id = $config_id")
            params["config_id"] = config_id
        if group_id:
            conditions.append("n.group_id = $group_id")
            params["group_id"] = group_id

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {" AND ".join(conditions)}
        WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score
        WHERE score > $min_score
        RETURN n, score
        ORDER BY score DESC
        LIMIT $limit
        """
        records = await self.connector.execute_query(query, **params)

        return [
            {
                "entity": self._map_to_entity(record),
                "score": record.get("score", 0.0)
            }
            for record in records
        ]
|
||||
216
api/app/repositories/neo4j/graph_saver.py
Normal file
216
api/app/repositories/neo4j/graph_saver.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from typing import List
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.add_nodes import add_dialogue_nodes, add_statement_nodes, add_chunk_nodes
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
STATEMENT_ENTITY_EDGE_SAVE,
|
||||
ENTITY_RELATIONSHIP_SAVE,
|
||||
EXTRACTED_ENTITY_NODE_SAVE,
|
||||
CHUNK_STATEMENT_EDGE_SAVE,
|
||||
STATEMENT_ENTITY_EDGE_SAVE,
|
||||
ENTITY_RELATIONSHIP_SAVE,
|
||||
EXTRACTED_ENTITY_NODE_SAVE,
|
||||
)
|
||||
from app.core.memory.models.graph_models import (
|
||||
DialogueNode,
|
||||
ChunkNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
StatementNode,
|
||||
ExtractedEntityNode,
|
||||
EntityEntityEdge,
|
||||
)
|
||||
|
||||
async def save_entities_and_relationships(
    entity_nodes: List[ExtractedEntityNode],
    entity_entity_edges: List[EntityEntityEdge],
    connector: Neo4jConnector
):
    """Save entity nodes and the relationships between them.

    Entities are serialised with ``model_dump()`` and written with
    ``EXTRACTED_ENTITY_NODE_SAVE``; edges are flattened to plain dicts and
    written with ``ENTITY_RELATIONSHIP_SAVE``. The outcome of each step is
    reported on stdout; nothing is returned.

    Args:
        entity_nodes: Entity nodes to save.
        entity_entity_edges: Entity-to-entity edges to save.
        connector: Neo4j connector used to run the queries.
    """
    def _iso(value):
        # Every timestamp is optional: a missing value is stored as null
        # instead of raising AttributeError on None.isoformat().
        return value.isoformat() if value else None

    all_entities = [entity.model_dump() for entity in entity_nodes]
    all_relationships = [
        {
            'source_id': edge.source,
            'target_id': edge.target,
            'predicate': edge.relation_type,
            'statement_id': edge.source_statement_id,
            'value': edge.relation_value,
            'statement': edge.statement,
            'valid_at': _iso(edge.valid_at),
            'invalid_at': _iso(edge.invalid_at),
            'created_at': _iso(edge.created_at),
            'expired_at': _iso(edge.expired_at),
            'run_id': edge.run_id,
            'group_id': edge.group_id,
            'user_id': edge.user_id,
            'apply_id': edge.apply_id,
        }
        for edge in entity_entity_edges
    ]

    # Save entities
    if all_entities:
        entity_uuids = await connector.execute_query(EXTRACTED_ENTITY_NODE_SAVE, entities=all_entities)
        if entity_uuids:
            print(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
        else:
            print("Failed to save entity nodes to Neo4j")
    else:
        print("No entity nodes to save")

    # Create relationships
    if all_relationships:
        relationship_uuids = await connector.execute_query(ENTITY_RELATIONSHIP_SAVE, relationships=all_relationships)
        if relationship_uuids:
            print(f"Successfully saved {len(relationship_uuids)} entity relationships (edges) to Neo4j")
        else:
            print("Failed to save entity relationships to Neo4j")
    else:
        print("No entity relationships to save")
|
||||
|
||||
|
||||
async def save_chunk_nodes(
    chunk_nodes: List[ChunkNode],
    connector: Neo4jConnector
):
    """Save chunk nodes through ``add_chunk_nodes`` and report the outcome.

    Args:
        chunk_nodes: Chunk nodes to save; nothing is written when empty.
        connector: Neo4j connector passed on to ``add_chunk_nodes``.
    """
    if chunk_nodes:
        saved = await add_chunk_nodes(chunk_nodes, connector)
        # A falsy result from add_chunk_nodes is reported as a failure.
        if not saved:
            print("Failed to save chunk nodes to Neo4j")
        else:
            print(f"Successfully saved {len(saved)} chunk nodes to Neo4j")
    else:
        print("No chunk nodes to save")
|
||||
|
||||
|
||||
async def save_statement_chunk_edges(
    statement_chunk_edges: List[StatementChunkEdge],
    connector: Neo4jConnector
):
    """Save statement-chunk edges with ``CHUNK_STATEMENT_EDGE_SAVE``.

    The save is best-effort: a failure is reported on stdout and does not
    propagate to the caller.

    Args:
        statement_chunk_edges: Edges to save; nothing is written when empty.
        connector: Neo4j connector used to run the query.
    """
    if not statement_chunk_edges:
        return

    all_sc_edges = [
        {
            "id": edge.id,
            "source": edge.source,
            "target": edge.target,
            "group_id": edge.group_id,
            "user_id": edge.user_id,
            "apply_id": edge.apply_id,
            "run_id": edge.run_id,
            "created_at": edge.created_at.isoformat() if edge.created_at else None,
            "expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
        }
        for edge in statement_chunk_edges
    ]

    try:
        await connector.execute_query(
            CHUNK_STATEMENT_EDGE_SAVE,
            chunk_statement_edges=all_sc_edges
        )
    except Exception as e:
        # Still best-effort, but the failure is made visible instead of being
        # discarded silently.
        print(f"Failed to save statement-chunk edges to Neo4j: {e}")
|
||||
|
||||
|
||||
async def save_statement_entity_edges(
    statement_entity_edges: List[StatementEntityEdge],
    connector: Neo4jConnector
):
    """Save statement-entity edges with ``STATEMENT_ENTITY_EDGE_SAVE``.

    The save is best-effort: a failure is reported on stdout and does not
    propagate to the caller.

    Args:
        statement_entity_edges: Edges to save; nothing is written when empty.
        connector: Neo4j connector used to run the query.
    """
    if not statement_entity_edges:
        print("No statement-entity edges to save")
        return

    # The input is known to be non-empty here, so the payload is too.
    all_se_edges = [
        {
            "source": edge.source,
            "target": edge.target,
            "group_id": edge.group_id,
            "user_id": edge.user_id,
            "apply_id": edge.apply_id,
            "run_id": edge.run_id,
            "connect_strength": edge.connect_strength,
            "created_at": edge.created_at.isoformat() if edge.created_at else None,
            "expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
        }
        for edge in statement_entity_edges
    ]

    try:
        await connector.execute_query(
            STATEMENT_ENTITY_EDGE_SAVE,
            relationships=all_se_edges
        )
    except Exception as e:
        # Still best-effort, but the failure is made visible instead of being
        # discarded silently.
        print(f"Failed to save statement-entity edges to Neo4j: {e}")
|
||||
|
||||
|
||||
async def save_dialog_and_statements_to_neo4j(
    dialogue_nodes: List[DialogueNode],
    chunk_nodes: List[ChunkNode],
    statement_nodes: List[StatementNode],
    entity_nodes: List[ExtractedEntityNode],
    entity_edges: List[EntityEntityEdge],
    statement_chunk_edges: List[StatementChunkEdge],
    statement_entity_edges: List[StatementEntityEdge],
    connector: Neo4jConnector
) -> bool:
    """Persist a full extraction result (nodes and edges) to Neo4j.

    Steps run in order: dialogues, chunks, statements, entities with their
    relationships, then statement-chunk and statement-entity edges. Progress
    is reported on stdout.

    Args:
        dialogue_nodes: DialogueNode objects to save.
        chunk_nodes: ChunkNode objects to save.
        statement_nodes: StatementNode objects to save.
        entity_nodes: ExtractedEntityNode objects to save.
        entity_edges: EntityEntityEdge objects to save.
        statement_chunk_edges: StatementChunkEdge objects to save.
        statement_entity_edges: StatementEntityEdge objects to save.
        connector: Neo4j connector instance.

    Returns:
        bool: False when saving dialogues or statements yields no result or
        when any step raises; True otherwise.
    """
    try:
        # Dialogues first: without them the remaining steps are skipped.
        saved_dialogues = await add_dialogue_nodes(dialogue_nodes, connector)
        if not saved_dialogues:
            print("Failed to save dialogues to Neo4j")
            return False
        print(f"Dialogues saved to Neo4j with UUIDs: {saved_dialogues}")

        # Chunks are saved in one batch.
        await save_chunk_nodes(chunk_nodes, connector)

        # Statements are optional, but a failed save aborts the run.
        if not statement_nodes:
            print("No statement nodes to save")
        else:
            saved_statements = await add_statement_nodes(statement_nodes, connector)
            if not saved_statements:
                print("Failed to save statement nodes to Neo4j")
                return False
            print(f"Successfully saved {len(saved_statements)} statement nodes to Neo4j")

        # Entities and entity-to-entity relationships.
        await save_entities_and_relationships(entity_nodes, entity_edges, connector)
        print("Successfully saved entities and relationships to Neo4j")

        # Edges that attach statements to chunks and to entities.
        await save_statement_chunk_edges(statement_chunk_edges, connector)
        await save_statement_entity_edges(statement_entity_edges, connector)
    except Exception as e:
        # Any error is reported and turned into a False result so the caller
        # can continue without database storage.
        print(f"Neo4j integration error: {e}")
        print("Continuing without database storage...")
        return False

    return True
|
||||
584
api/app/repositories/neo4j/graph_search.py
Normal file
584
api/app/repositories/neo4j/graph_search.py
Normal file
@@ -0,0 +1,584 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import asyncio
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
STATEMENT_EMBEDDING_SEARCH,
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
SEARCH_STATEMENTS_BY_TEMPORAL,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||
SEARCH_STATEMENTS_G_VALID_AT,
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
)
|
||||
|
||||
|
||||
async def search_graph(
    connector: Neo4jConnector,
    q: str,
    group_id: Optional[str] = None,
    limit: int = 50,
    include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """Run a free-text keyword search over the requested graph categories.

    One Cypher query is issued per requested category and all of them are
    awaited together with ``asyncio.gather``. The matching rules themselves
    live in the imported ``SEARCH_*`` query constants, not in this function.

    Args:
        connector: Object whose ``execute_query`` coroutine runs each query.
        q: Query text, forwarded to every query as ``q``.
        group_id: Optional group filter, forwarded as ``group_id``.
        limit: Maximum number of rows per category, forwarded as ``limit``.
        include: Categories to search: any of ``"statements"``,
            ``"entities"``, ``"chunks"``, ``"summaries"``. ``None`` selects
            all four; unrecognised names are ignored.

    Returns:
        Mapping from each requested category to its result rows. A category
        whose query raised an ``Exception`` maps to an empty list; categories
        that were not requested are absent from the mapping.
    """
    wanted = include
    if wanted is None:
        wanted = ["statements", "chunks", "entities", "summaries"]

    # category -> not-yet-awaited query coroutine, in scheduling order
    pending = {}

    def _enqueue(category, cypher):
        pending[category] = connector.execute_query(
            cypher,
            q=q,
            group_id=group_id,
            limit=limit,
        )

    if "statements" in wanted:
        _enqueue("statements", SEARCH_STATEMENTS_BY_KEYWORD)
    if "entities" in wanted:
        _enqueue("entities", SEARCH_ENTITIES_BY_NAME)
    if "chunks" in wanted:
        _enqueue("chunks", SEARCH_CHUNKS_BY_CONTENT)
    if "summaries" in wanted:
        _enqueue("summaries", SEARCH_MEMORY_SUMMARIES_BY_KEYWORD)

    # return_exceptions=True keeps one failing category from failing the rest.
    outcomes = await asyncio.gather(*pending.values(), return_exceptions=True)

    return {
        category: [] if isinstance(outcome, Exception) else outcome
        for category, outcome in zip(pending, outcomes)
    }
|
||||
|
||||
|
||||
async def search_graph_by_embedding(
    connector: Neo4jConnector,
    embedder_client,
    query_text: str,
    group_id: Optional[str] = None,
    limit: int = 50,
    include: Optional[List[str]] = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """Embedding-based search across statements, chunks, entities and summaries.

    The query text is embedded once with ``embedder_client``; one Cypher
    query per requested category is then run concurrently through
    ``asyncio.gather``. Ranking and filtering are done by the imported
    ``*_EMBEDDING_SEARCH`` queries, not by this function.

    Args:
        connector: Object whose ``execute_query`` coroutine runs each query.
        embedder_client: Object with an awaitable ``response(list_of_texts)``
            method; its first returned element is used as the query embedding.
        query_text: Text to embed and search for.
        group_id: Optional group filter, forwarded as ``group_id``.
        limit: Maximum number of rows per category, forwarded as ``limit``.
        include: Categories to search: any of ``"statements"``, ``"chunks"``,
            ``"entities"``, ``"summaries"``. ``None`` (the default) selects
            all four; unrecognised names are ignored.

    Returns:
        Mapping that always contains all four category keys. Categories that
        were not requested, whose query raised an ``Exception``, or for which
        no embedding could be produced map to an empty list.
    """
    import time

    # A list literal as the default would be shared between calls; resolve
    # the default here instead (same convention as search_graph).
    if include is None:
        include = ["statements", "chunks", "entities", "summaries"]

    embed_start = time.time()
    embeddings = await embedder_client.response([query_text])
    embed_time = time.time() - embed_start
    print(f"[PERF] Embedding generation took: {embed_time:.4f}s")

    results: Dict[str, List[Dict[str, Any]]] = {
        "statements": [],
        "chunks": [],
        "entities": [],
        "summaries": [],
    }

    # Without an embedding there is nothing to compare against.
    if not embeddings or not embeddings[0]:
        return results
    embedding = embeddings[0]

    # (category, query) pairs in scheduling order.
    searches = (
        ("statements", STATEMENT_EMBEDDING_SEARCH),
        ("chunks", CHUNK_EMBEDDING_SEARCH),
        ("entities", ENTITY_EMBEDDING_SEARCH),
        ("summaries", MEMORY_SUMMARY_EMBEDDING_SEARCH),
    )
    task_keys = []
    tasks = []
    for category, cypher in searches:
        if category in include:
            task_keys.append(category)
            tasks.append(connector.execute_query(
                cypher,
                embedding=embedding,
                group_id=group_id,
                limit=limit,
            ))

    query_start = time.time()
    # return_exceptions=True keeps one failing category from failing the rest.
    task_results = await asyncio.gather(*tasks, return_exceptions=True)
    query_time = time.time() - query_start
    print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")

    for category, outcome in zip(task_keys, task_results):
        results[category] = [] if isinstance(outcome, Exception) else outcome

    return results
|
||||
async def get_dedup_candidates_for_entities(
    connector: Neo4jConnector,
    group_id: str,
    entities: List[Dict[str, Any]],
    use_contains_fallback: bool = True,
    batch_size: int = 500,
    max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]:
    """Fetch existing-entity candidates for second-stage de-duplication.

    For every incoming entity the ``SEARCH_ENTITIES_BY_NAME`` query is run
    with the entity's stripped ``name`` and the given ``group_id``. Lookups
    run concurrently, bounded by a semaphore.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        group_id: Group whose entities are searched.
        entities: Incoming entity dicts. The keys read are ``id``, ``name``
            and, optionally, ``entity_type``.
        use_contains_fallback: When true and the first lookup yields no rows,
            the lookup is repeated with the lower-cased name.
        batch_size: Unused; kept only for signature compatibility.
        max_concurrency: Upper bound on simultaneous lookups (values below 1
            are treated as 1).

    Returns:
        Mapping ``incoming_id -> [candidate rows]``. Each row is the dict
        returned by the query with an added ``"incoming_id"`` key. Entities
        without an ``id`` are grouped under ``"__unknown__"``. Rows carrying
        an ``id`` appear at most once per incoming id. When the incoming
        entity has an ``entity_type``, only rows of that type are kept (for
        the fallback lookup as well). Failed lookups are logged and
        contribute no rows.
    """
    import logging

    if not entities:
        return {}

    log = logging.getLogger(__name__)
    # Semaphore(0) would block every lookup forever, so clamp to at least 1.
    sem = asyncio.Semaphore(max(1, max_concurrency))

    async def _lookup(term: str, inc_id: str, wanted_type: Any) -> List[Dict[str, Any]]:
        """Run one name lookup; return type-filtered rows tagged with inc_id."""
        try:
            rows = await connector.execute_query(
                SEARCH_ENTITIES_BY_NAME,
                q=term,
                group_id=group_id,
                limit=100,
            )
        except Exception as exc:
            # Best effort: one failed lookup must not abort the whole batch.
            log.warning("Entity candidate lookup failed for %r: %s", term, exc)
            return []

        if wanted_type:
            rows = [r for r in rows if r.get("entity_type") == wanted_type]
        for row in rows:
            # Downstream merge logic relies on this back-reference.
            row["incoming_id"] = inc_id
        return rows

    async def _query_by_name(incoming: Dict[str, Any]) -> tuple:
        """Resolve candidates for one incoming entity as (inc_id, rows)."""
        async with sem:
            inc_id = incoming.get("id") or "__unknown__"
            name = (incoming.get("name") or "").strip()
            if not name:
                return inc_id, []

            wanted_type = incoming.get("entity_type")
            rows = await _lookup(name, inc_id, wanted_type)
            if use_contains_fallback and not rows:
                rows = await _lookup(name.lower(), inc_id, wanted_type)
            return inc_id, rows

    results = await asyncio.gather(
        *(_query_by_name(entity) for entity in entities),
        return_exceptions=True,
    )

    merged: Dict[str, List[Dict[str, Any]]] = {}
    seen_ids: Dict[str, set] = {}
    for res in results:
        if isinstance(res, BaseException):
            log.warning("Entity candidate task failed: %r", res)
            continue
        inc_id, rows = res
        bucket = merged.setdefault(inc_id, [])
        seen = seen_ids.setdefault(inc_id, set())
        for rec in rows:
            rec_id = rec.get("id")
            if rec_id is None:
                # No id to compare on; keep the row rather than drop data.
                bucket.append(rec)
            elif rec_id not in seen:
                seen.add(rec_id)
                bucket.append(rec)
    return merged
|
||||
|
||||
|
||||
async def search_graph_by_keyword_temporal(
    connector: Neo4jConnector,
    query_text: str,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    valid_date: Optional[str] = None,
    invalid_date: Optional[str] = None,
    limit: int = 50,
) -> Dict[str, List[Any]]:
    """Keyword search over statements with optional temporal filters.

    Runs ``SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL`` with every argument
    forwarded as a query parameter; how the date arguments constrain the
    match is defined by that query, not here.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        query_text: Keyword text, forwarded as ``q``. Must be non-empty.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        start_date: Optional start-date parameter.
        end_date: Optional end-date parameter.
        valid_date: Optional valid-date parameter.
        invalid_date: Optional invalid-date parameter.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}``; the list is empty when ``query_text`` is
        empty (a message is printed and no query is run).
    """
    if not query_text:
        print("query_text不能为空")
        return {"statements": []}

    params = {
        "q": query_text,
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "start_date": start_date,
        "end_date": end_date,
        "valid_date": valid_date,
        "invalid_date": invalid_date,
        "limit": limit,
    }
    rows = await connector.execute_query(
        SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
        **params,
    )
    print(f"查询结果为:\n{rows}")

    return {"statements": rows}
|
||||
|
||||
|
||||
async def search_graph_by_temporal(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    valid_date: Optional[str] = None,
    invalid_date: Optional[str] = None,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements using temporal filters only.

    Runs ``SEARCH_STATEMENTS_BY_TEMPORAL`` with every argument forwarded as a
    query parameter, then prints the query text, the parameters and the rows.
    How the date arguments constrain the match is defined by that query.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        start_date: Optional start-date parameter.
        end_date: Optional end-date parameter.
        valid_date: Optional valid-date parameter.
        invalid_date: Optional invalid-date parameter.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "start_date": start_date,
        "end_date": end_date,
        "valid_date": valid_date,
        "invalid_date": invalid_date,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_BY_TEMPORAL, **params)

    # Debug trace, emitted only after the query succeeded.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
|
||||
async def search_graph_by_dialog_id(
    connector: Neo4jConnector,
    dialog_id: str,
    group_id: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Look up dialogues by their dialog id.

    Runs ``SEARCH_DIALOGUE_BY_DIALOG_ID`` with ``group_id``, ``dialog_id``
    and ``limit`` as query parameters.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        dialog_id: Dialog identifier to look up. Must be non-empty.
        group_id: Optional group filter.
        limit: Maximum number of rows requested.

    Returns:
        ``{"dialogues": rows}``; the list is empty when ``dialog_id`` is
        empty (a message is printed and no query is run).
    """
    if not dialog_id:
        print("dialog_id不能为空")
        return {"dialogues": []}

    params = {"group_id": group_id, "dialog_id": dialog_id, "limit": limit}
    rows = await connector.execute_query(SEARCH_DIALOGUE_BY_DIALOG_ID, **params)
    return {"dialogues": rows}
|
||||
|
||||
|
||||
async def search_graph_by_chunk_id(
    connector: Neo4jConnector,
    chunk_id : str,
    group_id: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Look up chunks by their chunk id.

    Runs ``SEARCH_CHUNK_BY_CHUNK_ID`` with ``group_id``, ``chunk_id`` and
    ``limit`` as query parameters.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        chunk_id: Chunk identifier to look up. Must be non-empty.
        group_id: Optional group filter.
        limit: Maximum number of rows requested.

    Returns:
        ``{"chunks": rows}``; the list is empty when ``chunk_id`` is empty
        (a message is printed and no query is run).
    """
    if not chunk_id:
        print("chunk_id不能为空")
        return {"chunks": []}

    params = {"group_id": group_id, "chunk_id": chunk_id, "limit": limit}
    rows = await connector.execute_query(SEARCH_CHUNK_BY_CHUNK_ID, **params)
    return {"chunks": rows}
|
||||
|
||||
|
||||
async def search_graph_by_created_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    created_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements by their ``created_at`` value.

    Runs ``SEARCH_STATEMENTS_BY_CREATED_AT`` with every argument forwarded as
    a query parameter, then prints the query text, the parameters and the
    rows. The comparison applied to ``created_at`` is defined by that query.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        created_at: Creation-time value to search with.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "created_at": created_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_BY_CREATED_AT, **params)

    # Debug trace, emitted only after the query succeeded.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
async def search_graph_by_valid_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    valid_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements by their ``valid_at`` value.

    Runs ``SEARCH_STATEMENTS_BY_VALID_AT`` with every argument forwarded as a
    query parameter, then prints the query text, the parameters and the
    rows. The comparison applied to ``valid_at`` is defined by that query.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        valid_at: Validity-time value to search with.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "valid_at": valid_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_BY_VALID_AT, **params)

    # Debug trace, emitted only after the query succeeded.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
async def search_graph_g_created_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    created_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements with the ``SEARCH_STATEMENTS_G_CREATED_AT`` query.

    Every argument is forwarded as a query parameter; afterwards the query
    text, the parameters and the rows are printed.
    NOTE(review): the ``g`` in the name presumably means "greater than"
    ``created_at``; confirm against the query definition.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        created_at: Creation-time value to compare against.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "created_at": created_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_G_CREATED_AT, **params)

    # Debug trace, emitted only after the query succeeded.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
async def search_graph_g_valid_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    valid_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements with the ``SEARCH_STATEMENTS_G_VALID_AT`` query.

    Every argument is forwarded as a query parameter; afterwards the query
    text, the parameters and the rows are printed.
    NOTE(review): the ``g`` in the name presumably means "greater than"
    ``valid_at``; confirm against the query definition.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        valid_at: Validity-time value to compare against.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "valid_at": valid_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_G_VALID_AT, **params)

    # Debug trace, emitted only after the query succeeded.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
async def search_graph_l_created_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    created_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements with the ``SEARCH_STATEMENTS_L_CREATED_AT`` query.

    Every argument is forwarded as a query parameter; afterwards the query
    text, the parameters and the rows are printed.
    NOTE(review): the ``l`` in the name presumably means "less than"
    ``created_at``; confirm against the query definition.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        created_at: Creation-time value to compare against.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "created_at": created_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_L_CREATED_AT, **params)

    # Debug trace, emitted only after the query succeeded.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
|
||||
async def search_graph_l_valid_at(
    connector: Neo4jConnector,
    group_id: Optional[str] = None,
    apply_id: Optional[str] = None,
    user_id: Optional[str] = None,
    valid_at: Optional[str] = None,
    limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
    """Search statements with the ``SEARCH_STATEMENTS_L_VALID_AT`` query.

    Every argument is forwarded as a query parameter; afterwards the query
    text, the parameters and the rows are printed.
    NOTE(review): the ``l`` in the name presumably means "less than"
    ``valid_at``; confirm against the query definition.

    Args:
        connector: Object whose ``execute_query`` coroutine runs the query.
        group_id: Optional group filter.
        apply_id: Optional application filter.
        user_id: Optional user filter.
        valid_at: Validity-time value to compare against.
        limit: Maximum number of rows requested.

    Returns:
        ``{"statements": rows}`` with the rows returned by the query.
    """
    params = {
        "group_id": group_id,
        "apply_id": apply_id,
        "user_id": user_id,
        "valid_at": valid_at,
        "limit": limit,
    }
    rows = await connector.execute_query(SEARCH_STATEMENTS_L_VALID_AT, **params)

    # Debug trace, emitted only after the query succeeded.
    print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
    print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
    print(f"查询结果为:\n{rows}")
    return {"statements": rows}
|
||||
114
api/app/repositories/neo4j/neo4j_connector.py
Normal file
114
api/app/repositories/neo4j/neo4j_connector.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j连接器模块
|
||||
|
||||
本模块提供Neo4j图数据库的连接和查询功能。
|
||||
从 app/core/memory/src/database/neo4j_connector.py 迁移而来。
|
||||
|
||||
Classes:
|
||||
Neo4jConnector: Neo4j数据库连接器,提供异步查询接口
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from neo4j import AsyncGraphDatabase, basic_auth
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class Neo4jConnector:
    """Asynchronous connector for the Neo4j graph database.

    Wraps a Neo4j async driver and exposes a small query API.

    Attributes:
        driver: The Neo4j async driver instance.

    Methods:
        close: Close the driver.
        execute_query: Run a Cypher query and return its rows as dicts.
        delete_group: Delete all data belonging to one group.
    """

    # Database every query is sent to.
    _DATABASE = "neo4j"

    def __init__(self):
        """Create the driver from the global settings.

        Reads ``NEO4J_URI``, ``NEO4J_USERNAME`` and ``NEO4J_PASSWORD`` from
        ``settings``.

        Raises:
            RuntimeError: If ``settings.NEO4J_PASSWORD`` is empty.
        """
        uri = settings.NEO4J_URI
        username = settings.NEO4J_USERNAME
        password = settings.NEO4J_PASSWORD

        if not password:
            raise RuntimeError(
                "NEO4J_PASSWORD is not set. Create a .env with NEO4J_PASSWORD or export it before running."
            )
        self.driver = AsyncGraphDatabase.driver(
            uri,
            auth=basic_auth(username, password)
        )

    async def close(self):
        """Close the driver; call this when the application shuts down."""
        await self.driver.close()

    async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
        """Run a Cypher query and return its rows.

        Args:
            query: Cypher query text.
            **kwargs: Query parameters, referenced in Cypher as ``$name``.

        Returns:
            List[Dict[str, Any]]: One dict per record (``record.data()``).

        Example:
            >>> connector = Neo4jConnector()
            >>> results = await connector.execute_query(
            ...     "MATCH (n:Person {name: $name}) RETURN n",
            ...     name="Alice"
            ... )
        """
        # The driver's configuration keywords end with an underscore
        # (``database_``). A plain ``database=`` keyword is forwarded as a
        # Cypher parameter and does not select the database.
        records, _summary, _keys = await self.driver.execute_query(
            query,
            database_=self._DATABASE,
            **kwargs
        )
        return [record.data() for record in records]

    async def delete_group(self, group_id: str):
        """Delete every node and relationship tagged with ``group_id``.

        This is destructive and cannot be undone.

        Args:
            group_id: Identifier of the group to delete.

        Example:
            >>> connector = Neo4jConnector()
            >>> await connector.delete_group("group_123")
            Group group_123 deleted.
        """
        # DETACH DELETE also removes the relationships attached to the nodes.
        await self.driver.execute_query(
            "MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n",
            database_=self._DATABASE,
            group_id=group_id
        )
        # Remove relationships that carry the group id but whose endpoint
        # nodes were not matched above.
        await self.driver.execute_query(
            "MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r",
            database_=self._DATABASE,
            group_id=group_id
        )
        print(f"Group {group_id} deleted.")
|
||||
319
api/app/repositories/neo4j/statement_repository.py
Normal file
319
api/app/repositories/neo4j/statement_repository.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""陈述句仓储模块
|
||||
|
||||
本模块提供陈述句节点的数据访问功能。
|
||||
|
||||
Classes:
|
||||
StatementRepository: 陈述句仓储,管理StatementNode的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.core.memory.models.graph_models import StatementNode
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||
|
||||
|
||||
class StatementRepository(BaseNeo4jRepository[StatementNode]):
    """Repository for statement nodes stored under the ``Statement`` label.

    Provides lookups by chunk, group and config id, keyword and temporal
    filtering, and embedding-similarity search. ``self.find``,
    ``self.connector`` and ``self.node_label`` are not defined in this class;
    presumably they come from ``BaseNeo4jRepository``.

    Attributes:
        connector: Neo4j connector used to run queries.
        node_label: Node label, fixed to ``"Statement"``.
    """

    def __init__(self, connector: Neo4jConnector):
        """Initialise the repository.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "Statement")

    def _map_to_entity(self, node_data: Dict) -> StatementNode:
        """Convert one query record into a ``StatementNode``.

        Args:
            node_data: Either a record with the node properties under the
                ``"n"`` key, or the node properties themselves.

        Returns:
            StatementNode: Entity built from the node properties.
        """
        # Work on a shallow copy so the caller's record is left untouched.
        props = dict(node_data.get('n', node_data))

        # created_at is converted whenever it is a string; the optional
        # timestamps only when they are non-empty strings.
        if isinstance(props.get('created_at'), str):
            props['created_at'] = datetime.fromisoformat(props['created_at'])
        for field in ('expired_at', 'valid_at', 'invalid_at'):
            value = props.get(field)
            if value and isinstance(value, str):
                props[field] = datetime.fromisoformat(value)

        temporal_info = props.get('temporal_info')
        if isinstance(temporal_info, dict):
            props['temporal_info'] = TemporalInfo(**temporal_info)
        elif not temporal_info:
            # Missing or empty: fall back to a default TemporalInfo.
            props['temporal_info'] = TemporalInfo()

        return StatementNode(**props)

    async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]:
        """Find statements by chunk id.

        Args:
            chunk_id: Chunk identifier.

        Returns:
            List[StatementNode]: Matching statements.
        """
        return await self.find({"chunk_id": chunk_id})

    async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[StatementNode]:
        """Find statements by group id.

        Args:
            group_id: Group identifier.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Matching statements.
        """
        return await self.find({"group_id": group_id}, limit=limit)

    async def _search_by_embedding_filtered(
        self,
        embedding: List[float],
        filters: Dict[str, Optional[str]],
        limit: int,
        min_score: float
    ) -> List[Dict]:
        """Shared implementation of the embedding-similarity searches.

        Args:
            embedding: Query vector.
            filters: Property name -> required value; falsy values are
                skipped. The property names are fixed by the callers in this
                class and never come from user input.
            limit: Maximum number of results.
            min_score: Results must score strictly above this value.

        Returns:
            List[Dict]: Dicts with ``statement`` (StatementNode) and
            ``score`` keys, ordered by descending score.
        """
        where_clauses = ["n.statement_embedding IS NOT NULL"]
        params = {
            "embedding": embedding,
            "min_score": min_score,
            "limit": limit
        }
        for prop, value in filters.items():
            if value:
                where_clauses.append(f"n.{prop} = ${prop}")
                params[prop] = value

        where_str = " AND ".join(where_clauses)

        # gds.similarity.cosine is a Graph Data Science library function.
        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_str}
        WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
        WHERE score > $min_score
        RETURN n, score
        ORDER BY score DESC
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)

        return [
            {
                "statement": self._map_to_entity(r),
                "score": r.get("score", 0.0)
            }
            for r in results
        ]

    async def search_by_embedding(
        self,
        embedding: List[float],
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search statements by cosine similarity to a query vector.

        Args:
            embedding: Query vector.
            group_id: Optional group filter.
            limit: Maximum number of results.
            min_score: Results must score strictly above this value.

        Returns:
            List[Dict]: Dicts with ``statement`` (StatementNode) and
            ``score`` (float) keys, ordered by descending score.
        """
        return await self._search_by_embedding_filtered(
            embedding,
            {"group_id": group_id},
            limit,
            min_score
        )

    async def search_by_keyword(
        self,
        keyword: str,
        group_id: Optional[str] = None,
        limit: int = 50
    ) -> List[StatementNode]:
        """Search statements whose text contains a keyword.

        Args:
            keyword: Keyword matched with Cypher ``CONTAINS``.
            group_id: Optional group filter.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Matching statements.
        """
        where_clause = "n.statement CONTAINS $keyword"
        params = {"keyword": keyword, "limit": limit}
        if group_id:
            where_clause += " AND n.group_id = $group_id"
            params["group_id"] = group_id

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_clause}
        RETURN n
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def find_by_temporal_range(
        self,
        group_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: int = 100
    ) -> List[StatementNode]:
        """Find a group's statements constrained by a time range.

        With ``start_date`` the statement's ``valid_at`` must be on or after
        it; with ``end_date`` its ``invalid_at`` must be unset or on or
        before it. Results are ordered by ``created_at``, newest first.

        Args:
            group_id: Group identifier.
            start_date: Optional lower bound, sent as an ISO-8601 string.
            end_date: Optional upper bound, sent as an ISO-8601 string.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Matching statements.
        """
        where_clauses = ["n.group_id = $group_id"]
        params = {"group_id": group_id, "limit": limit}

        if start_date is not None:
            where_clauses.append("n.valid_at >= $start_date")
            params["start_date"] = start_date.isoformat()

        if end_date is not None:
            where_clauses.append("(n.invalid_at IS NULL OR n.invalid_at <= $end_date)")
            params["end_date"] = end_date.isoformat()

        where_str = " AND ".join(where_clauses)

        query = f"""
        MATCH (n:{self.node_label})
        WHERE {where_str}
        RETURN n
        ORDER BY n.created_at DESC
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)
        return [self._map_to_entity(r) for r in results]

    async def find_strong_statements(
        self,
        group_id: str,
        limit: int = 100
    ) -> List[StatementNode]:
        """Find a group's statements whose ``connect_strength`` is ``"Strong"``.

        Args:
            group_id: Group identifier.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Strongly connected statements.
        """
        return await self.find(
            {"group_id": group_id, "connect_strength": "Strong"},
            limit=limit
        )

    async def find_by_config_id(
        self,
        config_id: str,
        limit: int = 100
    ) -> List[StatementNode]:
        """Find statements by config id.

        Args:
            config_id: Configuration identifier.
            limit: Maximum number of results.

        Returns:
            List[StatementNode]: Matching statements.
        """
        return await self.find({"config_id": config_id}, limit=limit)

    async def search_by_embedding_with_config(
        self,
        embedding: List[float],
        config_id: Optional[str] = None,
        group_id: Optional[str] = None,
        limit: int = 10,
        min_score: float = 0.7
    ) -> List[Dict]:
        """Search statements by cosine similarity, optionally per config.

        Same as ``search_by_embedding`` with an additional optional
        ``config_id`` filter.

        Args:
            embedding: Query vector.
            config_id: Optional configuration filter.
            group_id: Optional group filter.
            limit: Maximum number of results.
            min_score: Results must score strictly above this value.

        Returns:
            List[Dict]: Dicts with ``statement`` (StatementNode) and
            ``score`` (float) keys, ordered by descending score.
        """
        return await self._search_by_embedding_filtered(
            embedding,
            {"config_id": config_id, "group_id": group_id},
            limit,
            min_score
        )
|
||||
59
api/app/repositories/release_share_repository.py
Normal file
59
api/app/repositories/release_share_repository.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from app.models import ReleaseShare
|
||||
|
||||
|
||||
class ReleaseShareRepository:
    """Data-access helper for ``ReleaseShare`` rows (release share settings)."""

    def __init__(self, db: Session):
        """Store the SQLAlchemy session used by every method."""
        self.db = db

    def create(self, release_share: ReleaseShare) -> ReleaseShare:
        """Add a share configuration, commit, and return the refreshed row."""
        session = self.db
        session.add(release_share)
        session.commit()
        session.refresh(release_share)
        return release_share

    def get_by_id(self, share_id: uuid.UUID) -> Optional[ReleaseShare]:
        """Return the share configuration with this primary key, if any."""
        return self.db.get(ReleaseShare, share_id)

    def get_by_release_id(self, release_id: uuid.UUID) -> Optional[ReleaseShare]:
        """Return the first share configuration of the given release, if any."""
        by_release = select(ReleaseShare).where(ReleaseShare.release_id == release_id)
        return self.db.scalars(by_release).first()

    def get_by_share_token(self, share_token: str) -> Optional[ReleaseShare]:
        """Return the first share configuration with this token, if any."""
        by_token = select(ReleaseShare).where(ReleaseShare.share_token == share_token)
        return self.db.scalars(by_token).first()

    def update(self, release_share: ReleaseShare) -> ReleaseShare:
        """Commit pending changes and return the refreshed row."""
        session = self.db
        session.commit()
        session.refresh(release_share)
        return release_share

    def delete(self, release_share: ReleaseShare) -> None:
        """Delete the share configuration and commit."""
        session = self.db
        session.delete(release_share)
        session.commit()

    def token_exists(self, share_token: str) -> bool:
        """Return True when some share configuration already uses the token."""
        id_for_token = select(ReleaseShare.id).where(ReleaseShare.share_token == share_token)
        found = self.db.scalars(id_for_token).first()
        return found is not None

    def increment_view_count(self, share_id: uuid.UUID) -> None:
        """Bump the view counter and last-access time of one share.

        Loads the row, increments ``view_count``, sets ``last_accessed_at``
        to the current local time and commits, all synchronously. Nothing
        happens when the share does not exist.
        NOTE(review): this is a read-modify-write on the loaded row, not a
        single UPDATE statement.
        """
        from datetime import datetime

        by_id = select(ReleaseShare).where(ReleaseShare.id == share_id)
        share = self.db.scalars(by_id).first()
        if not share:
            return
        share.view_count += 1
        share.last_accessed_at = datetime.now()
        self.db.commit()
|
||||
167
api/app/repositories/tenant_repository.py
Normal file
167
api/app/repositories/tenant_repository.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import and_, or_, func
|
||||
from typing import List, Optional
|
||||
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.user_model import User
|
||||
from app.schemas.tenant_schema import TenantCreate, TenantUpdate
|
||||
|
||||
|
||||
class TenantRepository:
    """Data-access layer for tenants and the users attached to them.

    No method in this class commits. Creates, updates and user assignment
    flush the session; ``delete_tenant`` only queues the delete. Finishing
    the transaction is left to the caller.
    """

    def __init__(self, db: Session):
        """Bind the repository to an open SQLAlchemy session."""
        self.db = db

    @staticmethod
    def _filter_tenants(query, is_active: Optional[bool], search: Optional[str]):
        """Apply the optional activity and search filters to a tenant query.

        ``search`` is matched case-insensitively as a substring of either
        the tenant name or its description.
        """
        if is_active is not None:
            query = query.filter(Tenants.is_active == is_active)
        if search:
            pattern = f"%{search}%"
            query = query.filter(
                or_(
                    Tenants.name.ilike(pattern),
                    Tenants.description.ilike(pattern),
                )
            )
        return query

    @staticmethod
    def _filter_tenant_users(query, tenant_id: uuid.UUID, is_active: Optional[bool]):
        """Restrict a user query to one tenant and, optionally, by activity."""
        query = query.filter(User.tenant_id == tenant_id)
        if is_active is not None:
            query = query.filter(User.is_active == is_active)
        return query

    def create_tenant(self, tenant_data: TenantCreate) -> Tenants:
        """Insert a new tenant with a freshly generated UUID.

        Args:
            tenant_data: Supplies ``name``, ``description`` and ``is_active``.

        Returns:
            The new tenant, added and flushed but not committed.
        """
        db_tenant = Tenants(
            name=tenant_data.name,
            id=uuid.uuid4(),
            description=tenant_data.description,
            is_active=tenant_data.is_active,
        )
        self.db.add(db_tenant)
        self.db.flush()
        return db_tenant

    def get_tenant_by_id(self, tenant_id: uuid.UUID) -> Optional[Tenants]:
        """Return the tenant with the given ID, or ``None``."""
        return self.db.query(Tenants).filter(Tenants.id == tenant_id).first()

    def get_tenant_by_name(self, name: str) -> Optional[Tenants]:
        """Return the first tenant with exactly this name, or ``None``."""
        return self.db.query(Tenants).filter(Tenants.name == name).first()

    def get_tenants(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
        search: Optional[str] = None,
    ) -> List[Tenants]:
        """Return one page of tenants.

        No ORDER BY is applied, so row order is whatever the database
        returns.

        Args:
            skip: Number of rows to skip.
            limit: Maximum number of rows to return.
            is_active: When not ``None``, keep only tenants with this flag.
            search: When non-empty, substring to match against name or
                description (case-insensitive).
        """
        query = self._filter_tenants(self.db.query(Tenants), is_active, search)
        return query.offset(skip).limit(limit).all()

    def count_tenants(
        self,
        is_active: Optional[bool] = None,
        search: Optional[str] = None,
    ) -> int:
        """Count tenants using the same filters as ``get_tenants``."""
        query = self._filter_tenants(
            self.db.query(func.count(Tenants.id)), is_active, search
        )
        return query.scalar()

    def update_tenant(self, tenant_id: uuid.UUID, tenant_data: TenantUpdate) -> Optional[Tenants]:
        """Apply the explicitly-set fields of ``tenant_data`` to a tenant.

        Returns:
            The updated tenant (flushed, not committed), or ``None`` when no
            tenant has this ID.
        """
        db_tenant = self.get_tenant_by_id(tenant_id)
        if not db_tenant:
            return None

        # exclude_unset: only fields the caller actually provided are written.
        changes = tenant_data.dict(exclude_unset=True)
        for field, value in changes.items():
            setattr(db_tenant, field, value)

        self.db.flush()
        return db_tenant

    def delete_tenant(self, tenant_id: uuid.UUID) -> bool:
        """Mark a tenant for deletion.

        The delete is queued on the session; this method neither flushes
        nor commits.

        Returns:
            ``True`` when the tenant existed, ``False`` otherwise.
        """
        db_tenant = self.get_tenant_by_id(tenant_id)
        if not db_tenant:
            return False

        self.db.delete(db_tenant)
        return True

    def get_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> List[User]:
        """Return all users of a tenant, optionally filtered by activity."""
        query = self._filter_tenant_users(self.db.query(User), tenant_id, is_active)
        return query.all()

    def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]:
        """Return the tenant a user belongs to.

        Returns:
            ``None`` when the user does not exist, has no ``tenant_id``, or
            the referenced tenant is missing.
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user or not user.tenant_id:
            return None

        return self.get_tenant_by_id(user.tenant_id)

    def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
        """Move a user into a tenant.

        Returns:
            ``False`` when either the user or the tenant does not exist;
            ``True`` after the user's ``tenant_id`` was set and flushed.
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user:
            return False

        # The target tenant must exist before the user is pointed at it.
        if not self.get_tenant_by_id(tenant_id):
            return False

        user.tenant_id = tenant_id
        self.db.flush()
        return True

    def count_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> int:
        """Count the users of a tenant, optionally filtered by activity."""
        query = self._filter_tenant_users(
            self.db.query(func.count(User.id)), tenant_id, is_active
        )
        return query.scalar()
|
||||
|
||||
|
||||
# 便利函数,保持向后兼容
|
||||
def create_tenant(db: Session, tenant_data: TenantCreate) -> Tenants:
    """Function-style shortcut for ``TenantRepository(db).create_tenant``."""
    repo = TenantRepository(db)
    return repo.create_tenant(tenant_data)
|
||||
|
||||
def get_tenant_by_id(db: Session, tenant_id: uuid.UUID) -> Optional[Tenants]:
    """Function-style shortcut for ``TenantRepository(db).get_tenant_by_id``."""
    repo = TenantRepository(db)
    return repo.get_tenant_by_id(tenant_id)
|
||||
|
||||
def get_tenant_by_name(db: Session, name: str) -> Optional[Tenants]:
    """Function-style shortcut for ``TenantRepository(db).get_tenant_by_name``."""
    repo = TenantRepository(db)
    return repo.get_tenant_by_name(name)
|
||||
|
||||
def get_tenants(db: Session, skip: int = 0, limit: int = 100) -> List[Tenants]:
    """Return one page of tenants via ``TenantRepository.get_tenants``.

    Only ``skip`` and ``limit`` are forwarded; the repository's activity
    and search filters keep their defaults.
    """
    repo = TenantRepository(db)
    return repo.get_tenants(skip=skip, limit=limit)
|
||||
|
||||
def get_user_tenant(db: Session, user_id: uuid.UUID) -> Optional[Tenants]:
    """Function-style shortcut for ``TenantRepository(db).get_user_tenant``."""
    repo = TenantRepository(db)
    return repo.get_user_tenant(user_id)
|
||||
|
||||
def get_tenant_users(db: Session, tenant_id: uuid.UUID) -> List[User]:
    """Return every user of a tenant via ``TenantRepository.get_tenant_users``.

    The repository's ``is_active`` filter is left at its default.
    """
    repo = TenantRepository(db)
    return repo.get_tenant_users(tenant_id)
|
||||
322
api/app/repositories/user_repository.py
Normal file
322
api/app/repositories/user_repository.py
Normal file
@@ -0,0 +1,322 @@
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import and_, or_, func
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.schemas.user_schema import UserCreate, UserUpdate
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Module-level logger dedicated to database operations.
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class UserRepository:
    """Data-access layer for users.

    Failures are logged through ``db_logger`` at error level and re-raised
    unchanged. Writes are flushed on the session but never committed here.
    """

    def __init__(self, db: Session):
        """Bind the repository to an open SQLAlchemy session."""
        self.db = db

    def _users_with_tenant(self):
        """Start a ``User`` query that eagerly joins each user's tenant."""
        return self.db.query(User).options(joinedload(User.tenant))

    @staticmethod
    def _filter_users(query, is_active: Optional[bool] = None, search: Optional[str] = None):
        """Apply the optional activity and search filters to a user query.

        ``search`` is matched case-insensitively as a substring of either
        the username or the email.
        """
        if is_active is not None:
            query = query.filter(User.is_active == is_active)
        if search:
            pattern = f"%{search}%"
            query = query.filter(
                or_(
                    User.username.ilike(pattern),
                    User.email.ilike(pattern),
                )
            )
        return query

    def get_user_by_id(self, user_id: uuid.UUID) -> Optional[User]:
        """Return the user with this ID (tenant eagerly loaded), or ``None``."""
        db_logger.debug(f"根据ID查询用户: user_id={user_id}")

        try:
            user = self._users_with_tenant().filter(User.id == user_id).first()
            if user:
                db_logger.debug(f"用户查询成功: {user.username} (ID: {user_id})")
            else:
                db_logger.debug(f"用户不存在: user_id={user_id}")
            return user
        except Exception as e:
            db_logger.error(f"根据ID查询用户失败: user_id={user_id} - {str(e)}")
            raise

    def get_user_by_email(self, email: str) -> Optional[User]:
        """Return the user with this email (tenant eagerly loaded), or ``None``."""
        db_logger.debug(f"根据邮箱查询用户: email={email}")

        try:
            user = self._users_with_tenant().filter(User.email == email).first()
            if user:
                db_logger.debug(f"用户查询成功: {user.username} (email: {email})")
            else:
                db_logger.debug(f"用户不存在: email={email}")
            return user
        except Exception as e:
            db_logger.error(f"根据邮箱查询用户失败: email={email} - {str(e)}")
            raise

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the user with this username (tenant eagerly loaded), or ``None``."""
        db_logger.debug(f"根据用户名查询用户: username={username}")

        try:
            user = self._users_with_tenant().filter(User.username == username).first()
            if user:
                db_logger.debug(f"用户查询成功: {user.username} (ID: {user.id})")
            else:
                db_logger.debug(f"用户不存在: username={username}")
            return user
        except Exception as e:
            db_logger.error(f"根据用户名查询用户失败: username={username} - {str(e)}")
            raise

    def get_superuser(self) -> Optional[User]:
        """Return the first active superuser found, or ``None``."""
        db_logger.debug("查询超级用户")

        try:
            user = (
                self._users_with_tenant()
                .filter(User.is_active == True)  # noqa: E712 - SQL comparison
                .filter(User.is_superuser == True)  # noqa: E712
                .first()
            )
            if user:
                db_logger.debug(f"超级用户查询成功: {user.username}")
            else:
                db_logger.debug("超级用户不存在")
            return user
        except Exception as e:
            db_logger.error(f"查询超级用户失败: {str(e)}")
            raise

    def check_superuser_only(self) -> bool:
        """Return ``True`` when exactly one active superuser exists.

        The count is computed with ``COUNT(user.id)``; no user entities or
        tenant joins are loaded just to count rows.
        """
        db_logger.debug("检查是否只有一个超级用户")

        try:
            count = (
                self.db.query(func.count(User.id))
                .filter(User.is_active == True)  # noqa: E712 - SQL comparison
                .filter(User.is_superuser == True)  # noqa: E712
                .scalar()
            )
            return count == 1
        except Exception as e:
            db_logger.error(f"检查超级用户数量失败: {str(e)}")
            raise

    def create_user(
        self,
        user_data: UserCreate,
        hashed_password: str,
        tenant_id: Optional[uuid.UUID] = None,
        is_superuser: bool = False,
    ) -> User:
        """Insert a new user.

        Args:
            user_data: Supplies ``username`` and ``email``.
            hashed_password: Stored as-is in ``hashed_password``.
            tenant_id: Optional tenant the user belongs to.
            is_superuser: Whether the user is created as a superuser.

        Returns:
            The new user, added and flushed but not committed.
        """
        db_logger.debug(f"创建用户记录: username={user_data.username}, email={user_data.email}, is_superuser={is_superuser}")

        try:
            db_user = User(
                username=user_data.username,
                email=user_data.email,
                hashed_password=hashed_password,
                tenant_id=tenant_id,
                is_superuser=is_superuser,
            )
            self.db.add(db_user)
            self.db.flush()
            db_logger.info(f"用户记录创建成功: {user_data.username} (email: {user_data.email})")
            return db_user
        except Exception as e:
            db_logger.error(f"创建用户记录失败: username={user_data.username} - {str(e)}")
            raise

    def update_user(self, user_id: uuid.UUID, user_data: UserUpdate) -> Optional[User]:
        """Apply the explicitly-set fields of ``user_data`` to a user.

        Returns:
            The updated user (flushed, not committed), or ``None`` when the
            user does not exist.
        """
        db_logger.debug(f"更新用户: user_id={user_id}")

        try:
            user = self.get_user_by_id(user_id)
            if not user:
                db_logger.debug(f"用户不存在: user_id={user_id}")
                return None

            # exclude_unset: only fields the caller actually provided are written.
            changes = user_data.dict(exclude_unset=True)
            for field, value in changes.items():
                setattr(user, field, value)

            self.db.flush()
            db_logger.info(f"用户更新成功: {user.username}")
            return user
        except Exception as e:
            db_logger.error(f"更新用户失败: user_id={user_id} - {str(e)}")
            raise

    def delete_user(self, user_id: uuid.UUID) -> bool:
        """Delete a user and flush.

        Returns:
            ``True`` when the user existed and was deleted, else ``False``.
        """
        db_logger.debug(f"删除用户: user_id={user_id}")

        try:
            user = self.get_user_by_id(user_id)
            if not user:
                db_logger.debug(f"用户不存在: user_id={user_id}")
                return False

            self.db.delete(user)
            self.db.flush()
            db_logger.info(f"用户删除成功: {user.username}")
            return True
        except Exception as e:
            db_logger.error(f"删除用户失败: user_id={user_id} - {str(e)}")
            raise

    def get_users_by_tenant(
        self,
        tenant_id: uuid.UUID,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
        search: Optional[str] = None,
    ) -> List[User]:
        """Return one page of a tenant's users (tenant eagerly loaded).

        No ORDER BY is applied, so row order is whatever the database
        returns.

        Args:
            tenant_id: Tenant whose users are listed.
            skip: Number of rows to skip.
            limit: Maximum number of rows to return.
            is_active: When not ``None``, keep only users with this flag.
            search: When non-empty, substring to match against username or
                email (case-insensitive).
        """
        db_logger.debug(f"查询租户用户: tenant_id={tenant_id}")

        try:
            query = self._users_with_tenant().filter(User.tenant_id == tenant_id)
            query = self._filter_users(query, is_active, search)

            users = query.offset(skip).limit(limit).all()
            db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}")
            return users
        except Exception as e:
            db_logger.error(f"查询租户用户失败: tenant_id={tenant_id} - {str(e)}")
            raise

    def count_users_by_tenant(
        self,
        tenant_id: uuid.UUID,
        is_active: Optional[bool] = None,
        search: Optional[str] = None,
    ) -> int:
        """Count a tenant's users using the same filters as ``get_users_by_tenant``."""
        try:
            query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
            query = self._filter_users(query, is_active, search)
            return query.scalar()
        except Exception as e:
            db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}")
            raise

    def get_superusers_by_tenant(
        self,
        tenant_id: uuid.UUID,
        is_active: Optional[bool] = True,
    ) -> List[User]:
        """Return the superusers of a tenant.

        Args:
            tenant_id: Tenant to search in.
            is_active: Activity flag to require. Defaults to ``True``; pass
                ``None`` to skip the activity filter.
        """
        db_logger.debug(f"查询租户超管用户: tenant_id={tenant_id}")

        try:
            query = self._users_with_tenant().filter(
                and_(
                    User.tenant_id == tenant_id,
                    User.is_superuser == True,  # noqa: E712 - SQL comparison
                )
            )
            query = self._filter_users(query, is_active)

            users = query.all()
            db_logger.debug(f"租户超管用户查询成功: tenant_id={tenant_id}, count={len(users)}")
            return users
        except Exception as e:
            db_logger.error(f"查询租户超管用户失败: tenant_id={tenant_id} - {str(e)}")
            raise

    def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
        """Move a user into a tenant.

        Returns:
            ``False`` when either the user or the tenant does not exist;
            ``True`` after the user's ``tenant_id`` was set and flushed.
        """
        db_logger.debug(f"分配用户到租户: user_id={user_id}, tenant_id={tenant_id}")

        try:
            user = self.get_user_by_id(user_id)
            if not user:
                db_logger.debug(f"用户不存在: user_id={user_id}")
                return False

            # The target tenant must exist before the user is pointed at it.
            tenant = self.db.query(Tenants).filter(Tenants.id == tenant_id).first()
            if not tenant:
                db_logger.debug(f"租户不存在: tenant_id={tenant_id}")
                return False

            user.tenant_id = tenant_id
            self.db.flush()
            db_logger.info(f"用户分配成功: user={user.username}, tenant={tenant.name}")
            return True
        except Exception as e:
            db_logger.error(f"分配用户到租户失败: user_id={user_id}, tenant_id={tenant_id} - {str(e)}")
            raise

    def get_users_without_tenant(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
    ) -> List[User]:
        """Return one page of users whose ``tenant_id`` is NULL.

        Args:
            skip: Number of rows to skip.
            limit: Maximum number of rows to return.
            is_active: When not ``None``, keep only users with this flag.
        """
        try:
            query = self.db.query(User).filter(User.tenant_id.is_(None))
            query = self._filter_users(query, is_active)
            return query.offset(skip).limit(limit).all()
        except Exception as e:
            db_logger.error(f"查询无租户用户失败: {str(e)}")
            raise
|
||||
|
||||
|
||||
# 便利函数,保持向后兼容
|
||||
def get_user_by_id(db: Session, user_id: uuid.UUID) -> Optional[User]:
    """Function-style shortcut for ``UserRepository(db).get_user_by_id``."""
    repo = UserRepository(db)
    return repo.get_user_by_id(user_id)
|
||||
|
||||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
    """Function-style shortcut for ``UserRepository(db).get_user_by_email``."""
    repo = UserRepository(db)
    return repo.get_user_by_email(email)
|
||||
|
||||
def get_user_by_username(db: Session, username: str) -> Optional[User]:
    """Function-style shortcut for ``UserRepository(db).get_user_by_username``."""
    repo = UserRepository(db)
    return repo.get_user_by_username(username)
|
||||
|
||||
def get_superuser(db: Session) -> Optional[User]:
    """Function-style shortcut for ``UserRepository(db).get_superuser``."""
    repo = UserRepository(db)
    return repo.get_superuser()
|
||||
|
||||
def check_superuser_only(db: Session) -> bool:
    """Report whether exactly one active superuser exists.

    Delegates to ``UserRepository.check_superuser_only``, which returns a
    ``bool``; the return annotation now matches that.
    """
    return UserRepository(db).check_superuser_only()
|
||||
|
||||
def create_user(
    db: Session,
    user: UserCreate,
    hashed_password: str,
    tenant_id: Optional[uuid.UUID] = None,
    is_superuser: bool = False
) -> User:
    """Create a user via ``UserRepository.create_user`` (function-style API).

    Args:
        db: Session the repository is bound to.
        user: Passed to the repository as ``user_data``.
        hashed_password: Password hash to store.
        tenant_id: Optional tenant for the new user.
        is_superuser: Whether to create the user as a superuser.
    """
    return UserRepository(db).create_user(
        user_data=user,
        hashed_password=hashed_password,
        tenant_id=tenant_id,
        is_superuser=is_superuser,
    )
|
||||
|
||||
|
||||
def get_superusers_by_tenant(
    db: Session,
    tenant_id: uuid.UUID,
    is_active: Optional[bool] = True
) -> List[User]:
    """List a tenant's superusers via ``UserRepository.get_superusers_by_tenant``.

    Args:
        db: Session the repository is bound to.
        tenant_id: Tenant to search in.
        is_active: Activity flag forwarded to the repository (default ``True``).
    """
    return UserRepository(db).get_superusers_by_tenant(
        tenant_id=tenant_id,
        is_active=is_active,
    )
|
||||
134
api/app/repositories/workspace_invite_repository.py
Normal file
134
api/app/repositories/workspace_invite_repository.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
from typing import List, Optional
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from app.models.workspace_model import WorkspaceInvite, InviteStatus
|
||||
from app.schemas.workspace_schema import WorkspaceInviteCreate
|
||||
|
||||
|
||||
class WorkspaceInviteRepository:
    """Data-access layer for workspace invitations.

    The mutating methods (``create_invite``, ``update_invite_status``,
    ``expire_old_invites``) commit on the repository's session themselves.
    """

    def __init__(self, db: Session):
        """Bind the repository to an open SQLAlchemy session."""
        self.db = db

    def create_invite(
        self,
        workspace_id: uuid.UUID,
        invite_data: WorkspaceInviteCreate,
        token_hash: str,
        created_by_user_id: uuid.UUID
    ) -> WorkspaceInvite:
        """Create and commit a pending invitation.

        The expiry is the current local time plus
        ``invite_data.expires_in_days`` days.

        Args:
            workspace_id: Workspace the invitee is invited to.
            invite_data: Supplies ``email``, ``role`` and ``expires_in_days``.
            token_hash: Hash of the invitation token.
            created_by_user_id: User who issued the invitation.

        Returns:
            The new invitation, refreshed after the commit.
        """
        created_at_local = datetime.datetime.now()
        lifetime = datetime.timedelta(days=invite_data.expires_in_days)

        invite = WorkspaceInvite(
            workspace_id=workspace_id,
            email=invite_data.email,
            role=invite_data.role,
            token_hash=token_hash,
            status=InviteStatus.pending,
            expires_at=created_at_local + lifetime,
            created_by_user_id=created_by_user_id,
        )

        session = self.db
        session.add(invite)
        session.commit()
        session.refresh(invite)
        return invite

    def get_invite_by_token_hash(self, token_hash: str) -> Optional[WorkspaceInvite]:
        """Return the invitation with this token hash, or ``None``."""
        matching = self.db.query(WorkspaceInvite).filter(
            WorkspaceInvite.token_hash == token_hash
        )
        return matching.first()

    def get_invite_by_id(self, invite_id: uuid.UUID) -> Optional[WorkspaceInvite]:
        """Return the invitation with this ID, or ``None``."""
        matching = self.db.query(WorkspaceInvite).filter(
            WorkspaceInvite.id == invite_id
        )
        return matching.first()

    def get_workspace_invites(
        self,
        workspace_id: uuid.UUID,
        status: Optional[InviteStatus] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[WorkspaceInvite]:
        """Return one page of a workspace's invitations, newest first.

        Args:
            workspace_id: Workspace whose invitations are listed.
            status: When given, keep only invitations in this status.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.
        """
        invites = self.db.query(WorkspaceInvite).filter(
            WorkspaceInvite.workspace_id == workspace_id
        )
        if status:
            invites = invites.filter(WorkspaceInvite.status == status)

        newest_first = invites.order_by(WorkspaceInvite.created_at.desc())
        return newest_first.offset(offset).limit(limit).all()

    def get_pending_invite_by_email_and_workspace(
        self,
        email: str,
        workspace_id: uuid.UUID
    ) -> Optional[WorkspaceInvite]:
        """Return the pending invitation for an email in a workspace, if any."""
        is_pending_for_email = and_(
            WorkspaceInvite.email == email,
            WorkspaceInvite.workspace_id == workspace_id,
            WorkspaceInvite.status == InviteStatus.pending,
        )
        return self.db.query(WorkspaceInvite).filter(is_pending_for_email).first()

    def update_invite_status(
        self,
        invite_id: uuid.UUID,
        status: InviteStatus,
        accepted_at: Optional[datetime.datetime] = None
    ) -> Optional[WorkspaceInvite]:
        """Change an invitation's status and commit.

        ``updated_at`` is set to the current local time. ``accepted_at`` is
        written only when a value is supplied.

        Returns:
            The refreshed invitation, or ``None`` when it does not exist.
        """
        invite = self.get_invite_by_id(invite_id)
        if not invite:
            return None

        invite.status = status
        if accepted_at:
            invite.accepted_at = accepted_at
        invite.updated_at = datetime.datetime.now()

        self.db.commit()
        self.db.refresh(invite)
        return invite

    def revoke_invite(self, invite_id: uuid.UUID) -> Optional[WorkspaceInvite]:
        """Set an invitation's status to ``InviteStatus.revoked``."""
        return self.update_invite_status(invite_id, InviteStatus.revoked)

    def expire_old_invites(self) -> int:
        """Mark every pending invitation past its expiry as expired.

        Runs one bulk UPDATE and commits.

        Returns:
            The row count reported by the UPDATE.
        """
        now = datetime.datetime.now()

        overdue = self.db.query(WorkspaceInvite).filter(
            and_(
                WorkspaceInvite.status == InviteStatus.pending,
                WorkspaceInvite.expires_at < now,
            )
        )
        new_values = {
            WorkspaceInvite.status: InviteStatus.expired,
            WorkspaceInvite.updated_at: now,
        }
        expired_count = overdue.update(new_values)

        self.db.commit()
        return expired_count

    def count_workspace_invites(
        self,
        workspace_id: uuid.UUID,
        status: Optional[InviteStatus] = None
    ) -> int:
        """Count a workspace's invitations, optionally restricted to one status."""
        invites = self.db.query(WorkspaceInvite).filter(
            WorkspaceInvite.workspace_id == workspace_id
        )
        if status:
            invites = invites.filter(WorkspaceInvite.status == status)

        return invites.count()
|
||||
383
api/app/repositories/workspace_repository.py
Normal file
383
api/app/repositories/workspace_repository.py
Normal file
@@ -0,0 +1,383 @@
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from app.models.user_model import User
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
|
||||
from app.schemas.workspace_schema import WorkspaceCreate, WorkspaceUpdate
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Module-level logger dedicated to database operations.
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class WorkspaceRepository:
    """Data-access layer for workspaces and workspace memberships.

    Transaction handling differs between methods: ``create_workspace`` and
    ``add_member`` only flush, while the role-update and deactivation
    methods commit. Failures are logged through ``db_logger`` and re-raised
    unchanged.
    """

    def __init__(self, db: Session):
        """Bind the repository to an open SQLAlchemy session."""
        self.db = db

    def _find_active_member(self, *criteria) -> Optional[WorkspaceMember]:
        """Return the first active membership matching ``criteria``, if any."""
        return (
            self.db.query(WorkspaceMember)
            .filter(*criteria, WorkspaceMember.is_active == True)  # noqa: E712
            .first()
        )

    def _commit_and_refresh(self, member: WorkspaceMember) -> WorkspaceMember:
        """Commit the session and reload ``member`` from the database."""
        self.db.commit()
        self.db.refresh(member)
        return member

    def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
        """Insert a new workspace for a tenant.

        Args:
            workspace_data: Supplies name, description, icon, iconType,
                storage_type, llm, embedding and rerank.
            tenant_id: Tenant that owns the workspace.

        Returns:
            The new workspace, added and flushed but not committed.
        """
        db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}")

        try:
            db_workspace = Workspace(
                name=workspace_data.name,
                description=workspace_data.description,
                icon=workspace_data.icon,
                iconType=workspace_data.iconType,
                storage_type=workspace_data.storage_type,
                llm=workspace_data.llm,
                embedding=workspace_data.embedding,
                rerank=workspace_data.rerank,
                tenant_id=tenant_id,
            )
            self.db.add(db_workspace)
            self.db.flush()
            db_logger.info(f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
            return db_workspace
        except Exception as e:
            db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}")
            raise

    def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]:
        """Return the workspace with this ID, or ``None``."""
        db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}")

        try:
            workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
            if workspace:
                db_logger.debug(f"工作空间查询成功: {workspace.name} (ID: {workspace_id})")
            else:
                db_logger.debug(f"工作空间不存在: workspace_id={workspace_id}")
            return workspace
        except Exception as e:
            db_logger.error(f"根据ID查询工作空间失败: workspace_id={workspace_id} - {str(e)}")
            raise

    def get_workspace_models_configs(self, workspace_id: uuid.UUID) -> Optional[dict]:
        """Return a workspace's model configuration.

        Args:
            workspace_id: ID of the workspace.

        Returns:
            A dict with the keys ``llm``, ``embedding`` and ``rerank``, or
            ``None`` when the workspace does not exist.
        """
        db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}")

        try:
            workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
            if not workspace:
                db_logger.debug(f"工作空间不存在: workspace_id={workspace_id}")
                return None

            configs = {
                "llm": workspace.llm,
                "embedding": workspace.embedding,
                "rerank": workspace.rerank,
            }
            db_logger.debug(
                f"工作空间模型配置查询成功: workspace_id={workspace_id}, "
                f"llm={configs['llm']}, embedding={configs['embedding']}, rerank={configs['rerank']}"
            )
            return configs
        except Exception as e:
            db_logger.error(f"查询工作空间模型配置失败: workspace_id={workspace_id} - {str(e)}")
            raise

    def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]:
        """Return the active workspaces visible to a user, newest update first.

        A superuser gets every active workspace of their own tenant. Any
        other user gets the active workspaces that have a membership row
        for them. An unknown user yields an empty list.
        """
        db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}")

        try:
            # The user record supplies is_superuser and tenant_id.
            user = self.db.query(User).filter(User.id == user_id).first()
            if not user:
                db_logger.warning(f"用户不存在: user_id={user_id}")
                return []

            if user.is_superuser:
                workspaces = (
                    self.db.query(Workspace)
                    .filter(Workspace.tenant_id == user.tenant_id)
                    .filter(Workspace.is_active == True)  # noqa: E712
                    .order_by(Workspace.updated_at.desc())
                    .all()
                )
                db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}")
                return workspaces

            member_workspaces = (
                self.db.query(Workspace)
                .join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id)
                .filter(WorkspaceMember.user_id == user_id)
                .filter(Workspace.is_active == True)  # noqa: E712
                .order_by(Workspace.updated_at.desc())
                .all()
            )

            db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}")
            return member_workspaces
        except Exception as e:
            db_logger.error(f"查询用户工作空间失败: user_id={user_id} - {str(e)}")
            raise

    def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]:
        """Return every active workspace of a tenant."""
        db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}")

        try:
            workspaces = (
                self.db.query(Workspace)
                .filter(Workspace.tenant_id == tenant_id)
                .filter(Workspace.is_active == True)  # noqa: E712
                .all()
            )
            db_logger.debug(f"租户工作空间查询成功: tenant_id={tenant_id}, 数量={len(workspaces)}")
            return workspaces
        except Exception as e:
            db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}")
            raise

    def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
        """Add a user to a workspace.

        Args:
            workspace_id: Workspace to join.
            user_id: User being added.
            role: Role to grant; defaults to ``WorkspaceRole.member``.

        Returns:
            The new membership, added and flushed but not committed.
        """
        db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}")

        try:
            db_member = WorkspaceMember(
                user_id=user_id,
                workspace_id=workspace_id,
                role=role,
            )
            self.db.add(db_member)
            self.db.flush()
            db_logger.info(f"工作空间成员添加成功: user_id={user_id}, workspace_id={workspace_id}, role={role}")
            return db_member
        except Exception as e:
            db_logger.error(f"添加工作空间成员失败: user_id={user_id}, workspace_id={workspace_id} - {str(e)}")
            raise

    def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]:
        """Return a user's active membership in a workspace, or ``None``."""
        db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}")

        try:
            member = self._find_active_member(
                WorkspaceMember.user_id == user_id,
                WorkspaceMember.workspace_id == workspace_id,
            )
            if member:
                db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
            else:
                db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}")
            return member
        except Exception as e:
            db_logger.error(f"查询工作空间成员失败: user_id={user_id}, workspace_id={workspace_id} - {str(e)}")
            raise

    def get_members_by_workspace(self, workspace_id: uuid.UUID) -> List[WorkspaceMember]:
        """List a workspace's active memberships whose user is also active.

        The ``user`` and ``workspace`` relationships are eagerly loaded.
        """
        db_logger.debug(f"查询工作空间的成员列表: workspace_id={workspace_id}")
        try:
            members = (
                self.db.query(WorkspaceMember)
                .join(User, WorkspaceMember.user_id == User.id)
                .options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace))
                .filter(WorkspaceMember.workspace_id == workspace_id)
                .filter(WorkspaceMember.is_active == True)  # noqa: E712
                .filter(User.is_active == True)  # noqa: E712
                .all()
            )
            db_logger.debug(f"成员列表查询成功: workspace_id={workspace_id}, 数量={len(members)}")
            return members
        except Exception as e:
            db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}")
            raise

    def get_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
        """Return an active membership by its own ID.

        The membership's user must also be active. The ``user`` and
        ``workspace`` relationships are eagerly loaded.

        Returns:
            The membership, or ``None`` when nothing matches.
        """
        db_logger.debug(f"查询成员的工作空间: member_id={member_id}")
        try:
            member = (
                self.db.query(WorkspaceMember)
                .join(User, WorkspaceMember.user_id == User.id)
                .options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace))
                .filter(WorkspaceMember.id == member_id)
                .filter(WorkspaceMember.is_active == True)  # noqa: E712
                .filter(User.is_active == True)  # noqa: E712
                .first()
            )
            if member:
                db_logger.debug(f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
            else:
                db_logger.debug(f"成员不存在: member_id={member_id}")
            return member
        except Exception as e:
            db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}")
            raise

    def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
        """Change the role of a user's active membership and commit.

        Returns:
            The refreshed membership, or ``None`` when no active membership
            exists for this workspace and user.
        """
        try:
            member = self._find_active_member(
                WorkspaceMember.workspace_id == workspace_id,
                WorkspaceMember.user_id == user_id,
            )
            if not member:
                return None
            member.role = role
            return self._commit_and_refresh(member)
        except Exception as e:
            db_logger.error(f"更新成员角色失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
            raise

    def deactivate_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID) -> Optional[WorkspaceMember]:
        """Soft-delete a user's membership (``is_active = False``) and commit.

        Returns:
            The refreshed membership, or ``None`` when no active membership
            exists for this workspace and user.
        """
        try:
            member = self._find_active_member(
                WorkspaceMember.workspace_id == workspace_id,
                WorkspaceMember.user_id == user_id,
            )
            if not member:
                return None
            member.is_active = False
            return self._commit_and_refresh(member)
        except Exception as e:
            db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
            raise

    def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
        """Soft-delete a membership by its own ID and commit.

        The row is kept; only ``is_active`` is set to ``False``.

        Returns:
            The refreshed membership, or ``None`` when no active membership
            has this ID.
        """
        try:
            member = self._find_active_member(WorkspaceMember.id == member_id)
            if not member:
                return None
            member.is_active = False
            return self._commit_and_refresh(member)
        except Exception as e:
            db_logger.error(f"删除成员失败: id={member_id} - {str(e)}")
            raise

    def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
        """Change the role of an active membership by its own ID and commit.

        Returns:
            The refreshed membership, or ``None`` when no active membership
            has this ID.
        """
        try:
            member = self._find_active_member(WorkspaceMember.id == id)
            if not member:
                return None
            member.role = role
            return self._commit_and_refresh(member)
        except Exception as e:
            db_logger.error(f"更新成员角色失败: id={id} - {str(e)}")
            raise
|
||||
|
||||
# 保持向后兼容的函数
|
||||
def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None:
    """Function-style shortcut for ``WorkspaceRepository(db).get_workspace_by_id``."""
    return WorkspaceRepository(db).get_workspace_by_id(workspace_id)
|
||||
|
||||
|
||||
def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]:
    """Function-style shortcut for ``WorkspaceRepository(db).get_workspaces_by_user``."""
    return WorkspaceRepository(db).get_workspaces_by_user(user_id)
|
||||
|
||||
|
||||
def get_workspaces_by_tenant(db: Session, tenant_id: uuid.UUID) -> List[Workspace]:
    """Function-style shortcut for ``WorkspaceRepository(db).get_workspaces_by_tenant``."""
    return WorkspaceRepository(db).get_workspaces_by_tenant(tenant_id)
|
||||
|
||||
|
||||
def get_member_in_workspace(db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID) -> WorkspaceMember | None:
    """Return a user's active membership in a workspace via ``WorkspaceRepository.get_member``."""
    return WorkspaceRepository(db).get_member(user_id, workspace_id)
|
||||
|
||||
|
||||
def create_workspace(db: Session, workspace: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
    """Function-style wrapper around ``WorkspaceRepository.create_workspace``.

    Args:
        db: Session used to build a fresh ``WorkspaceRepository``.
        workspace: Creation payload forwarded to the repository method.
        tenant_id: Tenant identifier forwarded to the repository method.

    Returns:
        The repository method's result, annotated as a ``Workspace``.
    """
    return WorkspaceRepository(db).create_workspace(workspace, tenant_id)
|
||||
|
||||
|
||||
def add_member_to_workspace(
    db: Session,
    user_id: uuid.UUID,
    workspace_id: uuid.UUID,
    role: WorkspaceRole,
) -> WorkspaceMember:
    """Function-style wrapper around ``WorkspaceRepository.add_member``.

    Args:
        db: Session used to build a fresh ``WorkspaceRepository``.
        user_id: User identifier forwarded to the repository method.
        workspace_id: Workspace identifier forwarded to the repository method.
        role: Role forwarded to the repository method.

    Returns:
        The repository method's result, annotated as a ``WorkspaceMember``.
    """
    # The repository call takes workspace_id before user_id, the reverse of
    # this wrapper's own parameter order.
    return WorkspaceRepository(db).add_member(workspace_id, user_id, role)
|
||||
|
||||
|
||||
def get_members_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[WorkspaceMember]:
    """Function-style wrapper around ``WorkspaceRepository.get_members_by_workspace``.

    Args:
        db: Session used to build a fresh ``WorkspaceRepository``.
        workspace_id: Identifier forwarded to the repository method.

    Returns:
        The repository method's result, annotated as a list of
        ``WorkspaceMember``.
    """
    return WorkspaceRepository(db).get_members_by_workspace(workspace_id)
|
||||
|
||||
def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None:
    """Function-style wrapper around ``WorkspaceRepository.get_member_by_id``.

    Args:
        db: Session used to build a fresh ``WorkspaceRepository``.
        member_id: Identifier forwarded to the repository method.

    Returns:
        The repository method's result, annotated as a ``WorkspaceMember``
        or ``None``.
    """
    return WorkspaceRepository(db).get_member_by_id(member_id)
|
||||
|
||||
def update_member_role_in_workspace(
    db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
) -> Optional[WorkspaceMember]:
    """Function-style wrapper around ``WorkspaceRepository.update_member_role``.

    Args:
        db: Session used to build a fresh ``WorkspaceRepository``.
        user_id: User identifier forwarded to the repository method.
        workspace_id: Workspace identifier forwarded to the repository method.
        role: Role forwarded to the repository method.

    Returns:
        The repository method's result, annotated as an optional
        ``WorkspaceMember``.
    """
    # The repository call takes workspace_id before user_id, the reverse of
    # this wrapper's own parameter order.
    return WorkspaceRepository(db).update_member_role(workspace_id, user_id, role)
|
||||
|
||||
def remove_member_from_workspace(db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]:
    """Function-style wrapper around ``WorkspaceRepository.deactivate_member``.

    Args:
        db: Session used to build a fresh ``WorkspaceRepository``.
        user_id: User identifier forwarded to the repository method.
        workspace_id: Workspace identifier forwarded to the repository method.

    Returns:
        The repository method's result, annotated as an optional
        ``WorkspaceMember``.
    """
    # The repository call takes workspace_id before user_id, the reverse of
    # this wrapper's own parameter order.
    return WorkspaceRepository(db).deactivate_member(workspace_id, user_id)
|
||||
|
||||
def remove_member_from_workspace_by_id(db: Session, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
    """Function-style wrapper around ``WorkspaceRepository.delete_member_by_id``.

    Args:
        db: Session used to build a fresh ``WorkspaceRepository``.
        member_id: Identifier forwarded to the repository method.

    Returns:
        The repository method's result, annotated as an optional
        ``WorkspaceMember``.
    """
    return WorkspaceRepository(db).delete_member_by_id(member_id)
|
||||
|
||||
|
||||
def update_member_role_by_id(db: Session, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
    """Function-style wrapper around ``WorkspaceRepository.update_member_role_by_id``.

    Args:
        db: Session used to build a fresh ``WorkspaceRepository``.
        id: Member identifier forwarded to the repository method.
        role: Role forwarded to the repository method.

    Returns:
        The repository method's result, annotated as an optional
        ``WorkspaceMember``.
    """
    return WorkspaceRepository(db).update_member_role_by_id(id, role)
|
||||
|
||||
|
||||
def get_workspace_models_configs(db: Session, workspace_id: uuid.UUID) -> Optional[dict]:
    """Fetch the model configuration (llm, embedding, rerank) for a workspace.

    Delegates to ``WorkspaceRepository.get_workspace_models_configs``.

    Args:
        db: Database session.
        workspace_id: Workspace ID.

    Returns:
        A dict containing ``llm``, ``embedding`` and ``rerank``, or ``None``
        if the workspace does not exist.

    Example:
        >>> configs = get_workspace_models_configs(db, workspace_id)
        >>> if configs:
        ...     print(f"LLM: {configs['llm']}")
        ...     print(f"Embedding: {configs['embedding']}")
        ...     print(f"Rerank: {configs['rerank']}")
    """
    return WorkspaceRepository(db).get_workspace_models_configs(workspace_id)
|
||||
Reference in New Issue
Block a user