MemoryBear/api/app/services/prompt_optimizer_service.py
Ke Sun 79ab929fb0 Release/v0.2.3 (#355)
2026-02-06 19:01:57 +08:00


import os
import re
import uuid
from typing import Any, AsyncGenerator

import json_repair
from jinja2 import Template
from sqlalchemy.orm import Session

from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.models import RedBearModelConfig
from app.core.models.llm import RedBearLLM
from app.models import ModelConfig, ModelApiKey, ModelType, PromptOptimizerSessionHistory
from app.models.prompt_optimizer_model import (
    PromptOptimizerSession,
    RoleType
)
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
from app.repositories.prompt_optimizer_repository import (
    PromptOptimizerSessionRepository,
    PromptReleaseRepository
)
from app.schemas.prompt_optimizer_schema import OptimizePromptResult

logger = get_business_logger()


class PromptOptimizerService:
    def __init__(self, db: Session):
        self.db = db
        self.optim_repo = PromptOptimizerSessionRepository(self.db)
        self.release_repo = PromptReleaseRepository(self.db)

    def get_model_config(
        self,
        tenant_id: uuid.UUID,
        model_id: uuid.UUID
    ) -> ModelConfig:
        """
        Retrieve the model configuration for a specific tenant.

        This method fetches the model configuration associated with the given
        tenant_id and model_id. If no configuration is found, a BusinessException
        is raised.

        Args:
            tenant_id (uuid.UUID): The unique identifier of the tenant.
            model_id (uuid.UUID): The unique identifier of the model.

        Returns:
            ModelConfig: The corresponding model configuration object.

        Raises:
            BusinessException: If the model configuration does not exist.
        """
        model = ModelConfigRepository.get_by_id(
            self.db, model_id, tenant_id=tenant_id
        )
        if not model:
            raise BusinessException("Model configuration does not exist", BizCode.MODEL_NOT_FOUND)
        return model

    def create_session(
        self,
        tenant_id: uuid.UUID,
        user_id: uuid.UUID
    ) -> PromptOptimizerSession:
        """
        Create a new prompt optimization session.

        This method initializes a new prompt optimization session for the specified
        tenant and user, and persists it to the database.

        Args:
            tenant_id (uuid.UUID): The unique identifier of the tenant.
            user_id (uuid.UUID): The unique identifier of the user.

        Returns:
            PromptOptimizerSession: The newly created prompt optimization session.
        """
        session = self.optim_repo.create_session(
            tenant_id=tenant_id,
            user_id=user_id
        )
        self.db.commit()
        self.db.refresh(session)
        return session

    def get_session_message_history(
        self,
        session_id: uuid.UUID,
        user_id: uuid.UUID
    ) -> list[tuple[str, str]]:
        """
        Retrieve the chronological message history for a prompt optimization session.

        This method queries the database to fetch all messages associated with a
        specific prompt optimization session for a given user. Messages are returned
        in chronological order and typically include both user inputs and
        model-generated responses.

        Args:
            session_id (uuid.UUID): The unique identifier of the prompt optimization session.
            user_id (uuid.UUID): The unique identifier of the user associated with the session.

        Returns:
            list[tuple[str, str]]: A list of tuples representing messages. Each tuple contains:
                - role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
                - content (str): The content of the message.
        """
        history = self.optim_repo.get_session_history(
            session_id=session_id,
            user_id=user_id
        )
        messages = []
        for message in history:
            messages.append((message.role, message.content))
        return messages

    async def optimize_prompt(
        self,
        tenant_id: uuid.UUID,
        model_id: uuid.UUID,
        session_id: uuid.UUID,
        user_id: uuid.UUID,
        current_prompt: str,
        user_require: str
    ) -> AsyncGenerator[dict[str, str | Any], Any]:
        """
        Optimize a user-provided prompt using a configured prompt optimizer LLM.

        This method refines the original prompt according to the user's requirements,
        generating an optimized version that is directly usable by AI tools. The
        optimization process follows strict rules, including:
        - Wrapping user-inserted variables in double curly braces {{}}.
        - Adhering to Jinja2 variable syntax if applicable.
        - Ensuring a clear logic flow, explicit instructions, and strong executability.
        - Producing output in a strict JSON format.

        Steps performed:
        1. Store the user's requirement as a message in the session history.
        2. Retrieve the model configuration for the given tenant and model.
        3. Fetch the session message history for context.
        4. Instantiate the LLM with the appropriate API key and model configuration.
        5. Build the system message from the optimizer template.
        6. Format the user's original prompt and requirements as a user message.
        7. Stream the LLM response, forwarding the "prompt" field as it arrives.
        8. Store the "desc" summary of the changes and extract the prompt variables.

        Args:
            tenant_id (uuid.UUID): Tenant identifier.
            model_id (uuid.UUID): Prompt optimizer model identifier.
            session_id (uuid.UUID): Prompt optimization session identifier.
            user_id (uuid.UUID): Identifier of the user associated with the session.
            current_prompt (str): Original prompt to optimize.
            user_require (str): User's requirements or instructions for optimization.

        Yields:
            dict: While the optimized prompt is streaming, items of the form
                {"content": <partial prompt text>}. After the stream ends, one final
                item {"desc": <summary of changes>, "variables": <variable names>}.

        Raises:
            BusinessException: If the model has no API key, a prompt template cannot
                be loaded, or the LLM response cannot be parsed as a JSON object.
        """
        self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require)
        model_config = self.get_model_config(tenant_id, model_id)
        session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)

        logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")

        # Create LLM instance
        api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
        if not api_keys:
            # Without this guard the attribute access below fails on None.
            raise BusinessException("No API key is configured for the model", BizCode.MODEL_NOT_FOUND)
        api_config: ModelApiKey = api_keys[0]
        llm = RedBearLLM(RedBearModelConfig(
            model_name=api_config.model_name,
            provider=api_config.provider,
            api_key=api_config.api_key,
            base_url=api_config.api_base
        ), type=ModelType(model_config.type))

        # Load the system and user prompt templates shipped next to this module
        try:
            prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
            with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
                opt_system_prompt = f.read()
            rendered_system_message = Template(opt_system_prompt).render()

            with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
                opt_user_prompt = f.read()
        except FileNotFoundError:
            raise BusinessException(message="Prompt template not found", code=BizCode.NOT_FOUND)
        except Exception as e:
            logger.error(f"Failed to load prompt template: {e}")
            raise BusinessException(message="Internal server error", code=BizCode.INTERNAL_ERROR)

        rendered_user_message = Template(opt_user_prompt).render(
            current_prompt=current_prompt,
            user_require=user_require
        )

        # Build the message list: system prompt, prior history, current request
        messages = [
            (
                RoleType.SYSTEM.value,
                rendered_system_message
            ),
        ]
        messages.extend(session_history[:-1])  # last message is current message
        messages.extend([(RoleType.USER.value, rendered_user_message)])

        buffer = ""
        prompt_started = False
        prompt_finished = False
        idx = 0
        async for chunk in llm.astream(messages):
            content = getattr(chunk, "content", chunk)
            if not content:
                continue
            buffer += content
            # Hold back the last 20 characters so a partially received
            # '", "desc": "' delimiter is not emitted as prompt text.
            cache = buffer[:-20]

            if prompt_finished:
                continue

            if not prompt_started:
                # Try to find where the value of "prompt": " starts
                m = re.search(r'"prompt"\s*:\s*"', cache)
                if m:
                    prompt_started = True
                    prompt_index = m.end()
                    idx = prompt_index
            else:
                # Look for the delimiter that closes "prompt" and opens "desc"
                m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer)
                if m:
                    prompt_index = m.start()
                    prompt_finished = True
                    yield {"content": buffer[idx:prompt_index]}
                else:
                    yield {"content": cache[idx:]}
                    if len(cache) != 0:
                        idx = len(cache)

        logger.info(buffer)
        optim_result = json_repair.repair_json(buffer, return_objects=True)
        if not isinstance(optim_result, dict):
            # repair_json can return a non-dict object for malformed output
            raise BusinessException("Failed to parse the LLM response as JSON", BizCode.PARSER_NOT_SUPPORTED)
        desc = optim_result.get("desc")
        self.create_message(
            tenant_id=tenant_id,
            session_id=session_id,
            user_id=user_id,
            role=RoleType.ASSISTANT,
            content=desc
        )
        variables = self.parser_prompt_variables(optim_result.get("prompt"))
        logger.info(f"Prompt optimization completed, user_id={user_id}, session_id={session_id}")
        yield {"desc": desc, "variables": variables}

    @staticmethod
    def parser_prompt_variables(prompt: str):
        """Return the distinct {{ variable }} names found in the prompt."""
        try:
            pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}'
            matches = re.findall(pattern, prompt)
            variables = list(set(matches))
            return variables
        except Exception as e:
            logger.error(f"Failed to parse prompt variables - Error: {str(e)}", exc_info=True)
            raise BusinessException("Failed to parse prompt variables", BizCode.PARSER_NOT_SUPPORTED)

    @staticmethod
    def fill_prompt_variables(prompt: str, variables: dict[str, str]):
        """Replace each {{ variable }} with its value; unknown names are left as-is."""
        try:
            pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}'

            def replace_var(match):
                var_name = match.group(1)
                return variables.get(var_name, match.group(0))

            result = re.sub(pattern, replace_var, prompt)
            return result
        except Exception as e:
            logger.error(f"Failed to fill prompt variables - Error: {str(e)}", exc_info=True)
            raise BusinessException("Failed to fill prompt variables", BizCode.PARSER_NOT_SUPPORTED)

    def create_message(
        self,
        tenant_id: uuid.UUID,
        session_id: uuid.UUID,
        user_id: uuid.UUID,
        role: RoleType,
        content: str
    ) -> PromptOptimizerSessionHistory:
        """Insert a message into the session history."""
        message = self.optim_repo.create_message(
            tenant_id=tenant_id,
            session_id=session_id,
            user_id=user_id,
            role=role,
            content=content
        )
        self.db.commit()
        self.db.refresh(message)
        return message

    def save_prompt(
        self,
        tenant_id: uuid.UUID,
        session_id: uuid.UUID,
        title: str,
        prompt: str
    ) -> dict:
        """
        Create and save a new prompt release for a given session.

        Args:
            tenant_id (uuid.UUID): The ID of the tenant owning the prompt.
            session_id (uuid.UUID): The ID of the session to associate with this prompt.
            title (str): The title of the prompt release.
            prompt (str): The content of the prompt.

        Returns:
            dict: A dictionary containing:
                - id (UUID): The unique ID of the created prompt release.
                - session_id (UUID): The session ID linked to the release.
                - title (str): The title of the prompt.
                - prompt (str): The prompt content.
                - created_at (int): Timestamp (in milliseconds) of when the prompt was created.

        Raises:
            BusinessException: If the session does not exist or belongs to another
                tenant, or if a prompt release already exists for the given session.
        """
        session = self.optim_repo.get_session_by_id(session_id)
        if session is None or session.tenant_id != tenant_id:
            raise BusinessException(
                "Session does not exist or the current user has no access",
                BizCode.BAD_REQUEST
            )

        if self.release_repo.get_prompt_by_session_id(session_id):
            raise BusinessException(
                "A release already exists for the current session",
                BizCode.BAD_REQUEST
            )

        prompt_obj = self.release_repo.create_prompt_release(
            tenant_id=tenant_id,
            title=title,
            session_id=session_id,
            prompt=prompt
        )
        self.db.commit()
        self.db.refresh(prompt_obj)
        return {
            "id": prompt_obj.id,
            "session_id": prompt_obj.session_id,
            "title": prompt_obj.title,
            "prompt": prompt_obj.prompt,
            "created_at": int(prompt_obj.created_at.timestamp() * 1000)
        }

    def delete_prompt(
        self,
        tenant_id: uuid.UUID,
        prompt_id: uuid.UUID
    ) -> None:
        """
        Soft delete a prompt release by prompt_id.

        Args:
            tenant_id (uuid.UUID): Tenant identifier.
            prompt_id (uuid.UUID): Prompt identifier.

        Raises:
            BusinessException: If the prompt does not exist, has already been
                deleted, or belongs to another tenant.
        """
        prompt_obj = self.release_repo.get_prompt_by_id(prompt_id)
        if not prompt_obj or prompt_obj.is_delete:
            raise BusinessException(
                "Prompt does not exist or has already been deleted",
                BizCode.NOT_FOUND
            )

        if prompt_obj.tenant_id != tenant_id:
            raise BusinessException(
                "No permission to delete this prompt",
                BizCode.FORBIDDEN
            )

        self.release_repo.soft_delete_prompt(prompt_obj)
        self.db.commit()
        logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}")

    def get_release_list(
        self,
        tenant_id: uuid.UUID,
        page: int,
        page_size: int,
        filter_keyword: str | None = None
    ) -> dict[str, Any]:
        """
        Get a paginated list of prompt releases with an optional title filter.

        Args:
            tenant_id (uuid.UUID): Tenant identifier.
            page (int): Page number (starting from 1).
            page_size (int): Number of items per page.
            filter_keyword (str | None): Optional keyword to filter by title.

        Returns:
            dict: A dictionary containing:
                - page (dict): total, page, page_size, and hasnext.
                - keyword (str | None): The filter keyword that was applied.
                - items (list): The releases on the requested page.
        """
        offset = (page - 1) * page_size

        # Get total count and releases based on filter
        if filter_keyword:
            total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword)
            releases = self.release_repo.search_prompts_paginated(
                tenant_id=tenant_id,
                keyword=filter_keyword,
                offset=offset,
                limit=page_size
            )
        else:
            total = self.release_repo.count_prompts(tenant_id)
            releases = self.release_repo.get_prompts_paginated(
                tenant_id=tenant_id,
                offset=offset,
                limit=page_size
            )

        items = []
        for release in releases:
            # Get first user message from session
            first_message = self.optim_repo.get_first_user_message(
                session_id=release.session_id
            )
            items.append({
                "id": release.id,
                "title": release.title,
                "prompt": release.prompt,
                "created_at": int(release.created_at.timestamp() * 1000),
                "first_message": first_message
            })

        log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}"
        if filter_keyword:
            log_msg += f", filter='{filter_keyword}'"
        logger.info(log_msg)

        result = {
            "page": {
                "total": total,
                "page": page,
                "page_size": page_size,
                "hasnext": page * page_size < total
            },
            "keyword": filter_keyword,
            "items": items
        }
        return result
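
The `{{ variable }}` handling in `parser_prompt_variables` and `fill_prompt_variables` can be tried in isolation. The following is a minimal standalone sketch, not part of the service: the names `PLACEHOLDER`, `parse_variables`, and `fill_variables` are hypothetical, and it removes duplicates in first-appearance order with `dict.fromkeys`, whereas the service uses `set()` and so returns names in no guaranteed order.

```python
import re

# Same placeholder pattern as the service: {{ name }} with optional inner whitespace.
PLACEHOLDER = re.compile(r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}")


def parse_variables(prompt: str) -> list[str]:
    """Return placeholder names in first-appearance order, without duplicates."""
    # Assumption for this sketch: dict.fromkeys keeps insertion order,
    # unlike the set() used in the service.
    return list(dict.fromkeys(PLACEHOLDER.findall(prompt)))


def fill_variables(prompt: str, values: dict[str, str]) -> str:
    """Substitute known placeholders and leave unknown ones untouched."""
    return PLACEHOLDER.sub(lambda m: values.get(m.group(1), m.group(0)), prompt)


template = "Translate {{ text }} into {{lang}}. Keep {{ text }} intact; tone: {{tone}}."
names = parse_variables(template)
filled = fill_variables(template, {"text": "hello", "lang": "French"})
print(names)
print(filled)
```

Leaving unknown placeholders in place, as the service does, means a partially filled prompt can be filled again later once the remaining values are known.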