import os import re import uuid from typing import Any, AsyncGenerator import json_repair from jinja2 import Template from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.models import RedBearModelConfig from app.core.models.llm import RedBearLLM from app.models import ModelConfig, ModelApiKey, ModelType, PromptOptimizerSessionHistory from app.models.prompt_optimizer_model import ( PromptOptimizerSession, RoleType ) from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository from app.repositories.prompt_optimizer_repository import ( PromptOptimizerSessionRepository, PromptReleaseRepository ) from app.schemas.prompt_optimizer_schema import OptimizePromptResult from app.services.model_service import ModelApiKeyService logger = get_business_logger() class PromptOptimizerService: def __init__(self, db: Session): self.db = db self.optim_repo = PromptOptimizerSessionRepository(self.db) self.release_repo = PromptReleaseRepository(self.db) def get_model_config( self, tenant_id: uuid.UUID, model_id: uuid.UUID ) -> ModelConfig: """ Retrieve the model configuration for a specific tenant. This method fetches the model configuration associated with the given tenant_id and model_id. If no configuration is found, a BusinessException is raised. Args: tenant_id (uuid.UUID): The unique identifier of the tenant. model_id (uuid.UUID): The unique identifier of the model. Returns: ModelConfig: The corresponding model configuration object. Raises: BusinessException: If the model configuration does not exist. """ model = ModelConfigRepository.get_by_id( self.db, model_id, tenant_id=tenant_id ) if not model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) return model def create_session( self, tenant_id: uuid.UUID, user_id: uuid.UUID ) -> PromptOptimizerSession: """ Create a new prompt optimization session. This method initializes a new prompt optimization session for the specified tenant, application, and user, and persists it to the database. Args: tenant_id (uuid.UUID): The unique identifier of the tenant. user_id (uuid.UUID): The unique identifier of the user. Returns: PromptOptimzerSession: The newly created prompt optimization session. """ session = self.optim_repo.create_session( tenant_id=tenant_id, user_id=user_id ) self.db.commit() self.db.refresh(session) return session def get_session_message_history( self, session_id: uuid.UUID, user_id: uuid.UUID ) -> list[tuple[str, str]]: """ Retrieve the chronological message history for a prompt optimization session. This method queries the database to fetch all messages associated with a specific prompt optimization session for a given user. Messages are returned in chronological order and typically include both user inputs and model-generated responses. Args: session_id (uuid.UUID): The unique identifier of the prompt optimization session. user_id (uuid.UUID): The unique identifier of the user associated with the session. Returns: list[tuple[str, str]]: A list of tuples representing messages. Each tuple contains: - role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'. - content (str): The content of the message. """ history = self.optim_repo.get_session_history( session_id=session_id, user_id=user_id ) messages = [] for message in history: messages.append((message.role, message.content)) return messages async def optimize_prompt( self, tenant_id: uuid.UUID, model_id: uuid.UUID, session_id: uuid.UUID, user_id: uuid.UUID, current_prompt: str, user_require: str, skill: bool = False ) -> AsyncGenerator[dict[str, str | Any], Any]: """ Optimize a user-provided prompt using a configured prompt optimizer LLM. This method refines the original prompt according to the user's requirements, generating an optimized version that is directly usable by AI tools. The optimization process follows strict rules, including: - Wrapping user-inserted variables in double curly braces {{}}. - Adhering to Jinja2 variable syntax if applicable. - Ensuring a clear logic flow, explicit instructions, and strong executability. - Producing output in a strict JSON format. Steps performed: 1. Retrieve the model configuration for the given tenant and model. 2. Fetch the session message history for context. 3. Instantiate the LLM with the appropriate API key and model configuration. 4. Build system messages outlining optimization rules. 5. Format the user's original prompt and requirements as a user message. 6. Send messages to the LLM to generate the optimized prompt. 7. Generate a concise description summarizing the changes made during optimization. Args: tenant_id (uuid.UUID): Tenant identifier. model_id (uuid.UUID): Prompt optimizer model identifier. session_id (uuid.UUID): Prompt optimization session identifier. user_id (uuid.UUID): Identifier of the user associated with the session. current_prompt (str): Original prompt to optimize. user_require (str): User's requirements or instructions for optimization. skill(bool): Is skill required Returns: OptimizePromptResult: An object containing: - prompt: The optimized prompt string. - desc: A short description summarizing the changes. Raises: BusinessException: If the LLM response cannot be parsed as valid JSON or does not conform to the expected output format. """ self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require) model_config = self.get_model_config(tenant_id, model_id) session_history = self.get_session_message_history(session_id=session_id, user_id=user_id) logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}") # Create LLM instance # api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id) # api_config: ModelApiKey = api_keys[0] if api_keys else None api_config: ModelApiKey = ModelApiKeyService.get_available_api_key(self.db, model_config.id) llm = RedBearLLM(RedBearModelConfig( model_name=api_config.model_name, provider=api_config.provider, api_key=api_config.api_key, base_url=api_config.api_base ), type=ModelType(model_config.type)) try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() rendered_system_message = Template(opt_system_prompt).render(skill=skill) with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f: opt_user_prompt = f.read() except FileNotFoundError: raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) except Exception as e: logger.error(f"Failed to load system prompt template: {e}") raise BusinessException(message="Internal server error", code=BizCode.INTERNAL_ERROR) rendered_user_message = Template(opt_user_prompt).render( current_prompt=current_prompt, user_require=user_require ) # build message messages = [ # init system_prompt ( RoleType.SYSTEM.value, rendered_system_message ), ] messages.extend(session_history[:-1]) # last message is current message messages.extend([(RoleType.USER.value, rendered_user_message)]) buffer = "" prompt_started = False prompt_finished = False idx = 0 async for chunk in llm.astream(messages): content = getattr(chunk, "content", chunk) if not content: continue buffer += content cache = buffer[:-20] # 尝试找到 "prompt": " 开始位置 if prompt_finished: continue if not prompt_started: m = re.search(r'"prompt"\s*:\s*"', cache) if m: prompt_started = True prompt_index = m.end() idx = prompt_index else: m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer) if m: prompt_index = m.start() prompt_finished = True yield {"content": buffer[idx:prompt_index]} else: yield {"content": cache[idx:]} if len(cache) != 0: idx = len(cache) # optim_resp = await llm.astream(messages) logger.info(buffer) optim_result = json_repair.repair_json(buffer, return_objects=True) # prompt = optim_result.get("prompt") desc = optim_result.get("desc") ModelApiKeyService.record_api_key_usage(self.db, api_config.id) self.create_message( tenant_id=tenant_id, session_id=session_id, user_id=user_id, role=RoleType.ASSISTANT, content=desc ) variables = self.parser_prompt_variables(optim_result.get("prompt")) logger.info(f"Prompt optimization completed, user_id={user_id}, session_id={session_id}") yield {"desc": optim_result.get("desc"), "variables": variables} @staticmethod def parser_prompt_variables(prompt: str): try: pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}' matches = re.findall(pattern, prompt) variables = list(set(matches)) return variables except Exception as e: logger.error(f"Failed to parse prompt variables - Error: {str(e)}", exc_info=True) raise BusinessException("Failed to parse prompt variables", BizCode.PARSER_NOT_SUPPORTED) @staticmethod def fill_prompt_variables(prompt: str, variables: dict[str, str]): try: pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}' def replace_var(match): var_name = match.group(1) return variables.get(var_name, match.group(0)) result = re.sub(pattern, replace_var, prompt) return result except Exception as e: logger.error(f"Failed to fill prompt variables - Error: {str(e)}", exc_info=True) raise BusinessException("Failed to fill prompt variables", BizCode.PARSER_NOT_SUPPORTED) def create_message( self, tenant_id: uuid.UUID, session_id: uuid.UUID, user_id: uuid.UUID, role: RoleType, content: str ) -> PromptOptimizerSessionHistory: """Insert Message to Session History""" message = PromptOptimizerSessionRepository(self.db).create_message( tenant_id=tenant_id, session_id=session_id, user_id=user_id, role=role, content=content ) self.db.commit() self.db.refresh(message) return message def save_prompt( self, tenant_id: uuid.UUID, session_id: uuid.UUID, title: str, prompt: str ) -> dict: """ Create and save a new prompt release for a given session. Args: tenant_id (uuid.UUID): The ID of the tenant owning the prompt. session_id (uuid.UUID): The ID of the session to associate with this prompt. title (str): The title of the prompt release. prompt (str): The content of the prompt. Returns: dict: A dictionary containing: - id (UUID): The unique ID of the created prompt release. - session_id (UUID): The session ID linked to the release. - title (str): The title of the prompt. - prompt (str): The prompt content. - created_at (int): Timestamp (in milliseconds) of when the prompt was created. Raises: BusinessException: If a prompt release already exists for the given session. """ session = self.optim_repo.get_session_by_id(session_id) if session is None or session.tenant_id != tenant_id: raise BusinessException( "Session does not exist or the current user has no access", BizCode.BAD_REQUEST ) if self.release_repo.get_prompt_by_session_id(session_id): raise BusinessException( "A release already exists for the current session", BizCode.BAD_REQUEST ) prompt_obj = self.release_repo.create_prompt_release( tenant_id=tenant_id, title=title, session_id=session_id, prompt=prompt ) self.db.commit() self.db.refresh(prompt_obj) return { "id": prompt_obj.id, "session_id": prompt_obj.session_id, "title": prompt_obj.title, "prompt": prompt_obj.prompt, "created_at": int(prompt_obj.created_at.timestamp() * 1000) } def delete_prompt( self, tenant_id: uuid.UUID, prompt_id: uuid.UUID ) -> None: """ Soft delete a prompt release by prompt_id. Args: tenant_id (uuid.UUID): Tenant identifier. prompt_id (uuid.UUID): Prompt identifier. Raises: BusinessException: If the prompt does not exist or already deleted. """ prompt_obj = self.release_repo.get_prompt_by_id(prompt_id) if not prompt_obj or prompt_obj.is_delete: raise BusinessException( "Prompt does not exist or has already been deleted", BizCode.NOT_FOUND ) if prompt_obj.tenant_id != tenant_id: raise BusinessException( "No permission to delete this prompt", BizCode.FORBIDDEN ) self.release_repo.soft_delete_prompt(prompt_obj) self.db.commit() logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}") def get_release_list( self, tenant_id: uuid.UUID, page: int, page_size: int, filter_keyword: str | None = None ) -> dict[str, int | list[Any]]: """ Get paginated list of prompt releases with optional filter. Args: tenant_id (uuid.UUID): Tenant identifier. page (int): Page number (starting from 1). page_size (int): Number of items per page. filter_keyword (str | None): Optional keyword to filter by title. Returns: dict: Contains total count, pagination info, and list of releases. """ offset = (page - 1) * page_size # Get total count and releases based on filter if filter_keyword: total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword) releases = self.release_repo.search_prompts_paginated( tenant_id=tenant_id, keyword=filter_keyword, offset=offset, limit=page_size ) else: total = self.release_repo.count_prompts(tenant_id) releases = self.release_repo.get_prompts_paginated( tenant_id=tenant_id, offset=offset, limit=page_size ) items = [] for release in releases: # Get first user message from session first_message = self.optim_repo.get_first_user_message( session_id=release.session_id ) items.append({ "id": release.id, "title": release.title, "prompt": release.prompt, "created_at": int(release.created_at.timestamp() * 1000), "first_message": first_message }) log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}" if filter_keyword: log_msg += f", filter='{filter_keyword}'" logger.info(log_msg) result = { "page": { "total": total, "page": page, "page_size": page_size, "hasnext": page * page_size < total }, "keyword": filter_keyword, "items": items } return result