Initial commit

This commit is contained in:
Ke Sun
2025-11-30 18:22:17 +08:00
commit aea2fe391e
449 changed files with 83030 additions and 0 deletions

View File

@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
"""仓储模块
本模块提供统一的数据访问层包括PostgreSQL和Neo4j的仓储实现。
Classes:
RepositoryFactory: 仓储工厂,统一管理所有数据库的仓储实例
"""
from typing import Optional
from sqlalchemy.orm import Session
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.dialog_repository import DialogRepository
from app.repositories.neo4j.statement_repository import StatementRepository
from app.repositories.neo4j.entity_repository import EntityRepository
from app.repositories.user_repository import UserRepository
from app.repositories.workspace_repository import WorkspaceRepository
from app.repositories.app_repository import AppRepository
class RepositoryFactory:
"""仓储工厂 - 统一管理所有数据库的仓储
这个工厂类提供了获取各种仓储实例的统一接口。
支持Neo4j图数据库和PostgreSQL关系数据库的仓储。
Attributes:
neo4j_connector: Neo4j连接器实例可选
db_session: SQLAlchemy数据库会话可选
Example:
>>> # 创建工厂实例
>>> factory = RepositoryFactory(
... neo4j_connector=Neo4jConnector(),
... db_session=db_session
... )
>>>
>>> # 获取Neo4j仓储
>>> dialog_repo = factory.get_dialog_repository()
>>> statement_repo = factory.get_statement_repository()
>>>
>>> # 获取PostgreSQL仓储
>>> knowledge_repo = factory.get_knowledge_repository()
"""
def __init__(
self,
neo4j_connector: Optional[Neo4jConnector] = None,
db_session: Optional[Session] = None
):
"""初始化仓储工厂
Args:
neo4j_connector: Neo4j连接器实例可选
db_session: SQLAlchemy数据库会话可选
"""
self.neo4j_connector = neo4j_connector
self.db_session = db_session
# ==================== Neo4j 仓储 ====================
def get_dialog_repository(self) -> DialogRepository:
"""获取对话仓储
Returns:
DialogRepository: 对话仓储实例
Raises:
ValueError: 如果Neo4j连接器未初始化
"""
if not self.neo4j_connector:
raise ValueError("Neo4j connector not initialized")
return DialogRepository(self.neo4j_connector)
def get_statement_repository(self) -> StatementRepository:
"""获取陈述句仓储
Returns:
StatementRepository: 陈述句仓储实例
Raises:
ValueError: 如果Neo4j连接器未初始化
"""
if not self.neo4j_connector:
raise ValueError("Neo4j connector not initialized")
return StatementRepository(self.neo4j_connector)
def get_entity_repository(self) -> EntityRepository:
"""获取实体仓储
Returns:
EntityRepository: 实体仓储实例
Raises:
ValueError: 如果Neo4j连接器未初始化
"""
if not self.neo4j_connector:
raise ValueError("Neo4j connector not initialized")
return EntityRepository(self.neo4j_connector)
# ==================== PostgreSQL 仓储 ====================
# 注意现有的PostgreSQL仓储保持不变这里只是提供统一的访问接口
# 部分仓储如knowledge_repository、document_repository使用函数式接口
# 部分仓储如user_repository、workspace_repository使用类接口
def get_user_repository(self) -> UserRepository:
"""获取用户仓储
Returns:
UserRepository: 用户仓储实例
Raises:
ValueError: 如果数据库会话未初始化
"""
if not self.db_session:
raise ValueError("Database session not initialized")
return UserRepository(self.db_session)
def get_workspace_repository(self) -> WorkspaceRepository:
"""获取工作空间仓储
Returns:
WorkspaceRepository: 工作空间仓储实例
Raises:
ValueError: 如果数据库会话未初始化
"""
if not self.db_session:
raise ValueError("Database session not initialized")
return WorkspaceRepository(self.db_session)
def get_app_repository(self) -> AppRepository:
"""获取应用仓储
Returns:
AppRepository: 应用仓储实例
Raises:
ValueError: 如果数据库会话未初始化
"""
if not self.db_session:
raise ValueError("Database session not initialized")
return AppRepository(self.db_session)
def get_db_session(self) -> Session:
"""获取数据库会话
用于访问函数式仓储如knowledge_repository、document_repository
Returns:
Session: SQLAlchemy数据库会话
Raises:
ValueError: 如果数据库会话未初始化
Example:
>>> factory = RepositoryFactory(db_session=session)
>>> db = factory.get_db_session()
>>> # 使用函数式仓储
>>> from app.repositories import knowledge_repository
>>> knowledges = knowledge_repository.get_knowledges_paginated(db, [], 1, 10)
"""
if not self.db_session:
raise ValueError("Database session not initialized")
return self.db_session
__all__ = [
'RepositoryFactory',
]

View File

@@ -0,0 +1,138 @@
"""API Key Repository"""
from sqlalchemy.orm import Session
from sqlalchemy import select, func, and_
from typing import Optional, List, Tuple
import uuid
import datetime
from app.models.api_key_model import ApiKey, ApiKeyLog
from app.schemas import api_key_schema
class ApiKeyRepository:
"""API Key 数据访问层"""
@staticmethod
def create(db: Session, api_key_data: dict) -> ApiKey:
"""创建 API Key"""
api_key = ApiKey(**api_key_data)
db.add(api_key)
db.flush()
return api_key
@staticmethod
def get_by_id(db: Session, api_key_id: uuid.UUID) -> Optional[ApiKey]:
"""根据 ID 获取 API Key"""
return db.get(ApiKey, api_key_id)
@staticmethod
def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]:
"""根据哈希值获取 API Key"""
stmt = select(ApiKey).where(ApiKey.key_hash == key_hash)
return db.scalars(stmt).first()
@staticmethod
def list_by_workspace(
db: Session,
workspace_id: uuid.UUID,
query: api_key_schema.ApiKeyQuery
) -> Tuple[List[ApiKey], int]:
"""列出工作空间的 API Keys"""
stmt = select(ApiKey).where(ApiKey.workspace_id == workspace_id)
# 过滤条件
if query.type:
stmt = stmt.where(ApiKey.type == query.type)
if query.is_active is not None:
stmt = stmt.where(ApiKey.is_active == query.is_active)
if query.resource_id:
stmt = stmt.where(ApiKey.resource_id == query.resource_id)
# 总数
count_stmt = select(func.count()).select_from(stmt.subquery())
total = db.execute(count_stmt).scalar()
# 分页
stmt = stmt.order_by(ApiKey.created_at.desc())
stmt = stmt.offset((query.page - 1) * query.pagesize).limit(query.pagesize)
items = db.scalars(stmt).all()
return list(items), total
@staticmethod
def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey:
"""更新 API Key"""
api_key = db.get(ApiKey, api_key_id)
if api_key:
for key, value in update_data.items():
if value is not None:
setattr(api_key, key, value)
api_key.updated_at = datetime.datetime.now()
db.flush()
return api_key
@staticmethod
def delete(db: Session, api_key_id: uuid.UUID) -> bool:
"""删除 API Key"""
api_key = db.get(ApiKey, api_key_id)
if api_key:
db.delete(api_key)
db.flush()
return True
return False
@staticmethod
def update_usage(db: Session, api_key_id: uuid.UUID) -> bool:
"""更新使用统计"""
api_key = db.get(ApiKey, api_key_id)
if api_key:
api_key.usage_count += 1
api_key.quota_used += 1
api_key.last_used_at = datetime.datetime.now()
db.flush()
return True
return False
@staticmethod
def get_stats(db: Session, api_key_id: uuid.UUID) -> dict:
"""获取使用统计"""
api_key = db.get(ApiKey, api_key_id)
if not api_key:
return {}
# 今日请求数
today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
today_count_stmt = select(func.count()).select_from(ApiKeyLog).where(
and_(
ApiKeyLog.api_key_id == api_key_id,
ApiKeyLog.created_at >= today_start
)
)
requests_today = db.execute(today_count_stmt).scalar() or 0
# 平均响应时间
avg_time_stmt = select(func.avg(ApiKeyLog.response_time)).where(
ApiKeyLog.api_key_id == api_key_id
)
avg_response_time = db.execute(avg_time_stmt).scalar()
return {
"total_requests": api_key.usage_count,
"requests_today": requests_today,
"quota_used": api_key.quota_used,
"quota_limit": api_key.quota_limit,
"last_used_at": api_key.last_used_at,
"avg_response_time": float(avg_response_time) if avg_response_time else None
}
class ApiKeyLogRepository:
"""API Key 日志数据访问层"""
@staticmethod
def create(db: Session, log_data: dict) -> ApiKeyLog:
"""创建日志"""
log = ApiKeyLog(**log_data)
db.add(log)
db.flush()
return log

View File

@@ -0,0 +1,30 @@
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.models.app_model import App
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class AppRepository:
def __init__(self, db: Session):
self.db = db
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> List[App]:
"""根据工作空间ID查询应用"""
try:
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all()
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(apps)} 个应用")
return apps
except Exception as e:
db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {str(e)}")
raise
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
"""根据工作空间ID查询应用"""
repo = AppRepository(db)
return repo.get_apps_by_workspace_id(workspace_id)

View File

@@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
"""基础仓储接口模块
本模块定义了通用的仓储接口适用于所有数据库类型PostgreSQL、Neo4j等
遵循仓储模式Repository Pattern提供统一的数据访问抽象。
Classes:
BaseRepository: 基础仓储接口定义CRUD操作的抽象方法
"""
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, List, Optional, Dict, Any
T = TypeVar('T')
class BaseRepository(ABC, Generic[T]):
"""基础仓储接口 - 适用于所有数据库类型
这是一个抽象基类定义了所有仓储必须实现的基本CRUD操作。
使用泛型T来支持不同的实体类型。
Type Parameters:
T: 实体类型通常是Pydantic模型或ORM模型
Methods:
create: 创建新实体
get_by_id: 根据ID获取实体
update: 更新现有实体
delete: 删除实体
find: 根据条件查询实体列表
"""
@abstractmethod
async def create(self, entity: T) -> T:
"""创建实体
Args:
entity: 要创建的实体对象
Returns:
T: 创建后的实体对象可能包含生成的ID等
Raises:
Exception: 创建失败时抛出异常
"""
pass
@abstractmethod
async def get_by_id(self, entity_id: str) -> Optional[T]:
"""根据ID获取实体
Args:
entity_id: 实体的唯一标识符
Returns:
Optional[T]: 找到的实体对象如果不存在则返回None
Raises:
Exception: 查询失败时抛出异常
"""
pass
@abstractmethod
async def update(self, entity: T) -> T:
"""更新实体
Args:
entity: 要更新的实体对象必须包含ID
Returns:
T: 更新后的实体对象
Raises:
Exception: 更新失败时抛出异常
"""
pass
@abstractmethod
async def delete(self, entity_id: str) -> bool:
"""删除实体
Args:
entity_id: 要删除的实体ID
Returns:
bool: 删除成功返回True否则返回False
Raises:
Exception: 删除失败时抛出异常
"""
pass
@abstractmethod
async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]:
"""查询实体列表
Args:
filters: 查询条件字典,键为字段名,值为期望的值
limit: 返回结果的最大数量默认100
Returns:
List[T]: 符合条件的实体列表
Raises:
Exception: 查询失败时抛出异常
"""
pass

View File

@@ -0,0 +1,408 @@
# -*- coding: utf-8 -*-
"""数据配置Repository模块
本模块提供data_config表的数据访问层包括SQL查询构建和Neo4j Cypher查询。
从 app.core.memory.src.data_config_api.sql_queries 迁移而来。
Classes:
DataConfigRepository: 数据配置仓储类提供CRUD操作和查询构建
"""
from typing import Dict, Tuple, List
from sqlalchemy.orm import Session
from app.schemas.memory_storage_schema import (
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
)
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
# 表名常量
TABLE_NAME = "data_config"
class DataConfigRepository:
"""数据配置Repository
提供data_config表的数据访问方法包括
- SQL查询构建PostgreSQL
- Neo4j Cypher查询常量
"""
# ==================== Neo4j Cypher 查询常量 ====================
# Dialogue count by group
SEARCH_FOR_DIALOGUE = """
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# Chunk count by group
SEARCH_FOR_CHUNK = """
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# Statement count by group
SEARCH_FOR_STATEMENT = """
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# ExtractedEntity count by group
SEARCH_FOR_ENTITY = """
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# All counts by label and total
SEARCH_FOR_ALL = """
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN 'Chunk' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:Statement) WHERE n.group_id = $group_id RETURN 'Statement' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count
UNION ALL
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
"""
# Extracted entity details within group/app/user
SEARCH_FOR_DETIALS = """
MATCH (n:ExtractedEntity)
WHERE n.group_id = $group_id
RETURN n.entity_idx AS entity_idx,
n.connect_strength AS connect_strength,
n.description AS description,
n.entity_type AS entity_type,
n.name AS name,
n.fact_summary AS fact_summary,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.id AS id
"""
# Edges between extracted entities within group/app/user
SEARCH_FOR_EDGES = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id
RETURN
r.group_id AS group_id,
r.apply_id AS apply_id,
r.user_id AS user_id,
elementId(r) AS rel_id,
startNode(r).id AS source_id,
endNode(r).id AS target_id,
r.predicate AS predicate,
r.statement_id AS statement_id,
r.statement AS statement
"""
# Entity graph within group (source node, edge, target node)
SEARCH_FOR_ENTITY_GRAPH = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
WHERE n.group_id = $group_id
RETURN
{
entity_idx: n.entity_idx,
connect_strength: n.connect_strength,
description: n.description,
entity_type: n.entity_type,
name: n.name,
fact_summary: n.fact_summary,
id: n.id
} AS sourceNode,
{
rel_id: elementId(r),
source_id: startNode(r).id,
target_id: endNode(r).id,
predicate: r.predicate,
statement_id: r.statement_id,
statement: r.statement
} AS edge,
{
entity_idx: m.entity_idx,
connect_strength: m.connect_strength,
description: m.description,
entity_type: m.entity_type,
name: m.name,
fact_summary: m.fact_summary,
id: m.id
} AS targetNode
"""
# ==================== SQL 查询构建方法 ====================
@staticmethod
def build_insert(params: ConfigParamsCreate) -> Tuple[str, Dict]:
"""构建插入语句PostgreSQL 命名参数)
Args:
params: 配置参数创建模型
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建插入语句: config_name={params.config_name}, workspace_id={params.workspace_id}")
columns = [
"config_name",
"config_desc",
"workspace_id",
"llm_id",
"embedding_id",
"rerank_id",
"created_at",
]
placeholders = [
"%(config_name)s",
"%(config_desc)s",
"%(workspace_id)s::uuid",
"%(llm_id)s",
"%(embedding_id)s",
"%(rerank_id)s",
"timezone('Asia/Shanghai', now())",
]
query = f"INSERT INTO {TABLE_NAME} (" + ",".join(columns) + ") VALUES (" + ",".join(placeholders) + ")"
# 将 UUID 转换为字符串
workspace_id_str = str(params.workspace_id) if params.workspace_id else None
params_dict = {
"config_name": params.config_name,
"config_desc": params.config_desc,
"workspace_id": workspace_id_str,
"llm_id": params.llm_id,
"embedding_id": params.embedding_id,
"rerank_id": params.rerank_id,
}
return query, params_dict
@staticmethod
def build_update(update: ConfigUpdate) -> Tuple[str, Dict]:
"""构建基础配置更新语句PostgreSQL 命名参数)
Args:
update: 配置更新模型
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"构建更新语句: config_id={update.config_id}")
key_where = "config_id = %(config_id)s"
set_fields: List[str] = []
params: Dict = {
"config_id": update.config_id,
}
mapping = {
"config_name": "config_name",
"config_desc": "config_desc",
}
for api_field, db_col in mapping.items():
value = getattr(update, api_field)
if value is not None:
set_fields.append(f"{db_col} = %({api_field})s")
params[api_field] = value
set_fields.append("updated_at = timezone('Asia/Shanghai', now())")
if not set_fields:
raise ValueError("No fields to update")
query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}"
return query, params
@staticmethod
def build_update_extracted(update: ConfigUpdateExtracted) -> Tuple[str, Dict]:
"""构建记忆萃取引擎配置更新语句PostgreSQL 命名参数)
Args:
update: 萃取配置更新模型
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"构建萃取配置更新语句: config_id={update.config_id}")
key_where = "config_id = %(config_id)s"
set_fields: List[str] = []
params: Dict = {
"config_id": update.config_id,
}
mapping = {
# 模型选择
"llm_id": "llm",
"embedding_id": "embedding",
"rerank_id": "rerank",
# 记忆萃取引擎
"enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise",
"enable_llm_disambiguation": "enable_llm_disambiguation",
"deep_retrieval": "deep_retrieval",
"t_type_strict": "t_type_strict",
"t_name_strict": "t_name_strict",
"t_overall": "t_overall",
"state": "state",
"chunker_strategy": "chunker_strategy",
# 句子提取
"statement_granularity": "statement_granularity",
"include_dialogue_context": "include_dialogue_context",
"max_context": "max_context",
# 剪枝配置
"pruning_enabled": "pruning_enabled",
"pruning_scene": "pruning_scene",
"pruning_threshold": "pruning_threshold",
# 自我反思配置
"enable_self_reflexion": "enable_self_reflexion",
"iteration_period": "iteration_period",
"reflexion_range": "reflexion_range",
"baseline": "baseline",
}
for api_field, db_col in mapping.items():
value = getattr(update, api_field)
if value is not None:
set_fields.append(f"{db_col} = %({api_field})s")
params[api_field] = value
set_fields.append("updated_at = timezone('Asia/Shanghai', now())")
if not set_fields:
raise ValueError("No fields to update")
query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}"
return query, params
@staticmethod
def build_update_forget(update: ConfigUpdateForget) -> Tuple[str, Dict]:
"""构建遗忘引擎配置更新语句PostgreSQL 命名参数)
Args:
update: 遗忘配置更新模型
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"构建遗忘配置更新语句: config_id={update.config_id}")
key_where = "config_id = %(config_id)s"
set_fields: List[str] = []
params: Dict = {
"config_id": update.config_id,
}
mapping = {
# 遗忘引擎
"lambda_time": "lambda_time",
"lambda_mem": "lambda_mem",
# 由于 PostgreSQL 中 OFFSET 是保留字,需使用双引号包裹列名
"offset": '"offset"',
}
for api_field, db_col in mapping.items():
value = getattr(update, api_field)
if value is not None:
set_fields.append(f"{db_col} = %({api_field})s")
params[api_field] = value
set_fields.append("updated_at = timezone('Asia/Shanghai', now())")
if not set_fields:
raise ValueError("No fields to update")
query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}"
return query, params
@staticmethod
def build_select_extracted(key: ConfigKey) -> Tuple[str, Dict]:
"""构建萃取配置查询语句通过主键查询某条配置PostgreSQL 命名参数)
Args:
key: 配置键模型
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建萃取配置查询语句: config_id={key.config_id}")
# f"SELECT statement_granularity, include_dialogue_context, max_context, "
query = (
f"SELECT llm_id, embedding_id, rerank_id, "
f"enable_llm_dedup_blockwise, enable_llm_disambiguation, deep_retrieval, "
f"t_type_strict, t_name_strict, t_overall, chunker_strategy, "
f"statement_granularity, include_dialogue_context, max_context, "
f"pruning_enabled, pruning_scene, pruning_threshold, "
f"enable_self_reflexion, iteration_period, reflexion_range, baseline "
f"FROM {TABLE_NAME} WHERE config_id = %(config_id)s"
)
params = {"config_id": key.config_id}
return query, params
@staticmethod
def build_select_forget(key: ConfigKey) -> Tuple[str, Dict]:
"""构建遗忘配置查询语句通过主键查询某条配置PostgreSQL 命名参数)
Args:
key: 配置键模型
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建遗忘配置查询语句: config_id={key.config_id}")
query = (
f"SELECT lambda_time, lambda_mem, \"offset\" " # 用双引号包裹保留字别名
f"FROM {TABLE_NAME} WHERE config_id = %(config_id)s"
)
params = {"config_id": key.config_id}
return query, params
@staticmethod
def build_select_all(workspace_id = None) -> Tuple[str, Dict]:
"""构建查询所有配置参数的语句PostgreSQL 命名参数)
Args:
workspace_id: 工作空间IDUUID或字符串用于过滤查询结果
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")
if workspace_id:
# 将 UUID 转换为字符串以便在 SQL 中使用
workspace_id_str = str(workspace_id) if workspace_id else None
query = f"SELECT * FROM {TABLE_NAME} WHERE workspace_id = %(workspace_id)s::uuid ORDER BY updated_at DESC NULLS LAST"
params = {"workspace_id": workspace_id_str}
else:
query = f"SELECT * FROM {TABLE_NAME} ORDER BY updated_at DESC NULLS LAST"
params = {}
return query, params
@staticmethod
def build_delete(key: ConfigParamsDelete) -> Tuple[str, Dict]:
"""构建删除语句通过配置ID删除PostgreSQL 命名参数)
Args:
key: 配置删除模型
Returns:
Tuple[str, Dict]: (SQL查询字符串, 参数字典)
"""
db_logger.debug(f"构建删除语句: config_id={key.config_id}")
query = (
f"DELETE FROM {TABLE_NAME} WHERE config_id = %(config_id)s"
)
params = {"config_id": key.config_id}
return query, params

View File

@@ -0,0 +1,153 @@
import uuid
import datetime
from sqlalchemy.orm import Session
from app.models.document_model import Document
from app.schemas import document_schema
from app.core.logging_config import get_db_logger
# Obtain a dedicated logger for the database
db_logger = get_db_logger()
def get_documents_paginated(
db: Session,
filters: list,
page: int,
pagesize: int,
orderby: str = None,
desc: bool = False
) -> tuple[int, list]:
"""
Paged query document (with filtering and sorting)
"""
db_logger.debug(f"Query documents in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")
try:
query = db.query(Document)
# Apply filter conditions
for filter_cond in filters:
query = query.filter(filter_cond)
# Calculate the total count (for pagination)
total = query.count()
db_logger.debug(f"Total number of document queries: {total}")
# sort
if orderby:
order_attr = getattr(Document, orderby, None)
if order_attr is not None:
if desc:
query = query.order_by(order_attr.desc())
else:
query = query.order_by(order_attr.asc())
db_logger.debug(f"sort: {orderby}, desc={desc}")
# pagination
items = query.offset((page - 1) * pagesize).limit(pagesize).all()
db_logger.info(f"The document paging query has been successful: total={total}, Number of current page={len(items)}")
return total, [document_schema.Document.model_validate(item) for item in items]
except Exception as e:
db_logger.error(f"Querying document pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
raise
def create_document(db: Session, document: document_schema.DocumentCreate) -> Document:
db_logger.debug(f"Create a document record: file_name={document.file_name}")
try:
db_document = Document(**document.model_dump())
db.add(db_document)
db.commit()
db_logger.info(f"Document record created successfully: {document.file_name} (ID: {db_document.id})")
return db_document
except Exception as e:
db_logger.error(f"Failed to create a document record: title={document.file_name} - {str(e)}")
db.rollback()
raise
def get_document_by_id(db: Session, document_id: uuid.UUID) -> Document | None:
db_logger.debug(f"Query documents based on ID: document_id={document_id}")
try:
document = db.query(Document).filter(Document.id == document_id).first()
if document:
db_logger.debug(f"Document query successful: {document.file_name} (ID: {document_id})")
else:
db_logger.debug(f"Document does not exist: document_id={document_id}")
return document
except Exception as e:
db_logger.error(f"Failed to query the document based on the ID: document_id={document_id} - {str(e)}")
raise
def reset_documents_progress_by_kb_id(db: Session, kb_id: uuid.UUID) -> int:
"""
Reset the processing progress of all documents under the specified knowledge base
Args:
db: database session
kb_id: Knowledge Base ID
Returns:
int: Number of updated documents
"""
db_logger.debug(f"Reset the processing progress of all documents under the specified knowledge base: kb_id={kb_id}")
try:
# Build update conditions
filters = [
Document.kb_id == kb_id
]
# Build updated data
update_data = {
Document.chunk_num: 0,
Document.progress: 0,
Document.progress_msg: "Pending",
Document.process_duration: 0,
Document.run: 0, # Reset run status
Document.updated_at: datetime.datetime.now()
}
# Perform batch update
result = db.query(Document).filter(*filters).update(
update_data,
synchronize_session=False
)
# commit transaction
db.commit()
db_logger.debug(f"Successfully reset the processing progress of all documents under the specified knowledge base: kb_id: {kb_id}")
return result
except Exception as e:
db.rollback()
db_logger.error(f"Failed to reset the processing progress of all documents under the specified knowledge base: kb_id={kb_id} - {str(e)}")
raise
def delete_document_by_id(db: Session, document_id: uuid.UUID):
db_logger.debug(f"Delete document record: document_id={document_id}")
try:
# First, query the document information for logging purposes
document = db.query(Document).filter(Document.id == document_id).first()
if document:
file_name = document.file_name
else:
file_name = "unknown"
result = db.query(Document).filter(Document.id == document_id).delete()
db.commit()
if result > 0:
db_logger.info(f"Document record deleted successfully: {file_name} (ID: {document_id})")
else:
db_logger.warning(f"The document record does not exist, and cannot be deleted: document_id={document_id}")
except Exception as e:
db_logger.error(f"Failed to delete document record: document_id={document_id} - {str(e)}")
db.rollback()
raise

View File

@@ -0,0 +1,105 @@
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.models.end_user_model import EndUser
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class EndUserRepository:
def __init__(self, db: Session):
self.db = db
def get_end_users_by_app_id(self, app_id: uuid.UUID) -> List[EndUser]:
"""根据应用ID查询宿主"""
try:
end_users = (
self.db.query(EndUser)
.filter(EndUser.app_id == app_id)
.all()
)
db_logger.info(f"成功查询应用 {app_id} 下的 {len(end_users)} 个宿主")
return end_users
except Exception as e:
self.db.rollback()
db_logger.error(f"查询应用 {app_id} 下宿主时出错: {str(e)}")
raise
def get_end_user_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
"""根据 end_user_id 查询宿主"""
try:
end_user = (
self.db.query(EndUser)
.filter(EndUser.id == end_user_id)
.first()
)
if end_user:
db_logger.info(f"成功查询到宿主 {end_user_id}")
else:
db_logger.info(f"未找到宿主 {end_user_id}")
return end_user
except Exception as e:
self.db.rollback()
db_logger.error(f"查询宿主 {end_user_id} 时出错: {str(e)}")
raise
def get_or_create_end_user(
self,
app_id: uuid.UUID,
other_id: str,
original_user_id: Optional[str] = None
) -> EndUser:
"""获取或创建终端用户
Args:
app_id: 应用ID
other_id: 第三方ID
original_user_id: 原始用户ID (存储到 other_id)
"""
try:
# 尝试查找现有用户
end_user = (
self.db.query(EndUser)
.filter(
EndUser.app_id == app_id,
EndUser.other_id == other_id
)
.first()
)
if end_user:
db_logger.debug(f"找到现有终端用户: 应用ID {app_id}、第三方ID {other_id}")
return end_user
# 创建新用户
end_user = EndUser(
app_id=app_id,
other_id=other_id
)
self.db.add(end_user)
self.db.commit()
self.db.refresh(end_user)
db_logger.info(f"创建新终端用户: (other_id: {other_id}) for app {app_id}")
return end_user
except Exception as e:
self.db.rollback()
db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
raise
def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
"""根据应用ID查询宿主返回 EndUser ORM 列表)"""
repo = EndUserRepository(db)
end_users = repo.get_end_users_by_app_id(app_id)
return end_users
def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
"""根据 end_user_id 查询对应宿主"""
repo = EndUserRepository(db)
end_user = repo.get_end_user_by_id(end_user_id)
return end_user

View File

@@ -0,0 +1,121 @@
import uuid
from sqlalchemy.orm import Session
from app.models.file_model import File
from app.schemas import file_schema
from app.core.logging_config import get_db_logger
# Obtain a dedicated logger for the database
db_logger = get_db_logger()
def get_files_paginated(
db: Session,
filters: list,
page: int,
pagesize: int,
orderby: str = None,
desc: bool = False
) -> tuple[int, list]:
"""
Paged query file (with filtering and sorting)
"""
db_logger.debug(f"Query file in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")
try:
query = db.query(File)
# Apply filter conditions
for filter_cond in filters:
query = query.filter(filter_cond)
# Calculate the total count (for pagination)
total = query.count()
db_logger.debug(f"Total number of file queries: {total}")
# sort
if orderby:
order_attr = getattr(File, orderby, None)
if order_attr is not None:
if desc:
query = query.order_by(order_attr.desc())
else:
query = query.order_by(order_attr.asc())
db_logger.debug(f"sort: {orderby}, desc={desc}")
# pagination
items = query.offset((page - 1) * pagesize).limit(pagesize).all()
db_logger.info(f"The file paging query has been successful: total={total}, Number of current page={len(items)}")
return total, [file_schema.File.model_validate(item) for item in items]
except Exception as e:
db_logger.error(f"Querying file pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
raise
def create_file(db: Session, file: file_schema.FileCreate) -> File:
db_logger.debug(f"Create a file record: filename={file.file_name}")
try:
db_file = File(**file.model_dump())
db.add(db_file)
db.commit()
db_logger.info(f"File record created successfully: {file.file_name} (ID: {db_file.id})")
return db_file
except Exception as e:
db_logger.error(f"Failed to create a file record: filename={file.file_name} - {str(e)}")
db.rollback()
raise
def get_file_by_id(db: Session, file_id: uuid.UUID) -> File | None:
db_logger.debug(f"Query file based on ID: file_id={file_id}")
try:
file = db.query(File).filter(File.id == file_id).first()
if file:
db_logger.debug(f"File query successful: {file.file_name} (ID: {file_id})")
else:
db_logger.debug(f"File does not exist: file_id={file_id}")
return file
except Exception as e:
db_logger.error(f"Failed to query the file based on the ID: file_id={file_id} - {str(e)}")
raise
def get_files_by_parent_id(db: Session, parent_id: uuid.UUID | None) -> list | None:
db_logger.debug(f"Query file based on folder ID: parent_id={parent_id}")
try:
query = db.query(File)
if parent_id:
query = query.filter(File.parent_id == parent_id)
files = query.all()
db_logger.debug(f"Folder query file successful: parent_id={parent_id}, file_num={len(files)}")
return files
except Exception as e:
db_logger.error(f"Failed to query files based on folder ID: parent_id={parent_id} - {str(e)}")
raise
def delete_file_by_id(db: Session, file_id: uuid.UUID):
db_logger.debug(f"Delete file record: file_id={file_id}")
try:
# First, query the file information for logging purposes
file = db.query(File).filter(File.id == file_id).first()
if file:
filename = file.file_name
else:
filename = "unknown"
result = db.query(File).filter(File.id == file_id).delete()
db.commit()
if result > 0:
db_logger.info(f"File record deleted successfully: {filename} (ID: {file_id})")
else:
db_logger.warning(f"The file record does not exist, and cannot be deleted: file_id={file_id}")
except Exception as e:
db_logger.error(f"Failed to delete file record: file_id={file_id} - {str(e)}")
db.rollback()
raise

View File

@@ -0,0 +1,243 @@
"""
Generic File Repository
Handles database operations for generic file uploads.
"""
import uuid
from typing import Optional, List, Tuple, Dict, Any
from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func
from app.models.generic_file_model import GenericFile
from app.core.upload_enums import UploadContext
from app.core.logging_config import get_db_logger
# Get database logger
db_logger = get_db_logger()
class GenericFileRepository:
"""Repository for generic file operations"""
def __init__(self, db: Session):
self.db = db
def create_file(self, file_data: Dict[str, Any]) -> GenericFile:
"""
Create a new file record in the database.
Args:
file_data: Dictionary containing file information
Returns:
GenericFile: Created file record
Raises:
Exception: If database operation fails
"""
db_logger.debug(f"Creating file record: filename={file_data.get('file_name')}")
try:
db_file = GenericFile(**file_data)
self.db.add(db_file)
self.db.flush()
db_logger.info(f"File record created successfully: {file_data.get('file_name')} (ID: {db_file.id})")
return db_file
except Exception as e:
db_logger.error(f"Failed to create file record: filename={file_data.get('file_name')} - {str(e)}")
raise
def get_file_by_id(self, file_id: uuid.UUID) -> Optional[GenericFile]:
"""
Get a file by its ID.
Args:
file_id: UUID of the file
Returns:
Optional[GenericFile]: File record if found, None otherwise
"""
db_logger.debug(f"Querying file by ID: file_id={file_id}")
try:
file = self.db.query(GenericFile).filter(
and_(
GenericFile.id == file_id,
GenericFile.deleted_at.is_(None)
)
).first()
if file:
db_logger.debug(f"File found: {file.file_name} (ID: {file_id})")
else:
db_logger.debug(f"File not found: file_id={file_id}")
return file
except Exception as e:
db_logger.error(f"Failed to query file by ID: file_id={file_id} - {str(e)}")
raise
def update_file(self, file_id: uuid.UUID, update_data: Dict[str, Any]) -> Optional[GenericFile]:
"""
Update file metadata.
Args:
file_id: UUID of the file to update
update_data: Dictionary containing fields to update
Returns:
Optional[GenericFile]: Updated file record if found, None otherwise
"""
db_logger.debug(f"Updating file: file_id={file_id}")
try:
file = self.get_file_by_id(file_id)
if not file:
db_logger.debug(f"File not found for update: file_id={file_id}")
return None
# Update allowed fields
for field, value in update_data.items():
if hasattr(file, field) and field not in ['id', 'created_by', 'created_at', 'tenant_id']:
setattr(file, field, value)
# Update timestamp
file.updated_at = datetime.now()
self.db.flush()
db_logger.info(f"File updated successfully: {file.file_name} (ID: {file_id})")
return file
except Exception as e:
db_logger.error(f"Failed to update file: file_id={file_id} - {str(e)}")
raise
def delete_file(self, file_id: uuid.UUID) -> bool:
"""
Soft delete a file by setting deleted_at timestamp.
Args:
file_id: UUID of the file to delete
Returns:
bool: True if file was deleted, False if not found
"""
db_logger.debug(f"Soft deleting file: file_id={file_id}")
try:
file = self.get_file_by_id(file_id)
if not file:
db_logger.debug(f"File not found for deletion: file_id={file_id}")
return False
# Soft delete by setting deleted_at
file.deleted_at = datetime.now()
file.status = "deleted"
file.updated_at = datetime.now()
self.db.flush()
db_logger.info(f"File soft deleted successfully: {file.file_name} (ID: {file_id})")
return True
except Exception as e:
db_logger.error(f"Failed to delete file: file_id={file_id} - {str(e)}")
raise
def get_files_by_context(
self,
context: UploadContext,
tenant_id: uuid.UUID,
page: int = 1,
pagesize: int = 20,
status: Optional[str] = "active",
created_by: Optional[uuid.UUID] = None
) -> Tuple[int, List[GenericFile]]:
"""
Get files by context with pagination.
Args:
context: Upload context (avatar, app_icon, etc.)
tenant_id: Tenant ID for isolation
page: Page number (1-indexed)
pagesize: Number of items per page
status: File status filter (default: "active")
created_by: Optional filter by creator user ID
Returns:
Tuple[int, List[GenericFile]]: Total count and list of files
"""
db_logger.debug(
f"Querying files by context: context={context}, tenant_id={tenant_id}, "
f"page={page}, pagesize={pagesize}, status={status}"
)
try:
query = self.db.query(GenericFile).filter(
and_(
GenericFile.context == context,
GenericFile.tenant_id == tenant_id,
GenericFile.deleted_at.is_(None)
)
)
# Apply status filter
if status:
query = query.filter(GenericFile.status == status)
# Apply creator filter
if created_by:
query = query.filter(GenericFile.created_by == created_by)
# Get total count
total = query.count()
db_logger.debug(f"Total files found: {total}")
# Apply pagination and ordering
files = query.order_by(GenericFile.created_at.desc()).offset((page - 1) * pagesize).limit(pagesize).all()
db_logger.info(
f"Files query successful: context={context}, total={total}, "
f"returned={len(files)}"
)
return total, files
except Exception as e:
db_logger.error(
f"Failed to query files by context: context={context}, "
f"tenant_id={tenant_id} - {str(e)}"
)
raise
# Convenience functions for backward compatibility
def create_file(db: Session, file_data: Dict[str, Any]) -> GenericFile:
"""Create a new file record"""
return GenericFileRepository(db).create_file(file_data)
def get_file_by_id(db: Session, file_id: uuid.UUID) -> Optional[GenericFile]:
"""Get a file by its ID"""
return GenericFileRepository(db).get_file_by_id(file_id)
def update_file(db: Session, file_id: uuid.UUID, update_data: Dict[str, Any]) -> Optional[GenericFile]:
"""Update file metadata"""
return GenericFileRepository(db).update_file(file_id, update_data)
def delete_file(db: Session, file_id: uuid.UUID) -> bool:
"""Soft delete a file"""
return GenericFileRepository(db).delete_file(file_id)
def get_files_by_context(
db: Session,
context: UploadContext,
tenant_id: uuid.UUID,
page: int = 1,
pagesize: int = 20,
status: Optional[str] = "active",
created_by: Optional[uuid.UUID] = None
) -> Tuple[int, List[GenericFile]]:
"""Get files by context with pagination"""
return GenericFileRepository(db).get_files_by_context(
context, tenant_id, page, pagesize, status, created_by
)

View File

@@ -0,0 +1,211 @@
import uuid
from sqlalchemy.orm import Session
from app.models.knowledge_model import Knowledge
from app.schemas import knowledge_schema
from app.core.logging_config import get_db_logger
# Obtain a dedicated logger for the database
db_logger = get_db_logger()
def get_knowledges_paginated(
db: Session,
filters: list,
page: int,
pagesize: int,
orderby: str = None,
desc: bool = False
) -> tuple[int, list]:
"""
Paged query knowledge base (with filtering and sorting)
"""
db_logger.debug(f"Query knowledge base in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")
try:
query = db.query(Knowledge)
# Apply filter conditions
for filter_cond in filters:
query = query.filter(filter_cond)
# Calculate the total count (for pagination)
total = query.count()
db_logger.debug(f"Total number of knowledge base queries: {total}")
# sort
if orderby:
order_attr = getattr(Knowledge, orderby, None)
if order_attr is not None:
if desc:
query = query.order_by(order_attr.desc())
else:
query = query.order_by(order_attr.asc())
db_logger.debug(f"sort: {orderby}, desc={desc}")
# pagination
items = query.offset((page - 1) * pagesize).limit(pagesize).all()
db_logger.info(f"The knowledge base paging query has been successful: total={total}, Number of current page={len(items)}")
return total, [knowledge_schema.Knowledge.model_validate(item) for item in items]
except Exception as e:
db_logger.error(f"Querying knowledge base pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
raise
def get_chunded_knowledgeids(
db: Session,
filters: list
) -> list:
"""
Query the list of vectorized knowledge base IDs
Return: list[UUID] - List of knowledge base IDs
"""
db_logger.debug(f"Query the list of vectorized knowledge base IDs: filters_count={len(filters)}")
try:
# Only query the id field
query = db.query(Knowledge.id)
# Apply filter conditions
for filter_cond in filters:
query = query.filter(filter_cond)
# Get all IDs
items = query.all()
db_logger.info(f"Querying the vectorized knowledge base id list succeeded: count={len(items)}")
# Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column
return [item[0] for item in items]
except Exception as e:
db_logger.error(f"Querying the vectorized knowledge base id list failed: {str(e)}")
raise
def create_knowledge(db: Session, knowledge: knowledge_schema.KnowledgeCreate) -> Knowledge:
db_logger.debug(f"Create a knowledge base record: name={knowledge.name}")
try:
db_knowledge = Knowledge(**knowledge.model_dump())
db.add(db_knowledge)
db.commit()
db_logger.info(f"knowledge base record created successfully: {knowledge.name} (ID: {db_knowledge.id})")
return db_knowledge
except Exception as e:
db_logger.error(f"Failed to create a knowledge base record: name={knowledge.name} - {str(e)}")
db.rollback()
raise
def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID) -> Knowledge | None:
db_logger.debug(f"Query knowledge base based on ID: knowledge_id={knowledge_id}")
try:
knowledge = db.query(Knowledge).filter(Knowledge.id == knowledge_id).first()
if knowledge:
db_logger.debug(f"knowledge base query successful: {knowledge.name} (ID: {knowledge_id})")
else:
db_logger.debug(f"knowledge base does not exist: knowledge_id={knowledge_id}")
return knowledge
except Exception as e:
db_logger.error(f"Failed to query the knowledge base based on the ID: knowledge_id={knowledge_id} - {str(e)}")
raise
def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None:
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
try:
knowledge = db.query(Knowledge).filter(Knowledge.name == name).filter(Knowledge.workspace_id == workspace_id).first()
if knowledge:
db_logger.debug(f"knowledge base query successful: {name} (ID: {knowledge.id})")
else:
db_logger.debug(f"knowledge base does not exist: name={name}, workspace_id={workspace_id}")
return knowledge
except Exception as e:
db_logger.error(f"Failed to query the knowledge base based on the name and workspace_id: name={name}, workspace_id={workspace_id} - {str(e)}")
raise
def delete_knowledge_by_id(db: Session, knowledge_id: uuid.UUID):
db_logger.debug(f"Delete knowledge base record: knowledge_id={knowledge_id}")
try:
# First, query the knowledge base information for logging purposes
knowledge = db.query(Knowledge).filter(Knowledge.id == knowledge_id).first()
if knowledge:
knowledge_name = knowledge.name
else:
knowledge_name = "unknown"
result = db.query(Knowledge).filter(Knowledge.id == knowledge_id).delete()
db.commit()
if result > 0:
db_logger.info(f"knowledge base record deleted successfully: {knowledge_name} (ID: {knowledge_id})")
else:
db_logger.warning(f"The knowledge base record does not exist, and cannot be deleted: knowledge_id={knowledge_id}")
except Exception as e:
db_logger.error(f"Failed to delete knowledge base record: knowledge_id={knowledge_id} - {str(e)}")
db.rollback()
raise
def get_total_doc_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
"""
根据workspace_id查询knowledges表所有doc_num的总和
"""
db_logger.debug(f"Query total doc_num by workspace_id: workspace_id={workspace_id}")
try:
from sqlalchemy import func
result = db.query(func.sum(Knowledge.doc_num)).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1
).scalar()
total = result if result is not None else 0
db_logger.info(f"Total doc_num query successful: workspace_id={workspace_id}, total={total}")
return total
except Exception as e:
db_logger.error(f"Failed to query total doc_num: workspace_id={workspace_id} - {str(e)}")
raise
def get_total_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
"""
根据workspace_id查询knowledges表所有chunk_num的总和
"""
db_logger.debug(f"Query total chunk_num by workspace_id: workspace_id={workspace_id}")
try:
from sqlalchemy import func
result = db.query(func.sum(Knowledge.chunk_num)).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1
).scalar()
total = result if result is not None else 0
db_logger.info(f"Total chunk_num query successful: workspace_id={workspace_id}, total={total}")
return total
except Exception as e:
db_logger.error(f"Failed to query total chunk_num: workspace_id={workspace_id} - {str(e)}")
raise
def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
"""
根据workspace_id查询knowledges表所有不同id的数量知识库总数
"""
db_logger.debug(f"Query total knowledge base count by workspace_id: workspace_id={workspace_id}")
try:
count = db.query(Knowledge).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1
).count()
db_logger.info(f"Total knowledge base count query successful: workspace_id={workspace_id}, count={count}")
return count
except Exception as e:
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
raise

View File

@@ -0,0 +1,142 @@
import uuid
from sqlalchemy.orm import Session
from app.models.knowledgeshare_model import KnowledgeShare
from app.schemas import knowledgeshare_schema
from app.core.logging_config import get_db_logger
from sqlalchemy.orm import joinedload
from sqlalchemy import or_
# Obtain a dedicated logger for the database
db_logger = get_db_logger()
def get_knowledgeshares_paginated(
db: Session,
filters: list,
page: int,
pagesize: int,
orderby: str = None,
desc: bool = False
) -> tuple[int, list]:
"""
Paged query knowledge base sharing (with filtering and sorting)
"""
db_logger.debug(
f"Query knowledge base sharing in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")
try:
query = db.query(KnowledgeShare)
# Apply filter conditions
for filter_cond in filters:
query = query.filter(filter_cond)
# Calculate the total count (for pagination)
total = query.count()
db_logger.debug(f"Total number of knowledge base sharing queries: {total}")
# sort
if orderby:
order_attr = getattr(KnowledgeShare, orderby, None)
if order_attr is not None:
if desc:
query = query.order_by(order_attr.desc())
else:
query = query.order_by(order_attr.asc())
db_logger.debug(f"sort: {orderby}, desc={desc}")
# pagination
items = query.offset((page - 1) * pagesize).limit(pagesize).all()
db_logger.info(f"The knowledge base sharing paging query has been successful: total={total}, Number of current page={len(items)}")
return total, [knowledgeshare_schema.KnowledgeShare.model_validate(item) for item in items]
except Exception as e:
db_logger.error(f"Querying knowledge base sharing pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
raise
def get_source_kb_ids_by_target_kb_id(
db: Session,
filters: list
) -> list:
"""
Query the original knowledge base ID list by sharing the knowledge base
Return: list[UUID] - List of knowledge base IDs
"""
db_logger.debug(
f"Query the original knowledge base id list by sharing the knowledge base: filters_count={len(filters)}")
try:
# Only query the id field
query = db.query(KnowledgeShare.source_kb_id)
# Apply filter conditions
for filter_cond in filters:
query = query.filter(filter_cond)
# Get all IDs
items = query.all()
db_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: count={len(items)}")
# Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column
return [item[0] for item in items]
except Exception as e:
db_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: {str(e)}")
raise
def create_knowledgeshare(db: Session, knowledgeshare: knowledgeshare_schema.KnowledgeShareCreate) -> KnowledgeShare:
db_logger.debug(f"Create a knowledge base sharing record: source_kb_id={knowledgeshare.source_kb_id}")
try:
db_knowledgeshare = KnowledgeShare(**knowledgeshare.model_dump())
db.add(db_knowledgeshare)
db.commit()
db_logger.info(f"knowledge base sharing record created successfully: (ID: {db_knowledgeshare.id})")
return db_knowledgeshare
except Exception as e:
db_logger.error(f"Failed to create a knowledge base sharing record: source_kb_id={knowledgeshare.source_kb_id} - {str(e)}")
db.rollback()
raise
def get_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID) -> KnowledgeShare | None:
db_logger.debug(f"Query knowledge base sharing based on ID: knowledgeshare_id={knowledgeshare_id}")
try:
knowledgeshare = db.query(KnowledgeShare).filter(
or_(
KnowledgeShare.id == knowledgeshare_id,
KnowledgeShare.target_kb_id == knowledgeshare_id
)
).first()
if knowledgeshare:
db_logger.debug(f"knowledge base sharing query successful: (ID: {knowledgeshare_id})")
else:
db_logger.debug(f"knowledge base sharing does not exist: knowledgeshare_id={knowledgeshare_id}")
return knowledgeshare
except Exception as e:
db_logger.error(f"Failed to query the knowledge base sharing based on the ID: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
raise
def delete_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID):
db_logger.debug(f"Delete knowledge base sharing record: knowledgeshare_id={knowledgeshare_id}")
try:
result = db.query(KnowledgeShare).filter(
or_(
KnowledgeShare.id == knowledgeshare_id,
KnowledgeShare.target_kb_id == knowledgeshare_id
)
).delete()
db.commit()
if result > 0:
db_logger.info(f"knowledge base sharing record deleted successfully: (ID: {knowledgeshare_id})")
else:
db_logger.warning(f"The knowledge base sharing record does not exist, and cannot be deleted: knowledgeshare_id={knowledgeshare_id}")
except Exception as e:
db_logger.error(f"Failed to delete knowledge base sharing record: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
db.rollback()
raise

View File

@@ -0,0 +1,110 @@
from sqlalchemy import func
from sqlalchemy.orm import Session, aliased
from typing import List, Optional
import uuid
import datetime
from app.models.memory_increment_model import MemoryIncrement
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class MemoryIncrementRepository:
def __init__(self, db: Session):
self.db = db
def get_memory_increments_by_workspace_id(self, workspace_id: uuid.UUID, limit: int) -> List[MemoryIncrement]:
"""根据工作空间ID查询内存增量通过 MemoryIncrement 关联查询 MemoryIncrement 列表"""
try:
# 使用窗口函数按日期分区并排序
subquery = (
self.db.query(
MemoryIncrement,
func.row_number().over(
partition_by=func.date(MemoryIncrement.created_at), # 按日期分区
order_by=MemoryIncrement.created_at.desc() # 按时间戳升序排序
).label('row_num')
)
.filter(MemoryIncrement.workspace_id == workspace_id)
.subquery()
)
memory_increment_alias = aliased(MemoryIncrement, subquery)
memory_increments = (
self.db.query(memory_increment_alias)
.filter(subquery.c.row_num == 1) # 只取每个日期的第一条(最新的)
.order_by(memory_increment_alias.created_at.asc()) # 按时间戳降序排序
.limit(limit)
.all()
)
db_logger.info(f"成功查询工作空间 {workspace_id} 下的内存增量")
return memory_increments
except Exception as e:
db_logger.error(f"查询工作空间 {workspace_id} 下内存增量时出错: {str(e)}")
raise
def get_latest_memory_increment_by_workspace_id(self, workspace_id: uuid.UUID) -> Optional[MemoryIncrement]:
"""根据工作空间ID查询最新的内存增量记录"""
try:
memory_increment = (
self.db.query(MemoryIncrement)
.filter(MemoryIncrement.workspace_id == workspace_id)
.order_by(MemoryIncrement.created_at.desc(), MemoryIncrement.id.desc())
.first()
)
if memory_increment:
db_logger.info(f"成功查询工作空间 {workspace_id} 下的最新内存增量")
else:
db_logger.warning(f"未找到工作空间 {workspace_id} 下的内存增量记录")
return memory_increment
except Exception as e:
db_logger.error(f"查询工作空间 {workspace_id} 下最新内存增量时出错: {str(e)}")
raise
def write_memory_increment(
self,
workspace_id: uuid.UUID,
total_num: int
) -> MemoryIncrement:
"""写入内存增量"""
try:
memory_increment = MemoryIncrement(
workspace_id=workspace_id,
total_num=total_num,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now()
)
self.db.add(memory_increment)
self.db.commit()
self.db.refresh(memory_increment)
db_logger.info(f"成功写入内存增量: workspace_id={workspace_id}, total_num={total_num}")
return memory_increment
except Exception as e:
db_logger.error(f"写入内存增量失败: workspace_id={workspace_id}, total_num={total_num} - {str(e)}")
raise
def get_memory_increments_by_workspace_id(db: Session, workspace_id: uuid.UUID, limit: int) -> List[MemoryIncrement]:
"""根据工作空间ID查询内存增量返回 MemoryIncrement ORM 列表)"""
repo = MemoryIncrementRepository(db)
memory_increments = repo.get_memory_increments_by_workspace_id(workspace_id, limit)
return memory_increments
def write_memory_increment(
db: Session,
workspace_id: uuid.UUID,
total_num: int
) -> MemoryIncrement:
"""写入内存增量"""
repo = MemoryIncrementRepository(db)
memory_increment = repo.write_memory_increment(workspace_id, total_num)
return memory_increment
def get_latest_memory_increment_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> Optional[MemoryIncrement]:
"""根据工作空间ID查询最新的内存增量记录"""
repo = MemoryIncrementRepository(db)
return repo.get_latest_memory_increment_by_workspace_id(workspace_id)

View File

@@ -0,0 +1,386 @@
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import and_, or_, func, desc
from typing import List, Optional, Dict, Any, Tuple
import uuid
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider
from app.schemas.model_schema import (
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery
)
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class ModelConfigRepository:
"""模型配置Repository"""
@staticmethod
def get_by_id(db: Session, model_id: uuid.UUID) -> Optional[ModelConfig]:
"""根据ID获取模型配置"""
db_logger.debug(f"根据ID查询模型配置: model_id={model_id}")
try:
model = db.query(ModelConfig).options(
joinedload(ModelConfig.api_keys)
).filter(ModelConfig.id == model_id).first()
if model:
db_logger.debug(f"模型配置查询成功: {model.name} (ID: {model_id})")
else:
db_logger.debug(f"模型配置不存在: model_id={model_id}")
return model
except Exception as e:
db_logger.error(f"根据ID查询模型配置失败: model_id={model_id} - {str(e)}")
raise
@staticmethod
def get_by_name(db: Session, name: str) -> Optional[ModelConfig]:
"""根据名称获取模型配置"""
db_logger.debug(f"根据名称查询模型配置: name={name}")
try:
model = db.query(ModelConfig).filter(ModelConfig.name == name).first()
if model:
db_logger.debug(f"模型配置查询成功: {model.name}")
return model
except Exception as e:
db_logger.error(f"根据名称查询模型配置失败: name={name} - {str(e)}")
raise
@staticmethod
def search_by_name(db: Session, name: str, limit: int = 10) -> List[ModelConfig]:
"""按名称模糊匹配获取模型配置列表
Args:
name: 模型名称关键词(模糊匹配)
limit: 返回数量上限
Returns:
模型配置列表
"""
db_logger.debug(f"按名称模糊查询模型配置: name~{name}, limit={limit}")
try:
models = (
db.query(ModelConfig)
.filter(ModelConfig.name.ilike(f"%{name}%"))
.order_by(ModelConfig.name)
.limit(limit)
.all()
)
db_logger.debug(f"模糊查询成功: 返回数量={len(models)}")
return models
except Exception as e:
db_logger.error(f"按名称模糊查询模型配置失败: name~{name} - {str(e)}")
raise
@staticmethod
def get_list(db: Session, query: ModelConfigQuery) -> Tuple[List[ModelConfig], int]:
"""获取模型配置列表"""
db_logger.debug(f"查询模型配置列表: {query.dict()}")
try:
# 构建查询条件
filters = []
# 支持多个 type 值(使用 IN 查询)
if query.type:
filters.append(ModelConfig.type.in_(query.type))
if query.is_active is not None:
filters.append(ModelConfig.is_active == query.is_active)
if query.is_public is not None:
filters.append(ModelConfig.is_public == query.is_public)
if query.search:
# 搜索逻辑需要join ModelApiKey表来搜索model_name
search_filter = or_(
ModelConfig.name.ilike(f"%{query.search}%"),
# ModelConfig.description.ilike(f"%{query.search}%")
)
filters.append(search_filter)
# 构建基础查询
base_query = db.query(ModelConfig).options(
joinedload(ModelConfig.api_keys)
)
# 如果需要按provider筛选需要join ModelApiKey表
if query.provider:
base_query = base_query.join(ModelApiKey).filter(
ModelApiKey.provider == query.provider
).distinct()
if filters:
base_query = base_query.filter(and_(*filters))
# 获取总数
total = base_query.count()
# 分页查询
models = base_query.order_by(desc(ModelConfig.updated_at)).offset(
(query.page - 1) * query.pagesize
).limit(query.pagesize).all()
db_logger.debug(f"模型配置列表查询成功: 总数={total}, 当前页={len(models)}, type筛选={query.type}")
return models, total
except Exception as e:
db_logger.error(f"查询模型配置列表失败: {str(e)}")
raise
@staticmethod
def get_by_type(db: Session, model_type: ModelType, is_active: bool = True) -> List[ModelConfig]:
"""根据类型获取模型配置"""
db_logger.debug(f"根据类型查询模型配置: type={model_type}, is_active={is_active}")
try:
query = db.query(ModelConfig).options(
joinedload(ModelConfig.api_keys)
).filter(ModelConfig.type == model_type)
if is_active:
query = query.filter(ModelConfig.is_active == True)
models = query.order_by(ModelConfig.name).all()
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
return models
except Exception as e:
db_logger.error(f"根据类型查询模型配置失败: type={model_type} - {str(e)}")
raise
@staticmethod
def create(db: Session, model_data: dict) -> ModelConfig:
"""创建模型配置"""
db_logger.debug(f"创建模型配置: {model_data.get('name')}")
try:
db_model = ModelConfig(**model_data)
db.add(db_model)
db_logger.info(f"模型配置已添加到会话: {db_model.name}")
return db_model
except Exception as e:
db.rollback()
db_logger.error(f"创建模型配置失败: {model_data.get('name')} - {str(e)}")
raise
@staticmethod
def update(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate) -> Optional[ModelConfig]:
"""更新模型配置"""
db_logger.debug(f"更新模型配置: model_id={model_id}")
try:
db_model = db.query(ModelConfig).filter(ModelConfig.id == model_id).first()
if not db_model:
db_logger.warning(f"模型配置不存在: model_id={model_id}")
return None
# 更新字段
update_data = model_data.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(db_model, field, value)
db.commit()
db.refresh(db_model)
db_logger.info(f"模型配置更新成功: {db_model.name} (ID: {model_id})")
return db_model
except Exception as e:
db.rollback()
db_logger.error(f"更新模型配置失败: model_id={model_id} - {str(e)}")
raise
@staticmethod
def delete(db: Session, model_id: uuid.UUID) -> bool:
"""删除模型配置"""
db_logger.debug(f"删除模型配置: model_id={model_id}")
try:
db_model = db.query(ModelConfig).filter(ModelConfig.id == model_id).first()
if not db_model:
db_logger.warning(f"模型配置不存在: model_id={model_id}")
return False
db.delete(db_model)
db.commit()
db_logger.info(f"模型配置删除成功: model_id={model_id}")
return True
except Exception as e:
db.rollback()
db_logger.error(f"删除模型配置失败: model_id={model_id} - {str(e)}")
raise
@staticmethod
def get_stats(db: Session) -> Dict[str, Any]:
"""获取模型统计信息"""
db_logger.debug("获取模型统计信息")
try:
# 总数统计
total_models = db.query(ModelConfig).count()
active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count()
# 按类型统计
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
embedding_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.EMBEDDING).count()
rerank_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.RERANK).count()
# 按提供商统计 - 现在从ModelApiKey表获取
provider_stats = {}
provider_results = db.query(
ModelApiKey.provider, func.count(func.distinct(ModelApiKey.model_config_id))
).group_by(ModelApiKey.provider).all()
for provider, count in provider_results:
provider_stats[provider.value] = count
stats = {
"total_models": total_models,
"active_models": active_models,
"llm_count": llm_count,
"embedding_count": embedding_count,
"rerank_count": rerank_count,
"provider_stats": provider_stats
}
db_logger.debug(f"模型统计信息获取成功: {stats}")
return stats
except Exception as e:
db_logger.error(f"获取模型统计信息失败: {str(e)}")
raise
class ModelApiKeyRepository:
"""模型API Key Repository"""
@staticmethod
def get_by_id(db: Session, api_key_id: uuid.UUID) -> Optional[ModelApiKey]:
"""根据ID获取API Key"""
db_logger.debug(f"根据ID查询API Key: api_key_id={api_key_id}")
try:
api_key = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first()
if api_key:
db_logger.debug(f"API Key查询成功: {api_key.model_name} (ID: {api_key_id})")
return api_key
except Exception as e:
db_logger.error(f"根据ID查询API Key失败: api_key_id={api_key_id} - {str(e)}")
raise
@staticmethod
def get_by_model_config(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]:
"""根据模型配置ID获取API Key列表"""
db_logger.debug(f"根据模型配置ID查询API Key: model_config_id={model_config_id}")
try:
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
if is_active:
query = query.filter(ModelApiKey.is_active == True)
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")
return api_keys
except Exception as e:
db_logger.error(f"根据模型配置ID查询API Key失败: model_config_id={model_config_id} - {str(e)}")
raise
@staticmethod
def create(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
"""创建API Key"""
db_logger.debug(f"创建API Key: {api_key_data.provider}")
try:
db_api_key = ModelApiKey(**api_key_data.dict())
db.add(db_api_key)
db_logger.info(f"API Key已添加到会话: {db_api_key.provider}")
return db_api_key
except Exception as e:
db.rollback()
db_logger.error(f"创建API Key失败: {api_key_data.provider} - {str(e)}")
raise
@staticmethod
def update(db: Session, api_key_id: uuid.UUID, api_key_data: ModelApiKeyUpdate) -> Optional[ModelApiKey]:
"""更新API Key"""
db_logger.debug(f"更新API Key: api_key_id={api_key_id}")
try:
db_api_key = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first()
if not db_api_key:
db_logger.warning(f"API Key不存在: api_key_id={api_key_id}")
return None
# 更新字段
update_data = api_key_data.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(db_api_key, field, value)
db.commit()
db.refresh(db_api_key)
db_logger.info(f"API Key更新成功: {db_api_key.model_name} (ID: {api_key_id})")
return db_api_key
except Exception as e:
db.rollback()
db_logger.error(f"更新API Key失败: api_key_id={api_key_id} - {str(e)}")
raise
@staticmethod
def delete(db: Session, api_key_id: uuid.UUID) -> bool:
"""删除API Key"""
db_logger.debug(f"删除API Key: api_key_id={api_key_id}")
try:
db_api_key = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first()
if not db_api_key:
db_logger.warning(f"API Key不存在: api_key_id={api_key_id}")
return False
db.delete(db_api_key)
db.commit()
db_logger.info(f"API Key删除成功: api_key_id={api_key_id}")
return True
except Exception as e:
db.rollback()
db_logger.error(f"删除API Key失败: api_key_id={api_key_id} - {str(e)}")
raise
@staticmethod
def update_usage(db: Session, api_key_id: uuid.UUID) -> bool:
"""更新API Key使用统计"""
db_logger.debug(f"更新API Key使用统计: api_key_id={api_key_id}")
try:
db_api_key = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first()
if not db_api_key:
return False
# 更新使用次数和最后使用时间
current_count = int(db_api_key.usage_count or "0")
db_api_key.usage_count = str(current_count + 1)
db_api_key.last_used_at = func.now()
db.commit()
db_logger.debug(f"API Key使用统计更新成功: api_key_id={api_key_id}")
return True
except Exception as e:
db.rollback()
db_logger.error(f"更新API Key使用统计失败: api_key_id={api_key_id} - {str(e)}")
raise

View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""Neo4j仓储模块
本模块包含Neo4j图数据库的仓储实现用于管理知识图谱的节点和边。
Modules:
neo4j_connector: Neo4j数据库连接器
base_neo4j_repository: Neo4j仓储基类
dialog_repository: 对话仓储
statement_repository: 陈述句仓储
entity_repository: 实体仓储
cypher_queries: Cypher查询语句
graph_search: 图搜索功能
graph_saver: 图数据保存功能
add_nodes: 添加节点功能
add_edges: 添加边功能
create_indexes: 创建索引功能
"""
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.repositories.neo4j.dialog_repository import DialogRepository
from app.repositories.neo4j.statement_repository import StatementRepository
from app.repositories.neo4j.entity_repository import EntityRepository
__all__ = [
'Neo4jConnector',
'BaseNeo4jRepository',
'DialogRepository',
'StatementRepository',
'EntityRepository',
]

View File

@@ -0,0 +1,102 @@
from typing import List, Optional
import hashlib
from datetime import datetime
from uuid import uuid4
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE, MEMORY_SUMMARY_STATEMENT_EDGE_SAVE
from app.core.memory.models.message_models import Chunk
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.models.graph_models import MemorySummaryNode
async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add edges between chunk nodes and their statement nodes in Neo4j.
Args:
chunks: List of Chunk objects containing the statements
connector: Neo4j connector instance
Returns:
List of created edge UUIDs or None if failed
"""
if not chunks:
print("No chunks provided to create edges")
return []
try:
# Build edges deterministically per (chunk, statement) pair
edges: List[dict] = []
for chunk in chunks:
for stmt in getattr(chunk, "statements", []) or []:
stable_edge_id = hashlib.sha1(f"{chunk.id}|{stmt.id}".encode("utf-8")).hexdigest()
edge = {
"id": stable_edge_id,
"source": chunk.id,
"target": stmt.id,
"group_id": getattr(stmt, 'group_id', None),
"user_id":getattr(stmt, 'user_id', None),
"apply_id": getattr(stmt, 'apply_id', None),
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
"created_at": getattr(stmt, 'created_at', None),
"expired_at": getattr(stmt, 'expired_at', None),
# "created_at": getattr(statement, 'created_at', None),
# "expired_at": None # Set to None or appropriate default
}
edges.append(edge)
if not edges:
print("No statements found in chunks to create edges")
return []
# Execute the query to create edges
result = await connector.execute_query(
CHUNK_STATEMENT_EDGE_SAVE,
chunk_statement_edges=edges
)
created_uuids = [record.get("uuid") for record in result] if result else []
print(f"Successfully created {len(created_uuids)} chunk-statement edges")
return created_uuids
except Exception as e:
print(f"Error creating chunk-statement edges: {e}")
return None
async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Create edges from MemorySummary to Statements via their chunk_ids.
For each summary and each chunk_id in it, this links the summary to all statements
contained in that chunk using DERIVED_FROM_STATEMENT. This supports queries like
summary -> statement -> entity with minimal hops.
Args:
summaries: List of MemorySummaryNode objects
connector: Neo4j connector instance
Returns:
List of created edge elementIds or None if failed
"""
if not summaries:
return []
try:
edges: List[dict] = []
for s in summaries:
for chunk_id in getattr(s, "chunk_ids", []) or []:
edges.append({
"summary_id": s.id,
"chunk_id": chunk_id,
"group_id": s.group_id,
"run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
})
if not edges:
return []
result = await connector.execute_query(
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
edges=edges
)
created = [record.get("uuid") for record in result] if result else []
return created
except Exception:
return None

View File

@@ -0,0 +1,215 @@
from typing import List, Optional
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def delete_all_nodes(group_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database."""
result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n")
print(f"All group_id: {group_id} node and edge deleted successfully")
return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add dialogue nodes to Neo4j database.
Args:
dialogues: List of DialogueNode objects to save
connector: Neo4j connector instance
Returns:
List of created node UUIDs or None if failed
"""
if not dialogues:
print("No dialogues to save")
return []
try:
# Flatten DialogueNode objects to match Cypher expected fields
flattened_dialogues = []
for dialogue in dialogues:
flattened_dialogues.append({
"id": dialogue.id,
"group_id": dialogue.group_id,
"user_id": dialogue.user_id,
"apply_id": dialogue.apply_id,
"run_id": dialogue.run_id,
"ref_id": dialogue.ref_id,
"name": dialogue.name,
"created_at": dialogue.created_at.isoformat() if dialogue.created_at else None,
"expired_at": dialogue.expired_at.isoformat() if dialogue.expired_at else None,
"content": dialogue.content,
"dialog_embedding": dialogue.dialog_embedding
})
result = await connector.execute_query(
DIALOGUE_NODE_SAVE,
dialogues=flattened_dialogues
)
created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
return created_uuids
except Exception as e:
print(f"Error creating dialogue nodes: {e}")
return None
async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add statement nodes to Neo4j database.
Args:
statements: List of StatementNode objects to save
connector: Neo4j connector instance
Returns:
List of created node UUIDs or None if failed
"""
if not statements:
print("No statements to save")
return []
try:
# Flatten StatementNode objects to only include primitive types
flattened_statements = []
for statement in statements:
flattened_statement = {
"id": statement.id,
"name": statement.name,
"group_id": statement.group_id,
"user_id": statement.user_id,
"apply_id": statement.apply_id,
"run_id": statement.run_id,
"chunk_id": statement.chunk_id,
# "created_at": statement.created_at.isoformat(),
"created_at": statement.created_at.isoformat() if statement.created_at else None,
"expired_at": statement.expired_at.isoformat() if statement.expired_at else None,
"stmt_type": statement.stmt_type,
"temporal_info": statement.temporal_info.value,
"statement": statement.statement,
"connect_strength": statement.connect_strength,
"chunk_embedding": statement.chunk_embedding if statement.chunk_embedding else None,
# "temporal_validity_valid_at": statement.temporal_validity_valid_at.isoformat() if statement.temporal_validity_valid_at else None,
# "temporal_validity_invalid_at": statement.temporal_validity_invalid_at.isoformat() if statement.temporal_validity_invalid_at else None,
"valid_at": statement.valid_at.isoformat() if statement.valid_at else None,
"invalid_at": statement.invalid_at.isoformat() if statement.invalid_at else None,
# "triplet_extraction_info": json.dumps({
# "triplets": [triplet.model_dump() for triplet in statement.triplet_extraction_info.triplets] if statement.triplet_extraction_info else [],
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None
}
flattened_statements.append(flattened_statement)
result = await connector.execute_query(
STATEMENT_NODE_SAVE,
statements=flattened_statements
)
created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} statement nodes")
return created_uuids
except Exception as e:
print(f"Error creating statement nodes: {e}")
return None
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add chunk nodes to Neo4j in batch.
Args:
chunks: List of ChunkNode objects to add
connector: Neo4j connector instance
Returns:
List of created chunk UUIDs or None if failed
"""
if not chunks:
print("No chunk nodes to add")
return []
try:
# Convert chunk nodes to dictionaries for the query
flattened_chunks = []
for chunk in chunks:
# Flatten metadata properties to avoid Neo4j Map type issues
metadata = chunk.metadata if chunk.metadata else {}
flattened_chunk = {
"id": chunk.id,
"name": chunk.name,
"group_id": chunk.group_id,
"user_id": chunk.user_id,
"apply_id": chunk.apply_id,
"run_id": chunk.run_id,
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
"dialog_id": chunk.dialog_id,
"content": chunk.content,
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
"sequence_number": chunk.sequence_number,
"start_index": metadata.get("start_index"),
"end_index": metadata.get("end_index")
}
flattened_chunks.append(flattened_chunk)
result = await connector.execute_query(
CHUNK_NODE_SAVE,
chunks=flattened_chunks
)
created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} chunk nodes")
return created_uuids
except Exception as e:
print(f"Error creating chunk nodes: {e}")
return None
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add memory summary nodes to Neo4j in batch.
Args:
summaries: List of MemorySummaryNode objects to add
connector: Neo4j connector instance
Returns:
List of created summary node ids or None if failed
"""
if not summaries:
print("No memory summary nodes to add")
return []
try:
flattened = []
for s in summaries:
flattened.append({
"id": s.id,
"name": s.name,
"group_id": s.group_id,
"user_id": s.user_id,
"apply_id": s.apply_id,
"run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
"dialog_id": s.dialog_id,
"chunk_ids": s.chunk_ids,
"content": s.content,
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
"config_id": s.config_id, # 添加 config_id
})
result = await connector.execute_query(
MEMORY_SUMMARY_NODE_SAVE,
summaries=flattened
)
created_ids = [record.get("uuid") for record in result]
return created_ids
except Exception:
return None

View File

@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
"""Neo4j仓储基类模块
本模块提供Neo4j仓储的基类实现封装了通用的Neo4j节点操作。
Classes:
BaseNeo4jRepository: Neo4j仓储基类实现通用的CRUD操作
"""
from typing import List, Optional, Dict, Any, TypeVar
from app.repositories.base_repository import BaseRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
T = TypeVar('T')
class BaseNeo4jRepository(BaseRepository[T]):
"""Neo4j仓储基类 - 实现通用的Neo4j节点操作
这个基类封装了Neo4j节点的通用CRUD操作子类只需要实现
特定的映射逻辑和业务查询方法。
Attributes:
connector: Neo4j连接器实例
node_label: 节点标签(如"Dialogue", "Statement"等)
Type Parameters:
T: 实体类型通常是Pydantic模型
"""
def __init__(self, connector: Neo4jConnector, node_label: str):
"""初始化Neo4j仓储
Args:
connector: Neo4j连接器实例
node_label: 节点标签用于Cypher查询
"""
self.connector = connector
self.node_label = node_label
async def create(self, entity: T) -> T:
"""创建节点
将实体对象转换为Neo4j节点并保存到数据库。
Args:
entity: 要创建的实体对象
Returns:
T: 创建后的实体对象
Example:
>>> dialog = DialogueNode(id="123", name="对话1", ...)
>>> created = await repository.create(dialog)
"""
query = f"""
CREATE (n:{self.node_label} $props)
RETURN n
"""
result = await self.connector.execute_query(
query,
props=entity.model_dump()
)
return entity
async def get_by_id(self, entity_id: str) -> Optional[T]:
"""根据ID获取节点
Args:
entity_id: 节点ID
Returns:
Optional[T]: 找到的实体对象如果不存在则返回None
"""
query = f"""
MATCH (n:{self.node_label} {{id: $id}})
RETURN n
"""
result = await self.connector.execute_query(query, id=entity_id)
if result:
return self._map_to_entity(result[0])
return None
async def update(self, entity: T) -> T:
"""更新节点
更新现有节点的属性。使用SET +=语法合并属性。
Args:
entity: 要更新的实体对象必须包含id字段
Returns:
T: 更新后的实体对象
"""
query = f"""
MATCH (n:{self.node_label} {{id: $id}})
SET n += $props
RETURN n
"""
await self.connector.execute_query(
query,
id=entity.id,
props=entity.model_dump()
)
return entity
async def delete(self, entity_id: str) -> bool:
"""删除节点
删除指定ID的节点。使用DETACH DELETE同时删除相关的边。
Args:
entity_id: 要删除的节点ID
Returns:
bool: 删除成功返回True否则返回False
"""
query = f"""
MATCH (n:{self.node_label} {{id: $id}})
DETACH DELETE n
RETURN count(n) as deleted
"""
result = await self.connector.execute_query(query, id=entity_id)
return result[0]['deleted'] > 0 if result else False
async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]:
"""查询节点
根据过滤条件查询节点列表。
Args:
filters: 查询条件字典,键为属性名,值为期望的值
limit: 返回结果的最大数量
Returns:
List[T]: 符合条件的实体列表
Example:
>>> results = await repository.find(
... {"group_id": "group_123", "user_id": "user_456"},
... limit=50
... )
"""
# 构建查询条件
where_clauses = [f"n.{key} = ${key}" for key in filters.keys()]
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
query = f"""
MATCH (n:{self.node_label})
WHERE {where_str}
RETURN n
LIMIT $limit
"""
results = await self.connector.execute_query(
query,
limit=limit,
**filters
)
return [self._map_to_entity(r) for r in results]
def _map_to_entity(self, node_data: Dict) -> T:
"""将节点数据映射为实体对象
这是一个抽象方法,子类必须实现具体的映射逻辑。
Args:
node_data: 从Neo4j查询返回的节点数据字典
Returns:
T: 映射后的实体对象
Raises:
NotImplementedError: 如果子类未实现此方法
"""
raise NotImplementedError("Subclasses must implement _map_to_entity method")

View File

@@ -0,0 +1,332 @@
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes():
"""Create full-text indexes for keyword search with BM25 scoring."""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Creating Full-Text Indexes (for keyword search)")
print("=" * 70)
# 创建 Statements 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: statementsFulltext")
# # 创建 Dialogues 索引
# await connector.execute_query("""
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
# OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
# """)
# 创建 Entities 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: entitiesFulltext")
# 创建 Chunks 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: chunksFulltext")
# 创建 MemorySummary 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: summariesFulltext")
print("\nFull-text indexes created successfully with BM25 support.")
except Exception as e:
print(f"✗ Error creating full-text indexes: {e}")
finally:
await connector.close()
async def create_vector_indexes():
"""Create vector indexes for fast embedding similarity search.
Vector indexes provide 10-100x faster similarity search compared to manual cosine calculation.
This is critical for performance - reduces embedding search from ~1.4s to ~0.05-0.2s!
"""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Creating Vector Indexes (for embedding search)")
print("=" * 70)
print("Note: Adjust vector.dimensions if using different embedding model")
print(" Current setting: 1024 dimensions (for bge-m3)")
print()
# Statement embedding index
await connector.execute_query("""
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
FOR (s:Statement)
ON s.statement_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: statement_embedding_index")
# Chunk embedding index
await connector.execute_query("""
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
FOR (c:Chunk)
ON c.chunk_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: chunk_embedding_index")
# Entity name embedding index
await connector.execute_query("""
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
FOR (e:ExtractedEntity)
ON e.name_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: entity_embedding_index")
# Memory summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
FOR (m:MemorySummary)
ON m.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: summary_embedding_index")
# Dialogue embedding index (optional)
await connector.execute_query("""
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
FOR (d:Dialogue)
ON d.dialog_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: dialogue_embedding_index")
print("\nVector indexes created successfully!")
print("\nExpected performance improvement:")
print(" Before: ~1.4s for embedding search")
print(" After: ~0.05-0.2s for embedding search (10-30x faster!)")
except Exception as e:
print(f"✗ Error creating vector indexes: {e}")
finally:
await connector.close()
async def create_config_id_indexes():
"""Create indexes on config_id fields for improved query performance.
These indexes enable fast filtering of nodes by configuration ID,
which is essential for configuration isolation and multi-tenant scenarios.
"""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Creating Config ID Indexes")
print("=" * 70)
# Dialogue.config_id index
await connector.execute_query("""
CREATE INDEX dialogue_config_id_index IF NOT EXISTS
FOR (d:Dialogue) ON (d.config_id)
""")
print("✓ Created: dialogue_config_id_index")
# Statement.config_id index
await connector.execute_query("""
CREATE INDEX statement_config_id_index IF NOT EXISTS
FOR (s:Statement) ON (s.config_id)
""")
print("✓ Created: statement_config_id_index")
# ExtractedEntity.config_id index
await connector.execute_query("""
CREATE INDEX entity_config_id_index IF NOT EXISTS
FOR (e:ExtractedEntity) ON (e.config_id)
""")
print("✓ Created: entity_config_id_index")
# MemorySummary.config_id index
await connector.execute_query("""
CREATE INDEX summary_config_id_index IF NOT EXISTS
FOR (m:MemorySummary) ON (m.config_id)
""")
print("✓ Created: summary_config_id_index")
print("\nConfig ID indexes created successfully!")
print("These indexes enable fast filtering by configuration ID.")
except Exception as e:
print(f"✗ Error creating config_id indexes: {e}")
finally:
await connector.close()
async def create_unique_constraints():
"""Create uniqueness constraints for core node identifiers.
Ensures concurrent MERGE operations remain safe and prevents duplicates.
"""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Creating Unique Constraints")
print("=" * 70)
# Dialogue.id unique
await connector.execute_query(
"""
CREATE CONSTRAINT dialog_id_unique IF NOT EXISTS
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
"""
)
print("✓ Created: dialog_id_unique")
# Statement.id unique
await connector.execute_query(
"""
CREATE CONSTRAINT statement_id_unique IF NOT EXISTS
FOR (s:Statement) REQUIRE s.id IS UNIQUE
"""
)
print("✓ Created: statement_id_unique")
# Chunk.id unique
await connector.execute_query(
"""
CREATE CONSTRAINT chunk_id_unique IF NOT EXISTS
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
"""
)
print("✓ Created: chunk_id_unique")
print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.")
except Exception as e:
print(f"✗ Error creating unique constraints: {e}")
finally:
await connector.close()
async def create_all_indexes():
"""Create all indexes and constraints in one go."""
print("\n" + "=" * 70)
print("Neo4j Index & Constraint Setup")
print("=" * 70)
print("This will create:")
print(" 1. Full-text indexes (for keyword/BM25 search)")
print(" 2. Vector indexes (for embedding similarity search)")
print(" 3. Config ID indexes (for configuration isolation)")
print(" 4. Unique constraints (for data integrity)")
print("=" * 70)
await create_fulltext_indexes()
await create_vector_indexes()
await create_config_id_indexes()
await create_unique_constraints()
print("\n" + "=" * 70)
print("✓ All indexes and constraints created successfully!")
print("=" * 70)
print("\nTo verify, run in Neo4j Browser:")
print(" SHOW INDEXES")
print(" SHOW CONSTRAINTS")
print()
async def check_indexes():
"""Check what indexes currently exist."""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Checking Existing Indexes")
print("=" * 70)
query = "SHOW INDEXES"
result = await connector.execute_query(query)
fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT']
vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR']
range_indexes = [idx for idx in result if idx.get('type') == 'RANGE']
print(f"\nFull-text indexes: {len(fulltext_indexes)}")
for idx in fulltext_indexes:
print(f"{idx.get('name')}")
print(f"\nVector indexes: {len(vector_indexes)}")
for idx in vector_indexes:
print(f"{idx.get('name')}")
print(f"\nRange indexes (including config_id): {len(range_indexes)}")
for idx in range_indexes:
print(f"{idx.get('name')}")
if not vector_indexes:
print("\n⚠️ WARNING: No vector indexes found!")
print(" Embedding search will be VERY SLOW (~1.4s)")
print(" Run: python create_indexes.py")
# Check for config_id indexes
config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')]
if len(config_id_indexes) < 4:
print("\n⚠️ WARNING: Not all config_id indexes found!")
print(f" Expected 4, found {len(config_id_indexes)}")
print(" Run: python create_indexes.py config_id")
print("=" * 70)
finally:
await connector.close()
if __name__ == "__main__":
import asyncio
import sys
if len(sys.argv) > 1:
command = sys.argv[1]
if command == "check":
asyncio.run(check_indexes())
elif command == "fulltext":
asyncio.run(create_fulltext_indexes())
elif command == "vector":
asyncio.run(create_vector_indexes())
elif command == "config_id":
asyncio.run(create_config_id_indexes())
elif command == "constraints":
asyncio.run(create_unique_constraints())
else:
print(f"Unknown command: {command}")
print("\nUsage:")
print(" python create_indexes.py # Create all indexes")
print(" python create_indexes.py check # Check existing indexes")
print(" python create_indexes.py fulltext # Create only full-text indexes")
print(" python create_indexes.py vector # Create only vector indexes")
print(" python create_indexes.py config_id # Create only config_id indexes")
print(" python create_indexes.py constraints # Create only constraints")
else:
asyncio.run(create_all_indexes())

View File

@@ -0,0 +1,684 @@
DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue
MERGE (n:Dialogue {id: dialogue.id})
SET n.uuid = coalesce(n.uuid, dialogue.id),
n.group_id = dialogue.group_id,
n.user_id = dialogue.user_id,
n.apply_id = dialogue.apply_id,
n.run_id = dialogue.run_id,
n.ref_id = dialogue.ref_id,
n.created_at = dialogue.created_at,
n.expired_at = dialogue.expired_at,
n.content = dialogue.content,
n.dialog_embedding = dialogue.dialog_embedding
RETURN n.id AS uuid
"""
STATEMENT_NODE_SAVE = """
UNWIND $statements AS statement
MERGE (s:Statement {id: statement.id})
SET s += {
id: statement.id,
group_id: statement.group_id,
user_id: statement.user_id,
apply_id: statement.apply_id,
chunk_id: statement.chunk_id,
run_id: statement.run_id,
created_at: statement.created_at,
expired_at: statement.expired_at,
stmt_type: statement.stmt_type,
temporal_info: statement.temporal_info,
relevence_info: statement.relevence_info,
statement: statement.statement,
valid_at: statement.valid_at,
invalid_at: statement.invalid_at,
statement_embedding: statement.statement_embedding
}
RETURN s.id AS uuid
"""
CHUNK_NODE_SAVE = """
UNWIND $chunks AS chunk
MERGE (c:Chunk {id: chunk.id})
SET c += {
id: chunk.id,
name: chunk.name,
group_id: chunk.group_id,
user_id: chunk.user_id,
apply_id: chunk.apply_id,
run_id: chunk.run_id,
created_at: chunk.created_at,
expired_at: chunk.expired_at,
dialog_id: chunk.dialog_id,
content: chunk.content,
chunk_embedding: chunk.chunk_embedding,
sequence_number: chunk.sequence_number,
start_index: chunk.start_index,
end_index: chunk.end_index
}
RETURN c.id AS uuid
"""
# bug修改点
EXTRACTED_ENTITY_NODE_SAVE = """
// Upsert entity nodes safely: preserve existing non-empty fields when incoming is empty
UNWIND $entities AS entity
MERGE (e:ExtractedEntity {id: entity.id})
SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END,
e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END,
e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END,
e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END,
e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END,
e.created_at = CASE
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
THEN entity.created_at ELSE e.created_at END,
e.expired_at = CASE
WHEN entity.expired_at IS NOT NULL AND (e.expired_at IS NULL OR entity.expired_at > e.expired_at)
THEN entity.expired_at ELSE e.expired_at END,
e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
e.description = CASE
WHEN entity.description IS NOT NULL AND entity.description <> ''
AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
THEN entity.description ELSE e.description END,
e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
e.aliases = CASE
WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
THEN CASE WHEN e.aliases IS NULL THEN entity.aliases ELSE e.aliases + entity.aliases END
ELSE e.aliases END,
e.name_embedding = CASE
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
ELSE e.name_embedding END,
e.fact_summary = CASE
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
THEN entity.fact_summary ELSE e.fact_summary END,
e.connect_strength = CASE
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
ELSE CASE
WHEN e.connect_strength = 'strong' AND entity.connect_strength = 'weak' THEN 'both'
WHEN e.connect_strength = 'weak' AND entity.connect_strength = 'strong' THEN 'both'
WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength
ELSE e.connect_strength
END
END
RETURN e.id AS uuid
"""
# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships
ENTITY_RELATIONSHIP_SAVE = """
UNWIND $relationships AS rel
// Match entities by stable id within group, do not constrain by run_id
MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id})
MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id})
// Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate,
r.statement_id = rel.statement_id,
r.value = rel.value,
r.statement = rel.statement,
r.valid_at = rel.valid_at,
r.invalid_at = rel.invalid_at,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
r.run_id = rel.run_id,
r.group_id = rel.group_id
RETURN elementId(r) AS uuid
"""
# 在 Neo4j 5及后续版本中id() 函数已被标记为弃用用elementId() 函数替代
# 保存弱关系实体,设置 e.is_weak = true不维护 e.relations 聚合字段
WEAK_ENTITY_NODE_SAVE = """
UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += {
name: entity.name,
group_id: entity.group_id,
run_id: entity.run_id,
description: entity.description,
chunk_id: entity.chunk_id,
dialog_id: entity.dialog_id
}
// Independent weak flag仅标记弱关系不再维护 relations 聚合字段
SET e.is_weak = true
RETURN e.id AS id
"""
# 为强关系三元组中的主语和宾语创建/更新实体节点,仅设置 e.is_strong = true不维护 e.relations 字段
SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id}
// Independent strong flag
SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id}
// Independent strong flag
SET o.is_strong = true
"""
DIALOGUE_STATEMENT_EDGE_SAVE = """
UNWIND $dialogue_statement_edges AS edge
// 支持按 uuid 或 ref_id 连接到 Dialogue避免因来源 ID 不一致而断链
MATCH (dialogue:Dialogue)
WHERE dialogue.uuid = edge.source OR dialogue.ref_id = edge.source
MATCH (statement:Statement {id: edge.target})
// 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id,
e.group_id = edge.group_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
RETURN e.uuid AS uuid
"""
# 在 Neo4j 5及后续版本中id() 函数已被标记为弃用用elementId() 函数替代
CHUNK_STATEMENT_EDGE_SAVE = """
UNWIND $chunk_statement_edges AS edge
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id})
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.group_id = edge.group_id,
e.run_id = edge.run_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
RETURN e.id AS uuid
"""
STATEMENT_ENTITY_EDGE_SAVE = """
UNWIND $relationships AS rel
// Statement nodes are per-run; keep run_id constraint on statements
// Statement nodes are per-run; keep run_id constraint on statements
MATCH (statement:Statement {id: rel.source, run_id: rel.run_id})
// Entities are shared across runs within a group; do not constrain by run_id
MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id})
// Avoid duplicate edges across runs for same endpoints
MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.group_id = rel.group_id,
r.run_id = rel.run_id,
r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
r.connect_strength = rel.connect_strength
RETURN elementId(r) AS uuid
"""
ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL
AND ($group_id IS NULL OR e.group_id = $group_id)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.entity_type AS entity_type,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search: cosine similarity on Statement.statement_embedding
STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL
AND ($group_id IS NULL OR s.group_id = $group_id)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search: cosine similarity on Chunk.chunk_embedding
CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL
AND ($group_id IS NULL OR c.group_id = $group_id)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.content AS content,
c.dialog_id AS dialog_id,
score
ORDER BY score DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
c.id AS chunk_id_from_rel,
collect(DISTINCT e.id) AS entity_ids,
score
ORDER BY score DESC
LIMIT $limit
"""
# 查询实体名称包含指定字符串的实体
SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($group_id IS NULL OR e.group_id = $group_id)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
e.name AS name,
e.group_id AS group_id,
e.entity_type AS entity_type,
e.apply_id AS apply_id,
e.user_id AS user_id,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.entity_idx AS entity_idx,
e.statement_id AS statement_id,
e.description AS description,
e.aliases AS aliases,
e.name_embedding AS name_embedding,
e.fact_summary AS fact_summary,
e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids,
score
ORDER BY score DESC
LIMIT $limit
"""
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score
WHERE ($group_id IS NULL OR c.group_id = $group_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT e.id) AS entity_ids,
score
ORDER BY score DESC
LIMIT $limit
"""
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """
# UNWIND $rows AS row
# MATCH (e:ExtractedEntity)
# WHERE e.group_id = row.group_id
# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name)))
# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type)
# RETURN row.id AS incoming_id,
# e.id AS id,
# e.name AS name,
# e.group_id AS group_id,
# e.entity_idx AS entity_idx,
# e.entity_type AS entity_type,
# e.description AS description,
# e.statement_id AS statement_id,
# e.aliases AS aliases,
# e.name_embedding AS name_embedding,
# e.fact_summary AS fact_summary,
# e.connect_strength AS connect_strength,
# e.created_at AS created_at,
# e.expired_at AS expired_at
# """
# # 同组group_id下按name contains召回补充
# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """
# UNWIND $rows AS row
# MATCH (e:ExtractedEntity)
# WHERE e.group_id = row.group_id
# AND toLower(e.name) CONTAINS toLower(row.name)
# RETURN row.id AS incoming_id,
# e.id AS id,
# e.name AS name,
# e.group_id AS group_id,
# e.entity_idx AS entity_idx,
# e.entity_type AS entity_type,
# e.description AS description,
# e.statement_id AS statement_id,
# e.aliases AS aliases,
# e.name_embedding AS name_embedding,
# e.fact_summary AS fact_summary,
# e.connect_strength AS connect_strength,
# e.created_at AS created_at,
# e.expired_at AS expired_at
# """
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE ($group_id IS NULL OR d.group_id = $group_id)
AND d.id = $dialog_id
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at
ORDER BY d.created_at DESC
LIMIT $limit
"""
SEARCH_CHUNK_BY_CHUNK_ID = """
MATCH (c:Chunk)
WHERE ($group_id IS NULL OR c.group_id = $group_id)
AND c.id = $chunk_id
RETURN c.id AS chunk_id,
c.group_id AS group_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.created_at AS created_at,
c.expired_at AS expired_at,
c.sequence_number AS sequence_number
ORDER BY c.created_at DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_BY_TEMPORAL = """
MATCH (s:Statement)
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date))
AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date)))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
collect(DISTINCT s.id) AS statement_ids
ORDER BY datetime(s.created_at) DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score
WHERE ($group_id IS NULL OR s.group_id = $group_id)
AND ($apply_id IS NULL OR s.apply_id = $apply_id)
AND ($user_id IS NULL OR s.user_id = $user_id)
AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date)))
AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date))))
OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date)))
AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date)))))
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.group_id AS group_id,
s.apply_id AS apply_id,
s.user_id AS user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
c.id AS chunk_id_from_rel,
collect(DISTINCT e.id) AS entity_ids,
score
ORDER BY s.created_at DESC, score DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_BY_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_BY_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_G_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_L_CREATED_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.created_at DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_G_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_L_VALID_AT = """
MATCH (n:Statement)
WHERE ($group_id IS NULL OR n.group_id = $group_id)
AND ($apply_id IS NULL OR n.apply_id = $apply_id)
AND ($user_id IS NULL OR n.user_id = $user_id)
AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at))
RETURN n.id AS id,
n.statement AS statement,
n.group_id AS group_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
n.chunk_id AS chunk_id,
n.created_at AS created_at,
n.valid_at AS valid_at,
n.invalid_at AS invalid_at,
collect(DISTINCT n.id) AS statement_ids
ORDER BY n.valid_at DESC
LIMIT $limit
"""
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """
# UNWIND $rows AS row
# MATCH (e:ExtractedEntity)
# WHERE e.group_id = row.group_id
# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name)))
# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type)
# RETURN row.id AS incoming_id,
# e.id AS id,
# e.name AS name,
# e.group_id AS group_id,
# e.entity_idx AS entity_idx,
# e.entity_type AS entity_type,
# e.description AS description,
# e.statement_id AS statement_id,
# e.aliases AS aliases,
# e.name_embedding AS name_embedding,
# e.fact_summary AS fact_summary,
# e.connect_strength AS connect_strength,
# e.created_at AS created_at,
# e.expired_at AS expired_at
# """
# # 同组group_id下按name contains召回补充
# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """
# UNWIND $rows AS row
# MATCH (e:ExtractedEntity)
# WHERE e.group_id = row.group_id
# AND toLower(e.name) CONTAINS toLower(row.name)
# RETURN row.id AS incoming_id,
# e.id AS id,
# e.name AS name,
# e.group_id AS group_id,
# e.entity_idx AS entity_idx,
# e.entity_type AS entity_type,
# e.description AS description,
# e.statement_id AS statement_id,
# e.aliases AS aliases,
# e.name_embedding AS name_embedding,
# e.fact_summary AS fact_summary,
# e.connect_strength AS connect_strength,
# e.created_at AS created_at,
# e.expired_at AS expired_at
# """
# 根据id修改句子的invalid_at的值
UPDATE_STATEMENT_INVALID_AT = """
MATCH (n:Statement {group_id: $group_id, id: $id})
SET n.invalid_at = $new_invalid_at
"""
# MemorySummary keyword search using fulltext index
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score
WHERE ($group_id IS NULL OR m.group_id = $group_id)
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
m.created_at AS created_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search: cosine similarity on MemorySummary.summary_embedding
MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL
AND ($group_id IS NULL OR m.group_id = $group_id)
RETURN m.id AS id,
m.name AS name,
m.group_id AS group_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
m.created_at AS created_at,
score
ORDER BY score DESC
LIMIT $limit
"""
MEMORY_SUMMARY_NODE_SAVE = """
UNWIND $summaries AS summary
MERGE (m:MemorySummary {id: summary.id})
SET m += {
id: summary.id,
name: summary.name,
group_id: summary.group_id,
user_id: summary.user_id,
apply_id: summary.apply_id,
run_id: summary.run_id,
created_at: summary.created_at,
expired_at: summary.expired_at,
dialog_id: summary.dialog_id,
chunk_ids: summary.chunk_ids,
content: summary.content,
summary_embedding: summary.summary_embedding,
config_id: summary.config_id
}
RETURN m.id AS uuid
"""
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE = """
UNWIND $edges AS e
MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id})
MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id})
MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
SET r.group_id = e.group_id,
r.run_id = e.run_id,
r.created_at = e.created_at,
r.expired_at = e.expired_at
RETURN elementId(r) AS uuid
"""

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
"""对话仓储模块
本模块提供对话节点的数据访问功能。
Classes:
DialogRepository: 对话仓储管理DialogueNode的CRUD操作
"""
from typing import List, Optional, Dict
from datetime import datetime
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.core.memory.models.graph_models import DialogueNode
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""对话仓储
管理对话节点的创建、查询、更新和删除操作。
提供按group_id、user_id、ref_id等条件查询对话的方法。
Attributes:
connector: Neo4j连接器实例
node_label: 节点标签,固定为"Dialogue"
"""
def __init__(self, connector: Neo4jConnector):
"""初始化对话仓储
Args:
connector: Neo4j连接器实例
"""
super().__init__(connector, "Dialogue")
def _map_to_entity(self, node_data: Dict) -> DialogueNode:
"""将节点数据映射为对话实体
Args:
node_data: 从Neo4j查询返回的节点数据字典
Returns:
DialogueNode: 对话实体对象
"""
# 从查询结果中提取节点数据
n = node_data.get('n', node_data)
# 处理datetime字段
if isinstance(n.get('created_at'), str):
n['created_at'] = datetime.fromisoformat(n['created_at'])
if n.get('expired_at') and isinstance(n['expired_at'], str):
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
return DialogueNode(**n)
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据group_id查询对话
Args:
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find({"group_id": group_id}, limit=limit)
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据user_id查询对话
Args:
user_id: 用户ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find({"user_id": user_id}, limit=limit)
async def find_by_ref_id(self, ref_id: str) -> Optional[DialogueNode]:
"""根据ref_id查询对话
ref_id是外部对话系统的引用ID通常是唯一的。
Args:
ref_id: 引用ID
Returns:
Optional[DialogueNode]: 找到的对话如果不存在则返回None
"""
results = await self.find({"ref_id": ref_id}, limit=1)
return results[0] if results else None
async def find_by_group_and_user(
self,
group_id: str,
user_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据group_id和user_id查询对话
Args:
group_id: 组ID
user_id: 用户ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find(
{"group_id": group_id, "user_id": user_id},
limit=limit
)
async def find_recent_dialogs(
self,
group_id: str,
days: int = 7,
limit: int = 100
) -> List[DialogueNode]:
"""查询最近的对话
Args:
group_id: 组ID
days: 查询最近多少天的对话
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表,按创建时间倒序排列
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(
query,
group_id=group_id,
days=days,
limit=limit
)
return [self._map_to_entity(r) for r in results]
async def find_by_config_id(
self,
config_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据config_id查询对话
Args:
config_id: 配置ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find({"config_id": config_id}, limit=limit)
async def find_by_config_and_group(
self,
config_id: str,
group_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据config_id和group_id查询对话
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
Args:
config_id: 配置ID
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find(
{"config_id": config_id, "group_id": group_id},
limit=limit
)

View File

@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""实体仓储模块
本模块提供实体节点的数据访问功能。
Classes:
EntityRepository: 实体仓储管理ExtractedEntityNode的CRUD操作
"""
from typing import List, Optional, Dict
from datetime import datetime
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.core.memory.models.graph_models import ExtractedEntityNode
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
"""实体仓储
管理实体节点的创建、查询、更新和删除操作。
提供按类型、名称、向量相似度等条件查询实体的方法。
Attributes:
connector: Neo4j连接器实例
node_label: 节点标签,固定为"ExtractedEntity"
"""
def __init__(self, connector: Neo4jConnector):
"""初始化实体仓储
Args:
connector: Neo4j连接器实例
"""
super().__init__(connector, "ExtractedEntity")
def _map_to_entity(self, node_data: Dict) -> ExtractedEntityNode:
"""将节点数据映射为实体对象
Args:
node_data: 从Neo4j查询返回的节点数据字典
Returns:
ExtractedEntityNode: 实体对象
"""
# 从查询结果中提取节点数据
n = node_data.get('n', node_data)
# 处理datetime字段
if isinstance(n.get('created_at'), str):
n['created_at'] = datetime.fromisoformat(n['created_at'])
if n.get('expired_at') and isinstance(n['expired_at'], str):
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
return ExtractedEntityNode(**n)
async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]:
"""根据实体类型查询
Args:
entity_type: 实体类型(如"Person", "Organization"等)
limit: 返回结果的最大数量
Returns:
List[ExtractedEntityNode]: 实体列表
"""
return await self.find({"entity_type": entity_type}, limit=limit)
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[ExtractedEntityNode]:
"""根据group_id查询实体
Args:
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[ExtractedEntityNode]: 实体列表
"""
return await self.find({"group_id": group_id}, limit=limit)
async def find_by_name(
self,
name: str,
group_id: Optional[str] = None,
limit: int = 100
) -> List[ExtractedEntityNode]:
"""根据名称查询实体
支持模糊匹配CONTAINS
Args:
name: 实体名称
group_id: 可选的组ID过滤
limit: 返回结果的最大数量
Returns:
List[ExtractedEntityNode]: 实体列表
"""
where_clause = "n.name CONTAINS $name"
if group_id:
where_clause += " AND n.group_id = $group_id"
query = f"""
MATCH (n:{self.node_label})
WHERE {where_clause}
RETURN n
LIMIT $limit
"""
params = {"name": name, "limit": limit}
if group_id:
params["group_id"] = group_id
results = await self.connector.execute_query(query, **params)
return [self._map_to_entity(r) for r in results]
async def find_related_entities(
self,
entity_id: str,
relation_type: Optional[str] = None,
limit: int = 100
) -> List[ExtractedEntityNode]:
"""查询相关实体
查询与指定实体有关系的其他实体。
Args:
entity_id: 实体ID
relation_type: 可选的关系类型过滤
limit: 返回结果的最大数量
Returns:
List[ExtractedEntityNode]: 相关实体列表
"""
if relation_type:
query = """
MATCH (e1:ExtractedEntity {id: $entity_id})-[r:RELATES_TO {relation_type: $relation_type}]->(e2:ExtractedEntity)
RETURN e2 as n
LIMIT $limit
"""
results = await self.connector.execute_query(
query,
entity_id=entity_id,
relation_type=relation_type,
limit=limit
)
else:
query = """
MATCH (e1:ExtractedEntity {id: $entity_id})-[r:RELATES_TO]->(e2:ExtractedEntity)
RETURN e2 as n
LIMIT $limit
"""
results = await self.connector.execute_query(
query,
entity_id=entity_id,
limit=limit
)
return [self._map_to_entity(r) for r in results]
async def search_by_embedding(
self,
embedding: List[float],
group_id: Optional[str] = None,
limit: int = 10,
min_score: float = 0.7
) -> List[Dict]:
"""基于向量相似度搜索实体
使用余弦相似度计算查询向量与实体名称向量的相似度。
Args:
embedding: 查询向量
group_id: 可选的组ID过滤
limit: 返回结果的最大数量
min_score: 最小相似度分数阈值
Returns:
List[Dict]: 包含实体和相似度分数的字典列表
每个字典包含: entity (ExtractedEntityNode), score (float)
"""
where_clause = "n.name_embedding IS NOT NULL"
if group_id:
where_clause += " AND n.group_id = $group_id"
query = f"""
MATCH (n:{self.node_label})
WHERE {where_clause}
WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score
WHERE score > $min_score
RETURN n, score
ORDER BY score DESC
LIMIT $limit
"""
params = {
"embedding": embedding,
"min_score": min_score,
"limit": limit
}
if group_id:
params["group_id"] = group_id
results = await self.connector.execute_query(query, **params)
return [
{
"entity": self._map_to_entity(r),
"score": r.get("score", 0.0)
}
for r in results
]
async def find_by_statement_id(self, statement_id: str) -> List[ExtractedEntityNode]:
"""根据陈述句ID查询实体
查询从指定陈述句中提取的所有实体。
Args:
statement_id: 陈述句ID
Returns:
List[ExtractedEntityNode]: 实体列表
"""
return await self.find({"statement_id": statement_id})
async def find_strong_entities(
self,
group_id: str,
limit: int = 100
) -> List[ExtractedEntityNode]:
"""查询强连接的实体
Args:
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[ExtractedEntityNode]: 强连接的实体列表
"""
return await self.find(
{"group_id": group_id, "connect_strength": "Strong"},
limit=limit
)
async def get_entity_count_by_type(self, group_id: str) -> Dict[str, int]:
"""统计各类型实体的数量
Args:
group_id: 组ID
Returns:
Dict[str, int]: 实体类型到数量的映射
"""
query = """
MATCH (n:ExtractedEntity {group_id: $group_id})
RETURN n.entity_type as entity_type, count(n) as count
ORDER BY count DESC
"""
results = await self.connector.execute_query(query, group_id=group_id)
return {r["entity_type"]: r["count"] for r in results}
async def find_by_config_id(
self,
config_id: str,
limit: int = 100
) -> List[ExtractedEntityNode]:
"""根据config_id查询实体
Args:
config_id: 配置ID
limit: 返回结果的最大数量
Returns:
List[ExtractedEntityNode]: 实体列表
"""
return await self.find({"config_id": config_id}, limit=limit)
async def search_by_embedding_with_config(
self,
embedding: List[float],
config_id: Optional[str] = None,
group_id: Optional[str] = None,
limit: int = 10,
min_score: float = 0.7
) -> List[Dict]:
"""基于向量相似度搜索实体,可选择按config_id过滤
使用余弦相似度计算查询向量与实体名称向量的相似度。
支持按config_id过滤结果,确保只返回使用特定配置处理的实体。
Args:
embedding: 查询向量
config_id: 可选的配置ID过滤
group_id: 可选的组ID过滤
limit: 返回结果的最大数量
min_score: 最小相似度分数阈值
Returns:
List[Dict]: 包含实体和相似度分数的字典列表
每个字典包含: entity (ExtractedEntityNode), score (float)
"""
# 构建查询条件
where_clauses = ["n.name_embedding IS NOT NULL"]
params = {
"embedding": embedding,
"min_score": min_score,
"limit": limit
}
if config_id:
where_clauses.append("n.config_id = $config_id")
params["config_id"] = config_id
if group_id:
where_clauses.append("n.group_id = $group_id")
params["group_id"] = group_id
where_str = " AND ".join(where_clauses)
query = f"""
MATCH (n:{self.node_label})
WHERE {where_str}
WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score
WHERE score > $min_score
RETURN n, score
ORDER BY score DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [
{
"entity": self._map_to_entity(r),
"score": r.get("score", 0.0)
}
for r in results
]

View File

@@ -0,0 +1,216 @@
from typing import List
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.add_nodes import add_dialogue_nodes, add_statement_nodes, add_chunk_nodes
from app.repositories.neo4j.cypher_queries import (
STATEMENT_ENTITY_EDGE_SAVE,
ENTITY_RELATIONSHIP_SAVE,
EXTRACTED_ENTITY_NODE_SAVE,
CHUNK_STATEMENT_EDGE_SAVE,
STATEMENT_ENTITY_EDGE_SAVE,
ENTITY_RELATIONSHIP_SAVE,
EXTRACTED_ENTITY_NODE_SAVE,
)
from app.core.memory.models.graph_models import (
DialogueNode,
ChunkNode,
StatementChunkEdge,
StatementEntityEdge,
StatementNode,
ExtractedEntityNode,
EntityEntityEdge,
)
async def save_entities_and_relationships(
entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge],
connector: Neo4jConnector
):
"""Save entities and their relationships using graph models"""
all_entities = [entity.model_dump() for entity in entity_nodes]
all_relationships = []
for edge in entity_entity_edges:
relationship = {
'source_id': edge.source,
'target_id': edge.target,
'predicate': edge.relation_type,
'statement_id': edge.source_statement_id,
'value': edge.relation_value,
'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(),
'run_id': edge.run_id,
'group_id': edge.group_id,
'user_id': edge.user_id,
'apply_id': edge.apply_id,
}
all_relationships.append(relationship)
# Save entities
if all_entities:
entity_uuids = await connector.execute_query(EXTRACTED_ENTITY_NODE_SAVE, entities=all_entities)
if entity_uuids:
print(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
else:
print("Failed to save entity nodes to Neo4j")
else:
print("No entity nodes to save")
# Create relationships
if all_relationships:
relationship_uuids = await connector.execute_query(ENTITY_RELATIONSHIP_SAVE, relationships=all_relationships)
if relationship_uuids:
print(f"Successfully saved {len(relationship_uuids)} entity relationships (edges) to Neo4j")
else:
print("Failed to save entity relationships to Neo4j")
else:
print("No entity relationships to save")
async def save_chunk_nodes(
chunk_nodes: List[ChunkNode],
connector: Neo4jConnector
):
"""Save chunk nodes using graph models"""
if not chunk_nodes:
print("No chunk nodes to save")
return
chunk_uuids = await add_chunk_nodes(chunk_nodes, connector)
if chunk_uuids:
print(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
else:
print("Failed to save chunk nodes to Neo4j")
async def save_statement_chunk_edges(
statement_chunk_edges: List[StatementChunkEdge],
connector: Neo4jConnector
):
"""Save statement-chunk edges using graph models"""
if not statement_chunk_edges:
return
all_sc_edges = []
for edge in statement_chunk_edges:
all_sc_edges.append({
"id": edge.id,
"source": edge.source,
"target": edge.target,
"group_id": edge.group_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"run_id": edge.run_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
})
try:
await connector.execute_query(
CHUNK_STATEMENT_EDGE_SAVE,
chunk_statement_edges=all_sc_edges
)
except Exception:
pass
async def save_statement_entity_edges(
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
):
"""Save statement-entity edges using graph models"""
if not statement_entity_edges:
print("No statement-entity edges to save")
return
all_se_edges = []
for edge in statement_entity_edges:
edge_data = {
"source": edge.source,
"target": edge.target,
"group_id": edge.group_id,
"user_id": edge.user_id,
"apply_id": edge.apply_id,
"run_id": edge.run_id,
"connect_strength": edge.connect_strength,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
}
all_se_edges.append(edge_data)
if all_se_edges:
try:
await connector.execute_query(
STATEMENT_ENTITY_EDGE_SAVE,
relationships=all_se_edges
)
except Exception:
pass
async def save_dialog_and_statements_to_neo4j(
dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
Args:
dialogue_nodes: List of DialogueNode objects to save
chunk_nodes: List of ChunkNode objects to save
statement_nodes: List of StatementNode objects to save
entity_nodes: List of ExtractedEntityNode objects to save
entity_edges: List of EntityEntityEdge objects to save
statement_chunk_edges: List of StatementChunkEdge objects to save
statement_entity_edges: List of StatementEntityEdge objects to save
connector: Neo4j connector instance
Returns:
bool: True if successful, False otherwise
"""
try:
# Save all dialogue nodes in batch
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector)
if dialogue_uuids:
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
else:
print("Failed to save dialogues to Neo4j")
return False
# Save all chunk nodes in batch
await save_chunk_nodes(chunk_nodes, connector)
# Save all statement nodes in batch
if statement_nodes:
statement_uuids = await add_statement_nodes(statement_nodes, connector)
if statement_uuids:
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
else:
print("Failed to save statement nodes to Neo4j")
return False
else:
print("No statement nodes to save")
# Save entities and relationships
await save_entities_and_relationships(entity_nodes, entity_edges, connector)
print("Successfully saved entities and relationships to Neo4j")
# Save new edges
await save_statement_chunk_edges(statement_chunk_edges, connector)
await save_statement_entity_edges(statement_entity_edges, connector)
return True
except Exception as e:
print(f"Neo4j integration error: {e}")
print("Continuing without database storage...")
return False

View File

@@ -0,0 +1,584 @@
from typing import Any, Dict, List, Optional
import asyncio
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.cypher_queries import (
SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_ENTITIES_BY_NAME,
SEARCH_CHUNKS_BY_CONTENT,
STATEMENT_EMBEDDING_SEARCH,
CHUNK_EMBEDDING_SEARCH,
ENTITY_EMBEDDING_SEARCH,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
SEARCH_STATEMENTS_BY_TEMPORAL,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
SEARCH_DIALOGUE_BY_DIALOG_ID,
SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_VALID_AT,
SEARCH_STATEMENTS_G_CREATED_AT,
SEARCH_STATEMENTS_L_CREATED_AT,
SEARCH_STATEMENTS_G_VALID_AT,
SEARCH_STATEMENTS_L_VALID_AT,
)
async def search_graph(
connector: Neo4jConnector,
q: str,
group_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
OPTIMIZED: Runs all queries in parallel using asyncio.gather()
- Statements: matches s.statement CONTAINS q
- Entities: matches e.name CONTAINS q
- Chunks: matches s.content CONTAINS q (from Statement nodes)
- Summaries: matches ms.content CONTAINS q
Args:
connector: Neo4j connector
q: Query text
group_id: Optional group filter
limit: Max results per category
include: List of categories to search (default: all)
Returns:
Dictionary with search results per category
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Prepare tasks for parallel execution
tasks = []
task_keys = []
if "statements" in include:
tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD,
q=q,
group_id=group_id,
limit=limit,
))
task_keys.append("statements")
if "entities" in include:
tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=q,
group_id=group_id,
limit=limit,
))
task_keys.append("entities")
if "chunks" in include:
tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT,
q=q,
group_id=group_id,
limit=limit,
))
task_keys.append("chunks")
if "summaries" in include:
tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
q=q,
group_id=group_id,
limit=limit,
))
task_keys.append("summaries")
# Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# Build results dictionary
results = {}
for key, result in zip(task_keys, task_results):
if isinstance(result, Exception):
results[key] = []
else:
results[key] = result
return results
async def search_graph_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
group_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"],
) -> Dict[str, List[Dict[str, Any]]]:
"""
Embedding-based semantic search across Statements, Chunks, and Entities.
OPTIMIZED: Runs all queries in parallel using asyncio.gather()
- Computes query embedding with the provided embedder_client
- Ranks by cosine similarity in Cypher
- Filters by group_id if provided
- Returns up to 'limit' per included type
"""
import time
# Get embedding for the query
embed_start = time.time()
embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
embedding = embeddings[0]
# Prepare tasks for parallel execution
tasks = []
task_keys = []
# Statements (embedding)
if "statements" in include:
tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
limit=limit,
))
task_keys.append("statements")
# Chunks (embedding)
if "chunks" in include:
tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
limit=limit,
))
task_keys.append("chunks")
# Entities
if "entities" in include:
tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
limit=limit,
))
task_keys.append("entities")
# Memory summaries
if "summaries" in include:
tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH,
embedding=embedding,
group_id=group_id,
limit=limit,
))
task_keys.append("summaries")
# Execute all queries in parallel
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = {
"statements": [],
"chunks": [],
"entities": [],
"summaries": [],
}
for key, result in zip(task_keys, task_results):
if isinstance(result, Exception):
results[key] = []
else:
results[key] = result
return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector,
group_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]:
"""
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
- 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选;
- 保留并发控制与返回结构incoming_id -> [db_entity_props...]
- 若提供 `entity_type`,在本地对返回结果做类型过滤;
- `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。
返回incoming_id -> [db_entity_props...]
"""
if not entities:
return {}
sem = asyncio.Semaphore(max_concurrency)
async def _query_by_name(incoming: Dict[str, Any]) -> tuple[str, List[Dict[str, Any]]]:
async with sem:
inc_id = incoming.get("id") or "__unknown__"
name = (incoming.get("name") or "").strip()
if not name:
return inc_id, []
try:
# 全文索引按名称检索(包含 CONTAINS 语义)
rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=name,
group_id=group_id,
limit=100,
)
except Exception:
rows = []
# 可选本地类型过滤(若输入实体提供类型)
typ = incoming.get("entity_type")
if typ:
try:
rows = [r for r in rows if (r.get("entity_type") == typ)]
except Exception:
pass
# 注入 incoming_id 以保持兼容下游合并逻辑
for r in rows:
r["incoming_id"] = inc_id
# 简单的降级:若为空且允许 fallback可按小写名再次查询
if use_contains_fallback and not rows and name:
try:
rows = await connector.execute_query(
SEARCH_ENTITIES_BY_NAME,
q=name.lower(),
group_id=group_id,
limit=100,
)
for r in rows:
r["incoming_id"] = inc_id
except Exception:
pass
return inc_id, rows
tasks = [_query_by_name(e) for e in entities]
results = await asyncio.gather(*tasks, return_exceptions=True)
merged: Dict[str, List[Dict[str, Any]]] = {}
for res in results:
if isinstance(res, Exception):
# 静默跳过单条失败
continue
inc_id, rows = res
inc_id = inc_id or "__unknown__"
merged.setdefault(inc_id, [])
existing_ids = {x.get("id") for x in merged[inc_id]}
for rec in rows:
if rec.get("id") not in existing_ids:
merged[inc_id].append(rec)
return merged
async def search_graph_by_keyword_temporal(
connector: Neo4jConnector,
query_text: str,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 50,
) -> Dict[str, List[Any]]:
"""
Temporal keyword search across Statements.
- Matches statements containing query_text created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id
- Returns up to 'limit' statements
"""
if not query_text:
print(f"query_text不能为空")
return {"statements": []}
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
q=query_text,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
valid_date=valid_date,
invalid_date=invalid_date,
limit=limit,
)
print(f"查询结果为:\n{statements}")
return {"statements": statements}
async def search_graph_by_temporal(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
- Matches statements created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_TEMPORAL,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
valid_date=valid_date,
invalid_date=invalid_date,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
return {"statements": statements}
async def search_graph_by_dialog_id(
connector: Neo4jConnector,
dialog_id: str,
group_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Dialogues.
- Matches dialogues with dialog_id
- Optionally filters by group_id
- Returns up to 'limit' dialogues
"""
if not dialog_id:
print(f"dialog_id不能为空")
return {"dialogues": []}
dialogues = await connector.execute_query(
SEARCH_DIALOGUE_BY_DIALOG_ID,
group_id=group_id,
dialog_id=dialog_id,
limit=limit,
)
return {"dialogues": dialogues}
async def search_graph_by_chunk_id(
connector: Neo4jConnector,
chunk_id : str,
group_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
print(f"chunk_id不能为空")
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
group_id=group_id,
chunk_id=chunk_id,
limit=limit,
)
return {"chunks": chunks}
async def search_graph_by_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
return {"statements": statements}
async def search_graph_by_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
return {"statements": statements}
async def search_graph_g_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
return {"statements": statements}
async def search_graph_g_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
return {"statements": statements}
async def search_graph_l_created_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
- Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
return {"statements": statements}
async def search_graph_l_valid_at(
connector: Neo4jConnector,
group_id: Optional[str] = None,
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
- Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id
- Returns up to 'limit' statements
"""
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT,
group_id=group_id,
apply_id=apply_id,
user_id=user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
return {"statements": statements}

View File

@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
"""Neo4j连接器模块
本模块提供Neo4j图数据库的连接和查询功能。
从 app/core/memory/src/database/neo4j_connector.py 迁移而来。
Classes:
Neo4jConnector: Neo4j数据库连接器提供异步查询接口
"""
import os
from typing import Any, List, Dict
from neo4j import AsyncGraphDatabase, basic_auth
from app.core.config import settings
class Neo4jConnector:
"""Neo4j数据库连接器
提供与Neo4j图数据库的连接和查询功能。
使用异步驱动程序以支持高并发操作。
Attributes:
driver: Neo4j异步驱动程序实例
Methods:
close: 关闭数据库连接
execute_query: 执行Cypher查询
delete_group: 删除指定组的所有数据
"""
def __init__(self):
"""初始化Neo4j连接器
从配置文件和环境变量中读取连接信息。
Raises:
RuntimeError: 如果NEO4J_PASSWORD环境变量未设置
"""
# 从全局配置和环境变量获取 Neo4j 配置
uri = settings.NEO4J_URI
username = settings.NEO4J_USERNAME
password = settings.NEO4J_PASSWORD
if not password:
raise RuntimeError(
"NEO4J_PASSWORD is not set. Create a .env with NEO4J_PASSWORD or export it before running."
)
self.driver = AsyncGraphDatabase.driver(
uri,
auth=basic_auth(username, password)
)
async def close(self):
"""关闭数据库连接
释放数据库连接资源。应在应用程序关闭时调用。
"""
await self.driver.close()
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
"""执行Cypher查询
Args:
query: Cypher查询语句
**kwargs: 查询参数将作为参数传递给Cypher查询
Returns:
List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典
Example:
>>> connector = Neo4jConnector()
>>> results = await connector.execute_query(
... "MATCH (n:Person {name: $name}) RETURN n",
... name="Alice"
... )
"""
result = await self.driver.execute_query(
query,
database="neo4j",
**kwargs
)
records, summary, keys = result
return [record.data() for record in records]
async def delete_group(self, group_id: str):
"""删除指定组的所有数据
删除所有属于指定group_id的节点和边。
这是一个危险操作,会永久删除数据。
Args:
group_id: 要删除的组ID
Example:
>>> connector = Neo4jConnector()
>>> await connector.delete_group("group_123")
Group group_123 deleted.
"""
# 删除节点DETACH DELETE会同时删除相关的边
await self.driver.execute_query(
"MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n",
database="neo4j",
group_id=group_id
)
# 删除独立的边(如果有的话)
await self.driver.execute_query(
"MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r",
database="neo4j",
group_id=group_id
)
print(f"Group {group_id} deleted.")

View File

@@ -0,0 +1,319 @@
# -*- coding: utf-8 -*-
"""陈述句仓储模块
本模块提供陈述句节点的数据访问功能。
Classes:
StatementRepository: 陈述句仓储管理StatementNode的CRUD操作
"""
from typing import List, Optional, Dict
from datetime import datetime
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.core.memory.models.graph_models import StatementNode
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.data.ontology import TemporalInfo
class StatementRepository(BaseNeo4jRepository[StatementNode]):
"""陈述句仓储
管理陈述句节点的创建、查询、更新和删除操作。
提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。
Attributes:
connector: Neo4j连接器实例
node_label: 节点标签,固定为"Statement"
"""
def __init__(self, connector: Neo4jConnector):
"""初始化陈述句仓储
Args:
connector: Neo4j连接器实例
"""
super().__init__(connector, "Statement")
def _map_to_entity(self, node_data: Dict) -> StatementNode:
"""将节点数据映射为陈述句实体
Args:
node_data: 从Neo4j查询返回的节点数据字典
Returns:
StatementNode: 陈述句实体对象
"""
# 从查询结果中提取节点数据
n = node_data.get('n', node_data)
# 处理datetime字段
if isinstance(n.get('created_at'), str):
n['created_at'] = datetime.fromisoformat(n['created_at'])
if n.get('expired_at') and isinstance(n['expired_at'], str):
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
if n.get('valid_at') and isinstance(n['valid_at'], str):
n['valid_at'] = datetime.fromisoformat(n['valid_at'])
if n.get('invalid_at') and isinstance(n['invalid_at'], str):
n['invalid_at'] = datetime.fromisoformat(n['invalid_at'])
# 处理temporal_info字段
if isinstance(n.get('temporal_info'), dict):
n['temporal_info'] = TemporalInfo(**n['temporal_info'])
elif not n.get('temporal_info'):
# 如果没有temporal_info创建一个默认的
n['temporal_info'] = TemporalInfo()
return StatementNode(**n)
async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]:
"""根据chunk_id查询陈述句
Args:
chunk_id: 分块ID
Returns:
List[StatementNode]: 陈述句列表
"""
return await self.find({"chunk_id": chunk_id})
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[StatementNode]:
"""根据group_id查询陈述句
Args:
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 陈述句列表
"""
return await self.find({"group_id": group_id}, limit=limit)
async def search_by_embedding(
self,
embedding: List[float],
group_id: Optional[str] = None,
limit: int = 10,
min_score: float = 0.7
) -> List[Dict]:
"""基于向量相似度搜索陈述句
使用余弦相似度计算查询向量与陈述句向量的相似度。
Args:
embedding: 查询向量
group_id: 可选的组ID过滤
limit: 返回结果的最大数量
min_score: 最小相似度分数阈值
Returns:
List[Dict]: 包含陈述句和相似度分数的字典列表
每个字典包含: statement (StatementNode), score (float)
"""
# 构建查询条件
where_clause = "n.statement_embedding IS NOT NULL"
if group_id:
where_clause += " AND n.group_id = $group_id"
query = f"""
MATCH (n:{self.node_label})
WHERE {where_clause}
WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
WHERE score > $min_score
RETURN n, score
ORDER BY score DESC
LIMIT $limit
"""
params = {
"embedding": embedding,
"min_score": min_score,
"limit": limit
}
if group_id:
params["group_id"] = group_id
results = await self.connector.execute_query(query, **params)
return [
{
"statement": self._map_to_entity(r),
"score": r.get("score", 0.0)
}
for r in results
]
async def search_by_keyword(
self,
keyword: str,
group_id: Optional[str] = None,
limit: int = 50
) -> List[StatementNode]:
"""基于关键词搜索陈述句
Args:
keyword: 搜索关键词
group_id: 可选的组ID过滤
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 陈述句列表
"""
where_clause = "n.statement CONTAINS $keyword"
if group_id:
where_clause += " AND n.group_id = $group_id"
query = f"""
MATCH (n:{self.node_label})
WHERE {where_clause}
RETURN n
LIMIT $limit
"""
params = {"keyword": keyword, "limit": limit}
if group_id:
params["group_id"] = group_id
results = await self.connector.execute_query(query, **params)
return [self._map_to_entity(r) for r in results]
async def find_by_temporal_range(
self,
group_id: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100
) -> List[StatementNode]:
"""根据时间范围查询陈述句
查询在指定时间范围内有效的陈述句。
Args:
group_id: 组ID
start_date: 开始日期(可选)
end_date: 结束日期(可选)
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 陈述句列表
"""
where_clauses = ["n.group_id = $group_id"]
params = {"group_id": group_id, "limit": limit}
if start_date:
where_clauses.append("n.valid_at >= $start_date")
params["start_date"] = start_date.isoformat()
if end_date:
where_clauses.append("(n.invalid_at IS NULL OR n.invalid_at <= $end_date)")
params["end_date"] = end_date.isoformat()
where_str = " AND ".join(where_clauses)
query = f"""
MATCH (n:{self.node_label})
WHERE {where_str}
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_entity(r) for r in results]
async def find_strong_statements(
self,
group_id: str,
limit: int = 100
) -> List[StatementNode]:
"""查询强连接的陈述句
Args:
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 强连接的陈述句列表
"""
return await self.find(
{"group_id": group_id, "connect_strength": "Strong"},
limit=limit
)
async def find_by_config_id(
self,
config_id: str,
limit: int = 100
) -> List[StatementNode]:
"""根据config_id查询陈述句
Args:
config_id: 配置ID
limit: 返回结果的最大数量
Returns:
List[StatementNode]: 陈述句列表
"""
return await self.find({"config_id": config_id}, limit=limit)
async def search_by_embedding_with_config(
self,
embedding: List[float],
config_id: Optional[str] = None,
group_id: Optional[str] = None,
limit: int = 10,
min_score: float = 0.7
) -> List[Dict]:
"""基于向量相似度搜索陈述句,可选择按config_id过滤
使用余弦相似度计算查询向量与陈述句向量的相似度。
支持按config_id过滤结果,确保只返回使用特定配置处理的陈述句。
Args:
embedding: 查询向量
config_id: 可选的配置ID过滤
group_id: 可选的组ID过滤
limit: 返回结果的最大数量
min_score: 最小相似度分数阈值
Returns:
List[Dict]: 包含陈述句和相似度分数的字典列表
每个字典包含: statement (StatementNode), score (float)
"""
# 构建查询条件
where_clauses = ["n.statement_embedding IS NOT NULL"]
params = {
"embedding": embedding,
"min_score": min_score,
"limit": limit
}
if config_id:
where_clauses.append("n.config_id = $config_id")
params["config_id"] = config_id
if group_id:
where_clauses.append("n.group_id = $group_id")
params["group_id"] = group_id
where_str = " AND ".join(where_clauses)
query = f"""
MATCH (n:{self.node_label})
WHERE {where_str}
WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score
WHERE score > $min_score
RETURN n, score
ORDER BY score DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [
{
"statement": self._map_to_entity(r),
"score": r.get("score", 0.0)
}
for r in results
]

View File

@@ -0,0 +1,59 @@
import uuid
from typing import Optional
from sqlalchemy.orm import Session
from sqlalchemy import select
from app.models import ReleaseShare
class ReleaseShareRepository:
"""发布版本分享仓储"""
def __init__(self, db: Session):
self.db = db
def create(self, release_share: ReleaseShare) -> ReleaseShare:
"""创建分享配置"""
self.db.add(release_share)
self.db.commit()
self.db.refresh(release_share)
return release_share
def get_by_id(self, share_id: uuid.UUID) -> Optional[ReleaseShare]:
"""根据 ID 获取分享配置"""
return self.db.get(ReleaseShare, share_id)
def get_by_release_id(self, release_id: uuid.UUID) -> Optional[ReleaseShare]:
"""根据发布版本 ID 获取分享配置"""
stmt = select(ReleaseShare).where(ReleaseShare.release_id == release_id)
return self.db.scalars(stmt).first()
def get_by_share_token(self, share_token: str) -> Optional[ReleaseShare]:
"""根据分享 token 获取分享配置"""
stmt = select(ReleaseShare).where(ReleaseShare.share_token == share_token)
return self.db.scalars(stmt).first()
def update(self, release_share: ReleaseShare) -> ReleaseShare:
"""更新分享配置"""
self.db.commit()
self.db.refresh(release_share)
return release_share
def delete(self, release_share: ReleaseShare) -> None:
"""删除分享配置"""
self.db.delete(release_share)
self.db.commit()
def token_exists(self, share_token: str) -> bool:
"""检查 token 是否已存在"""
stmt = select(ReleaseShare.id).where(ReleaseShare.share_token == share_token)
return self.db.scalars(stmt).first() is not None
def increment_view_count(self, share_id: uuid.UUID) -> None:
"""增加访问次数(异步更新,不阻塞)"""
from datetime import datetime
stmt = select(ReleaseShare).where(ReleaseShare.id == share_id)
share = self.db.scalars(stmt).first()
if share:
share.view_count += 1
share.last_accessed_at = datetime.now()
self.db.commit()

View File

@@ -0,0 +1,167 @@
import uuid
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import and_, or_, func
from typing import List, Optional
from app.models.tenant_model import Tenants
from app.models.user_model import User
from app.schemas.tenant_schema import TenantCreate, TenantUpdate
class TenantRepository:
"""租户数据访问层"""
def __init__(self, db: Session):
self.db = db
def create_tenant(self, tenant_data: TenantCreate) -> Tenants:
"""创建租户"""
db_tenant = Tenants(
name=tenant_data.name,
id=uuid.uuid4(),
description=tenant_data.description,
is_active=tenant_data.is_active
)
self.db.add(db_tenant)
self.db.flush()
return db_tenant
def get_tenant_by_id(self, tenant_id: uuid.UUID) -> Optional[Tenants]:
"""根据ID获取租户"""
return self.db.query(Tenants).filter(Tenants.id == tenant_id).first()
def get_tenant_by_name(self, name: str) -> Optional[Tenants]:
"""根据名称获取租户"""
return self.db.query(Tenants).filter(Tenants.name == name).first()
def get_tenants(
self,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> List[Tenants]:
"""获取租户列表"""
query = self.db.query(Tenants)
if is_active is not None:
query = query.filter(Tenants.is_active == is_active)
if search:
query = query.filter(
or_(
Tenants.name.ilike(f"%{search}%"),
Tenants.description.ilike(f"%{search}%")
)
)
return query.offset(skip).limit(limit).all()
def count_tenants(
self,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> int:
"""统计租户数量"""
query = self.db.query(func.count(Tenants.id))
if is_active is not None:
query = query.filter(Tenants.is_active == is_active)
if search:
query = query.filter(
or_(
Tenants.name.ilike(f"%{search}%"),
Tenants.description.ilike(f"%{search}%")
)
)
return query.scalar()
def update_tenant(self, tenant_id: uuid.UUID, tenant_data: TenantUpdate) -> Optional[Tenants]:
"""更新租户"""
db_tenant = self.get_tenant_by_id(tenant_id)
if not db_tenant:
return None
for field, value in tenant_data.dict(exclude_unset=True).items():
setattr(db_tenant, field, value)
self.db.flush()
return db_tenant
def delete_tenant(self, tenant_id: uuid.UUID) -> bool:
"""删除租户"""
db_tenant = self.get_tenant_by_id(tenant_id)
if not db_tenant:
return False
self.db.delete(db_tenant)
return True
def get_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> List[User]:
"""获取租户下的所有用户"""
query = self.db.query(User).filter(User.tenant_id == tenant_id)
if is_active is not None:
query = query.filter(User.is_active == is_active)
return query.all()
def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]:
"""获取用户所属的租户"""
user = self.db.query(User).filter(User.id == user_id).first()
if not user or not user.tenant_id:
return None
return self.get_tenant_by_id(user.tenant_id)
def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
"""将用户分配给租户"""
user = self.db.query(User).filter(User.id == user_id).first()
if not user:
return False
# 验证租户存在
tenant = self.get_tenant_by_id(tenant_id)
if not tenant:
return False
user.tenant_id = tenant_id
self.db.flush()
return True
def count_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> int:
"""统计租户下的用户数量"""
query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
if is_active is not None:
query = query.filter(User.is_active == is_active)
return query.scalar()
# 便利函数,保持向后兼容
def create_tenant(db: Session, tenant_data: TenantCreate) -> Tenants:
"""创建租户"""
return TenantRepository(db).create_tenant(tenant_data)
def get_tenant_by_id(db: Session, tenant_id: uuid.UUID) -> Optional[Tenants]:
"""根据ID获取租户"""
return TenantRepository(db).get_tenant_by_id(tenant_id)
def get_tenant_by_name(db: Session, name: str) -> Optional[Tenants]:
"""根据名称获取租户"""
return TenantRepository(db).get_tenant_by_name(name)
def get_tenants(db: Session, skip: int = 0, limit: int = 100) -> List[Tenants]:
"""获取租户列表"""
return TenantRepository(db).get_tenants(skip=skip, limit=limit)
def get_user_tenant(db: Session, user_id: uuid.UUID) -> Optional[Tenants]:
"""获取用户所属的租户"""
return TenantRepository(db).get_user_tenant(user_id)
def get_tenant_users(db: Session, tenant_id: uuid.UUID) -> List[User]:
"""获取租户下的所有用户"""
return TenantRepository(db).get_tenant_users(tenant_id)

View File

@@ -0,0 +1,322 @@
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import and_, or_, func
from typing import List, Optional
import uuid
from app.models.user_model import User
from app.models.tenant_model import Tenants
from app.schemas.user_schema import UserCreate, UserUpdate
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class UserRepository:
"""用户数据访问层"""
def __init__(self, db: Session):
self.db = db
def get_user_by_id(self, user_id: uuid.UUID) -> Optional[User]:
"""根据ID获取用户"""
db_logger.debug(f"根据ID查询用户: user_id={user_id}")
try:
user = self.db.query(User).options(joinedload(User.tenant)).filter(User.id == user_id).first()
if user:
db_logger.debug(f"用户查询成功: {user.username} (ID: {user_id})")
else:
db_logger.debug(f"用户不存在: user_id={user_id}")
return user
except Exception as e:
db_logger.error(f"根据ID查询用户失败: user_id={user_id} - {str(e)}")
raise
def get_user_by_email(self, email: str) -> Optional[User]:
"""根据邮箱获取用户"""
db_logger.debug(f"根据邮箱查询用户: email={email}")
try:
user = self.db.query(User).options(joinedload(User.tenant)).filter(User.email == email).first()
if user:
db_logger.debug(f"用户查询成功: {user.username} (email: {email})")
else:
db_logger.debug(f"用户不存在: email={email}")
return user
except Exception as e:
db_logger.error(f"根据邮箱查询用户失败: email={email} - {str(e)}")
raise
def get_user_by_username(self, username: str) -> Optional[User]:
"""根据用户名获取用户"""
db_logger.debug(f"根据用户名查询用户: username={username}")
try:
user = self.db.query(User).options(joinedload(User.tenant)).filter(User.username == username).first()
if user:
db_logger.debug(f"用户查询成功: {user.username} (ID: {user.id})")
else:
db_logger.debug(f"用户不存在: username={username}")
return user
except Exception as e:
db_logger.error(f"根据用户名查询用户失败: username={username} - {str(e)}")
raise
def get_superuser(self) -> Optional[User]:
"""获取超级用户"""
db_logger.debug("查询超级用户")
try:
user = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active == True).filter(User.is_superuser == True).first()
if user:
db_logger.debug(f"超级用户查询成功: {user.username}")
else:
db_logger.debug("超级用户不存在")
return user
except Exception as e:
db_logger.error(f"查询超级用户失败: {str(e)}")
raise
def check_superuser_only(self) -> bool:
"""检查是否只有一个超级用户"""
db_logger.debug("检查是否只有一个超级用户")
try:
count = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active == True).filter(User.is_superuser == True).count()
return count == 1
except Exception as e:
db_logger.error(f"检查超级用户数量失败: {str(e)}")
raise
def create_user(
self,
user_data: UserCreate,
hashed_password: str,
tenant_id: Optional[uuid.UUID] = None,
is_superuser: bool = False
) -> User:
"""创建用户"""
db_logger.debug(f"创建用户记录: username={user_data.username}, email={user_data.email}, is_superuser={is_superuser}")
try:
db_user = User(
username=user_data.username,
email=user_data.email,
hashed_password=hashed_password,
tenant_id=tenant_id,
is_superuser=is_superuser,
)
self.db.add(db_user)
self.db.flush()
db_logger.info(f"用户记录创建成功: {user_data.username} (email: {user_data.email})")
return db_user
except Exception as e:
db_logger.error(f"创建用户记录失败: username={user_data.username} - {str(e)}")
raise
def update_user(self, user_id: uuid.UUID, user_data: UserUpdate) -> Optional[User]:
"""更新用户"""
db_logger.debug(f"更新用户: user_id={user_id}")
try:
user = self.get_user_by_id(user_id)
if not user:
db_logger.debug(f"用户不存在: user_id={user_id}")
return None
for field, value in user_data.dict(exclude_unset=True).items():
setattr(user, field, value)
self.db.flush()
db_logger.info(f"用户更新成功: {user.username}")
return user
except Exception as e:
db_logger.error(f"更新用户失败: user_id={user_id} - {str(e)}")
raise
def delete_user(self, user_id: uuid.UUID) -> bool:
"""删除用户"""
db_logger.debug(f"删除用户: user_id={user_id}")
try:
user = self.get_user_by_id(user_id)
if not user:
db_logger.debug(f"用户不存在: user_id={user_id}")
return False
self.db.delete(user)
self.db.flush()
db_logger.info(f"用户删除成功: {user.username}")
return True
except Exception as e:
db_logger.error(f"删除用户失败: user_id={user_id} - {str(e)}")
raise
def get_users_by_tenant(
self,
tenant_id: uuid.UUID,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> List[User]:
"""获取租户下的用户列表"""
db_logger.debug(f"查询租户用户: tenant_id={tenant_id}")
try:
query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id)
if is_active is not None:
query = query.filter(User.is_active == is_active)
if search:
query = query.filter(
or_(
User.username.ilike(f"%{search}%"),
User.email.ilike(f"%{search}%")
)
)
users = query.offset(skip).limit(limit).all()
db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}")
return users
except Exception as e:
db_logger.error(f"查询租户用户失败: tenant_id={tenant_id} - {str(e)}")
raise
def count_users_by_tenant(
self,
tenant_id: uuid.UUID,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> int:
"""统计租户下的用户数量"""
try:
query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
if is_active is not None:
query = query.filter(User.is_active == is_active)
if search:
query = query.filter(
or_(
User.username.ilike(f"%{search}%"),
User.email.ilike(f"%{search}%")
)
)
return query.scalar()
except Exception as e:
db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}")
raise
def get_superusers_by_tenant(
self,
tenant_id: uuid.UUID,
is_active: Optional[bool] = True
) -> List[User]:
"""获取租户下的超管用户列表"""
db_logger.debug(f"查询租户超管用户: tenant_id={tenant_id}")
try:
query = self.db.query(User).options(joinedload(User.tenant)).filter(
and_(
User.tenant_id == tenant_id,
User.is_superuser == True
)
)
if is_active is not None:
query = query.filter(User.is_active == is_active)
users = query.all()
db_logger.debug(f"租户超管用户查询成功: tenant_id={tenant_id}, count={len(users)}")
return users
except Exception as e:
db_logger.error(f"查询租户超管用户失败: tenant_id={tenant_id} - {str(e)}")
raise
def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
"""将用户分配给租户"""
db_logger.debug(f"分配用户到租户: user_id={user_id}, tenant_id={tenant_id}")
try:
user = self.get_user_by_id(user_id)
if not user:
db_logger.debug(f"用户不存在: user_id={user_id}")
return False
# 验证租户存在
tenant = self.db.query(Tenants).filter(Tenants.id == tenant_id).first()
if not tenant:
db_logger.debug(f"租户不存在: tenant_id={tenant_id}")
return False
user.tenant_id = tenant_id
self.db.flush()
db_logger.info(f"用户分配成功: user={user.username}, tenant={tenant.name}")
return True
except Exception as e:
db_logger.error(f"分配用户到租户失败: user_id={user_id}, tenant_id={tenant_id} - {str(e)}")
raise
def get_users_without_tenant(
self,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None
) -> List[User]:
"""获取没有租户的用户列表"""
try:
query = self.db.query(User).filter(User.tenant_id.is_(None))
if is_active is not None:
query = query.filter(User.is_active == is_active)
return query.offset(skip).limit(limit).all()
except Exception as e:
db_logger.error(f"查询无租户用户失败: {str(e)}")
raise
# 便利函数,保持向后兼容
def get_user_by_id(db: Session, user_id: uuid.UUID) -> Optional[User]:
"""根据ID获取用户"""
return UserRepository(db).get_user_by_id(user_id)
def get_user_by_email(db: Session, email: str) -> Optional[User]:
"""根据邮箱获取用户"""
return UserRepository(db).get_user_by_email(email)
def get_user_by_username(db: Session, username: str) -> Optional[User]:
"""根据用户名获取用户"""
return UserRepository(db).get_user_by_username(username)
def get_superuser(db: Session) -> Optional[User]:
"""获取超级用户"""
return UserRepository(db).get_superuser()
def check_superuser_only(db: Session) -> Optional[User]:
"""检查是否只有一个超级用户"""
return UserRepository(db).check_superuser_only()
def create_user(
db: Session,
user: UserCreate,
hashed_password: str,
tenant_id: Optional[uuid.UUID] = None,
is_superuser: bool = False
) -> User:
"""创建用户(函数式接口)"""
repo = UserRepository(db)
return repo.create_user(user, hashed_password, tenant_id, is_superuser)
def get_superusers_by_tenant(
db: Session,
tenant_id: uuid.UUID,
is_active: Optional[bool] = True
) -> List[User]:
"""获取租户下的超管用户列表(函数式接口)"""
repo = UserRepository(db)
return repo.get_superusers_by_tenant(tenant_id, is_active)

View File

@@ -0,0 +1,134 @@
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from typing import List, Optional
import datetime
import uuid
from app.models.workspace_model import WorkspaceInvite, InviteStatus
from app.schemas.workspace_schema import WorkspaceInviteCreate
class WorkspaceInviteRepository:
def __init__(self, db: Session):
self.db = db
def create_invite(
self,
workspace_id: uuid.UUID,
invite_data: WorkspaceInviteCreate,
token_hash: str,
created_by_user_id: uuid.UUID
) -> WorkspaceInvite:
"""创建工作空间邀请"""
expires_at = datetime.datetime.now() + datetime.timedelta(days=invite_data.expires_in_days)
db_invite = WorkspaceInvite(
workspace_id=workspace_id,
email=invite_data.email,
role=invite_data.role,
token_hash=token_hash,
status=InviteStatus.pending,
expires_at=expires_at,
created_by_user_id=created_by_user_id
)
self.db.add(db_invite)
self.db.commit()
self.db.refresh(db_invite)
return db_invite
def get_invite_by_token_hash(self, token_hash: str) -> Optional[WorkspaceInvite]:
"""根据令牌哈希获取邀请"""
return self.db.query(WorkspaceInvite).filter(
WorkspaceInvite.token_hash == token_hash
).first()
def get_invite_by_id(self, invite_id: uuid.UUID) -> Optional[WorkspaceInvite]:
"""根据ID获取邀请"""
return self.db.query(WorkspaceInvite).filter(
WorkspaceInvite.id == invite_id
).first()
def get_workspace_invites(
self,
workspace_id: uuid.UUID,
status: Optional[InviteStatus] = None,
limit: int = 50,
offset: int = 0
) -> List[WorkspaceInvite]:
"""获取工作空间的邀请列表"""
query = self.db.query(WorkspaceInvite).filter(
WorkspaceInvite.workspace_id == workspace_id
)
if status:
query = query.filter(WorkspaceInvite.status == status)
return query.order_by(WorkspaceInvite.created_at.desc()).offset(offset).limit(limit).all()
def get_pending_invite_by_email_and_workspace(
self,
email: str,
workspace_id: uuid.UUID
) -> Optional[WorkspaceInvite]:
"""获取指定邮箱在指定工作空间的待处理邀请"""
return self.db.query(WorkspaceInvite).filter(
and_(
WorkspaceInvite.email == email,
WorkspaceInvite.workspace_id == workspace_id,
WorkspaceInvite.status == InviteStatus.pending
)
).first()
def update_invite_status(
self,
invite_id: uuid.UUID,
status: InviteStatus,
accepted_at: Optional[datetime.datetime] = None
) -> Optional[WorkspaceInvite]:
"""更新邀请状态"""
invite = self.get_invite_by_id(invite_id)
if invite:
invite.status = status
if accepted_at:
invite.accepted_at = accepted_at
invite.updated_at = datetime.datetime.now()
self.db.commit()
self.db.refresh(invite)
return invite
def revoke_invite(self, invite_id: uuid.UUID) -> Optional[WorkspaceInvite]:
"""撤销邀请"""
return self.update_invite_status(invite_id, InviteStatus.revoked)
def expire_old_invites(self) -> int:
"""将过期的邀请标记为已过期"""
now = datetime.datetime.now()
expired_count = self.db.query(WorkspaceInvite).filter(
and_(
WorkspaceInvite.status == InviteStatus.pending,
WorkspaceInvite.expires_at < now
)
).update(
{
WorkspaceInvite.status: InviteStatus.expired,
WorkspaceInvite.updated_at: now
}
)
self.db.commit()
return expired_count
def count_workspace_invites(
self,
workspace_id: uuid.UUID,
status: Optional[InviteStatus] = None
) -> int:
"""统计工作空间邀请数量"""
query = self.db.query(WorkspaceInvite).filter(
WorkspaceInvite.workspace_id == workspace_id
)
if status:
query = query.filter(WorkspaceInvite.status == status)
return query.count()

View File

@@ -0,0 +1,383 @@
from sqlalchemy.orm import Session, joinedload
from app.models.user_model import User
from typing import List, Optional
import uuid
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
from app.schemas.workspace_schema import WorkspaceCreate, WorkspaceUpdate
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class WorkspaceRepository:
"""工作空间数据访问层"""
def __init__(self, db: Session):
self.db = db
def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
"""创建工作空间"""
db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}")
try:
db_workspace = Workspace(
name=workspace_data.name,
description=workspace_data.description,
icon=workspace_data.icon,
iconType=workspace_data.iconType,
storage_type=workspace_data.storage_type,
llm=workspace_data.llm,
embedding=workspace_data.embedding,
rerank=workspace_data.rerank,
tenant_id=tenant_id
)
self.db.add(db_workspace)
self.db.flush()
db_logger.info(f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
return db_workspace
except Exception as e:
db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}")
raise
def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]:
"""根据ID获取工作空间"""
db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}")
try:
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
db_logger.debug(f"工作空间查询成功: {workspace.name} (ID: {workspace_id})")
else:
db_logger.debug(f"工作空间不存在: workspace_id={workspace_id}")
return workspace
except Exception as e:
db_logger.error(f"根据ID查询工作空间失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_models_configs(self, workspace_id: uuid.UUID) -> Optional[dict]:
"""根据workspace_id获取模型配置llm, embedding, rerank
Args:
workspace_id: 工作空间ID
Returns:
包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
"""
db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}")
try:
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
configs = {
"llm": workspace.llm,
"embedding": workspace.embedding,
"rerank": workspace.rerank
}
db_logger.debug(
f"工作空间模型配置查询成功: workspace_id={workspace_id}, "
f"llm={configs['llm']}, embedding={configs['embedding']}, rerank={configs['rerank']}"
)
return configs
else:
db_logger.debug(f"工作空间不存在: workspace_id={workspace_id}")
return None
except Exception as e:
db_logger.error(f"查询工作空间模型配置失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]:
"""获取用户参与的所有工作空间(包括用户创建的和作为成员的)"""
db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}")
try:
# 首先获取用户信息以获取 tenant_id
from app.models.user_model import User
user = self.db.query(User).filter(User.id == user_id).first()
if not user:
db_logger.warning(f"用户不存在: user_id={user_id}")
return []
if user.is_superuser:
# 超级用户获取对应tenantid所有工作空间
workspaces = (
self.db.query(Workspace)
.filter(Workspace.tenant_id == user.tenant_id)
.filter(Workspace.is_active == True)
.order_by(Workspace.updated_at.desc())
.all()
)
db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}")
return workspaces
# 获取用户作为成员的工作空间
member_workspaces = (
self.db.query(Workspace)
.join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id)
.filter(WorkspaceMember.user_id == user_id)
.filter(Workspace.is_active == True)
.order_by(Workspace.updated_at.desc())
.all()
)
db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}")
return member_workspaces
except Exception as e:
db_logger.error(f"查询用户工作空间失败: user_id={user_id} - {str(e)}")
raise
def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]:
"""获取租户的所有工作空间"""
db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}")
try:
workspaces = (
self.db.query(Workspace)
.filter(Workspace.tenant_id == tenant_id)
.filter(Workspace.is_active == True)
.all()
)
db_logger.debug(f"租户工作空间查询成功: tenant_id={tenant_id}, 数量={len(workspaces)}")
return workspaces
except Exception as e:
db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}")
raise
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
"""添加工作空间成员"""
db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}")
try:
db_member = WorkspaceMember(
user_id=user_id,
workspace_id=workspace_id,
role=role
)
self.db.add(db_member)
self.db.flush()
db_logger.info(f"工作空间成员添加成功: user_id={user_id}, workspace_id={workspace_id}, role={role}")
return db_member
except Exception as e:
db_logger.error(f"添加工作空间成员失败: user_id={user_id}, workspace_id={workspace_id} - {str(e)}")
raise
def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]:
"""获取工作空间成员"""
db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}")
try:
member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.user_id == user_id,
WorkspaceMember.workspace_id == workspace_id,
WorkspaceMember.is_active == True,
).first()
if member:
db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
else:
db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}")
return member
except Exception as e:
db_logger.error(f"查询工作空间成员失败: user_id={user_id}, workspace_id={workspace_id} - {str(e)}")
raise
def get_members_by_workspace(self, workspace_id: uuid.UUID) -> List[WorkspaceMember]:
"""按工作空间获取成员列表,并预加载 user 与 workspace 关系"""
db_logger.debug(f"查询工作空间的成员列表: workspace_id={workspace_id}")
try:
members = (
self.db.query(WorkspaceMember)
.join(User, WorkspaceMember.user_id == User.id)
.options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace))
.filter(WorkspaceMember.workspace_id == workspace_id)
.filter(WorkspaceMember.is_active == True)
.filter(User.is_active == True)
.all()
)
db_logger.debug(f"成员列表查询成功: workspace_id={workspace_id}, 数量={len(members)}")
return members
except Exception as e:
db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_member_by_id(self, member_id: uuid.UUID) -> WorkspaceMember:
"""按成员ID获取工作空间成员并预加载 user 与 workspace 关系"""
db_logger.debug(f"查询成员的工作空间: member_id={member_id}")
try:
member = (
self.db.query(WorkspaceMember)
.join(User, WorkspaceMember.user_id == User.id)
.options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace))
.filter(WorkspaceMember.id == member_id)
.filter(WorkspaceMember.is_active == True)
.filter(User.is_active == True)
.first()
)
if member:
db_logger.debug(f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
else:
db_logger.debug(f"成员不存在: member_id={member_id}")
return member
except Exception as e:
db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}")
raise
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
try:
member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.workspace_id == workspace_id,
WorkspaceMember.user_id == user_id,
WorkspaceMember.is_active == True,
).first()
if not member:
return None
member.role = role
self.db.commit()
self.db.refresh(member)
return member
except Exception as e:
db_logger.error(f"更新成员角色失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
raise
def deactivate_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID) -> Optional[WorkspaceMember]:
try:
member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.workspace_id == workspace_id,
WorkspaceMember.user_id == user_id,
WorkspaceMember.is_active == True,
).first()
if not member:
return None
member.is_active = False
self.db.commit()
self.db.refresh(member)
return member
except Exception as e:
db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
raise
def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
try:
member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.id == member_id,
WorkspaceMember.is_active == True,
).first()
if not member:
return None
member.is_active = False
self.db.commit()
self.db.refresh(member)
return member
except Exception as e:
db_logger.error(f"删除成员失败: id={member_id} - {str(e)}")
raise
def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
try:
member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.id == id,
WorkspaceMember.is_active == True,
).first()
if not member:
return None
member.role = role
self.db.commit()
self.db.refresh(member)
return member
except Exception as e:
db_logger.error(f"更新成员角色失败: id={id} - {str(e)}")
raise
# 保持向后兼容的函数
def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None:
repo = WorkspaceRepository(db)
return repo.get_workspace_by_id(workspace_id)
def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]:
repo = WorkspaceRepository(db)
return repo.get_workspaces_by_user(user_id)
def get_workspaces_by_tenant(db: Session, tenant_id: uuid.UUID) -> List[Workspace]:
repo = WorkspaceRepository(db)
return repo.get_workspaces_by_tenant(tenant_id)
def get_member_in_workspace(db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID) -> WorkspaceMember | None:
repo = WorkspaceRepository(db)
return repo.get_member(user_id, workspace_id)
def create_workspace(db: Session, workspace: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
repo = WorkspaceRepository(db)
return repo.create_workspace(workspace, tenant_id)
def add_member_to_workspace(
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
) -> WorkspaceMember:
repo = WorkspaceRepository(db)
return repo.add_member(workspace_id, user_id, role)
def get_members_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[WorkspaceMember]:
repo = WorkspaceRepository(db)
return repo.get_members_by_workspace(workspace_id)
def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None:
repo = WorkspaceRepository(db)
return repo.get_member_by_id(member_id)
def update_member_role_in_workspace(
db: Session,
user_id: uuid.UUID,
workspace_id: uuid.UUID,
role: WorkspaceRole,
) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db)
return repo.update_member_role(workspace_id, user_id, role)
def remove_member_from_workspace(
db: Session,
user_id: uuid.UUID,
workspace_id: uuid.UUID,
) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db)
return repo.deactivate_member(workspace_id, user_id)
def remove_member_from_workspace_by_id(
db: Session,
member_id: uuid.UUID,
) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db)
return repo.delete_member_by_id(member_id)
def update_member_role_by_id(
db: Session,
id: uuid.UUID,
role: WorkspaceRole,
) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db)
return repo.update_member_role_by_id(id, role)
def get_workspace_models_configs(db: Session, workspace_id: uuid.UUID) -> Optional[dict]:
"""根据workspace_id获取模型配置llm, embedding, rerank
Args:
db: 数据库会话
workspace_id: 工作空间ID
Returns:
包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
Example:
>>> configs = get_workspace_models_configs(db, workspace_id)
>>> if configs:
>>> print(f"LLM: {configs['llm']}")
>>> print(f"Embedding: {configs['embedding']}")
>>> print(f"Rerank: {configs['rerank']}")
"""
repo = WorkspaceRepository(db)
return repo.get_workspace_models_configs(workspace_id)