feat(prompt_opt): support streaming output for prompt optimization API
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import json_repair
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
@@ -123,7 +124,7 @@ class PromptOptimizerService:
|
||||
user_id: uuid.UUID,
|
||||
current_prompt: str,
|
||||
user_require: str
|
||||
) -> OptimizePromptResult:
|
||||
) -> AsyncGenerator[dict[str, str | Any], Any]:
|
||||
"""
|
||||
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
||||
|
||||
@@ -161,6 +162,7 @@ class PromptOptimizerService:
|
||||
BusinessException: If the LLM response cannot be parsed as valid JSON
|
||||
or does not conform to the expected output format.
|
||||
"""
|
||||
self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require)
|
||||
model_config = self.get_model_config(tenant_id, model_id)
|
||||
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
||||
|
||||
@@ -202,17 +204,54 @@ class PromptOptimizerService:
|
||||
messages.extend(session_history[:-1]) # last message is current message
|
||||
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
||||
logger.info(f"Prompt optimization message: {messages}")
|
||||
optim_resp = await llm.ainvoke(messages)
|
||||
logger.info(optim_resp.content)
|
||||
optim_result = json_repair.repair_json(optim_resp.content, return_objects=True)
|
||||
prompt = optim_result.get("prompt")
|
||||
desc = optim_result.get("desc")
|
||||
buffer = ""
|
||||
prompt_started = False
|
||||
prompt_finished = False
|
||||
idx = 0
|
||||
|
||||
return OptimizePromptResult(
|
||||
prompt=prompt,
|
||||
desc=desc
|
||||
async for chunk in llm.astream(messages):
|
||||
content = getattr(chunk, "content", chunk)
|
||||
if not content:
|
||||
continue
|
||||
buffer += content
|
||||
cache = buffer[:-20]
|
||||
|
||||
# 尝试找到 "prompt": " 开始位置
|
||||
if prompt_finished:
|
||||
continue
|
||||
|
||||
if not prompt_started:
|
||||
m = re.search(r'"prompt"\s*:\s*"', cache)
|
||||
if m:
|
||||
prompt_started = True
|
||||
prompt_index = m.end()
|
||||
idx = prompt_index
|
||||
else:
|
||||
m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer)
|
||||
if m:
|
||||
prompt_index = m.start()
|
||||
prompt_finished = True
|
||||
yield {"type": "delta", "content": buffer[idx:prompt_index]}
|
||||
else:
|
||||
yield {"type": "delta", "content": cache[idx:]}
|
||||
if len(cache) != 0:
|
||||
idx = len(cache)
|
||||
|
||||
# optim_resp = await llm.astream(messages)
|
||||
logger.info(buffer)
|
||||
optim_result = json_repair.repair_json(buffer, return_objects=True)
|
||||
# prompt = optim_result.get("prompt")
|
||||
desc = optim_result.get("desc")
|
||||
self.create_message(
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
role=RoleType.ASSISTANT,
|
||||
content=desc
|
||||
)
|
||||
|
||||
yield {"type": "done", "desc": optim_result.get("desc")}
|
||||
|
||||
@staticmethod
|
||||
def parser_prompt_variables(prompt: str):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user