feat(prompt_opt): support streaming output for prompt optimization API

This commit is contained in:
mengyonghao
2026-01-05 10:53:53 +08:00
parent fc831e04c1
commit eaf2437633
4 changed files with 75 additions and 44 deletions

View File

@@ -1,5 +1,6 @@
import re
import uuid
from typing import Any, AsyncGenerator
import json_repair
from langchain_core.prompts import ChatPromptTemplate
@@ -123,7 +124,7 @@ class PromptOptimizerService:
user_id: uuid.UUID,
current_prompt: str,
user_require: str
) -> OptimizePromptResult:
) -> AsyncGenerator[dict[str, str | Any], Any]:
"""
Optimize a user-provided prompt using a configured prompt optimizer LLM.
@@ -161,6 +162,7 @@ class PromptOptimizerService:
BusinessException: If the LLM response cannot be parsed as valid JSON
or does not conform to the expected output format.
"""
self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require)
model_config = self.get_model_config(tenant_id, model_id)
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
@@ -202,17 +204,54 @@ class PromptOptimizerService:
messages.extend(session_history[:-1]) # last message is current message
messages.extend([(RoleType.USER.value, rendered_user_message)])
logger.info(f"Prompt optimization message: {messages}")
optim_resp = await llm.ainvoke(messages)
logger.info(optim_resp.content)
optim_result = json_repair.repair_json(optim_resp.content, return_objects=True)
prompt = optim_result.get("prompt")
desc = optim_result.get("desc")
buffer = ""
prompt_started = False
prompt_finished = False
idx = 0
return OptimizePromptResult(
prompt=prompt,
desc=desc
async for chunk in llm.astream(messages):
content = getattr(chunk, "content", chunk)
if not content:
continue
buffer += content
cache = buffer[:-20]
# Try to locate where the "prompt": " value starts
if prompt_finished:
continue
if not prompt_started:
m = re.search(r'"prompt"\s*:\s*"', cache)
if m:
prompt_started = True
prompt_index = m.end()
idx = prompt_index
else:
m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer)
if m:
prompt_index = m.start()
prompt_finished = True
yield {"type": "delta", "content": buffer[idx:prompt_index]}
else:
yield {"type": "delta", "content": cache[idx:]}
if len(cache) != 0:
idx = len(cache)
# optim_resp = await llm.astream(messages)
logger.info(buffer)
optim_result = json_repair.repair_json(buffer, return_objects=True)
# prompt = optim_result.get("prompt")
desc = optim_result.get("desc")
self.create_message(
tenant_id=tenant_id,
session_id=session_id,
user_id=user_id,
role=RoleType.ASSISTANT,
content=desc
)
yield {"type": "done", "desc": optim_result.get("desc")}
@staticmethod
def parser_prompt_variables(prompt: str):
try: