Files
MemoryBear/api/app/services/prompt_optimizer_service.py
Timebomb2018 590ec3a446 feat(model and app):
1. Increase support for visual models and multimodal models;
2. The application and workflow can input various multimodal files such as images, documents, audio, and videos.
2026-03-05 09:55:54 +08:00

474 lines
18 KiB
Python

import os
import re
import uuid
from typing import Any, AsyncGenerator
import json_repair
from jinja2 import Template
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.models import RedBearModelConfig
from app.core.models.llm import RedBearLLM
from app.models import ModelConfig, ModelApiKey, ModelType, PromptOptimizerSessionHistory
from app.models.prompt_optimizer_model import (
PromptOptimizerSession,
RoleType
)
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
from app.repositories.prompt_optimizer_repository import (
PromptOptimizerSessionRepository,
PromptReleaseRepository
)
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
from app.services.model_service import ModelApiKeyService
logger = get_business_logger()
class PromptOptimizerService:
def __init__(self, db: Session):
self.db = db
self.optim_repo = PromptOptimizerSessionRepository(self.db)
self.release_repo = PromptReleaseRepository(self.db)
def get_model_config(
self,
tenant_id: uuid.UUID,
model_id: uuid.UUID
) -> ModelConfig:
"""
Retrieve the model configuration for a specific tenant.
This method fetches the model configuration associated with the given
tenant_id and model_id. If no configuration is found, a BusinessException
is raised.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
model_id (uuid.UUID): The unique identifier of the model.
Returns:
ModelConfig: The corresponding model configuration object.
Raises:
BusinessException: If the model configuration does not exist.
"""
model = ModelConfigRepository.get_by_id(
self.db, model_id, tenant_id=tenant_id
)
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return model
def create_session(
self,
tenant_id: uuid.UUID,
user_id: uuid.UUID
) -> PromptOptimizerSession:
"""
Create a new prompt optimization session.
This method initializes a new prompt optimization session for the specified
tenant, application, and user, and persists it to the database.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
user_id (uuid.UUID): The unique identifier of the user.
Returns:
PromptOptimzerSession: The newly created prompt optimization session.
"""
session = self.optim_repo.create_session(
tenant_id=tenant_id,
user_id=user_id
)
self.db.commit()
self.db.refresh(session)
return session
def get_session_message_history(
self,
session_id: uuid.UUID,
user_id: uuid.UUID
) -> list[tuple[str, str]]:
"""
Retrieve the chronological message history for a prompt optimization session.
This method queries the database to fetch all messages associated with a
specific prompt optimization session for a given user. Messages are returned
in chronological order and typically include both user inputs and
model-generated responses.
Args:
session_id (uuid.UUID): The unique identifier of the prompt optimization session.
user_id (uuid.UUID): The unique identifier of the user associated with the session.
Returns:
list[tuple[str, str]]: A list of tuples representing messages. Each tuple contains:
- role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
- content (str): The content of the message.
"""
history = self.optim_repo.get_session_history(
session_id=session_id,
user_id=user_id
)
messages = []
for message in history:
messages.append((message.role, message.content))
return messages
async def optimize_prompt(
self,
tenant_id: uuid.UUID,
model_id: uuid.UUID,
session_id: uuid.UUID,
user_id: uuid.UUID,
current_prompt: str,
user_require: str,
skill: bool = False
) -> AsyncGenerator[dict[str, str | Any], Any]:
"""
Optimize a user-provided prompt using a configured prompt optimizer LLM.
This method refines the original prompt according to the user's requirements,
generating an optimized version that is directly usable by AI tools. The
optimization process follows strict rules, including:
- Wrapping user-inserted variables in double curly braces {{}}.
- Adhering to Jinja2 variable syntax if applicable.
- Ensuring a clear logic flow, explicit instructions, and strong executability.
- Producing output in a strict JSON format.
Steps performed:
1. Retrieve the model configuration for the given tenant and model.
2. Fetch the session message history for context.
3. Instantiate the LLM with the appropriate API key and model configuration.
4. Build system messages outlining optimization rules.
5. Format the user's original prompt and requirements as a user message.
6. Send messages to the LLM to generate the optimized prompt.
7. Generate a concise description summarizing the changes made during optimization.
Args:
tenant_id (uuid.UUID): Tenant identifier.
model_id (uuid.UUID): Prompt optimizer model identifier.
session_id (uuid.UUID): Prompt optimization session identifier.
user_id (uuid.UUID): Identifier of the user associated with the session.
current_prompt (str): Original prompt to optimize.
user_require (str): User's requirements or instructions for optimization.
skill(bool): Is skill required
Returns:
OptimizePromptResult: An object containing:
- prompt: The optimized prompt string.
- desc: A short description summarizing the changes.
Raises:
BusinessException: If the LLM response cannot be parsed as valid JSON
or does not conform to the expected output format.
"""
self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require)
model_config = self.get_model_config(tenant_id, model_id)
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
# Create LLM instance
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
# api_config: ModelApiKey = api_keys[0] if api_keys else None
api_config: ModelApiKey = ModelApiKeyService.get_available_api_key(self.db, model_config.id)
llm = RedBearLLM(RedBearModelConfig(
model_name=api_config.model_name,
provider=api_config.provider,
api_key=api_config.api_key,
base_url=api_config.api_base,
is_omni=api_config.is_omni
), type=ModelType(model_config.type))
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(skill=skill)
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
opt_user_prompt = f.read()
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
except Exception as e:
logger.error(f"Failed to load system prompt template: {e}")
raise BusinessException(message="Internal server error", code=BizCode.INTERNAL_ERROR)
rendered_user_message = Template(opt_user_prompt).render(
current_prompt=current_prompt,
user_require=user_require
)
# build message
messages = [
# init system_prompt
(
RoleType.SYSTEM.value,
rendered_system_message
),
]
messages.extend(session_history[:-1]) # last message is current message
messages.extend([(RoleType.USER.value, rendered_user_message)])
buffer = ""
prompt_started = False
prompt_finished = False
idx = 0
async for chunk in llm.astream(messages):
content = getattr(chunk, "content", chunk)
if not content:
continue
buffer += content
cache = buffer[:-20]
# 尝试找到 "prompt": " 开始位置
if prompt_finished:
continue
if not prompt_started:
m = re.search(r'"prompt"\s*:\s*"', cache)
if m:
prompt_started = True
prompt_index = m.end()
idx = prompt_index
else:
m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer)
if m:
prompt_index = m.start()
prompt_finished = True
yield {"content": buffer[idx:prompt_index]}
else:
yield {"content": cache[idx:]}
if len(cache) != 0:
idx = len(cache)
# optim_resp = await llm.astream(messages)
logger.info(buffer)
optim_result = json_repair.repair_json(buffer, return_objects=True)
# prompt = optim_result.get("prompt")
desc = optim_result.get("desc")
ModelApiKeyService.record_api_key_usage(self.db, api_config.id)
self.create_message(
tenant_id=tenant_id,
session_id=session_id,
user_id=user_id,
role=RoleType.ASSISTANT,
content=desc
)
variables = self.parser_prompt_variables(optim_result.get("prompt"))
logger.info(f"Prompt optimization completed, user_id={user_id}, session_id={session_id}")
yield {"desc": optim_result.get("desc"), "variables": variables}
@staticmethod
def parser_prompt_variables(prompt: str):
try:
pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}'
matches = re.findall(pattern, prompt)
variables = list(set(matches))
return variables
except Exception as e:
logger.error(f"Failed to parse prompt variables - Error: {str(e)}", exc_info=True)
raise BusinessException("Failed to parse prompt variables", BizCode.PARSER_NOT_SUPPORTED)
@staticmethod
def fill_prompt_variables(prompt: str, variables: dict[str, str]):
try:
pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}'
def replace_var(match):
var_name = match.group(1)
return variables.get(var_name, match.group(0))
result = re.sub(pattern, replace_var, prompt)
return result
except Exception as e:
logger.error(f"Failed to fill prompt variables - Error: {str(e)}", exc_info=True)
raise BusinessException("Failed to fill prompt variables", BizCode.PARSER_NOT_SUPPORTED)
def create_message(
self,
tenant_id: uuid.UUID,
session_id: uuid.UUID,
user_id: uuid.UUID,
role: RoleType,
content: str
) -> PromptOptimizerSessionHistory:
"""Insert Message to Session History"""
message = PromptOptimizerSessionRepository(self.db).create_message(
tenant_id=tenant_id,
session_id=session_id,
user_id=user_id,
role=role,
content=content
)
self.db.commit()
self.db.refresh(message)
return message
def save_prompt(
self,
tenant_id: uuid.UUID,
session_id: uuid.UUID,
title: str,
prompt: str
) -> dict:
"""
Create and save a new prompt release for a given session.
Args:
tenant_id (uuid.UUID): The ID of the tenant owning the prompt.
session_id (uuid.UUID): The ID of the session to associate with this prompt.
title (str): The title of the prompt release.
prompt (str): The content of the prompt.
Returns:
dict: A dictionary containing:
- id (UUID): The unique ID of the created prompt release.
- session_id (UUID): The session ID linked to the release.
- title (str): The title of the prompt.
- prompt (str): The prompt content.
- created_at (int): Timestamp (in milliseconds) of when the prompt was created.
Raises:
BusinessException: If a prompt release already exists for the given session.
"""
session = self.optim_repo.get_session_by_id(session_id)
if session is None or session.tenant_id != tenant_id:
raise BusinessException(
"Session does not exist or the current user has no access",
BizCode.BAD_REQUEST
)
if self.release_repo.get_prompt_by_session_id(session_id):
raise BusinessException(
"A release already exists for the current session",
BizCode.BAD_REQUEST
)
prompt_obj = self.release_repo.create_prompt_release(
tenant_id=tenant_id,
title=title,
session_id=session_id,
prompt=prompt
)
self.db.commit()
self.db.refresh(prompt_obj)
return {
"id": prompt_obj.id,
"session_id": prompt_obj.session_id,
"title": prompt_obj.title,
"prompt": prompt_obj.prompt,
"created_at": int(prompt_obj.created_at.timestamp() * 1000)
}
def delete_prompt(
self,
tenant_id: uuid.UUID,
prompt_id: uuid.UUID
) -> None:
"""
Soft delete a prompt release by prompt_id.
Args:
tenant_id (uuid.UUID): Tenant identifier.
prompt_id (uuid.UUID): Prompt identifier.
Raises:
BusinessException: If the prompt does not exist or already deleted.
"""
prompt_obj = self.release_repo.get_prompt_by_id(prompt_id)
if not prompt_obj or prompt_obj.is_delete:
raise BusinessException(
"Prompt does not exist or has already been deleted",
BizCode.NOT_FOUND
)
if prompt_obj.tenant_id != tenant_id:
raise BusinessException(
"No permission to delete this prompt",
BizCode.FORBIDDEN
)
self.release_repo.soft_delete_prompt(prompt_obj)
self.db.commit()
logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}")
def get_release_list(
self,
tenant_id: uuid.UUID,
page: int,
page_size: int,
filter_keyword: str | None = None
) -> dict[str, int | list[Any]]:
"""
Get paginated list of prompt releases with optional filter.
Args:
tenant_id (uuid.UUID): Tenant identifier.
page (int): Page number (starting from 1).
page_size (int): Number of items per page.
filter_keyword (str | None): Optional keyword to filter by title.
Returns:
dict: Contains total count, pagination info, and list of releases.
"""
offset = (page - 1) * page_size
# Get total count and releases based on filter
if filter_keyword:
total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword)
releases = self.release_repo.search_prompts_paginated(
tenant_id=tenant_id,
keyword=filter_keyword,
offset=offset,
limit=page_size
)
else:
total = self.release_repo.count_prompts(tenant_id)
releases = self.release_repo.get_prompts_paginated(
tenant_id=tenant_id,
offset=offset,
limit=page_size
)
items = []
for release in releases:
# Get first user message from session
first_message = self.optim_repo.get_first_user_message(
session_id=release.session_id
)
items.append({
"id": release.id,
"title": release.title,
"prompt": release.prompt,
"created_at": int(release.created_at.timestamp() * 1000),
"first_message": first_message
})
log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}"
if filter_keyword:
log_msg += f", filter='{filter_keyword}'"
logger.info(log_msg)
result = {
"page": {
"total": total,
"page": page,
"page_size": page_size,
"hasnext": page * page_size < total
},
"keyword": filter_keyword,
"items": items
}
return result