Files
MemoryBear/api/app/schemas/prompt_schema.py

62 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from jinja2 import Environment, Template, meta
from typing import Any, Dict
from enum import Enum
from pydantic import BaseModel, Field
from abc import ABC
from typing import Union, List
# Closed set of speaker roles for a prompt message.
# Members also subclass ``str``, so each compares equal to its raw value
# (for example ``PromptMessageRole.USER == "user"``).
class PromptMessageRole(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
# One plain-text content part of a prompt message.
class TextPromptMessageContent(BaseModel):
    # Part type tag; always "text" unless explicitly overridden.
    type: str = "text"
    # The text payload of this part.
    data: str


# Alias naming the content-part type(s) a PromptMessage content list may hold.
# Currently this is only the text part defined above.
PromptMessageContentUnionTypes = TextPromptMessageContent
# Chat message model: a role plus optional content and an optional name.
# ``content`` is a plain string, a list of content parts, or None.
class PromptMessage(ABC, BaseModel):
    role: PromptMessageRole
    # Plain text, a list of content parts, or None when unset.
    content: Union[str, List[PromptMessageContentUnionTypes], None] = None
    # Optional name attached to the message.
    name: Union[str, None] = None

    model_config = {"arbitrary_types_allowed": True}

    def is_empty(self) -> bool:
        """Return True if ``content`` is falsy (None, an empty string, or an empty list)."""
        has_content = bool(self.content)
        return not has_content

    def get_text_content(self) -> str:
        """Return the message text.

        A string ``content`` is returned unchanged. For a list, the ``data``
        of each ``TextPromptMessageContent`` part is concatenated in order and
        any other part type is skipped. Any other value (e.g. None) yields an
        empty string.
        """
        content = self.content
        if isinstance(content, list):
            return "".join(
                part.data
                for part in content
                if isinstance(part, TextPromptMessageContent)
            )
        if isinstance(content, str):
            return content
        return ""
def render_prompt_message(template_str: str, role: PromptMessageRole, params: Union[Dict[str, Any], None]) -> PromptMessage:
    """Render a Jinja2 template into a single-part text PromptMessage.

    The template is parsed once to discover the variables it references.
    Any referenced variable that is missing from ``params`` is rendered as an
    empty string.

    Args:
        template_str: Jinja2 template source.
        role: Role assigned to the resulting message.
        params: Values for the template variables. The mapping is NOT
            modified. ``None`` is treated as an empty mapping.

    Returns:
        A ``PromptMessage`` whose ``content`` is a one-element list holding
        the rendered text as a ``TextPromptMessageContent``.
    """
    # NOTE(review): a non-sandboxed Environment lets template expressions
    # reach arbitrary Python objects. If template_str can come from end users,
    # consider jinja2.sandbox.SandboxedEnvironment. Confirm where templates
    # originate.
    env = Environment()
    parsed_content = env.parse(template_str)
    variables = meta.find_undeclared_variables(parsed_content)

    # Fill defaults on a copy so the caller's dict is not mutated as a side
    # effect. Referenced-but-missing variables default to ''.
    context: Dict[str, Any] = dict(params) if params else {}
    for var in variables:
        context.setdefault(var, "")

    # Compile from the already-parsed AST with the same environment, instead
    # of re-parsing the source through a separate Template() environment.
    rendered_text = env.from_string(parsed_content).render(context)
    return PromptMessage(
        role=role,
        content=[TextPromptMessageContent(data=rendered_text)],
    )