Files
MemoryBear/api/app/repositories/__init__.py

172 lines
5.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""仓储模块
本模块提供统一的数据访问层包括PostgreSQL和Neo4j的仓储实现。
Classes:
RepositoryFactory: 仓储工厂,统一管理所有数据库的仓储实例
"""
from typing import Optional
from sqlalchemy.orm import Session
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.dialog_repository import DialogRepository
from app.repositories.neo4j.statement_repository import StatementRepository
from app.repositories.neo4j.entity_repository import EntityRepository
from app.repositories.user_repository import UserRepository
from app.repositories.workspace_repository import WorkspaceRepository
from app.repositories.app_repository import AppRepository
class RepositoryFactory:
"""仓储工厂 - 统一管理所有数据库的仓储
这个工厂类提供了获取各种仓储实例的统一接口。
支持Neo4j图数据库和PostgreSQL关系数据库的仓储。
Attributes:
neo4j_connector: Neo4j连接器实例可选
db_session: SQLAlchemy数据库会话可选
Example:
>>> # 创建工厂实例
>>> factory = RepositoryFactory(
... neo4j_connector=Neo4jConnector(),
... db_session=db_session
... )
>>>
>>> # 获取Neo4j仓储
>>> dialog_repo = factory.get_dialog_repository()
>>> statement_repo = factory.get_statement_repository()
>>>
>>> # 获取PostgreSQL仓储
>>> knowledge_repo = factory.get_knowledge_repository()
"""
def __init__(
self,
neo4j_connector: Optional[Neo4jConnector] = None,
db_session: Optional[Session] = None
):
"""初始化仓储工厂
Args:
neo4j_connector: Neo4j连接器实例可选
db_session: SQLAlchemy数据库会话可选
"""
self.neo4j_connector = neo4j_connector
self.db_session = db_session
# ==================== Neo4j 仓储 ====================
def get_dialog_repository(self) -> DialogRepository:
"""获取对话仓储
Returns:
DialogRepository: 对话仓储实例
Raises:
ValueError: 如果Neo4j连接器未初始化
"""
if not self.neo4j_connector:
raise ValueError("Neo4j connector not initialized")
return DialogRepository(self.neo4j_connector)
def get_statement_repository(self) -> StatementRepository:
"""获取陈述句仓储
Returns:
StatementRepository: 陈述句仓储实例
Raises:
ValueError: 如果Neo4j连接器未初始化
"""
if not self.neo4j_connector:
raise ValueError("Neo4j connector not initialized")
return StatementRepository(self.neo4j_connector)
def get_entity_repository(self) -> EntityRepository:
"""获取实体仓储
Returns:
EntityRepository: 实体仓储实例
Raises:
ValueError: 如果Neo4j连接器未初始化
"""
if not self.neo4j_connector:
raise ValueError("Neo4j connector not initialized")
return EntityRepository(self.neo4j_connector)
# ==================== PostgreSQL 仓储 ====================
# 注意现有的PostgreSQL仓储保持不变这里只是提供统一的访问接口
# 部分仓储如knowledge_repository、document_repository使用函数式接口
# 部分仓储如user_repository、workspace_repository使用类接口
def get_user_repository(self) -> UserRepository:
"""获取用户仓储
Returns:
UserRepository: 用户仓储实例
Raises:
ValueError: 如果数据库会话未初始化
"""
if not self.db_session:
raise ValueError("Database session not initialized")
return UserRepository(self.db_session)
def get_workspace_repository(self) -> WorkspaceRepository:
"""获取工作空间仓储
Returns:
WorkspaceRepository: 工作空间仓储实例
Raises:
ValueError: 如果数据库会话未初始化
"""
if not self.db_session:
raise ValueError("Database session not initialized")
return WorkspaceRepository(self.db_session)
def get_app_repository(self) -> AppRepository:
"""获取应用仓储
Returns:
AppRepository: 应用仓储实例
Raises:
ValueError: 如果数据库会话未初始化
"""
if not self.db_session:
raise ValueError("Database session not initialized")
return AppRepository(self.db_session)
def get_db_session(self) -> Session:
"""获取数据库会话
用于访问函数式仓储如knowledge_repository、document_repository
Returns:
Session: SQLAlchemy数据库会话
Raises:
ValueError: 如果数据库会话未初始化
Example:
>>> factory = RepositoryFactory(db_session=session)
>>> db = factory.get_db_session()
>>> # 使用函数式仓储
>>> from app.repositories import knowledge_repository
>>> knowledges = knowledge_repository.get_knowledges_paginated(db, [], 1, 10)
"""
if not self.db_session:
raise ValueError("Database session not initialized")
return self.db_session
__all__ = [
'RepositoryFactory',
]