Merge branch 'refs/heads/develop' into feature/20251219_xjn

This commit is contained in:
谢俊男
2025-12-18 16:58:45 +08:00
18 changed files with 2489 additions and 1375 deletions

View File

@@ -28,6 +28,7 @@ from . import (
public_share_controller,
multi_agent_controller,
workflow_controller,
prompt_optimizer_controller
)
# 创建管理端 API 路由器
@@ -58,5 +59,6 @@ manager_router.include_router(public_share_controller.router) # 公开路由(
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(multi_agent_controller.router)
manager_router.include_router(workflow_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
__all__ = ["manager_router"]

View File

@@ -1,13 +1,9 @@
from fastapi import APIRouter, Depends, status, Query
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
from typing import Optional
import uuid
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelProvider, ModelType
@@ -39,7 +35,7 @@ def get_model_providers():
@router.get("", response_model=ApiResponse)
def get_model_list(
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING"),
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
@@ -54,13 +50,21 @@ def get_model_list(
支持多个 type 参数:
- 单个:?type=LLM
- 多个:?type=LLM&type=EMBEDDING
- 多个(逗号分隔)?type=LLM,EMBEDDING
- 多个(重复参数):?type=LLM&type=EMBEDDING
"""
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
try:
# 解析 type 参数(支持逗号分隔)
type_list = None
if type:
type_values = [t.strip() for t in type.split(',')]
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery(
type=type,
type=type_list,
provider=provider,
is_active=is_active,
is_public=is_public,

View File

@@ -0,0 +1,170 @@
import uuid
from fastapi import APIRouter, Depends, Path
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.dependencies import get_current_user, get_db
from app.models.prompt_optimizer_model import RoleType
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
from app.schemas.response_schema import ApiResponse
from app.services.prompt_optimizer_service import PromptOptimizerService
router = APIRouter(prefix="/prompt", tags=["Prompts-Optimization"])
logger = get_api_logger()
@router.post(
"/sessions",
summary="Create a new prompt optimization session",
response_model=ApiResponse
)
def create_prompt_session(
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Create a new prompt optimization session for the current user.
Returns:
ApiResponse: Contains the newly generated session ID.
"""
service = PromptOptimizerService(db)
# create new session
session = service.create_session(current_user.tenant_id, current_user.id)
result_schema = CreateSessionResponse.model_validate(session)
return success(data=result_schema)
@router.get(
"/sessions/{session_id}",
summary="获取 prompt 优化历史对话",
response_model=ApiResponse
)
def get_prompt_session(
session_id: uuid.UUID = Path(..., description="Session ID"),
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Retrieve all messages from a specified prompt optimization session.
Args:
session_id (UUID): The ID of the session to retrieve
db (Session): Database session
current_user: Current logged-in user
Returns:
ApiResponse: Contains the session ID and the list of messages.
"""
service = PromptOptimizerService(db)
history = service.get_session_message_history(
session_id=session_id,
user_id=current_user.id
)
messages = [
SessionMessage(role=role, content=content)
for role, content in history
]
result = SessionHistoryResponse(
session_id=session_id,
messages=messages
)
return success(data=result)
@router.post(
"/sessions/{session_id}/messages",
summary="Get prompt optimization",
response_model=ApiResponse
)
async def get_prompt_opt(
session_id: uuid.UUID = Path(..., description="Session ID"),
data: PromptOptMessage = ...,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Send a user message in the specified session and return the optimized prompt
along with its description and variables.
Args:
session_id (UUID): The session ID
data (PromptOptMessage): Contains the user message, model ID, and current prompt
db (Session): Database session
current_user: Current user information
Returns:
ApiResponse: Contains the optimized prompt, description, and a list of variables.
"""
service = PromptOptimizerService(db)
service.create_message(
tenant_id=current_user.tenant_id,
session_id=session_id,
user_id=current_user.id,
role=RoleType.USER,
content=data.message
)
opt_result = await service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
message=data.message
)
service.create_message(
tenant_id=current_user.tenant_id,
session_id=session_id,
user_id=current_user.id,
role=RoleType.ASSISTANT,
content=opt_result.desc
)
variables = service.parser_prompt_variables(opt_result.prompt)
result = {
"prompt": opt_result.prompt,
"desc": opt_result.desc,
"variables": variables
}
result_schema = OptimizePromptResponse.model_validate(result)
return success(data=result_schema)
@router.put(
"/model",
summary="Create or update prompt model config",
response_model=ApiResponse
)
def set_system_prompt(
data: PromptOptModelSet = ...,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Create or update a system prompt model configuration for the tenant.
Args:
data (PromptOptModelSet): Model configuration data including model ID,
system prompt, and optional configuration ID
db (Session): Database session
current_user: Current user information
Returns:
UUID: The ID of the created or updated model configuration.
"""
if data.id is None:
data.id = uuid.uuid4()
model_config = PromptOptimizerService(db).create_update_model_config(
current_user.tenant_id,
data.id,
data.system_prompt
)
return success(data=model_config.id)

View File

@@ -119,7 +119,7 @@ def keyword_extraction(chat_mdl, content, topn=3):
rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
@@ -194,7 +194,7 @@ def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
kwd = kwd[0]
@@ -314,7 +314,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defi
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
_, msg = message_fit_in(hist, chat_mdl.max_length)
_, msg = message_fit_in(hist, getattr(chat_mdl, 'max_length', 8096))
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return """
@@ -341,7 +341,7 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096))
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@@ -350,7 +350,7 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096))
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@@ -378,7 +378,7 @@ def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached:
return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096))
ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
@@ -641,7 +641,7 @@ def split_chunks(chunks, max_length: int):
async def run_toc_from_text(chunks, chat_mdl, callback=None):
input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
input_budget = int(getattr(chat_mdl, 'max_length', 8096) * INPUT_UTILIZATION) - num_tokens_from_string(
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
)

View File

@@ -20,6 +20,7 @@ from .data_config_model import DataConfig
from .multi_agent_model import MultiAgentConfig, AgentInvocation
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from .retrieval_info import RetrievalInfo
from .prompt_optimizer_model import PromptOptimizerModelConfig, PromptOptimizerSession, PromptOptimizerSessionHistory
__all__ = [
"Tenants",
@@ -54,5 +55,8 @@ __all__ = [
"WorkflowConfig",
"WorkflowExecution",
"WorkflowNodeExecution",
"RetrievalInfo"
"RetrievalInfo",
"PromptOptimizerModelConfig",
"PromptOptimizerSession",
"PromptOptimizerSessionHistory"
]

View File

@@ -16,7 +16,26 @@ class Document(Base):
file_size = Column(Integer, default=0, comment="file size(byte)")
file_meta = Column(JSON, nullable=False, default={})
parser_id = Column(String, index=True, nullable=False, comment="default parser ID")
parser_config = Column(JSON, nullable=False, default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"}, comment="default parser config")
parser_config = Column(JSON, nullable=False,
default={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"graphrag": {
"use_graphrag": False,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category",
],
"method": "general",
}
}, comment="default parser config")
chunk_num = Column(Integer, default=0, comment="chunk num")
progress = Column(Float, default=0)
progress_msg = Column(String, default="", comment="process message")

View File

@@ -56,7 +56,25 @@ class Knowledge(Base):
chunk_num = Column(Integer, default=0, comment="chunk num")
parser_id = Column(String, index=True, default="naive", comment="default parser ID")
parser_config = Column(JSON, nullable=False,
default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"},
default={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"graphrag": {
"use_graphrag": False,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category",
],
"method": "general",
}
},
comment="default parser config")
status = Column(Integer, index=True, default=1, comment="is it validate(0: disable, 1: enable, 2:Soft-delete)")
created_at = Column(DateTime, default=datetime.datetime.now)

View File

@@ -15,6 +15,25 @@ class ModelType(StrEnum):
EMBEDDING = "embedding"
RERANK = "rerank"
@classmethod
def from_str(cls, value: str) -> "ModelType":
"""
Get a ModelType enum instance from a string value.
Args:
value (str): The string representation of the model type.
Returns:
ModelType: The corresponding ModelType enum object.
Raises:
ValueError: If the given value does not match any ModelType.
"""
try:
return cls(value)
except ValueError:
raise ValueError(f"Invalid ModelType: {value}")
class ModelProvider(StrEnum):
"""模型提供商枚举"""

View File

@@ -0,0 +1,173 @@
import datetime
import uuid
from enum import StrEnum
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class RoleType(StrEnum):
"""
Enumeration of message roles used in prompt optimization conversations.
This enum standardizes the role identifiers for messages stored in the
prompt optimization session history, ensuring consistency across
system-generated messages, user inputs, and assistant responses.
Attributes:
SYSTEM (str): Represents system-level instructions or prompts that
define the behavior or constraints of the assistant.
USER (str): Represents messages originating from the end user.
ASSISTANT (str): Represents messages generated by the AI assistant.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
class PromptOptimizerModelConfig(Base):
"""
Prompt Optimization Model Configuration.
This table stores system-level prompt configurations for each tenant.
The configuration defines the base system prompt used during prompt
optimization sessions and serves as a foundational instruction set
for the optimization process.
Each tenant may have one or more model configurations depending on
business requirements.
Table Name:
prompt_model_config
Columns:
id (UUID):
Primary key. Unique identifier for the prompt model configuration.
tenant_id (UUID):
Foreign key referencing `tenants.id`.
Identifies the tenant that owns this configuration.
system_prompt (Text):
The system-level prompt used to guide prompt optimization logic.
created_at (DateTime):
Timestamp indicating when the configuration was created.
updated_at (DateTime):
Timestamp indicating the last update time of the configuration.
Usage:
- Loaded when initializing a prompt optimization session
- Acts as the root system instruction for all subsequent prompts
"""
__tablename__ = "prompt_model_config"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
# model_id = Column(UUID(as_uuid=True), nullable=False, comment="Model ID")
system_prompt = Column(Text, nullable=False, comment="System Prompt")
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="Update Time")
class PromptOptimizerSession(Base):
"""
Prompt Optimization Session Registry.
This table records high-level metadata for prompt optimization sessions.
Each record represents a single logical session initiated by a user
under a specific tenant.
The session acts as a container for multiple conversation messages
stored in the session history table.
Table Name:
prompt_opt_session_list
Columns:
id (UUID):
Public-facing session identifier used to group conversation history.
tenant_id (UUID):
Foreign key referencing `tenants.id`.
Identifies the tenant under which the session is created.
user_id (UUID):
Foreign key referencing `users.id`.
Identifies the user who initiated the session.
created_at (DateTime):
Timestamp indicating when the session was created.
Design Notes:
- This table intentionally does not store message content
- Message-level data is stored in `prompt_opt_session_history`
- Enables efficient session listing and pagination
"""
__tablename__ = "prompt_opt_session_list"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="Session ID")
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
# app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
class PromptOptimizerSessionHistory(Base):
"""
Prompt Optimization Session Message History.
This table stores the complete conversational history of a prompt
optimization session, including system prompts, user inputs, and
assistant responses.
Each record represents a single message within a session, preserving
the chronological order of interactions.
Table Name:
prompt_opt_session_history
Columns:
id (UUID):
Primary key. Unique identifier for the message record.
tenant_id (UUID):
Foreign key referencing `tenants.id`.
Identifies the tenant under which the session operates.
session_id (UUID):
Logical session identifier linking messages to a session.
user_id (UUID):
Foreign key referencing `users.id`.
Identifies the user associated with the session.
message_role (Text):
Role of the message sender (e.g., system, user, assistant).
message_content (Text):
Raw message content generated or provided during the session.
prompt (Text):
The prompt snapshot used at the time of message generation.
created_at (DateTime):
Timestamp indicating when the message was created.
Design Notes:
- Supports full conversation replay and audit
- Enables prompt evolution tracking over time
- Indexed by creation time for efficient chronological queries
"""
__tablename__ = "prompt_opt_session_history"
__table_args__ = (
Index(
"ix_prompt_opt_session_history_session_user_created",
"session_id",
"user_id",
"created_at"
),
)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
# app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
session_id = Column(UUID(as_uuid=True), ForeignKey("prompt_opt_session_list.id"),nullable=False, comment="Session ID")
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
role = Column(String, nullable=False, comment="Message Role")
content = Column(Text, nullable=False, comment="Message Content")
# prompt = Column(Text, nullable=False, comment="Prompt")
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)

View File

@@ -115,7 +115,9 @@ def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Kn
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
try:
knowledge = db.query(Knowledge).filter(Knowledge.name == name).filter(Knowledge.workspace_id == workspace_id).first()
knowledge = db.query(Knowledge).filter(Knowledge.name == name,
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1).first()
if knowledge:
db_logger.debug(f"knowledge base query successful: {name} (ID: {knowledge.id})")
else:

View File

@@ -3,9 +3,9 @@ from sqlalchemy import and_, or_, func, desc
from typing import List, Optional, Dict, Any, Tuple
import uuid
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
from app.schemas.model_schema import (
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery
)
from app.core.logging_config import get_db_logger
@@ -32,7 +32,7 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
@@ -60,7 +60,7 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
@@ -92,7 +92,7 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
@@ -117,13 +117,21 @@ class ModelConfigRepository:
filters.append(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
# 支持多个 type 值(使用 IN 查询)
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
if query.type:
filters.append(ModelConfig.type.in_(query.type))
type_values = list(query.type)
# 如果包含 chat 或 llm则同时包含两者
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
if ModelType.CHAT not in type_values:
type_values.append(ModelType.CHAT)
if ModelType.LLM not in type_values:
type_values.append(ModelType.LLM)
filters.append(ModelConfig.type.in_(type_values))
if query.is_active is not None:
filters.append(ModelConfig.is_active == query.is_active)
@@ -183,12 +191,12 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
if is_active:
query = query.filter(ModelConfig.is_active == True)
query = query.filter(ModelConfig.is_active)
models = query.order_by(ModelConfig.name).all()
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
@@ -285,7 +293,7 @@ class ModelConfigRepository:
try:
# 总数统计
total_models = db.query(ModelConfig).count()
active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count()
active_models = db.query(ModelConfig).filter(ModelConfig.is_active).count()
# 按类型统计
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
@@ -344,7 +352,7 @@ class ModelApiKeyRepository:
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
if is_active:
query = query.filter(ModelApiKey.is_active == True)
query = query.filter(ModelApiKey.is_active)
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")

View File

@@ -0,0 +1,229 @@
import uuid
from typing import Optional
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.prompt_optimizer_model import (
PromptOptimizerModelConfig,
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
)
db_logger = get_db_logger()
class PromptOptimizerModelConfigRepository:
"""Repository for managing prompt optimizer model configurations."""
def __init__(self, db: Session):
self.db = db
def get_by_tenant_id(self, tenant_id: uuid.UUID) -> Optional[PromptOptimizerModelConfig]:
"""
Retrieve the prompt optimizer model configuration for a specific tenant.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
Returns:
Optional[PromptOptimizerModelConfig]: The model configuration if found, else None.
"""
db_logger.debug(f"Get prompt optimization model configuration: tenant_id={tenant_id}")
try:
config = self.db.query(PromptOptimizerModelConfig).filter(
PromptOptimizerModelConfig.tenant_id == tenant_id,
# PromptOptimizerModelConfig.model_id == model_id
).first()
if config:
db_logger.debug(f"Prompt optimization model configuration found: (ID: {config.id})")
else:
db_logger.debug(f"Prompt optimization model configuration not found: tenant_id={tenant_id}")
return config
except Exception as e:
db_logger.error(
f"Error retrieving prompt optimization model configuration: tenant_id={tenant_id} - {str(e)}")
raise
def get_by_config_id(self, tenant_id: uuid.UUID, config_id: uuid.UUID) -> Optional[PromptOptimizerModelConfig]:
"""
Retrieve a specific prompt optimizer model configuration by config ID and tenant ID.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
config_id (uuid.UUID): The unique identifier of the model configuration.
Returns:
Optional[PromptOptimizerModelConfig]: The model configuration if found, else None.
"""
db_logger.debug(f"Get prompt optimization model configuration: config_id={config_id}, tenant_id={tenant_id}")
try:
model = self.db.query(PromptOptimizerModelConfig).filter(
PromptOptimizerModelConfig.tenant_id == tenant_id,
PromptOptimizerModelConfig.id == config_id
).first()
if model:
db_logger.debug(f"Prompt optimization model configuration found: (ID: {model.id})")
else:
db_logger.debug(f"Prompt optimization model configuration not found: config_id={config_id}")
return model
except Exception as e:
db_logger.error(
f"Error retrieving prompt optimization model configuration: model_id={config_id} - {str(e)}")
raise
def create_or_update(
self,
config_id: uuid.UUID,
tenant_id: uuid.UUID,
system_prompt: str,
) -> Optional[PromptOptimizerModelConfig]:
"""
Create a new or update an existing prompt optimizer model configuration.
If a configuration with the given config_id exists, it updates its system_prompt.
Otherwise, it creates a new configuration record.
Args:
config_id (uuid.UUID): The unique identifier for the configuration.
tenant_id (uuid.UUID): The tenant's unique identifier.
system_prompt (str): The system prompt content for prompt optimization.
Returns:
Optional[PromptOptimizerModelConfig]: The created or updated model configuration.
"""
db_logger.debug(f"Create/Update prompt optimization model configuration: tenant_id={tenant_id}")
existing_config = self.get_by_config_id(tenant_id, config_id)
if existing_config:
existing_config.system_prompt = system_prompt
self.db.commit()
self.db.refresh(existing_config)
db_logger.debug(f"Prompt optimization model configuration update: ID:{config_id}")
return existing_config
else:
config = PromptOptimizerModelConfig(
id=config_id,
# model_id=model_id,
tenant_id=tenant_id,
system_prompt=system_prompt
)
self.db.add(config)
self.db.commit()
self.db.refresh(config)
db_logger.debug(f"Prompt optimization model configuration created: ID:{config.id}")
return config
class PromptOptimizerSessionRepository:
"""Repository for managing prompt optimization sessions and session history."""
def __init__(self, db: Session):
self.db = db
def create_session(
self,
tenant_id: uuid.UUID,
user_id: uuid.UUID
) -> PromptOptimizerSession:
"""
Create a new prompt optimization session for a user and app.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
user_id (uuid.UUID): The unique identifier of the user.
Returns:
PromptOptimizerSession: The newly created session object.
"""
db_logger.debug(f"Create prompt optimization session: tenant_id={tenant_id}, user_id={user_id}")
try:
session = PromptOptimizerSession(
tenant_id=tenant_id,
user_id=user_id,
)
self.db.add(session)
self.db.commit()
self.db.refresh(session)
db_logger.debug(f"Prompt optimization session created: ID:{session.id}")
return session
except Exception as e:
db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}")
raise
def get_session_history(
self,
session_id: uuid.UUID,
user_id: uuid.UUID
) -> list[type[PromptOptimizerSessionHistory]]:
"""
Retrieve all message history of a specific prompt optimization session.
Args:
session_id (uuid.UUID): The unique identifier of the session.
user_id (uuid.UUID): The unique identifier of the user.
Returns:
list[PromptOptimizerSessionHistory]: A list of session history records
ordered by creation time ascending.
"""
db_logger.debug(f"Get prompt optimization session history: "
f"user_id={user_id}, session_id={session_id}")
try:
# First get the internal session ID from the session list table
session = self.db.query(PromptOptimizerSession).filter(
PromptOptimizerSession.id == session_id,
PromptOptimizerSession.user_id == user_id
).first()
if not session:
return []
history = self.db.query(PromptOptimizerSessionHistory).filter(
PromptOptimizerSessionHistory.session_id == session.id,
PromptOptimizerSessionHistory.user_id == user_id
).order_by(PromptOptimizerSessionHistory.created_at.asc()).all()
return history
except Exception as e:
db_logger.error(f"Error retrieving prompt optimization session history: session_id={session_id} - {str(e)}")
raise
def create_message(
self,
tenant_id: uuid.UUID,
session_id: uuid.UUID,
user_id: uuid.UUID,
role: RoleType,
content: str,
) -> PromptOptimizerSessionHistory:
"""
Create a new message in the session history.
This method is a placeholder for future implementation.
"""
try:
# Get the session to ensure it exists and belongs to the user
session = self.db.query(PromptOptimizerSession).filter(
PromptOptimizerSession.id == session_id,
PromptOptimizerSession.user_id == user_id,
PromptOptimizerSession.tenant_id == tenant_id
).first()
if not session:
db_logger.error(f"Session {session_id} not found for user {user_id}")
raise ValueError(f"Session {session_id} not found for user {user_id}")
message = PromptOptimizerSessionHistory(
tenant_id=tenant_id,
session_id=session.id,
user_id=user_id,
role=role.value,
content=content,
)
self.db.add(message)
self.db.commit()
return message
except Exception as e:
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
raise

View File

@@ -0,0 +1,99 @@
from pydantic import BaseModel, Field
from uuid import UUID
# =========================================
# API Request Schemas
# =========================================
class PromptOptMessage(BaseModel):
model_id: UUID = Field(
...,
description="Model ID"
)
message: str = Field(
...,
min_length=1,
description="User's input message"
)
current_prompt: str = Field(
default="",
description="currently optimized prompt"
)
class PromptOptModelSet(BaseModel):
id: UUID | None = Field(
default=None,
description="Configuration ID"
)
system_prompt: str = Field(
...,
description="System Prompt"
)
# =========================================
# Service Layer Results
# =========================================
class OptimizePromptResult(BaseModel):
prompt: str = Field(
...,
description="Optimized Prompt"
)
desc: str = Field(
...,
description="Description"
)
# =========================================
# API Response Schemas
# =========================================
class CreateSessionResponse(BaseModel):
model_config = {"from_attributes": True}
id: UUID = Field(
...,
description="Session ID"
)
class OptimizePromptResponse(BaseModel):
model_config = {"from_attributes": True}
prompt: str = Field(
...,
description="Optimized Prompt"
)
desc: str = Field(
...,
description="Description"
)
variables: list = Field(
...,
description="Variables"
)
class SessionMessage(BaseModel):
role: str = Field(
...,
description="Message role (user/assistant)"
)
content: str = Field(
...,
description="Message content"
)
class SessionHistoryResponse(BaseModel):
session_id: UUID = Field(
...,
description="Session ID"
)
messages: list[SessionMessage] = Field(
...,
description="List of messages in the session"
)

View File

@@ -0,0 +1,280 @@
import json
import re
import uuid
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.models import RedBearModelConfig
from app.core.models.llm import RedBearLLM
from app.models import ModelConfig, ModelApiKey, ModelType, PromptOptimizerSessionHistory
from app.models.prompt_optimizer_model import (
PromptOptimizerModelConfig,
PromptOptimizerSession,
RoleType
)
from app.repositories.model_repository import ModelConfigRepository
from app.repositories.prompt_optimizer_repository import (
PromptOptimizerModelConfigRepository,
PromptOptimizerSessionRepository
)
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
logger = get_business_logger()
class PromptOptimizerService:
def __init__(self, db: Session):
self.db = db
def get_model_config(
self,
tenant_id: uuid.UUID,
model_id: uuid.UUID
) -> tuple[PromptOptimizerModelConfig, ModelConfig]:
"""
Retrieve the prompt optimizer model configuration and model configuration.
This method retrieves the prompt optimizer model configuration associated
with the specified model ID and tenant. It also fetches the corresponding
model configuration.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
model_id (uuid.UUID): The unique identifier of the prompt optimization model.
Returns:
tuple[PromptOptimzerModelConfig, ModelConfig]:
A tuple containing the prompt optimizer model configuration
and the corresponding model configuration.
Raises:
BusinessException: If the prompt optimizer model configuration does not exist.
BusinessException: If the model configuration does not exist.
"""
prompt_config = PromptOptimizerModelConfigRepository(self.db).get_by_tenant_id(
tenant_id
)
if not prompt_config:
raise BusinessException("提示词模型配置不存在", BizCode.NOT_FOUND)
model = ModelConfigRepository.get_by_id(
self.db, model_id, tenant_id=tenant_id
)
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return prompt_config, model
def create_update_model_config(
self,
tenant_id: uuid.UUID,
config_id: uuid.UUID,
system_prompt: str,
) -> PromptOptimizerModelConfig:
"""
Create or update a prompt optimizer model configuration.
This method creates a new prompt optimizer model configuration or updates
an existing one identified by the given configuration ID. The configuration
defines the system prompt used for prompt optimization.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
config_id (uuid.UUID): The unique identifier of the configuration to create or update.
system_prompt (str): The system prompt content used for prompt optimization.
Returns:
PromptOptimzerModelConfig: The created or updated prompt optimizer model configuration.
"""
prompt_config = PromptOptimizerModelConfigRepository(self.db).create_or_update(
config_id=config_id,
tenant_id=tenant_id,
system_prompt=system_prompt,
)
return prompt_config
def create_session(
self,
tenant_id: uuid.UUID,
user_id: uuid.UUID
) -> PromptOptimizerSession:
"""
Create a new prompt optimization session.
This method initializes a new prompt optimization session for the specified
tenant, application, and user, and persists it to the database.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
user_id (uuid.UUID): The unique identifier of the user.
Returns:
PromptOptimzerSession: The newly created prompt optimization session.
"""
session = PromptOptimizerSessionRepository(self.db).create_session(
tenant_id=tenant_id,
user_id=user_id
)
return session
def get_session_message_history(
self,
session_id: uuid.UUID,
user_id: uuid.UUID
) -> list[tuple[str, str]]:
"""
Retrieve the chronological message history for a prompt optimization session.
This method queries the database to fetch all messages associated with a
specific prompt optimization session for a given user. Messages are returned
in chronological order and typically include both user inputs and
model-generated responses.
Args:
session_id (uuid.UUID): The unique identifier of the prompt optimization session.
user_id (uuid.UUID): The unique identifier of the user associated with the session.
Returns:
list[tuple[str, str]]: A list of tuples representing messages. Each tuple contains:
- role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
- content (str): The content of the message.
"""
history = PromptOptimizerSessionRepository(self.db).get_session_history(
session_id=session_id,
user_id=user_id
)
messages = []
for message in history:
messages.append((message.role, message.content))
return messages
async def optimize_prompt(
self,
tenant_id: uuid.UUID,
model_id: uuid.UUID,
session_id: uuid.UUID,
user_id: uuid.UUID,
current_prompt: str,
message: str
) -> OptimizePromptResult:
"""
Optimize a prompt using a prompt optimizer LLM.
This method uses a configured prompt optimizer model to refine an existing
prompt based on the user's requirements. The optimized prompt is generated
according to predefined system rules, including Jinja2 variable syntax and
a strict JSON output format.
Args:
tenant_id (uuid.UUID): The unique identifier of the tenant.
model_id (uuid.UUID): The unique identifier of the prompt optimizer model.
session_id (uuid.UUID): The unique identifier of the prompt optimization session.
user_id (uuid.UUID): The unique identifier of the user associated with the session.
current_prompt (str): The original prompt to be optimized.
message (str): The user's requirements or modification instructions.
Returns:
dict: A dictionary containing the optimized prompt and the description
of changes, in the following format:
{
"prompt": "<optimized_prompt>",
"desc": "<change_description>"
}
Raises:
BusinessException: If the model response cannot be parsed as valid JSON
or does not conform to the expected output format.
"""
prompt_config, model_config = self.get_model_config(tenant_id, model_id)
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
# Create LLM instance
api_config: ModelApiKey = model_config.api_keys[0]
llm = RedBearLLM(RedBearModelConfig(
model_name=api_config.model_name,
provider=api_config.provider,
api_key=api_config.api_key,
base_url=api_config.api_base
), type=ModelType.from_str(model_config.type))
# build message
messages = [
# init system_prompt
(RoleType.SYSTEM.value, prompt_config.system_prompt),
# base model limit
(RoleType.SYSTEM.value,
"Optimization Rules:\n"
"1. Fully adjust the prompt content according to the user's requirements.\n"
"2. When the user requests the insertion of variables, you must use Jinja2 syntax {{variable_name}} "
"(the variable name should be determined based on the user's requirement).\n"
"3. Keep the prompt logic clear and instructions explicit.\n"
"4. Ensure that the modified prompt can be directly used.\n\n"
"Output Requirements:\n"
"Provide the result in JSON format, containing exactly two fields:\n"
" - prompt: The modified prompt (string).\n"
" - desc: A response addressing the user's optimization request (string).")
]
messages.extend(session_history[:-1]) # last message is current message
user_message_template = ChatPromptTemplate.from_messages([
(RoleType.USER.value, "[current_prompt]\n{current_prompt}\n[user_require]\n{message}")
])
formatted_user_message = user_message_template.format(current_prompt=current_prompt, message=message)
messages.extend([(RoleType.USER.value, formatted_user_message)])
logger.info(f"Prompt optimization message: {messages}")
result = await llm.ainvoke(messages)
try:
data_dict = json.loads(result.content)
model_resp = OptimizePromptResult.model_validate(data_dict)
except Exception as e:
logger.error(f"Failed to parse model reponse to json - Error: {str(e)}", exc_info=True)
raise BusinessException("Failed to parse model response", BizCode.PARSER_NOT_SUPPORTED)
return model_resp
@staticmethod
def parser_prompt_variables(prompt: str):
try:
pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}'
matches = re.findall(pattern, prompt)
variables = list(set(matches))
return variables
except Exception as e:
logger.error(f"Failed to parse prompt variables - Error: {str(e)}", exc_info=True)
raise BusinessException("Failed to parse prompt variables", BizCode.PARSER_NOT_SUPPORTED)
@staticmethod
def fill_prompt_variables(prompt: str, variables: dict[str, str]):
try:
pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}'
def replace_var(match):
var_name = match.group(1)
return variables.get(var_name, match.group(0))
result = re.sub(pattern, replace_var, prompt)
return result
except Exception as e:
logger.error(f"Failed to fill prompt variables - Error: {str(e)}", exc_info=True)
raise BusinessException("Failed to fill prompt variables", BizCode.PARSER_NOT_SUPPORTED)
def create_message(
self,
tenant_id: uuid.UUID,
session_id: uuid.UUID,
user_id: uuid.UUID,
role: RoleType,
content: str
) -> PromptOptimizerSessionHistory:
"""Insert Message to Session History"""
message = PromptOptimizerSessionRepository(self.db).create_message(
tenant_id=tenant_id,
session_id=session_id,
user_id=user_id,
role=role,
content=content
)
return message

View File

@@ -0,0 +1,74 @@
"""202512171846
Revision ID: 87a6537b4074
Revises: 64ddbf3c3bcc
Create Date: 2025-12-17 18:45:16.574812
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '87a6537b4074'
down_revision: Union[str, None] = '64ddbf3c3bcc'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('prompt_model_config',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('tenant_id', sa.UUID(), nullable=False, comment='Tenant ID'),
sa.Column('system_prompt', sa.Text(), nullable=False, comment='System Prompt'),
sa.Column('created_at', sa.DateTime(), nullable=True, comment='Creation Time'),
sa.Column('updated_at', sa.DateTime(), nullable=True, comment='Update Time'),
sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_prompt_model_config_id'), 'prompt_model_config', ['id'], unique=False)
op.create_table('prompt_opt_session_list',
sa.Column('id', sa.UUID(), nullable=False, comment='Session ID'),
sa.Column('tenant_id', sa.UUID(), nullable=False, comment='Tenant ID'),
sa.Column('user_id', sa.UUID(), nullable=False, comment='User ID'),
sa.Column('created_at', sa.DateTime(), nullable=True, comment='Creation Time'),
sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_prompt_opt_session_list_created_at'), 'prompt_opt_session_list', ['created_at'], unique=False)
op.create_index(op.f('ix_prompt_opt_session_list_id'), 'prompt_opt_session_list', ['id'], unique=False)
op.create_table('prompt_opt_session_history',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('tenant_id', sa.UUID(), nullable=False, comment='Tenant ID'),
sa.Column('session_id', sa.UUID(), nullable=False, comment='Session ID'),
sa.Column('user_id', sa.UUID(), nullable=False, comment='User ID'),
sa.Column('role', sa.String(), nullable=False, comment='Message Role'),
sa.Column('content', sa.Text(), nullable=False, comment='Message Content'),
sa.Column('created_at', sa.DateTime(), nullable=True, comment='Creation Time'),
sa.ForeignKeyConstraint(['session_id'], ['prompt_opt_session_list.id'], ),
sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_prompt_opt_session_history_created_at'), 'prompt_opt_session_history', ['created_at'], unique=False)
op.create_index(op.f('ix_prompt_opt_session_history_id'), 'prompt_opt_session_history', ['id'], unique=False)
op.create_index('ix_prompt_opt_session_history_session_user_created', 'prompt_opt_session_history', ['session_id', 'user_id', 'created_at'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_prompt_opt_session_history_session_user_created', table_name='prompt_opt_session_history')
op.drop_index(op.f('ix_prompt_opt_session_history_id'), table_name='prompt_opt_session_history')
op.drop_index(op.f('ix_prompt_opt_session_history_created_at'), table_name='prompt_opt_session_history')
op.drop_table('prompt_opt_session_history')
op.drop_index(op.f('ix_prompt_opt_session_list_id'), table_name='prompt_opt_session_list')
op.drop_index(op.f('ix_prompt_opt_session_list_created_at'), table_name='prompt_opt_session_list')
op.drop_table('prompt_opt_session_list')
op.drop_index(op.f('ix_prompt_model_config_id'), table_name='prompt_model_config')
op.drop_table('prompt_model_config')
# ### end Alembic commands ###

View File

@@ -126,6 +126,7 @@ dependencies = [
"pytest-asyncio>=1.3.0",
"uvicorn>=0.34.0",
"celery>=5.5.2",
"simpleeval>=1.0.3",
]
[tool.pytest.ini_options]

View File

@@ -121,3 +121,4 @@ fastmcp>=2.13.1
pytest-asyncio>=1.3.0
uvicorn>=0.34.0
celery>=5.5.2
simpleeval>=1.0.3

2701
api/uv.lock generated

File diff suppressed because it is too large Load Diff