Merge branch 'refs/heads/release/v0.2.3' into fix/release_memory_bug
# Conflicts: # api/app/core/memory/agent/langgraph_graph/write_graph.py
This commit is contained in:
@@ -3,9 +3,14 @@ import platform
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.config import settings
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
@@ -63,6 +68,11 @@ celery_app.conf.update(
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
@@ -79,40 +89,40 @@ celery_app.conf.update(
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
# memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
"schedule": workspace_reflection_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"regenerate-memory-cache": {
|
||||
"task": "app.tasks.regenerate_memory_cache",
|
||||
"schedule": memory_cache_regeneration_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"run-forgetting-cycle": {
|
||||
"task": "app.tasks.run_forgetting_cycle_task",
|
||||
"schedule": forgetting_cycle_schedule,
|
||||
"kwargs": {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
}
|
||||
# beat_schedule_config = {
|
||||
# "run-workspace-reflection": {
|
||||
# "task": "app.tasks.workspace_reflection_task",
|
||||
# "schedule": workspace_reflection_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
# "regenerate-memory-cache": {
|
||||
# "task": "app.tasks.regenerate_memory_cache",
|
||||
# "schedule": memory_cache_regeneration_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
# "run-forgetting-cycle": {
|
||||
# "task": "app.tasks.run_forgetting_cycle_task",
|
||||
# "schedule": forgetting_cycle_schedule,
|
||||
# "kwargs": {
|
||||
# "config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
# if settings.DEFAULT_WORKSPACE_ID:
|
||||
# beat_schedule_config["write-total-memory"] = {
|
||||
# "task": "app.controllers.memory_storage_controller.search_all",
|
||||
# "schedule": memory_increment_schedule,
|
||||
# "kwargs": {
|
||||
# "workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
# },
|
||||
# }
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
# celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
@@ -182,14 +182,6 @@ def _get_ontology_service(
|
||||
detail=f"找不到指定的LLM模型: {llm_id}"
|
||||
)
|
||||
|
||||
# 检查是否为组合模型
|
||||
if hasattr(model_config, 'is_composite') and model_config.is_composite:
|
||||
logger.error(f"Model {llm_id} is a composite model, which is not supported for ontology extraction")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="本体提取不支持使用组合模型,请选择单个模型"
|
||||
)
|
||||
|
||||
# 验证模型配置了API密钥
|
||||
if not model_config.api_keys:
|
||||
logger.error(f"Model {llm_id} has no API key configuration")
|
||||
|
||||
@@ -148,8 +148,10 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
return messages
|
||||
|
||||
# TODO: 移到memory module
|
||||
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
||||
db = next(get_db())
|
||||
#TODO: 魔法数字
|
||||
scope=6
|
||||
|
||||
try:
|
||||
@@ -159,6 +161,12 @@ class LangChainAgent:
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
|
||||
# Handle case where no session exists in Redis (returns False)
|
||||
if not result or result is False:
|
||||
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
|
||||
return
|
||||
|
||||
if type=="chunk" or type=="aggregate":
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
@@ -166,7 +174,14 @@ class LangChainAgent:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'写入短长期:')
|
||||
else:
|
||||
# TODO: This branch handles type="time" strategy, currently unused.
|
||||
# Will be activated when time-based long-term storage is implemented.
|
||||
# TODO: 魔法数字 - extract 5 to a constant
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
# Handle case where no session exists in Redis (returns False or empty)
|
||||
if not long_time_data or long_time_data is False:
|
||||
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
|
||||
return
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
@@ -307,9 +322,12 @@ class LangChainAgent:
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
long_term_messages=await agent_chat_messages(message_chat,content)
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
'''长期'''
|
||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
||||
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
||||
response = {
|
||||
"content": content,
|
||||
@@ -441,9 +459,13 @@ class LangChainAgent:
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
||||
long_term_messages = await agent_chat_messages(message_chat, full_content)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
||||
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
# TODO:删除
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
print(contents)
|
||||
@@ -60,6 +61,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
redis_messages = []
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
@@ -91,6 +93,9 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
# Handle case where no session exists in Redis (returns False or empty)
|
||||
if not long_time_data or long_time_data is False:
|
||||
return
|
||||
format_messages = await chat_data_format(long_time_data)
|
||||
if format_messages!=[]:
|
||||
await write_messages(end_user_id, format_messages, memory_config)
|
||||
@@ -108,8 +113,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
|
||||
# Handle case where no session exists in Redis (returns False or empty)
|
||||
if not result or result is False:
|
||||
history = []
|
||||
else:
|
||||
history = await format_parsing(result)
|
||||
|
||||
@@ -44,7 +44,7 @@ class CodeNodeConfig(BaseNodeConfig):
|
||||
description="code content"
|
||||
)
|
||||
|
||||
language: Literal['python3', 'nodejs'] = Field(
|
||||
language: Literal['python3', 'javascript'] = Field(
|
||||
...,
|
||||
description="language"
|
||||
)
|
||||
|
||||
@@ -110,7 +110,7 @@ class CodeNode(BaseNode):
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
)
|
||||
elif self.typed_config.language == 'nodejs':
|
||||
elif self.typed_config.language == 'javascript':
|
||||
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
|
||||
@@ -4,16 +4,19 @@
|
||||
从文件系统加载预定义的工作流模板
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')
|
||||
|
||||
|
||||
class TemplateLoader:
|
||||
"""工作流模板加载器"""
|
||||
|
||||
def __init__(self, templates_dir: str = "app/templates/workflows"):
|
||||
|
||||
def __init__(self, templates_dir: str = TEMPLATE_DIR):
|
||||
"""初始化模板加载器
|
||||
|
||||
Args:
|
||||
@@ -22,7 +25,7 @@ class TemplateLoader:
|
||||
self.templates_dir = Path(templates_dir)
|
||||
if not self.templates_dir.exists():
|
||||
raise ValueError(f"模板目录不存在: {templates_dir}")
|
||||
|
||||
|
||||
def list_templates(self) -> list[dict]:
|
||||
"""列出所有可用的模板
|
||||
|
||||
@@ -30,22 +33,22 @@ class TemplateLoader:
|
||||
模板列表,每个模板包含 id, name, description 等信息
|
||||
"""
|
||||
templates = []
|
||||
|
||||
|
||||
# 遍历模板目录
|
||||
for template_dir in self.templates_dir.iterdir():
|
||||
if not template_dir.is_dir():
|
||||
continue
|
||||
|
||||
|
||||
# 检查是否有 template.yml 文件
|
||||
template_file = template_dir / "template.yml"
|
||||
if not template_file.exists():
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# 读取模板配置
|
||||
with open(template_file, 'r', encoding='utf-8') as f:
|
||||
template_data = yaml.safe_load(f)
|
||||
|
||||
|
||||
# 提取模板信息
|
||||
templates.append({
|
||||
"id": template_dir.name,
|
||||
@@ -59,9 +62,9 @@ class TemplateLoader:
|
||||
except Exception as e:
|
||||
print(f"加载模板 {template_dir.name} 失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
return templates
|
||||
|
||||
|
||||
def load_template(self, template_id: str) -> Optional[dict]:
|
||||
"""加载指定的模板
|
||||
|
||||
@@ -73,14 +76,14 @@ class TemplateLoader:
|
||||
"""
|
||||
template_dir = self.templates_dir / template_id
|
||||
template_file = template_dir / "template.yml"
|
||||
|
||||
|
||||
if not template_file.exists():
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(template_file, 'r', encoding='utf-8') as f:
|
||||
template_data = yaml.safe_load(f)
|
||||
|
||||
|
||||
# 返回工作流配置部分
|
||||
return {
|
||||
"name": template_data.get("name", template_id),
|
||||
@@ -94,7 +97,7 @@ class TemplateLoader:
|
||||
except Exception as e:
|
||||
print(f"加载模板 {template_id} 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_template_readme(self, template_id: str) -> Optional[str]:
|
||||
"""获取模板的 README 文档
|
||||
|
||||
@@ -106,10 +109,10 @@ class TemplateLoader:
|
||||
"""
|
||||
template_dir = self.templates_dir / template_id
|
||||
readme_file = template_dir / "README.md"
|
||||
|
||||
|
||||
if not readme_file.exists():
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(readme_file, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
@@ -235,6 +235,8 @@ class MemoryConfigRepository:
|
||||
llm_id=params.llm_id,
|
||||
embedding_id=params.embedding_id,
|
||||
rerank_id=params.rerank_id,
|
||||
reflection_model_id=params.reflection_model_id,
|
||||
emotion_model_id=params.emotion_model_id,
|
||||
)
|
||||
db.add(db_config)
|
||||
db.flush() # 获取自增ID但不提交事务
|
||||
|
||||
@@ -877,7 +877,8 @@ RETURN
|
||||
CASE
|
||||
WHEN ms:ExtractedEntity THEN {
|
||||
text: ms.name,
|
||||
created_at: ms.created_at
|
||||
created_at: ms.created_at,
|
||||
type: "情景记忆"
|
||||
}
|
||||
END
|
||||
) AS ExtractedEntity,
|
||||
@@ -887,7 +888,8 @@ RETURN
|
||||
CASE
|
||||
WHEN n:MemorySummary THEN {
|
||||
text: n.content,
|
||||
created_at: n.created_at
|
||||
created_at: n.created_at,
|
||||
type: "长期沉淀"
|
||||
}
|
||||
END
|
||||
) AS MemorySummary,
|
||||
@@ -895,7 +897,8 @@ RETURN
|
||||
collect(
|
||||
DISTINCT {
|
||||
text: e.statement,
|
||||
created_at: e.created_at
|
||||
created_at: e.created_at,
|
||||
type: "情绪记忆"
|
||||
}
|
||||
) AS statement;
|
||||
"""
|
||||
|
||||
@@ -236,6 +236,8 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||
|
||||
|
||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||
|
||||
@@ -187,7 +187,7 @@ class AppStatisticsService:
|
||||
daily_tokens[date_str] = 0
|
||||
daily_tokens[date_str] += int(tokens)
|
||||
|
||||
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["tokens"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""会话服务"""
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
@@ -529,12 +530,12 @@ class ConversationService:
|
||||
takeaways=[],
|
||||
info_score=0,
|
||||
)
|
||||
|
||||
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
system_prompt = f.read()
|
||||
rendered_system_message = Template(system_prompt).render()
|
||||
|
||||
with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f:
|
||||
with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f:
|
||||
user_prompt = f.read()
|
||||
rendered_user_message = Template(user_prompt).render(
|
||||
language=language,
|
||||
|
||||
@@ -53,7 +53,10 @@ def get_workspace_end_users(
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> List[EndUser]:
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||
|
||||
返回结果按 updated_at 从新到旧排序(NULL 值排在最后)
|
||||
"""
|
||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
@@ -68,9 +71,14 @@ def get_workspace_end_users(
|
||||
app_ids = [app.id for app in apps_orm]
|
||||
|
||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||
# 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
from app.models.end_user_model import EndUser as EndUserModel
|
||||
from sqlalchemy import desc, nullslast
|
||||
end_users_orm = db.query(EndUserModel).filter(
|
||||
EndUserModel.app_id.in_(app_ids)
|
||||
).order_by(
|
||||
nullslast(desc(EndUserModel.updated_at)),
|
||||
desc(EndUserModel.id)
|
||||
).all()
|
||||
|
||||
# 转换为 Pydantic 模型(只在需要时转换)
|
||||
|
||||
@@ -286,7 +286,7 @@ class MemoryReflectionService:
|
||||
# 检查是否需要执行反思
|
||||
should_execute = False
|
||||
hours_diff = 0
|
||||
|
||||
|
||||
if current_reflection_time is None:
|
||||
# 首次执行反思
|
||||
should_execute = True
|
||||
@@ -298,11 +298,11 @@ class MemoryReflectionService:
|
||||
reflection_time = datetime.fromisoformat(current_reflection_time)
|
||||
else:
|
||||
reflection_time = current_reflection_time
|
||||
|
||||
|
||||
current_time = datetime.now()
|
||||
time_diff = current_time - reflection_time
|
||||
hours_diff = int(time_diff.total_seconds() / 3600)
|
||||
|
||||
|
||||
# 检查是否达到反思周期
|
||||
if hours_diff >= iteration_period:
|
||||
should_execute = True
|
||||
@@ -312,7 +312,7 @@ class MemoryReflectionService:
|
||||
except (ValueError, TypeError) as e:
|
||||
api_logger.warning(f"解析反思时间失败: {e},将执行反思")
|
||||
should_execute = True
|
||||
|
||||
|
||||
if should_execute:
|
||||
api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时")
|
||||
# 3. 执行反思引擎
|
||||
@@ -345,7 +345,7 @@ class MemoryReflectionService:
|
||||
"next_reflection_in_hours": iteration_period - hours_diff
|
||||
}
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
config_id = config_data.get("config_id", "unknown")
|
||||
api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}")
|
||||
@@ -356,7 +356,7 @@ class MemoryReflectionService:
|
||||
"end_user_id": end_user_id,
|
||||
"config_data": config_data
|
||||
}
|
||||
|
||||
|
||||
def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig:
|
||||
"""Create reflective configuration objects from configuration data"""
|
||||
|
||||
@@ -364,12 +364,12 @@ class MemoryReflectionService:
|
||||
if reflexion_range_value is None or reflexion_range_value == "":
|
||||
reflexion_range_value = "partial"
|
||||
reflexion_range = ReflectionRange(reflexion_range_value)
|
||||
|
||||
|
||||
baseline_value = config_data.get("baseline")
|
||||
if baseline_value is None or baseline_value == "":
|
||||
baseline_value = "TIME"
|
||||
baseline = ReflectionBaseline(baseline_value)
|
||||
|
||||
|
||||
# iteration_period =
|
||||
iteration_period = config_data.get("iteration_period", 24)
|
||||
if isinstance(iteration_period, str):
|
||||
@@ -377,7 +377,6 @@ class MemoryReflectionService:
|
||||
iteration_period = int(iteration_period)
|
||||
except (ValueError, TypeError):
|
||||
iteration_period = 24 # 默认24小时
|
||||
|
||||
return ReflectionConfig(
|
||||
enabled=config_data.get("enable_self_reflexion", False),
|
||||
iteration_period=str(iteration_period), # ReflectionConfig期望字符串
|
||||
|
||||
@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not params.rerank_id:
|
||||
params.rerank_id = configs.get('rerank')
|
||||
|
||||
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
|
||||
if not params.reflection_model_id:
|
||||
params.reflection_model_id = params.llm_id
|
||||
if not params.emotion_model_id:
|
||||
params.emotion_model_id = params.llm_id
|
||||
|
||||
config = MemoryConfigRepository.create(self.db, params)
|
||||
self.db.commit()
|
||||
return {"affected": 1, "config_id": config.config_id}
|
||||
@@ -203,6 +209,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"end_user_id": config.end_user_id,
|
||||
"config_id_old": config_id_old,
|
||||
"apply_id": config.apply_id,
|
||||
"scene_id": config.scene_id,
|
||||
"llm_id": config.llm_id,
|
||||
"embedding_id": config.embedding_id,
|
||||
"rerank_id": config.rerank_id,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator
|
||||
@@ -182,11 +183,12 @@ class PromptOptimizerService:
|
||||
base_url=api_config.api_base
|
||||
), type=ModelType(model_config.type))
|
||||
try:
|
||||
with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render()
|
||||
|
||||
with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f:
|
||||
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_user_prompt = f.read()
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
|
||||
288
api/app/tasks.py
288
api/app/tasks.py
@@ -1066,6 +1066,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # Rollback failed transaction to allow next query
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
@@ -1204,3 +1205,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
def long_term_storage_window_task(
    self,
    end_user_id: str,
    langchain_messages: List[Dict[str, Any]],
    config_id: str,
    scope: int = 6
) -> Dict[str, Any]:
    """Celery task for window-based long-term memory storage.

    Saves the formatted messages to the Redis buffer, loads the memory
    configuration for ``config_id`` and then delegates to ``window_dialogue``.
    Per the original description, ``window_dialogue`` accumulates messages until
    the window size (``scope``) is reached and then writes the batch to Neo4j;
    that logic lives outside this function.

    Args:
        end_user_id: End user identifier
        langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
        config_id: Memory configuration ID
        scope: Window size (number of messages before triggering write)

    Returns:
        Dict containing task status and metadata. On success: ``status`` is
        "SUCCESS" together with ``strategy``, ``scope``, ``end_user_id``,
        ``config_id``, ``elapsed_time`` and ``task_id``. On any exception:
        ``status`` is "FAILURE" with an ``error`` string; the exception is
        logged and not re-raised.
    """
    # Logger is imported and created inside the task body rather than at module level.
    from app.core.logging_config import get_logger
    logger = get_logger(__name__)

    logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        # Async body of the task; driven synchronously via run_until_complete below.
        from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
        from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
        from app.core.memory.agent.utils.redis_tool import write_store
        from app.services.memory_config_service import MemoryConfigService

        # Take one object from the get_db() generator (presumably a DB session —
        # TODO confirm); it is closed in the finally clause below.
        db = next(get_db())
        try:
            # Save to Redis buffer first
            # NOTE(review): this write happens before the config is loaded, so a
            # config-loading failure leaves the messages already buffered — confirm intended.
            write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))

            # Load memory config
            config_service = MemoryConfigService(db)
            memory_config = config_service.load_memory_config(
                config_id=config_id,
                service_name="LongTermStorageTask"
            )

            # Execute window-based dialogue storage
            await window_dialogue(end_user_id, langchain_messages, memory_config, scope)

            return {"status": "SUCCESS", "strategy": "window", "scope": scope}
        finally:
            db.close()

    # nest_asyncio is optional: apply it when installed, otherwise carry on without it.
    try:
        import nest_asyncio
        nest_asyncio.apply()
    except ImportError:
        pass

    # Reuse the current event loop when one is available and open; otherwise
    # create a fresh loop and install it as the current one.
    try:
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    try:
        result = loop.run_until_complete(_run())
        elapsed_time = time.time() - start_time

        logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")

        # Merge the strategy result with per-invocation metadata.
        return {
            **result,
            "end_user_id": end_user_id,
            "config_id": config_id,
            "elapsed_time": elapsed_time,
            "task_id": self.request.id
        }
    except Exception as e:
        # Failures are reported through the return value instead of being raised.
        # NOTE(review): because the exception is swallowed, Celery will presumably
        # record the task as succeeded — confirm callers inspect "status".
        elapsed_time = time.time() - start_time
        logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)

        return {
            "status": "FAILURE",
            "strategy": "window",
            "error": str(e),
            "end_user_id": end_user_id,
            "config_id": config_id,
            "elapsed_time": elapsed_time,
            "task_id": self.request.id
        }
|
||||
|
||||
|
||||
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
|
||||
# def long_term_storage_time_task(
|
||||
# self,
|
||||
# end_user_id: str,
|
||||
# config_id: str,
|
||||
# time_window: int = 5
|
||||
# ) -> Dict[str, Any]:
|
||||
# """Celery task for time-based long-term memory storage.
|
||||
|
||||
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
|
||||
|
||||
# Args:
|
||||
# end_user_id: End user identifier
|
||||
# config_id: Memory configuration ID
|
||||
# time_window: Time window in minutes for retrieving recent sessions
|
||||
|
||||
# Returns:
|
||||
# Dict containing task status and metadata
|
||||
# """
|
||||
# from app.core.logging_config import get_logger
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
|
||||
# start_time = time.time()
|
||||
|
||||
# async def _run() -> Dict[str, Any]:
|
||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# db = next(get_db())
|
||||
# try:
|
||||
# # Load memory config
|
||||
# config_service = MemoryConfigService(db)
|
||||
# memory_config = config_service.load_memory_config(
|
||||
# config_id=config_id,
|
||||
# service_name="LongTermStorageTask"
|
||||
# )
|
||||
|
||||
# # Execute time-based storage
|
||||
# await memory_long_term_storage(end_user_id, memory_config, time_window)
|
||||
|
||||
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
|
||||
# finally:
|
||||
# db.close()
|
||||
|
||||
# try:
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
# try:
|
||||
# loop = asyncio.get_event_loop()
|
||||
# if loop.is_closed():
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# except RuntimeError:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# try:
|
||||
# result = loop.run_until_complete(_run())
|
||||
# elapsed_time = time.time() - start_time
|
||||
|
||||
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
# return {
|
||||
# **result,
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# except Exception as e:
|
||||
# elapsed_time = time.time() - start_time
|
||||
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
# return {
|
||||
# "status": "FAILURE",
|
||||
# "strategy": "time",
|
||||
# "error": str(e),
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
|
||||
|
||||
# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True)
|
||||
# def long_term_storage_aggregate_task(
|
||||
# self,
|
||||
# end_user_id: str,
|
||||
# langchain_messages: List[Dict[str, Any]],
|
||||
# config_id: str
|
||||
# ) -> Dict[str, Any]:
|
||||
# """Celery task for aggregate-based long-term memory storage.
|
||||
|
||||
# Uses LLM to determine if new messages describe the same event as history.
|
||||
# Only writes to Neo4j if messages represent new information (not duplicates).
|
||||
|
||||
# Args:
|
||||
# end_user_id: End user identifier
|
||||
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||
# config_id: Memory configuration ID
|
||||
|
||||
# Returns:
|
||||
# Dict containing task status, is_same_event flag, and metadata
|
||||
# """
|
||||
# from app.core.logging_config import get_logger
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
|
||||
# start_time = time.time()
|
||||
|
||||
# async def _run() -> Dict[str, Any]:
|
||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
|
||||
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
# from app.core.memory.agent.utils.redis_tool import write_store
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# db = next(get_db())
|
||||
# try:
|
||||
# # Save to Redis buffer first
|
||||
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
|
||||
# # Load memory config
|
||||
# config_service = MemoryConfigService(db)
|
||||
# memory_config = config_service.load_memory_config(
|
||||
# config_id=config_id,
|
||||
# service_name="LongTermStorageTask"
|
||||
# )
|
||||
|
||||
# # Execute aggregate judgment
|
||||
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
# return {
|
||||
# "status": "SUCCESS",
|
||||
# "strategy": "aggregate",
|
||||
# "is_same_event": result.get("is_same_event", False),
|
||||
# "wrote_to_neo4j": not result.get("is_same_event", False)
|
||||
# }
|
||||
# finally:
|
||||
# db.close()
|
||||
|
||||
# try:
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
# try:
|
||||
# loop = asyncio.get_event_loop()
|
||||
# if loop.is_closed():
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# except RuntimeError:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# try:
|
||||
# result = loop.run_until_complete(_run())
|
||||
# elapsed_time = time.time() - start_time
|
||||
|
||||
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
# return {
|
||||
# **result,
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# except Exception as e:
|
||||
# elapsed_time = time.time() - start_time
|
||||
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
# return {
|
||||
# "status": "FAILURE",
|
||||
# "strategy": "aggregate",
|
||||
# "error": str(e),
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
|
||||
Reference in New Issue
Block a user