feat(prompt_opt): support streaming output for prompt optimization API
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
@@ -70,12 +72,12 @@ def get_prompt_session(
|
||||
SessionMessage(role=role, content=content)
|
||||
for role, content in history
|
||||
]
|
||||
|
||||
|
||||
result = SessionHistoryResponse(
|
||||
session_id=session_id,
|
||||
messages=messages
|
||||
)
|
||||
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@@ -104,35 +106,25 @@ async def get_prompt_opt(
|
||||
ApiResponse: Contains the optimized prompt, description, and a list of variables.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
service.create_message(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
role=RoleType.USER,
|
||||
content=data.message
|
||||
)
|
||||
opt_result = await service.optimize_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
model_id=data.model_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
)
|
||||
service.create_message(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
role=RoleType.ASSISTANT,
|
||||
content=opt_result.desc
|
||||
)
|
||||
variables = service.parser_prompt_variables(opt_result.prompt)
|
||||
result = {
|
||||
"prompt": opt_result.prompt,
|
||||
"desc": opt_result.desc,
|
||||
"variables": variables
|
||||
}
|
||||
result_schema = OptimizePromptResponse.model_validate(result)
|
||||
return success(data=result_schema)
|
||||
|
||||
async def event_generator():
|
||||
async for chunk in service.optimize_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
model_id=data.model_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
从文件系统加载预定义的工作流模板
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class TemplateLoader:
|
||||
"""工作流模板加载器"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import json_repair
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
@@ -123,7 +124,7 @@ class PromptOptimizerService:
|
||||
user_id: uuid.UUID,
|
||||
current_prompt: str,
|
||||
user_require: str
|
||||
) -> OptimizePromptResult:
|
||||
) -> AsyncGenerator[dict[str, str | Any], Any]:
|
||||
"""
|
||||
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
||||
|
||||
@@ -161,6 +162,7 @@ class PromptOptimizerService:
|
||||
BusinessException: If the LLM response cannot be parsed as valid JSON
|
||||
or does not conform to the expected output format.
|
||||
"""
|
||||
self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require)
|
||||
model_config = self.get_model_config(tenant_id, model_id)
|
||||
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
||||
|
||||
@@ -202,17 +204,54 @@ class PromptOptimizerService:
|
||||
messages.extend(session_history[:-1]) # last message is current message
|
||||
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
||||
logger.info(f"Prompt optimization message: {messages}")
|
||||
optim_resp = await llm.ainvoke(messages)
|
||||
logger.info(optim_resp.content)
|
||||
optim_result = json_repair.repair_json(optim_resp.content, return_objects=True)
|
||||
prompt = optim_result.get("prompt")
|
||||
desc = optim_result.get("desc")
|
||||
buffer = ""
|
||||
prompt_started = False
|
||||
prompt_finished = False
|
||||
idx = 0
|
||||
|
||||
return OptimizePromptResult(
|
||||
prompt=prompt,
|
||||
desc=desc
|
||||
async for chunk in llm.astream(messages):
|
||||
content = getattr(chunk, "content", chunk)
|
||||
if not content:
|
||||
continue
|
||||
buffer += content
|
||||
cache = buffer[:-20]
|
||||
|
||||
# 尝试找到 "prompt": " 开始位置
|
||||
if prompt_finished:
|
||||
continue
|
||||
|
||||
if not prompt_started:
|
||||
m = re.search(r'"prompt"\s*:\s*"', cache)
|
||||
if m:
|
||||
prompt_started = True
|
||||
prompt_index = m.end()
|
||||
idx = prompt_index
|
||||
else:
|
||||
m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer)
|
||||
if m:
|
||||
prompt_index = m.start()
|
||||
prompt_finished = True
|
||||
yield {"type": "delta", "content": buffer[idx:prompt_index]}
|
||||
else:
|
||||
yield {"type": "delta", "content": cache[idx:]}
|
||||
if len(cache) != 0:
|
||||
idx = len(cache)
|
||||
|
||||
# optim_resp = await llm.astream(messages)
|
||||
logger.info(buffer)
|
||||
optim_result = json_repair.repair_json(buffer, return_objects=True)
|
||||
# prompt = optim_result.get("prompt")
|
||||
desc = optim_result.get("desc")
|
||||
self.create_message(
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
role=RoleType.ASSISTANT,
|
||||
content=desc
|
||||
)
|
||||
|
||||
yield {"type": "done", "desc": optim_result.get("desc")}
|
||||
|
||||
@staticmethod
|
||||
def parser_prompt_variables(prompt: str):
|
||||
try:
|
||||
|
||||
@@ -25,7 +25,7 @@ Rules
|
||||
Basic Principles
|
||||
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
|
||||
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
|
||||
Structure Rule: Use a clear block structure including [Role], [Task], [Requirements], [Input], [Output], [Constraints] labels.
|
||||
Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints
|
||||
Language Rule: All label languages must fully match the user input language.
|
||||
|
||||
Behavior Guidelines
|
||||
|
||||
Reference in New Issue
Block a user