feat(prompt_opt): support streaming output for prompt optimization API
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
import json
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Path
|
from fastapi import APIRouter, Depends, Path
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
@@ -104,35 +106,25 @@ async def get_prompt_opt(
|
|||||||
ApiResponse: Contains the optimized prompt, description, and a list of variables.
|
ApiResponse: Contains the optimized prompt, description, and a list of variables.
|
||||||
"""
|
"""
|
||||||
service = PromptOptimizerService(db)
|
service = PromptOptimizerService(db)
|
||||||
service.create_message(
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
role=RoleType.USER,
|
|
||||||
content=data.message
|
|
||||||
)
|
|
||||||
opt_result = await service.optimize_prompt(
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
model_id=data.model_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
current_prompt=data.current_prompt,
|
|
||||||
user_require=data.message
|
|
||||||
)
|
|
||||||
service.create_message(
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
role=RoleType.ASSISTANT,
|
|
||||||
content=opt_result.desc
|
|
||||||
)
|
|
||||||
variables = service.parser_prompt_variables(opt_result.prompt)
|
|
||||||
result = {
|
|
||||||
"prompt": opt_result.prompt,
|
|
||||||
"desc": opt_result.desc,
|
|
||||||
"variables": variables
|
|
||||||
}
|
|
||||||
result_schema = OptimizePromptResponse.model_validate(result)
|
|
||||||
return success(data=result_schema)
|
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
async for chunk in service.optimize_prompt(
|
||||||
|
tenant_id=current_user.tenant_id,
|
||||||
|
model_id=data.model_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
current_prompt=data.current_prompt,
|
||||||
|
user_require=data.message
|
||||||
|
):
|
||||||
|
# chunk 是 prompt 的增量内容
|
||||||
|
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,11 +4,11 @@
|
|||||||
从文件系统加载预定义的工作流模板
|
从文件系统加载预定义的工作流模板
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import yaml
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
class TemplateLoader:
|
class TemplateLoader:
|
||||||
"""工作流模板加载器"""
|
"""工作流模板加载器"""
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
import json_repair
|
import json_repair
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
@@ -123,7 +124,7 @@ class PromptOptimizerService:
|
|||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
current_prompt: str,
|
current_prompt: str,
|
||||||
user_require: str
|
user_require: str
|
||||||
) -> OptimizePromptResult:
|
) -> AsyncGenerator[dict[str, str | Any], Any]:
|
||||||
"""
|
"""
|
||||||
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
||||||
|
|
||||||
@@ -161,6 +162,7 @@ class PromptOptimizerService:
|
|||||||
BusinessException: If the LLM response cannot be parsed as valid JSON
|
BusinessException: If the LLM response cannot be parsed as valid JSON
|
||||||
or does not conform to the expected output format.
|
or does not conform to the expected output format.
|
||||||
"""
|
"""
|
||||||
|
self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require)
|
||||||
model_config = self.get_model_config(tenant_id, model_id)
|
model_config = self.get_model_config(tenant_id, model_id)
|
||||||
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
|
||||||
|
|
||||||
@@ -202,17 +204,54 @@ class PromptOptimizerService:
|
|||||||
messages.extend(session_history[:-1]) # last message is current message
|
messages.extend(session_history[:-1]) # last message is current message
|
||||||
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
messages.extend([(RoleType.USER.value, rendered_user_message)])
|
||||||
logger.info(f"Prompt optimization message: {messages}")
|
logger.info(f"Prompt optimization message: {messages}")
|
||||||
optim_resp = await llm.ainvoke(messages)
|
buffer = ""
|
||||||
logger.info(optim_resp.content)
|
prompt_started = False
|
||||||
optim_result = json_repair.repair_json(optim_resp.content, return_objects=True)
|
prompt_finished = False
|
||||||
prompt = optim_result.get("prompt")
|
idx = 0
|
||||||
desc = optim_result.get("desc")
|
|
||||||
|
|
||||||
return OptimizePromptResult(
|
async for chunk in llm.astream(messages):
|
||||||
prompt=prompt,
|
content = getattr(chunk, "content", chunk)
|
||||||
desc=desc
|
if not content:
|
||||||
|
continue
|
||||||
|
buffer += content
|
||||||
|
cache = buffer[:-20]
|
||||||
|
|
||||||
|
# 尝试找到 "prompt": " 开始位置
|
||||||
|
if prompt_finished:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not prompt_started:
|
||||||
|
m = re.search(r'"prompt"\s*:\s*"', cache)
|
||||||
|
if m:
|
||||||
|
prompt_started = True
|
||||||
|
prompt_index = m.end()
|
||||||
|
idx = prompt_index
|
||||||
|
else:
|
||||||
|
m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer)
|
||||||
|
if m:
|
||||||
|
prompt_index = m.start()
|
||||||
|
prompt_finished = True
|
||||||
|
yield {"type": "delta", "content": buffer[idx:prompt_index]}
|
||||||
|
else:
|
||||||
|
yield {"type": "delta", "content": cache[idx:]}
|
||||||
|
if len(cache) != 0:
|
||||||
|
idx = len(cache)
|
||||||
|
|
||||||
|
# optim_resp = await llm.astream(messages)
|
||||||
|
logger.info(buffer)
|
||||||
|
optim_result = json_repair.repair_json(buffer, return_objects=True)
|
||||||
|
# prompt = optim_result.get("prompt")
|
||||||
|
desc = optim_result.get("desc")
|
||||||
|
self.create_message(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role=RoleType.ASSISTANT,
|
||||||
|
content=desc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield {"type": "done", "desc": optim_result.get("desc")}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parser_prompt_variables(prompt: str):
|
def parser_prompt_variables(prompt: str):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ Rules
|
|||||||
Basic Principles
|
Basic Principles
|
||||||
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
|
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
|
||||||
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
|
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
|
||||||
Structure Rule: Use a clear block structure including [Role], [Task], [Requirements], [Input], [Output], [Constraints] labels.
|
Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints
|
||||||
Language Rule: All label languages must fully match the user input language.
|
Language Rule: All label languages must fully match the user input language.
|
||||||
|
|
||||||
Behavior Guidelines
|
Behavior Guidelines
|
||||||
|
|||||||
Reference in New Issue
Block a user