322 lines
11 KiB
Python
322 lines
11 KiB
Python
import uuid
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.logging_config import get_db_logger
|
|
from app.models.prompt_optimizer_model import (
|
|
PromptOptimizerSession,
|
|
PromptOptimizerSessionHistory,
|
|
RoleType,
|
|
PromptHistory
|
|
)
|
|
|
|
db_logger = get_db_logger()
|
|
|
|
|
|
class PromptOptimizerSessionRepository:
|
|
"""Repository for managing prompt optimization sessions and session history."""
|
|
|
|
def __init__(self, db: Session):
|
|
self.db = db
|
|
|
|
def get_session_by_id(self, session_id: uuid.UUID) -> PromptOptimizerSession | None:
|
|
session = self.db.query(PromptOptimizerSession).filter(
|
|
PromptOptimizerSession.id == session_id,
|
|
).first()
|
|
return session
|
|
|
|
def create_session(
|
|
self,
|
|
tenant_id: uuid.UUID,
|
|
user_id: uuid.UUID
|
|
) -> PromptOptimizerSession:
|
|
"""
|
|
Create a new prompt optimization session for a user and app.
|
|
|
|
Args:
|
|
tenant_id (uuid.UUID): The unique identifier of the tenant.
|
|
user_id (uuid.UUID): The unique identifier of the user.
|
|
|
|
Returns:
|
|
PromptOptimizerSession: The newly created session object.
|
|
"""
|
|
db_logger.debug(f"Create prompt optimization session: tenant_id={tenant_id}, user_id={user_id}")
|
|
try:
|
|
session = PromptOptimizerSession(
|
|
tenant_id=tenant_id,
|
|
user_id=user_id,
|
|
)
|
|
self.db.add(session)
|
|
return session
|
|
except Exception as e:
|
|
db_logger.error(f"Error creating prompt optimization session: - {str(e)}")
|
|
raise
|
|
|
|
def get_session_history(
|
|
self,
|
|
session_id: uuid.UUID,
|
|
user_id: uuid.UUID
|
|
) -> list[type[PromptOptimizerSessionHistory]]:
|
|
"""
|
|
Retrieve all message history of a specific prompt optimization session.
|
|
|
|
Args:
|
|
session_id (uuid.UUID): The unique identifier of the session.
|
|
user_id (uuid.UUID): The unique identifier of the user.
|
|
|
|
Returns:
|
|
list[PromptOptimizerSessionHistory]: A list of session history records
|
|
ordered by creation time ascending.
|
|
"""
|
|
db_logger.debug(f"Get prompt optimization session history: "
|
|
f"user_id={user_id}, session_id={session_id}")
|
|
|
|
try:
|
|
# First get the internal session ID from the session list table
|
|
session = self.db.query(PromptOptimizerSession).filter(
|
|
PromptOptimizerSession.id == session_id,
|
|
PromptOptimizerSession.user_id == user_id
|
|
).first()
|
|
|
|
if not session:
|
|
return []
|
|
|
|
history = self.db.query(PromptOptimizerSessionHistory).filter(
|
|
PromptOptimizerSessionHistory.session_id == session.id,
|
|
PromptOptimizerSessionHistory.user_id == user_id
|
|
).order_by(PromptOptimizerSessionHistory.created_at.asc()).all()
|
|
return history
|
|
except Exception as e:
|
|
db_logger.error(f"Error retrieving prompt optimization session history: session_id={session_id} - {str(e)}")
|
|
raise
|
|
|
|
def create_message(
|
|
self,
|
|
tenant_id: uuid.UUID,
|
|
session_id: uuid.UUID,
|
|
user_id: uuid.UUID,
|
|
role: RoleType,
|
|
content: str,
|
|
) -> PromptOptimizerSessionHistory:
|
|
"""
|
|
Create a new message in the session history.
|
|
|
|
This method is a placeholder for future implementation.
|
|
"""
|
|
try:
|
|
# Get the session to ensure it exists and belongs to the user
|
|
session = self.db.query(PromptOptimizerSession).filter(
|
|
PromptOptimizerSession.id == session_id,
|
|
PromptOptimizerSession.user_id == user_id,
|
|
PromptOptimizerSession.tenant_id == tenant_id
|
|
).first()
|
|
|
|
if not session:
|
|
db_logger.error(f"Session {session_id} not found for user {user_id}")
|
|
raise ValueError(f"Session {session_id} not found for user {user_id}")
|
|
|
|
message = PromptOptimizerSessionHistory(
|
|
tenant_id=tenant_id,
|
|
session_id=session.id,
|
|
user_id=user_id,
|
|
role=role.value,
|
|
content=content,
|
|
)
|
|
self.db.add(message)
|
|
|
|
return message
|
|
except Exception as e:
|
|
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
|
|
raise
|
|
|
|
def get_first_user_message(self, session_id: uuid.UUID) -> str | None:
|
|
"""
|
|
Get the first user message from a session.
|
|
|
|
Args:
|
|
session_id (uuid.UUID): The session ID.
|
|
|
|
Returns:
|
|
str | None: The content of the first user message, or None if not found.
|
|
"""
|
|
try:
|
|
message = self.db.query(PromptOptimizerSessionHistory).filter(
|
|
PromptOptimizerSessionHistory.session_id == session_id,
|
|
PromptOptimizerSessionHistory.role == RoleType.USER.value
|
|
).order_by(
|
|
PromptOptimizerSessionHistory.created_at.asc()
|
|
).first()
|
|
|
|
return message.content if message else None
|
|
except Exception as e:
|
|
db_logger.error(f"Error getting first user message: session_id={session_id} - {str(e)}")
|
|
raise
|
|
|
|
|
|
class PromptReleaseRepository:
|
|
def __init__(self, db: Session):
|
|
self.db = db
|
|
|
|
def get_prompt_by_session_id(self, session_id: uuid.UUID) -> PromptHistory | None:
|
|
prompt_obj = self.db.query(PromptHistory).filter(
|
|
PromptHistory.session_id == session_id,
|
|
PromptHistory.is_delete.is_(False)
|
|
).first()
|
|
return prompt_obj
|
|
|
|
def create_prompt_release(
|
|
self,
|
|
tenant_id: uuid.UUID,
|
|
title: str,
|
|
session_id: uuid.UUID,
|
|
prompt: str,
|
|
) -> PromptHistory:
|
|
try:
|
|
prompt_obj = PromptHistory(
|
|
tenant_id=tenant_id,
|
|
title=title,
|
|
session_id=session_id,
|
|
prompt=prompt,
|
|
)
|
|
self.db.add(prompt_obj)
|
|
return prompt_obj
|
|
except Exception as e:
|
|
db_logger.error(f"Error creating prompt release: session_id={session_id} - {str(e)}")
|
|
raise
|
|
|
|
def soft_delete_prompt(self, prompt_obj: PromptHistory) -> None:
|
|
"""
|
|
Soft delete a prompt release by setting is_delete flag to True.
|
|
|
|
Args:
|
|
prompt_obj (PromptHistory): The prompt release object to delete.
|
|
"""
|
|
try:
|
|
prompt_obj.is_delete = True
|
|
db_logger.debug(f"Soft deleted prompt release: id={prompt_obj.id}, session_id={prompt_obj.session_id}")
|
|
except Exception as e:
|
|
db_logger.error(f"Error soft deleting prompt release: id={prompt_obj.id} - {str(e)}")
|
|
raise
|
|
|
|
def get_prompt_by_id(self, prompt_id: uuid.UUID) -> PromptHistory | None:
|
|
"""
|
|
Get a prompt release by its ID.
|
|
|
|
Args:
|
|
prompt_id (uuid.UUID): The prompt release ID.
|
|
|
|
Returns:
|
|
PromptHistory | None: The prompt release object or None if not found.
|
|
"""
|
|
try:
|
|
prompt_obj = self.db.query(PromptHistory).filter(
|
|
PromptHistory.id == prompt_id
|
|
).first()
|
|
return prompt_obj
|
|
except Exception as e:
|
|
db_logger.error(f"Error getting prompt release by id: id={prompt_id} - {str(e)}")
|
|
raise
|
|
|
|
def count_prompts(self, tenant_id: uuid.UUID) -> int:
|
|
"""
|
|
Count total number of non-deleted prompts for a tenant.
|
|
|
|
Args:
|
|
tenant_id (uuid.UUID): The tenant ID.
|
|
|
|
Returns:
|
|
int: Total count of prompts.
|
|
"""
|
|
try:
|
|
count = self.db.query(PromptHistory).filter(
|
|
PromptHistory.tenant_id == tenant_id,
|
|
PromptHistory.is_delete.is_(False)
|
|
).count()
|
|
return count
|
|
except Exception as e:
|
|
db_logger.error(f"Error counting prompts: tenant_id={tenant_id} - {str(e)}")
|
|
raise
|
|
|
|
def get_prompts_paginated(
|
|
self,
|
|
tenant_id: uuid.UUID,
|
|
offset: int,
|
|
limit: int
|
|
) -> list[PromptHistory]:
|
|
"""
|
|
Get paginated list of prompt releases for a tenant.
|
|
|
|
Args:
|
|
tenant_id (uuid.UUID): The tenant ID.
|
|
offset (int): Number of records to skip.
|
|
limit (int): Maximum number of records to return.
|
|
|
|
Returns:
|
|
list[PromptHistory]: List of prompt releases.
|
|
"""
|
|
try:
|
|
prompts = self.db.query(PromptHistory).filter(
|
|
PromptHistory.tenant_id == tenant_id,
|
|
PromptHistory.is_delete.is_(False)
|
|
).order_by(
|
|
PromptHistory.created_at.desc()
|
|
).offset(offset).limit(limit).all()
|
|
return prompts
|
|
except Exception as e:
|
|
db_logger.error(f"Error getting paginated prompts: tenant_id={tenant_id} - {str(e)}")
|
|
raise
|
|
|
|
def count_prompts_by_keyword(self, tenant_id: uuid.UUID, keyword: str) -> int:
|
|
"""
|
|
Count total number of non-deleted prompts matching keyword for a tenant.
|
|
|
|
Args:
|
|
tenant_id (uuid.UUID): The tenant ID.
|
|
keyword (str): Search keyword for title.
|
|
|
|
Returns:
|
|
int: Total count of matching prompts.
|
|
"""
|
|
try:
|
|
count = self.db.query(PromptHistory).filter(
|
|
PromptHistory.tenant_id == tenant_id,
|
|
PromptHistory.is_delete.is_(False),
|
|
PromptHistory.title.ilike(f"%{keyword}%")
|
|
).count()
|
|
return count
|
|
except Exception as e:
|
|
db_logger.error(f"Error counting prompts by keyword: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
|
|
raise
|
|
|
|
def search_prompts_paginated(
|
|
self,
|
|
tenant_id: uuid.UUID,
|
|
keyword: str,
|
|
offset: int,
|
|
limit: int
|
|
) -> list[PromptHistory]:
|
|
"""
|
|
Search prompt releases by keyword in title with pagination.
|
|
|
|
Args:
|
|
tenant_id (uuid.UUID): The tenant ID.
|
|
keyword (str): Search keyword for title.
|
|
offset (int): Number of records to skip.
|
|
limit (int): Maximum number of records to return.
|
|
|
|
Returns:
|
|
list[PromptHistory]: List of matching prompt releases.
|
|
"""
|
|
try:
|
|
prompts = self.db.query(PromptHistory).filter(
|
|
PromptHistory.tenant_id == tenant_id,
|
|
PromptHistory.is_delete.is_(False),
|
|
PromptHistory.title.ilike(f"%{keyword}%")
|
|
).order_by(
|
|
PromptHistory.created_at.desc()
|
|
).offset(offset).limit(limit).all()
|
|
return prompts
|
|
except Exception as e:
|
|
db_logger.error(f"Error searching prompts: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
|
|
raise
|