perf(prompt_opt): improve prompt optimization and model output quality
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import re
|
||||
import uuid
|
||||
|
||||
import json_repair
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from jinja2 import Template
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -170,68 +172,45 @@ class PromptOptimizerService:
|
||||
api_key=api_config.api_key,
|
||||
base_url=api_config.api_base
|
||||
), type=ModelType(model_config.type))
|
||||
try:
|
||||
with open('app/templates/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render()
|
||||
|
||||
with open('app/templates/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f:
|
||||
opt_user_prompt = f.read()
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load system prompt template: {e}")
|
||||
raise BusinessException(message="Internal server error", code=BizCode.INTERNAL_ERROR)
|
||||
rendered_user_message = Template(opt_user_prompt).render(
|
||||
current_prompt=current_prompt,
|
||||
user_require=user_require
|
||||
)
|
||||
|
||||
# build message
|
||||
messages = [
|
||||
# init system_prompt
|
||||
(
|
||||
RoleType.SYSTEM.value,
|
||||
"Your task is to optimize the original prompt provided by the user so that it can be directly used by AI tools,"
|
||||
"and the variables that the user needs to insert must be wrapped in {{}}. "
|
||||
"The optimized prompt should align with the optimization direction specified by the user (if any) and ensure clear logic, explicit instructions, and strong executability. "
|
||||
"Please follow these rules when optimizing: "
|
||||
'1. Ensure variables are wrapped in {{}}, e.g., optimize "Please enter your question" to "Please enter your {{question}}"'
|
||||
"2. Instructions must be specific and operable, avoiding vague expressions"
|
||||
"3. If the original prompt lacks key elements (such as output format requirements), supplement them completely "
|
||||
"4. Keep the language concise and avoid redundancy "
|
||||
"5. If the user does not specify an optimization direction, the default optimization is to make the prompt structurally clear and with explicit instructions"
|
||||
"Please directly output the optimized prompt without additional explanations. The optimized prompt should be directly usable with correct variable positions."
|
||||
rendered_system_message
|
||||
),
|
||||
]
|
||||
|
||||
# base model limit
|
||||
(RoleType.SYSTEM.value,
|
||||
"Optimization Rules:\n"
|
||||
"1. Fully adjust the prompt content according to the user's requirements.\n"
|
||||
"When variables are required, use double curly braces {{variable_name}} as placeholders."
|
||||
"Variable names must be derived from the user's requirements.\n"
|
||||
"3. Keep the prompt logic clear and instructions explicit.\n"
|
||||
"4. Ensure that the modified prompt can be directly used.\n\n")
|
||||
]
|
||||
messages.extend(session_history[:-1]) # last message is current message
|
||||
user_message_template = ChatPromptTemplate.from_messages([
|
||||
(RoleType.USER.value, "[original_prompt]\n{current_prompt}\n[user_require]\n{user_require}")
|
||||
])
|
||||
formatted_user_message = user_message_template.format(current_prompt=current_prompt, user_require=user_require)
|
||||
messages.extend([(RoleType.USER.value, formatted_user_message)])
|
||||
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
||||
logger.info(f"Prompt optimization message: {messages}")
|
||||
optim_prompt = await llm.ainvoke(messages)
|
||||
optim_desc = [
|
||||
(
|
||||
RoleType.SYSTEM.value,
|
||||
"You are a prompt optimization assistant.\n"
|
||||
"Compare the original prompt, the user's requirements, "
|
||||
"and the optimized prompt.\n"
|
||||
"Summarize the changes made during optimization.\n\n"
|
||||
"Rules:\n"
|
||||
"1. Output must be a single short sentence.\n"
|
||||
"2. Be concise and factual.\n"
|
||||
"3. Do not explain the prompts themselves.\n"
|
||||
"4. Do not include any extra text."
|
||||
),
|
||||
(
|
||||
"[Original Prompt]\n"
|
||||
f"{current_prompt}\n\n"
|
||||
"[User Requirements]\n"
|
||||
f"{user_require}\n\n"
|
||||
"[Optimized Prompt]\n"
|
||||
f"{optim_prompt.content}"
|
||||
)
|
||||
]
|
||||
optim_desc = await llm.ainvoke(optim_desc)
|
||||
optim_resp = await llm.ainvoke(messages)
|
||||
logger.info(optim_resp.content)
|
||||
optim_result = json_repair.repair_json(optim_resp.content, return_objects=True)
|
||||
prompt = optim_result.get("prompt")
|
||||
desc = optim_result.get("desc")
|
||||
|
||||
return OptimizePromptResult(
|
||||
prompt=optim_prompt.content,
|
||||
desc=optim_desc.content
|
||||
prompt=prompt,
|
||||
desc=desc
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -253,6 +232,7 @@ class PromptOptimizerService:
|
||||
def replace_var(match):
|
||||
var_name = match.group(1)
|
||||
return variables.get(var_name, match.group(0))
|
||||
|
||||
result = re.sub(pattern, replace_var, prompt)
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user