fix(prompt): remove hard-coded import of prompt file paths (#279)

* Fix/develop memory bug (#274)

* 遗漏的历史映射

* 遗漏的历史映射

* fix_timeline_memories

* fix(web): update retrieve_type key

* Fix/develop memory bug (#276)

* 遗漏的历史映射

* 遗漏的历史映射

* fix_timeline_memories

* fix_timeline_memories

* write_gragp/bug_fix

* write_gragp/bug_fix

* write_gragp/bug_fix

* chore(celery): disable periodic task scheduling

* fix(prompt): remove hard-coded import of prompt file paths

---------

Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Ke Sun <kesun5@illinois.edu>
This commit is contained in:
Eternity
2026-02-03 10:29:51 +08:00
committed by GitHub
parent 5d5351f0bc
commit b471d56a86
9 changed files with 62 additions and 64 deletions

View File

@@ -3,9 +3,10 @@ import platform
from datetime import timedelta from datetime import timedelta
from urllib.parse import quote from urllib.parse import quote
from app.core.config import settings
from celery import Celery from celery import Celery
from app.core.config import settings
# 创建 Celery 应用实例 # 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0 # broker: 任务队列(使用 Redis DB 0
# backend: 结果存储(使用 Redis DB 10 # backend: 结果存储(使用 Redis DB 10
@@ -79,40 +80,40 @@ celery_app.conf.update(
celery_app.autodiscover_tasks(['app']) celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks # Celery Beat schedule for periodic tasks
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) # memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) # memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME # workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期 # forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
# 构建定时任务配置 # 构建定时任务配置
beat_schedule_config = { # beat_schedule_config = {
"run-workspace-reflection": { # "run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task", # "task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule, # "schedule": workspace_reflection_schedule,
"args": (), # "args": (),
}, # },
"regenerate-memory-cache": { # "regenerate-memory-cache": {
"task": "app.tasks.regenerate_memory_cache", # "task": "app.tasks.regenerate_memory_cache",
"schedule": memory_cache_regeneration_schedule, # "schedule": memory_cache_regeneration_schedule,
"args": (), # "args": (),
}, # },
"run-forgetting-cycle": { # "run-forgetting-cycle": {
"task": "app.tasks.run_forgetting_cycle_task", # "task": "app.tasks.run_forgetting_cycle_task",
"schedule": forgetting_cycle_schedule, # "schedule": forgetting_cycle_schedule,
"kwargs": { # "kwargs": {
"config_id": None, # 使用默认配置,可以通过环境变量配置 # "config_id": None, # 使用默认配置,可以通过环境变量配置
}, # },
}, # },
} # }
# 如果配置了默认工作空间ID则添加记忆总量统计任务 # 如果配置了默认工作空间ID则添加记忆总量统计任务
if settings.DEFAULT_WORKSPACE_ID: # if settings.DEFAULT_WORKSPACE_ID:
beat_schedule_config["write-total-memory"] = { # beat_schedule_config["write-total-memory"] = {
"task": "app.controllers.memory_storage_controller.search_all", # "task": "app.controllers.memory_storage_controller.search_all",
"schedule": memory_increment_schedule, # "schedule": memory_increment_schedule,
"kwargs": { # "kwargs": {
"workspace_id": settings.DEFAULT_WORKSPACE_ID, # "workspace_id": settings.DEFAULT_WORKSPACE_ID,
}, # },
} # }
celery_app.conf.beat_schedule = beat_schedule_config # celery_app.conf.beat_schedule = beat_schedule_config

View File

@@ -39,7 +39,6 @@ async def make_write_graph():
graph = workflow.compile() graph = workflow.compile()
yield graph yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
@@ -49,7 +48,7 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
db_session = next(get_db()) db_session = next(get_db())
config_service = MemoryConfigService(db_session) config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
config_id="08ed205c-0f05-49c3-8e0c-a580d28f5fd4", # 改为整数 config_id=memory_config, # 改为整数
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
if long_term_type=='chunk': if long_term_type=='chunk':
@@ -63,7 +62,7 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
"""方案三:聚合判断""" """方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config) await aggregate_judgment(end_user_id, langchain_messages, memory_config)
#
# async def main(): # async def main():
# """主函数 - 运行工作流""" # """主函数 - 运行工作流"""
# langchain_messages = [ # langchain_messages = [
@@ -80,14 +79,7 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID # end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" # memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) # # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# from app.core.memory.agent.utils.redis_tool import write_store # result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# result=write_store.get_session_by_userid(end_user_id)
# data=await format_parsing(result,"dict")
# chunk_data=data[:6]
#
# long_time_data = write_store.find_user_recent_sessions(end_user_id, 240)
# long_=await messages_parse(long_time_data)
# print(long_)
# #
# #
# if __name__ == "__main__": # if __name__ == "__main__":

View File

@@ -877,7 +877,8 @@ RETURN
CASE CASE
WHEN ms:ExtractedEntity THEN { WHEN ms:ExtractedEntity THEN {
text: ms.name, text: ms.name,
created_at: ms.created_at created_at: ms.created_at,
type: "情景记忆"
} }
END END
) AS ExtractedEntity, ) AS ExtractedEntity,
@@ -887,7 +888,8 @@ RETURN
CASE CASE
WHEN n:MemorySummary THEN { WHEN n:MemorySummary THEN {
text: n.content, text: n.content,
created_at: n.created_at created_at: n.created_at,
type: "长期沉淀"
} }
END END
) AS MemorySummary, ) AS MemorySummary,
@@ -895,7 +897,8 @@ RETURN
collect( collect(
DISTINCT { DISTINCT {
text: e.statement, text: e.statement,
created_at: e.created_at created_at: e.created_at,
type: "情绪记忆"
} }
) AS statement; ) AS statement;
""" """

View File

@@ -1,4 +1,5 @@
"""会话服务""" """会话服务"""
import os
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Annotated from typing import Annotated
@@ -529,12 +530,12 @@ class ConversationService:
takeaways=[], takeaways=[],
info_score=0, info_score=0,
) )
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f: with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f:
system_prompt = f.read() system_prompt = f.read()
rendered_system_message = Template(system_prompt).render() rendered_system_message = Template(system_prompt).render()
with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f: with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f:
user_prompt = f.read() user_prompt = f.read()
rendered_user_message = Template(user_prompt).render( rendered_user_message = Template(user_prompt).render(
language=language, language=language,

View File

@@ -377,7 +377,6 @@ class MemoryReflectionService:
iteration_period = int(iteration_period) iteration_period = int(iteration_period)
except (ValueError, TypeError): except (ValueError, TypeError):
iteration_period = 24 # 默认24小时 iteration_period = 24 # 默认24小时
return ReflectionConfig( return ReflectionConfig(
enabled=config_data.get("enable_self_reflexion", False), enabled=config_data.get("enable_self_reflexion", False),
iteration_period=str(iteration_period), # ReflectionConfig期望字符串 iteration_period=str(iteration_period), # ReflectionConfig期望字符串

View File

@@ -1,3 +1,4 @@
import os
import re import re
import uuid import uuid
from typing import Any, AsyncGenerator from typing import Any, AsyncGenerator
@@ -182,11 +183,12 @@ class PromptOptimizerService:
base_url=api_config.api_base base_url=api_config.api_base
), type=ModelType(model_config.type)) ), type=ModelType(model_config.type))
try: try:
with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read() opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render() rendered_system_message = Template(opt_system_prompt).render()
with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f: with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
opt_user_prompt = f.read() opt_user_prompt = f.read()
except FileNotFoundError: except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)

View File

@@ -64,7 +64,7 @@ const KnowledgeListModal = forwardRef<KnowledgeModalRef, KnowledgeModalProps>(({
...item, ...item,
config: { config: {
similarity_threshold: 0.7, similarity_threshold: 0.7,
strategy: "hybrid", retrieve_type: "hybrid",
top_k: 3, top_k: 3,
weight: 1, weight: 1,
} }

View File

@@ -64,7 +64,7 @@ const KnowledgeListModal = forwardRef<KnowledgeModalRef, KnowledgeModalProps>(({
...item, ...item,
config: { config: {
similarity_threshold: 0.7, similarity_threshold: 0.7,
strategy: "hybrid", retrieve_type: "hybrid",
top_k: 3, top_k: 3,
weight: 1, weight: 1,
} }

View File

@@ -885,7 +885,7 @@ export const useWorkflowGraph = ({
...itemConfig, ...itemConfig,
...(data.config[key].defaultValue || {}), ...(data.config[key].defaultValue || {}),
knowledge_bases: knowledge_bases?.map((vo: any) => { knowledge_bases: knowledge_bases?.map((vo: any) => {
const kb_config = vo.config || { similarity_threshold: vo.similarity_threshold, strategy: vo.strategy, top_k: vo.top_k, weight: vo.weight } const kb_config = vo.config || { similarity_threshold: vo.similarity_threshold, retrieve_type: vo.retrieve_type, top_k: vo.top_k, weight: vo.weight }
return { kb_id: vo.kb_id || vo.id, ...kb_config, } return { kb_id: vo.kb_id || vo.id, ...kb_config, }
}) })
} }