From 6920deef638e75229cf9834f365df160c70aabd5 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Wed, 21 Jan 2026 11:33:52 +0800 Subject: [PATCH 1/7] Fix/memory bug fix (#162) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 读取的接口,去掉全局锁 * 输出数组 * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化测试接口 * 反思优化测试接口 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察) * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 --- .../controllers/memory_agent_controller.py | 18 +- .../memory_short_term_controller.py | 1 + .../controllers/user_memory_controllers.py | 46 +++- api/app/schemas/emotion_schema.py | 5 + api/app/schemas/end_user_schema.py | 1 + api/app/schemas/memory_episodic_schema.py | 2 + api/app/schemas/memory_explicit_schema.py | 4 + api/app/services/memory_agent_service.py | 13 +- api/app/services/memory_base_service.py | 259 +++++++++++++++++- .../memory_entity_relationship_service.py | 156 +++++++++-- api/app/services/user_memory_service.py | 63 +++-- 11 files changed, 511 insertions(+), 57 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index b7da943c..46fe3043 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -9,7 +9,7 @@ from app.db import get_db from app.dependencies import cur_workspace_access_guard, get_current_user from app.models import ModelApiKey from app.models.user_model import User -from app.repositories import knowledge_repository +from app.repositories import knowledge_repository, WorkspaceRepository from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.response_schema import ApiResponse from app.services import task_service, workspace_service @@ -616,8 +616,10 @@ async def get_knowledge_type_stats_api( @router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse) async def get_hot_memory_tags_by_user_api( end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), + language_type: Optional[str] ="zh", limit: int = Query(20, description="返回标签数量限制"), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session=Depends(get_db), ): """ 获取指定用户的热门记忆标签 @@ -628,10 +630,22 @@ async def get_hot_memory_tags_by_user_api( ... ] """ + + workspace_id=current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None + api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}") try: result = await memory_agent_service.get_hot_memory_tags_by_user( end_user_id=end_user_id, + language_type=language_type, + model_id=model_id, limit=limit ) return success(data=result, msg="获取热门记忆标签成功") diff --git a/api/app/controllers/memory_short_term_controller.py b/api/app/controllers/memory_short_term_controller.py index 64991f4d..9cf66749 100644 --- a/api/app/controllers/memory_short_term_controller.py +++ b/api/app/controllers/memory_short_term_controller.py @@ -20,6 +20,7 @@ router = APIRouter( @router.get("/short_term") async def short_term_configs( end_user_id: str, + language_type:Optional[str] = "zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index a96c7a52..d99eb47e 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -12,6 +12,7 @@ from app.core.logging_config import get_api_logger from app.core.response_utils import success, fail from app.core.error_codes import BizCode from app.core.api_key_utils import timestamp_to_datetime +from app.services.memory_base_service import Translation_English from app.services.user_memory_service import ( UserMemoryService, analytics_memory_types, @@ -20,7 +21,7 @@ from app.services.user_memory_service import ( from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest - +from app.repositories.workspace_repository import WorkspaceRepository from app.schemas.end_user_schema import ( EndUserProfileResponse, EndUserProfileUpdate, @@ -44,6 +45,7 @@ router = APIRouter( @router.get("/analytics/memory_insight/report", response_model=ApiResponse) async def get_memory_insight_report_api( end_user_id: str, + language_type: str = "zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -53,10 +55,18 @@ async def get_memory_insight_report_api( 此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。 如需生成新的洞察报告,请使用专门的生成接口。 """ + workspace_id = current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 - result = await user_memory_service.get_cached_memory_insight(db, end_user_id) + result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type) if result["is_cached"]: api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}") @@ -72,6 +82,7 @@ async def get_memory_insight_report_api( @router.get("/analytics/user_summary", response_model=ApiResponse) async def get_user_summary_api( end_user_id: str, + language_type: str="zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -81,10 +92,18 @@ async def get_user_summary_api( 此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。 如需生成新的用户摘要,请使用专门的生成接口。 """ + workspace_id = current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 - result = await user_memory_service.get_cached_user_summary(db, end_user_id) + result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type) if result["is_cached"]: api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") @@ -253,7 +272,6 @@ async def get_graph_data_api( depth=depth, center_node_id=center_node_id ) - # 检查是否有错误消息 if "message" in result and result["statistics"]["total_nodes"] == 0: api_logger.warning(f"图数据查询返回空结果: {result.get('message')}") @@ -278,7 +296,13 @@ async def get_end_user_profile( db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间") @@ -296,7 +320,6 @@ async def get_end_user_profile( if not end_user: api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") - # 构建响应数据 profile_data = EndUserProfileResponse( id=end_user.id, @@ -396,12 +419,21 @@ async def update_end_user_profile( return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e)) @router.get("/memory_space/timeline_memories", response_model=ApiResponse) -async def memory_space_timeline_of_shared_memories(id: str, label: str, +async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): + workspace_id=current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None MemoryEntity = MemoryEntityService(id, label) - timeline_memories_result = await MemoryEntity.get_timeline_memories_server() + timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type) + return success(data=timeline_memories_result, msg="共同记忆时间线") @router.get("/memory_space/relationship_evolution", response_model=ApiResponse) async def memory_space_relationship_evolution(id: str, label: str, diff --git a/api/app/schemas/emotion_schema.py b/api/app/schemas/emotion_schema.py index 5175fed1..cfa65b0f 100644 --- a/api/app/schemas/emotion_schema.py +++ b/api/app/schemas/emotion_schema.py @@ -11,6 +11,7 @@ class EmotionTagsRequest(BaseModel): start_date: Optional[str] = Field(None, description="开始日期(ISO格式,如:2024-01-01)") end_date: Optional[str] = Field(None, description="结束日期(ISO格式,如:2024-12-31)") limit: int = Field(10, ge=1, le=100, description="返回数量限制") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class EmotionWordcloudRequest(BaseModel): @@ -18,20 +19,24 @@ class EmotionWordcloudRequest(BaseModel): group_id: str = Field(..., description="组ID") emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)") limit: int = Field(50, ge=1, le=200, description="返回词语数量") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class EmotionHealthRequest(BaseModel): """获取情绪健康指数请求""" group_id: str = Field(..., description="组ID") time_range: str = Field("30d", description="时间范围(7d/30d/90d)") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class EmotionSuggestionsRequest(BaseModel): """获取个性化情绪建议请求""" group_id: str = Field(..., description="组ID") config_id: Optional[int] = Field(None, description="配置ID(用于指定LLM模型)") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class EmotionGenerateSuggestionsRequest(BaseModel): """生成个性化情绪建议请求""" end_user_id: str = Field(..., description="终端用户ID") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index c9f9146d..6f7498a0 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -44,6 +44,7 @@ class EndUserProfileResponse(BaseModel): updatetime_profile: Optional[datetime.datetime] = Field(description="核心档案信息最后更新时间", default=None) + class EndUserProfileUpdate(BaseModel): """终端用户基本信息更新请求模型""" end_user_id: str = Field(description="终端用户ID") diff --git a/api/app/schemas/memory_episodic_schema.py b/api/app/schemas/memory_episodic_schema.py index 832bf34b..74e68837 100644 --- a/api/app/schemas/memory_episodic_schema.py +++ b/api/app/schemas/memory_episodic_schema.py @@ -51,6 +51,7 @@ class EpisodicMemoryOverviewRequest(BaseModel): """情景记忆总览查询请求""" end_user_id: str = Field(..., description="终端用户ID") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") time_range: str = Field( default="all", description="时间范围筛选,可选值:all, today, this_week, this_month" @@ -70,3 +71,4 @@ class EpisodicMemoryDetailsRequest(BaseModel): end_user_id: str = Field(..., description="终端用户ID") summary_id: str = Field(..., description="情景记忆摘要ID") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") diff --git a/api/app/schemas/memory_explicit_schema.py b/api/app/schemas/memory_explicit_schema.py index c2b51a81..823a3116 100644 --- a/api/app/schemas/memory_explicit_schema.py +++ b/api/app/schemas/memory_explicit_schema.py @@ -1,15 +1,19 @@ """ 显性记忆的请求和响应模型 """ +from typing import Optional + from pydantic import BaseModel, Field class ExplicitMemoryOverviewRequest(BaseModel): """显性记忆总览查询请求""" end_user_id: str = Field(..., description="终端用户ID") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class ExplicitMemoryDetailsRequest(BaseModel): """显性记忆详情查询请求""" end_user_id: str = Field(..., description="终端用户ID") memory_id: str = Field(..., description="记忆ID(情景记忆或语义记忆的ID)") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index c9230a26..fd0cb0eb 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -26,6 +26,7 @@ from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError +from app.services.memory_base_service import Translation_English from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, @@ -692,7 +693,9 @@ class MemoryAgentService: async def get_hot_memory_tags_by_user( self, end_user_id: Optional[str] = None, - limit: int = 20 + limit: int = 20, + model_id: Optional[str] = None, + language_type: Optional[str] = "zh" ) -> List[Dict[str, Any]]: """ 获取指定用户的热门记忆标签 @@ -710,7 +713,13 @@ class MemoryAgentService: try: # by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度) tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) - payload = [{"name": t, "frequency": f} for t, f in tags] + payload=[] + for tag, freq in tags: + if language_type!="zh": + tag=await Translation_English(model_id, tag) + payload.append({"name": tag, "frequency": freq}) + else: + payload.append({"name": tag, "frequency": freq}) return payload except Exception as e: logger.error(f"热门记忆标签查询失败: {e}") diff --git a/api/app/services/memory_base_service.py b/api/app/services/memory_base_service.py index 6f844ae9..25a8281d 100644 --- a/api/app/services/memory_base_service.py +++ b/api/app/services/memory_base_service.py @@ -3,17 +3,268 @@ Memory Base Service 提供记忆服务的基础功能和共享辅助方法。 """ - +import asyncio +import re from datetime import datetime from typing import Optional - +from pydantic import BaseModel from app.core.logging_config import get_logger from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.services.emotion_analytics_service import EmotionAnalyticsService - +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.models.base import RedBearModelConfig +from app.services.memory_config_service import MemoryConfigService +from app.db import get_db_context logger = get_logger(__name__) +class TranslationResponse(BaseModel): + """翻译响应模型""" + data: str + +class MemoryTransService: + """记忆翻译服务,提供中英文翻译功能""" + + def __init__(self, llm_client=None, model_id: Optional[str] = None): + """ + 初始化翻译服务 + + Args: + llm_client: LLM客户端实例或模型ID字符串(可选) + model_id: 模型ID,用于初始化LLM客户端(可选) + + Note: + - 如果llm_client是字符串,会被当作model_id使用 + - 如果同时提供llm_client和model_id,优先使用llm_client + """ + # 处理llm_client参数:如果是字符串,当作model_id + if isinstance(llm_client, str): + self.model_id = llm_client + self.llm_client = None + else: + self.llm_client = llm_client + self.model_id = model_id + + self._initialized = False + + def _ensure_llm_client(self): + """确保LLM客户端已初始化""" + if self._initialized: + return + + if self.llm_client is None: + if self.model_id: + with get_db_context() as db: + config_service = MemoryConfigService(db) + model_config = config_service.get_model_config(self.model_id) + + extra_params = { + "temperature": 0.2, + "max_tokens": 400, + "top_p": 0.8, + "stream": False, + } + + self.llm_client = OpenAIClient( + RedBearModelConfig( + model_name=model_config.get("model_name"), + provider=model_config.get("provider"), + api_key=model_config.get("api_key"), + base_url=model_config.get("base_url"), + timeout=model_config.get("timeout", 30), + max_retries=model_config.get("max_retries", 3), + extra_params=extra_params + ), + type_=model_config.get("type") + ) + else: + raise ValueError("必须提供 llm_client 或 model_id 之一") + + self._initialized = True + + async def translate_to_english(self, text: str) -> str: + """ + 将中文翻译为英文 + + Args: + text: 要翻译的中文文本 + + Returns: + 翻译后的英文文本 + """ + self._ensure_llm_client() + + translation_messages = [ + { + "role": "user", + "content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}" + } + ] + + try: + response = await self.llm_client.response_structured( + messages=translation_messages, + response_model=TranslationResponse + ) + return response.data + except Exception as e: + logger.error(f"翻译失败: {str(e)}") + return text # 翻译失败时返回原文 + + async def is_english(self, text: str) -> bool: + """ + 检查文本是否为英文 + + Args: + text: 要检查的文本(必须是字符串) + + Returns: + True 如果文本主要是英文,False 否则 + + Note: + - 只接受字符串类型 + - 检查是否主要由英文字母和常见标点组成 + - 允许数字、空格和常见标点符号 + """ + if not isinstance(text, str): + raise TypeError(f"is_english 只接受字符串类型,收到: {type(text).__name__}") + + if not text.strip(): + return True # 空字符串视为英文 + + # 更宽松的英文检查:允许字母、数字、空格和常见标点 + # 如果文本中英文字符占比超过 80%,认为是英文 + english_chars = sum(1 for c in text if c.isascii() and (c.isalnum() or c.isspace() or c in '.,!?;:\'"()-')) + total_chars = len(text) + + if total_chars == 0: + return True + + return (english_chars / total_chars) >= 0.8 + async def Translate(self, text: str, target_language: str = "en") -> str: + """ + 通用翻译方法(保持向后兼容) + + Args: + text: 要翻译的文本 + target_language: 目标语言,"en"表示英文,"zh"表示中文 + + Returns: + 翻译后的文本 + """ + if target_language == "en": + return await self.translate_to_english(text) + else: + logger.warning(f"不支持的目标语言: {target_language},返回原文") + return text + + # 测试翻译服务 +async def Translation_English(modid, text, fields=None): + """ + 将数据翻译为英文(支持字段级翻译) + + Args: + modid: 模型ID + text: 要翻译的数据(可以是字符串、字典或列表) + fields: 需要翻译的字段列表(可选) + 如果为None,默认翻译: ['content', 'summary', 'statement', 'description', + 'name', 'aliases', 'caption', 'emotion_keywords'] + + Returns: + 翻译后的数据,保持原有结构 + + Note: + - 对于字符串:直接翻译 + - 对于列表:递归处理每个元素,保持列表长度和索引不变 + - 对于字典:只翻译指定字段(fields参数) + - 对于其他类型:原样返回 + """ + trans_service = MemoryTransService(modid) + + # 处理字符串类型 + if isinstance(text, str): + # 空字符串直接返回 + if not text.strip(): + return text + + try: + is_eng = await trans_service.is_english(text) + if not is_eng: + english_result = await trans_service.Translate(text) + return english_result + return text + except Exception as e: + logger.warning(f"翻译字符串失败: {e}") + return text + + # 处理列表类型 + elif isinstance(text, list): + english_result = [] + for item in text: + # 递归处理列表中的每个元素 + if isinstance(item, str): + # 字符串元素:检查是否需要翻译 + if not item.strip(): + english_result.append(item) + continue + + try: + is_eng = await trans_service.is_english(item) + if not is_eng: + translated = await trans_service.Translate(item) + english_result.append(translated) + else: + # 保留英文项,不改变列表长度 + english_result.append(item) + except Exception as e: + logger.warning(f"翻译列表项失败: {e}") + english_result.append(item) + + elif isinstance(item, dict): + # 字典元素:递归调用自己处理字典 + translated_dict = await Translation_English(modid, item, fields) + english_result.append(translated_dict) + + elif isinstance(item, list): + # 嵌套列表:递归处理 + translated_list = await Translation_English(modid, item, fields) + english_result.append(translated_list) + + else: + # 其他类型(数字、布尔值等):原样保留 + english_result.append(item) + + return english_result + + # 处理字典类型 + elif isinstance(text, dict): + # 确定要翻译的字段 + if fields is None: + # 默认翻译字段 + fields = [ + 'content', 'summary', 'statement', 'description', + 'name', 'aliases', 'caption', 'emotion_keywords', + 'text', 'title', 'label', 'type' # 添加常用字段 + ] + + # 创建副本,避免修改原始数据 + result = text.copy() + + for field in fields: + if field in result and result[field] is not None: + # 递归翻译字段值(可能是字符串、列表或嵌套字典) + try: + result[field] = await Translation_English(modid, result[field], fields) + except Exception as e: + logger.warning(f"翻译字段 {field} 失败: {e}") + # 翻译失败时保留原值 + continue + + return result + + # 其他类型(数字、布尔值、None等):原样返回 + else: + return text class MemoryBaseService: """记忆服务基类,提供共享的辅助方法""" @@ -294,4 +545,4 @@ class MemoryBaseService: except Exception as e: logger.error(f"获取遗忘记忆数量时出错: {str(e)}", exc_info=True) - return 0 + return 0 \ No newline at end of file diff --git a/api/app/services/memory_entity_relationship_service.py b/api/app/services/memory_entity_relationship_service.py index eedb7c29..9b5f3c99 100644 --- a/api/app/services/memory_entity_relationship_service.py +++ b/api/app/services/memory_entity_relationship_service.py @@ -16,6 +16,7 @@ import json from datetime import datetime from app.schemas.memory_episodic_schema import EmotionType +from app.services.memory_base_service import Translation_English logger = logging.getLogger(__name__) @@ -24,7 +25,7 @@ class MemoryEntityService: self.id = id self.table = table self.connector = Neo4jConnector() - async def get_timeline_memories_server(self): + async def get_timeline_memories_server(self,model_id, language_type): """ 获取时间线记忆数据 @@ -48,10 +49,10 @@ class MemoryEntityService: logger.info(f"获取时间线记忆数据 - ID: {self.id}, Table: {self.table}") # 根据表类型选择查询 - if self.table == 'Statement': + if self.table == 'Statement': # Statement只需要输入ID,使用简化查询 results = await self.connector.execute_query(Memory_Timeline_Statement, id=self.id) - elif self.table == 'ExtractedEntity': + elif self.table == 'ExtractedEntity': # ExtractedEntity类型查询 results = await self.connector.execute_query(Memory_Timeline_ExtractedEntity, id=self.id) else: @@ -62,7 +63,7 @@ class MemoryEntityService: logger.info(f"时间线查询结果类型: {type(results)}, 长度: {len(results) if isinstance(results, list) else 'N/A'}") # 处理查询结果 - timeline_data = self._process_timeline_results(results) + timeline_data =await self._process_timeline_results(results, model_id, language_type) logger.info(f"成功获取时间线记忆数据: 总计 {len(timeline_data.get('timelines_memory', []))} 条") @@ -71,12 +72,14 @@ class MemoryEntityService: except Exception as e: logger.error(f"获取时间线记忆数据失败: {str(e)}", exc_info=True) return str(e) - def _process_timeline_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]: + async def _process_timeline_results(self, results: List[Dict[str, Any]], model_id: str, language_type: str) -> Dict[str, Any]: """ 处理时间线查询结果 Args: results: Neo4j查询结果 + model_id: 模型ID用于翻译 + language_type: 语言类型 ('zh' 或其他) Returns: 处理后的时间线数据字典 @@ -104,19 +107,19 @@ class MemoryEntityService: # 处理MemorySummary summary = data.get('MemorySummary') if summary is not None: - processed_summary = self._process_field_value(summary, "MemorySummary") + processed_summary = await self._process_field_value(summary, "MemorySummary") memory_summary_list.extend(processed_summary) # 处理Statement statement = data.get('statement') if statement is not None: - processed_statement = self._process_field_value(statement, "Statement") + processed_statement = await self._process_field_value(statement, "Statement") statement_list.extend(processed_statement) # 处理ExtractedEntity extracted_entity = data.get('ExtractedEntity') if extracted_entity is not None: - processed_entity = self._process_field_value(extracted_entity, "ExtractedEntity") + processed_entity = await self._process_field_value(extracted_entity, "ExtractedEntity") extracted_entity_list.extend(processed_entity) # 去重 - 现在处理的是字典列表,需要更智能的去重 @@ -128,6 +131,8 @@ class MemoryEntityService: all_timeline_data = memory_summary_list + statement_list all_timeline_data = self._merge_same_text_items(all_timeline_data) + # 如果需要翻译(非中文),对整个结果进行翻译 + result = { "MemorySummary": memory_summary_list, "Statement": statement_list, @@ -233,7 +238,7 @@ class MemoryEntityService: except Exception: return False - def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]: + async def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]: """ 处理字段值,支持字符串、列表等类型 @@ -251,13 +256,13 @@ class MemoryEntityService: # 如果是列表,处理每个元素 for item in value: if self._is_valid_item(item): - processed_item = self._process_single_item(item) + processed_item = await self._process_single_item(item) if processed_item: processed_values.append(processed_item) elif isinstance(value, dict): # 如果是字典,直接处理 if self._is_valid_item(value): - processed_item = self._process_single_item(value) + processed_item = await self._process_single_item(value) if processed_item: processed_values.append(processed_item) elif isinstance(value, str): @@ -304,7 +309,7 @@ class MemoryEntityService: return (str(item).strip() != '' and "MemorySummaryChunk" not in str(item)) - def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]: + async def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]: """ 处理单个项目 @@ -369,6 +374,117 @@ class MemoryEntityService: logger.warning(f"转换时间格式失败: {e}, 原始值: {dt}") return str(dt) if dt is not None else None + async def _translate_list( + self, + data_list: List[Dict[str, Any]], + model_id: str, + fields: List[str] + ) -> List[Dict[str, Any]]: + """ + 翻译列表中每个字典的指定字段(并发有限度以降低整体延迟) + + Args: + data_list: 要翻译的字典列表 + model_id: 模型ID + fields: 需要翻译的字段列表 + + Returns: + 翻译后的字典列表 + """ + # 空列表或无字段时直接返回 + if not data_list or not fields: + return data_list + + import asyncio + + # 并发限制,避免一次性发起过多请求 + # 可根据实际情况调整(建议 5-10) + concurrency_limit = 5 + semaphore = asyncio.Semaphore(concurrency_limit) + + async def translate_single_field( + index: int, + field: str, + value: Any, + ) -> Optional[tuple]: + """ + 翻译单个字段并返回 (索引, 字段名, 翻译结果) + + Returns: + (index, field, translated_value) 或 None(如果跳过) + """ + # 跳过空值 + if value is None or value == "": + return None + + # 统一转成字符串再翻译,防止非字符串类型导致错误 + text = str(value) + + try: + async with semaphore: + # 调用 Translation_English 进行翻译 + # 注意:Translation_English 的参数顺序是 (model_id, text) + translated = await Translation_English(model_id, text) + + # 如果翻译结果为空,保留原值 + if translated is None or translated == "": + return None + + return index, field, translated + except Exception as e: + logger.warning(f"翻译字段 {field} (索引 {index}) 失败: {e}") + return None + + # 构造所有需要翻译的任务 + tasks = [] + for idx, item in enumerate(data_list): + # 防御性检查:确保 item 是字典 + if not isinstance(item, dict): + continue + + for field in fields: + if field not in item: + continue + + value = item.get(field) + + # 对于 None 或空字符串的值,直接跳过,不创建任务 + if value is None or value == "": + continue + + tasks.append( + asyncio.create_task( + translate_single_field(idx, field, value) + ) + ) + + # 如果没有需要翻译的任务,直接返回原列表 + if not tasks: + return data_list + + # 使用 gather 并发执行翻译任务(受 semaphore 限制) + # return_exceptions=True 可以防止单个任务失败导致整体失败 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 创建深拷贝以避免修改原始数据 + translated_list = [item.copy() if isinstance(item, dict) else item for item in data_list] + + # 将翻译结果回填到列表 + for result in results: + # 跳过 None 结果和异常 + if result is None or isinstance(result, Exception): + if isinstance(result, Exception): + logger.warning(f"翻译任务异常: {result}") + continue + + idx, field, translated = result + + # 防御性检查索引范围 + if 0 <= idx < len(translated_list) and isinstance(translated_list[idx], dict): + translated_list[idx][field] = translated + + return translated_list + @@ -426,15 +542,19 @@ class MemoryEmotion: # 如果解析失败,返回原始字符串 return iso_string - async def get_emotion(self) -> Dict[str, Any]: + async def get_emotion(self, model_id: str = None, language_type: str = 'zh') -> Dict[str, Any]: """ 获取情绪随时间变化数据 + Args: + model_id: 模型ID用于翻译 + language_type: 语言类型 ('zh' 或其他) + Returns: 包含情绪数据的字典 """ try: - logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}") + logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}, language_type={language_type}") if self.table == 'Statement': results = await self.connector.execute_query(Memory_Space_Emotion_Statement, id=self.id) @@ -450,6 +570,10 @@ class MemoryEmotion: # 转换Neo4j类型 final_data = self._convert_neo4j_types(emotion_data) + # 如果需要翻译(非中文) + if language_type != 'zh' and model_id and final_data: + final_data = await self._translate_emotion_data(final_data, model_id) + logger.info(f"成功获取 {len(final_data)} 条情绪数据") return final_data @@ -590,16 +714,14 @@ class MemoryInteraction: """ try: logger.info(f"获取交互数据 - ID: {self.id}, Table: {self.table}") - ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id) if ori_data!=[]: # name = ori_data[0]['name'] - group_id = ori_data[0]['group_id'] + group_id = [i['group_id'] for i in ori_data][0] Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id) if not Space_User: return [] user_id=Space_User[0]['id'] - results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 9221ab06..ae07256a 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -18,7 +18,7 @@ from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.services.implicit_memory_service import ImplicitMemoryService -from app.services.memory_base_service import MemoryBaseService +from app.services.memory_base_service import MemoryBaseService, MemoryTransService, Translation_English from app.services.memory_config_service import MemoryConfigService from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_short_service import ShortService @@ -360,7 +360,9 @@ class UserMemoryService: async def get_cached_memory_insight( self, db: Session, - end_user_id: str + end_user_id: str, + model_id: str, + language_type: str ) -> Dict[str, Any]: """ 从数据库获取缓存的记忆洞察(四个维度) @@ -419,11 +421,18 @@ class UserMemoryService: key_findings_array = [] logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察(四维度)") + memory_insight=end_user.memory_insight + behavior_pattern=end_user.behavior_pattern + growth_trajectory=end_user.growth_trajectory + if language_type!='zh': + memory_insight=await Translation_English(model_id,memory_insight) + behavior_pattern=await Translation_English(model_id,behavior_pattern) + growth_trajectory=await Translation_English(model_id,growth_trajectory) return { - "memory_insight": end_user.memory_insight, # 总体概述存储在 memory_insight - "behavior_pattern": end_user.behavior_pattern, + "memory_insight":memory_insight, # 总体概述存储在 memory_insight + "behavior_pattern":behavior_pattern, "key_findings": key_findings_array, # 返回数组 - "growth_trajectory": end_user.growth_trajectory, + "growth_trajectory": growth_trajectory, "updated_at": self._datetime_to_timestamp(end_user.memory_insight_updated_at), "is_cached": True } @@ -457,7 +466,9 @@ class UserMemoryService: async def get_cached_user_summary( self, db: Session, - end_user_id: str + end_user_id: str, + model_id:str, + language_type:str="zh" ) -> Dict[str, Any]: """ 从数据库获取缓存的用户摘要(四个部分) @@ -481,7 +492,6 @@ class UserMemoryService: user_uuid = uuid.UUID(end_user_id) repo = EndUserRepository(db) end_user = repo.get_by_id(user_uuid) - if not end_user: logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") return { @@ -495,20 +505,29 @@ class UserMemoryService: } # 检查是否有缓存数据(至少有一个字段不为空) + user_summary=end_user.user_summary + personality_traits=end_user.personality_traits + core_values=end_user.core_values + one_sentence_summary=end_user.one_sentence_summary + if language_type!='zh': + user_summary=await Translation_English(model_id, user_summary) + personality_traits = await Translation_English(model_id, personality_traits) + core_values = await Translation_English(model_id, core_values) + one_sentence_summary = await Translation_English(model_id, one_sentence_summary) has_cache = any([ - end_user.user_summary, - end_user.personality_traits, - end_user.core_values, - end_user.one_sentence_summary + user_summary, + personality_traits, + core_values, + one_sentence_summary ]) if has_cache: logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要") return { - "user_summary": end_user.user_summary, - "personality": end_user.personality_traits, - "core_values": end_user.core_values, - "one_sentence": end_user.one_sentence_summary, + "user_summary": user_summary, + "personality": personality_traits, + "core_values":core_values, + "one_sentence": one_sentence_summary, "updated_at": self._datetime_to_timestamp(end_user.user_summary_updated_at), "is_cached": True } @@ -1367,7 +1386,6 @@ async def analytics_memory_types( return memory_types - async def analytics_graph_data( db: Session, end_user_id: str, @@ -1557,7 +1575,7 @@ async def analytics_graph_data( f"成功获取图数据: end_user_id={end_user_id}, " f"nodes={len(nodes)}, edges={len(edges)}" ) - + return { "nodes": nodes, "edges": edges, @@ -1606,11 +1624,7 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_ # 获取该节点类型的白名单字段 allowed_fields = field_whitelist.get(label, []) - - # 如果没有定义白名单,返回空字典(或者可以返回所有字段) - # if not allowed_fields: - # # 对于未定义的节点类型,只返回基本字段 - # allowed_fields = ["name", "created_at", "caption"] + count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;""" node_results = await (_neo4j_connector.execute_query(count_neo4j)) # 提取白名单中的字段 @@ -1618,13 +1632,12 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_ for field in allowed_fields: if field in properties: value = properties[field] - if str(field) == 'entity_type': + if str(field) == 'entity_type': value=type_mapping.get(value,'') if str(field)=="emotion_type": value=EmotionType.EMOTION_MAPPING.get(value) - if str(field)=="emotion_subject": + if str(field)=="emotion_subject": value=EmotionSubject.SUBJECT_MAPPING.get(value) - # 清理 Neo4j 特殊类型 filtered_props[field] = _clean_neo4j_value(value) filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0] return filtered_props From 2e504f9c485aef4411c5d5621bb7e9f43cdfd6e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Wed, 21 Jan 2026 13:55:32 +0800 Subject: [PATCH 2/7] Feature/distinction role (#165) * [feature]A set of information for role recognition writing * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [fix]Based on the AI review to fix the code * [changes]Disable the function of batch writing multiple groups of conversations in a cumulative manner * [fix]Addressing vulnerability risks --- .../controllers/memory_agent_controller.py | 27 ++- api/app/core/agent/langchain_agent.py | 217 +++++++++++------- .../langgraph_graph/nodes/write_nodes.py | 40 ++-- .../agent/langgraph_graph/write_graph.py | 13 +- .../core/memory/agent/utils/get_dialogs.py | 63 ++--- .../core/memory/agent/utils/write_tools.py | 9 +- .../core/memory/llm_tools/chunker_client.py | 188 +++++++-------- .../core/memory/llm_tools/openai_client.py | 15 +- api/app/core/memory/models/graph_models.py | 7 + api/app/core/memory/models/message_models.py | 30 +-- .../extraction_orchestrator.py | 20 +- .../knowledge_extraction/chunk_extraction.py | 55 ++--- .../statement_extraction.py | 42 ++-- api/app/repositories/neo4j/add_nodes.py | 6 +- api/app/schemas/memory_agent_schema.py | 2 +- api/app/services/memory_agent_service.py | 76 +++++- api/app/tasks.py | 10 +- 17 files changed, 490 insertions(+), 330 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 46fe3043..416ed710 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -160,9 +160,12 @@ async def write_server( api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") try: + # 获取标准化的消息列表 + messages_list = memory_agent_service.get_messages_list(user_input) + result = await memory_agent_service.write_memory( user_input.group_id, - user_input.message, + messages_list, # 传递结构化消息列表 config_id, db, storage_type, @@ -219,9 +222,12 @@ async def write_server_async( if knowledge: user_rag_memory_id = str(knowledge.id) api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") try: + # 获取标准化的消息列表 + messages_list = memory_agent_service.get_messages_list(user_input) + task = celery_app.send_task( "app.core.memory.agent.write_message", - args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id] + args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id] ) api_logger.info(f"Write task queued: {task.id}") @@ -564,8 +570,23 @@ async def status_type( """ api_logger.info(f"Status type check requested for group {user_input.group_id}") try: + # 获取标准化的消息列表 + messages_list = memory_agent_service.get_messages_list(user_input) + + # 将消息列表转换为字符串用于分类 + # 只取最后一条用户消息进行分类 + last_user_message = "" + for msg in reversed(messages_list): + if msg.get('role') == 'user': + last_user_message = msg.get('content', '') + break + + if not last_user_message: + # 如果没有用户消息,使用所有消息的内容 + last_user_message = " ".join([msg.get('content', '') for msg in messages_list]) + result = await memory_agent_service.classify_message_type( - user_input.message, + last_user_message, user_input.config_id, db ) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 91445b12..87b46e6f 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -145,44 +145,98 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - async def term_memory_save(self,messages,end_user_end,aimessages): - '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' - end_user_end=f"Term_{end_user_end}" - print(messages) - print(aimessages) - session_id = store.save_session( - userid=end_user_end, - messages=messages, - apply_id=end_user_end, - group_id=end_user_end, - aimessages=aimessages - ) - store.delete_duplicate_sessions() - # logger.info(f'Redis_Agent:{end_user_end};{session_id}') - return session_id - async def term_memory_redis_read(self,end_user_end): - end_user_end = f"Term_{end_user_end}" - history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) - # logger.info(f'Redis_Agent:{end_user_end};{history}') - messagss_list=[] - retrieved_content=[] - for messages in history: - query = messages.get("Query") - aimessages = messages.get("Answer") - messagss_list.append(f'用户:{query}。AI回复:{aimessages}') - retrieved_content.append({query: aimessages}) - return messagss_list,retrieved_content +# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # async def term_memory_save(self,messages,end_user_end,aimessages): + # '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' + # end_user_end=f"Term_{end_user_end}" + # print(messages) + # print(aimessages) + # session_id = store.save_session( + # userid=end_user_end, + # messages=messages, + # apply_id=end_user_end, + # group_id=end_user_end, + # aimessages=aimessages + # ) + # store.delete_duplicate_sessions() + # # logger.info(f'Redis_Agent:{end_user_end};{session_id}') + # return session_id + +# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # async def term_memory_redis_read(self,end_user_end): + # end_user_end = f"Term_{end_user_end}" + # history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) + # # logger.info(f'Redis_Agent:{end_user_end};{history}') + # messagss_list=[] + # retrieved_content=[] + # for messages in history: + # query = messages.get("Query") + # aimessages = messages.get("Answer") + # messagss_list.append(f'用户:{query}。AI回复:{aimessages}') + # retrieved_content.append({query: aimessages}) + # return messagss_list,retrieved_content - - async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id): + async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): + """ + 写入记忆(支持结构化消息) + + Args: + storage_type: 存储类型 (neo4j/rag) + end_user_id: 终端用户ID + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ if storage_type == "rag": - await write_rag(end_user_id, message, user_rag_memory_id) + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') else: - write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, - user_rag_memory_id) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] + + # 始终添加用户消息(如果不为空) + if user_message: + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if ai_message: + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + # 调用 Celery 任务,传递结构化消息列表 + # 数据流: + # 1. structured_messages 传递给 write_message_task + # 2. write_message_task 调用 memory_agent_service.write_memory + # 3. write_memory 调用 write_tools.write,传递 messages 参数 + # 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数 + # 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段 + # 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段 + logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # group_id: 用户ID + structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] + actual_config_id, # config_id: 配置ID + storage_type, # storage_type: "neo4j" + user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'Agent:{actual_end_user_id};{write_status}') + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') async def chat( self, @@ -227,29 +281,30 @@ class LangChainAgent: actual_end_user_id = end_user_id if end_user_id is not None else "unknown" logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') +# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码 +# history_term_memory_result = await self.term_memory_redis_read(end_user_id) +# history_term_memory = history_term_memory_result[0] +# db_for_memory = next(get_db()) +# if memory_flag: +# if len(history_term_memory)>=4 and storage_type != "rag": +# history_term_memory = ';'.join(history_term_memory) +# retrieved_content = history_term_memory_result[1] +# print(retrieved_content) +# # 为长期记忆操作获取新的数据库连接 +# try: +# repo = LongTermMemoryRepository(db_for_memory) +# repo.upsert(end_user_id, retrieved_content) +# logger.info( +# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') +# except Exception as e: +# logger.error(f"Failed to write to LongTermMemory: {e}") +# raise +# finally: +# db_for_memory.close() - history_term_memory_result = await self.term_memory_redis_read(end_user_id) - history_term_memory = history_term_memory_result[0] - db_for_memory = next(get_db()) - if memory_flag: - if len(history_term_memory)>=4 and storage_type != "rag": - history_term_memory = ';'.join(history_term_memory) - retrieved_content = history_term_memory_result[1] - print(retrieved_content) - # 为长期记忆操作获取新的数据库连接 - try: - repo = LongTermMemoryRepository(db_for_memory) - repo.upsert(end_user_id, retrieved_content) - logger.info( - f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') - except Exception as e: - logger.error(f"Failed to write to LongTermMemory: {e}") - raise - finally: - db_for_memory.close() - - await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id) - await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id) +# # 长期记忆写入( +# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id) +# # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表 messages = self._prepare_messages(message, history, context) @@ -277,8 +332,10 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id) - await self.term_memory_save(message_chat,end_user_id,content) + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) + # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # await self.term_memory_save(message_chat, end_user_id, content) response = { "content": content, "model": self.model_name, @@ -346,27 +403,27 @@ class LangChainAgent: db.close() except Exception as e: logger.warning(f"Failed to get db session: {e}") +# # TODO 乐力齐 +# history_term_memory_result = await self.term_memory_redis_read(end_user_id) +# history_term_memory = history_term_memory_result[0] +# if memory_flag: +# if len(history_term_memory) >= 4 and storage_type != "rag": +# history_term_memory = ';'.join(history_term_memory) +# retrieved_content = history_term_memory_result[1] +# db_for_memory = next(get_db()) +# try: +# repo = LongTermMemoryRepository(db_for_memory) +# repo.upsert(end_user_id, retrieved_content) +# logger.info( +# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') +# # 长期记忆写入 +# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id) +# except Exception as e: +# logger.error(f"Failed to write to long term memory: {e}") +# finally: +# db_for_memory.close() - history_term_memory_result = await self.term_memory_redis_read(end_user_id) - history_term_memory = history_term_memory_result[0] - if memory_flag: - if len(history_term_memory) >= 4 and storage_type != "rag": - history_term_memory = ';'.join(history_term_memory) - retrieved_content = history_term_memory_result[1] - db_for_memory = next(get_db()) - try: - repo = LongTermMemoryRepository(db_for_memory) - repo.upsert(end_user_id, retrieved_content) - logger.info( - f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') - await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id, - history_term_memory, actual_config_id) - except Exception as e: - logger.error(f"Failed to write to long term memory: {e}") - finally: - db_for_memory.close() - - await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id) + # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表 messages = self._prepare_messages(message, history, context) @@ -418,8 +475,10 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") if memory_flag: - await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id) - await self.term_memory_save(message_chat, end_user_id, full_content) + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) + # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # await self.term_memory_save(message_chat, end_user_id, full_content) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py index 8421d059..6af313c3 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -9,22 +9,29 @@ async def write_node(state: WriteState) -> WriteState: Write data to the database/file system. Args: - ctx: FastMCP context for dependency injection - content: Data content to write - user_id: User identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration + state: WriteState containing messages, group_id, and memory_config Returns: - dict: Contains 'status', 'saved_to', and 'data' fields + dict: Contains 'write_result' with status and data fields """ - content=state.get('data','') - group_id=state.get('group_id','') - memory_config=state.get('memory_config', '') + messages = state.get('messages', []) + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', '') + + # Convert LangChain messages to structured format expected by write() + structured_messages = [] + for msg in messages: + if hasattr(msg, 'type') and hasattr(msg, 'content'): + # Map LangChain message types to role names + role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type + structured_messages.append({ + "role": role, + "content": msg.content # content is now guaranteed to be a string + }) + try: - result=await write( - content=content, + result = await write( + messages=structured_messages, user_id=group_id, apply_id=group_id, group_id=group_id, @@ -32,18 +39,17 @@ async def write_node(state: WriteState) -> WriteState: ) logger.info(f"Write completed successfully! Config: {memory_config.config_name}") - write_result= { + write_result = { "status": "success", - "data": content, + "data": structured_messages, "config_id": memory_config.config_id, "config_name": memory_config.config_name, } - return {"write_result":write_result} - + return {"write_result": write_result} except Exception as e: logger.error(f"Data_write failed: {e}", exc_info=True) - write_result= { + write_result = { "status": "error", "message": str(e), } diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 5a6f1e28..fe281a23 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -14,7 +14,6 @@ from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node -from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) @@ -27,18 +26,12 @@ async def make_write_graph(): """ Create a write graph workflow for memory operations. - Args: - user_id: User identifier - tools: MCP tools loaded from session - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration + The workflow directly processes messages from the initial state + and saves them to Neo4j storage. """ workflow = StateGraph(WriteState) - workflow.add_node("content_input", content_input_write) workflow.add_node("save_neo4j", write_node) - workflow.add_edge(START, "content_input") - workflow.add_edge("content_input", "save_neo4j") + workflow.add_edge(START, "save_neo4j") workflow.add_edge("save_neo4j", END) graph = workflow.compile() diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index b03fe57c..82a41773 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -12,32 +12,49 @@ async def get_chunked_dialogs( group_id: str = "group_1", user_id: str = "user1", apply_id: str = "applyid", - content: str = "这是用户的输入", + messages: list = None, ref_id: str = "wyl_20251027", config_id: str = None ) -> List[DialogData]: - """Generate chunks from all test data entries using the specified chunker strategy. + """Generate chunks from structured messages using the specified chunker strategy. Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) group_id: Group identifier user_id: User identifier apply_id: Application identifier - content: Dialog content + messages: Structured message list [{"role": "user", "content": "..."}, ...] ref_id: Reference identifier config_id: Configuration ID for processing Returns: - List of DialogData objects with generated chunks for each test entry + List of DialogData objects with generated chunks """ - dialog_data_list = [] - messages = [] - - messages.append(ConversationMessage(role="用户", msg=content)) - - # Create DialogData - conversation_context = ConversationContext(msgs=messages) - # Create DialogData with group_id based on the entry's id for uniqueness + from app.core.logging_config import get_agent_logger + logger = get_agent_logger(__name__) + + if not messages or not isinstance(messages, list) or len(messages) == 0: + raise ValueError("messages parameter must be a non-empty list") + + conversation_messages = [] + + for idx, msg in enumerate(messages): + if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: + raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields") + + role = msg['role'] + content = msg['content'] + + if role not in ['user', 'assistant']: + raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") + + if content.strip(): + conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) + + if not conversation_messages: + raise ValueError("Message list cannot be empty after filtering") + + conversation_context = ConversationContext(msgs=conversation_messages) dialog_data = DialogData( context=conversation_context, ref_id=ref_id, @@ -46,25 +63,11 @@ async def get_chunked_dialogs( apply_id=apply_id, config_id=config_id ) - # Create DialogueChunker and process the dialogue + chunker = DialogueChunker(chunker_strategy) extracted_chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = extracted_chunks + + logger.info(f"DialogData created with {len(extracted_chunks)} chunks") - dialog_data_list.append(dialog_data) - - # Convert to dict with datetime serialized - def serialize_datetime(obj): - if isinstance(obj, datetime): - return obj.isoformat() - raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") - - combined_output = [dd.model_dump() for dd in dialog_data_list] - - print(dialog_data_list) - - # with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f: - # json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime) - - - return dialog_data_list + return [dialog_data] diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 53c941ad..1df0b336 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -29,25 +29,22 @@ logger = get_agent_logger(__name__) async def write( - content: str, user_id: str, apply_id: str, group_id: str, memory_config: MemoryConfig, + messages: list, ref_id: str = "wyl20251027", ) -> None: """ Execute the complete knowledge extraction pipeline. - Only MemoryConfig is needed - LLM and embedding clients are constructed - internally from the config. - Args: - content: Dialogue content to process user_id: User identifier apply_id: Application identifier group_id: Group identifier memory_config: MemoryConfig object containing all configuration + messages: Structured message list [{"role": "user", "content": "..."}, ...] ref_id: Reference ID, defaults to "wyl20251027" """ # Extract config values @@ -89,7 +86,7 @@ async def write( group_id=group_id, user_id=user_id, apply_id=apply_id, - content=content, + messages=messages, ref_id=ref_id, config_id=config_id, ) diff --git a/api/app/core/memory/llm_tools/chunker_client.py b/api/app/core/memory/llm_tools/chunker_client.py index 4178ce0a..87cdb9f4 100644 --- a/api/app/core/memory/llm_tools/chunker_client.py +++ b/api/app/core/memory/llm_tools/chunker_client.py @@ -4,6 +4,7 @@ import os import asyncio import json import numpy as np +import logging # Fix tokenizer parallelism warning os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk try: from app.core.memory.llm_tools.openai_client import OpenAIClient except Exception: - # 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入 OpenAIClient = Any +# Initialize logger +logger = logging.getLogger(__name__) + class LLMChunker: - """基于LLM的智能分块策略""" + """LLM-based intelligent chunking strategy""" def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000): self.llm_client = llm_client self.chunk_size = chunk_size async def __call__(self, text: str) -> List[Any]: - # 使用LLM分析文本结构并进行智能分块 prompt = f""" - 请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。 - 请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。 + Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long. + Return results in JSON format with a chunks array, each chunk having a text field. - 文本内容: + Text content: {text[:5000]} """ messages = [ - {"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"}, + {"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."}, {"role": "user", "content": prompt} ] @@ -171,8 +173,6 @@ class ChunkerClient: base_chunk_size=self.chunk_size, ) elif chunker_config.chunker_strategy == "SentenceChunker": - # 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数 - # 为了兼容不同版本,这里仅传递广泛支持的参数 self.chunker = SentenceChunker( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, @@ -186,100 +186,93 @@ class ChunkerClient: async def generate_chunks(self, dialogue: DialogData): """ - 生成分块,支持异步操作 + Generate chunks following 1 Message = 1 Chunk strategy. + + Each message creates one chunk, directly inheriting role information. + If a message is too long, it will be split into multiple sub-chunks, + each maintaining the same speaker. + + Raises: + ValueError: If dialogue has no messages or chunking fails """ - try: - # 预处理文本:确保对话标记格式统一 - content = dialogue.content - content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号 - content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行 - - if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__): - # 同步分块器 - chunks = self.chunker(content) + # Validate dialogue has messages + if not dialogue.context or not dialogue.context.msgs: + raise ValueError( + f"Dialogue {dialogue.ref_id} has no messages. " + f"Cannot generate chunks from empty dialogue." + ) + + dialogue.chunks = [] + + # 按消息分块:每个消息创建一个或多个 chunk,直接继承角色 + for msg_idx, msg in enumerate(dialogue.context.msgs): + # Validate message has required attributes + if not hasattr(msg, 'role') or not hasattr(msg, 'msg'): + raise ValueError( + f"Message {msg_idx} in dialogue {dialogue.ref_id} " + f"missing 'role' or 'msg' attribute" + ) + + msg_content = msg.msg.strip() + + # Skip empty messages + if not msg_content: + continue + + # 如果消息太长,可以进一步分块 + if len(msg_content) > self.chunk_size: + # 对单个消息的内容进行分块 + try: + sub_chunks = self.chunker(msg_content) + except Exception as e: + raise ValueError( + f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}" + ) + + for idx, sub_chunk in enumerate(sub_chunks): + sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk) + sub_chunk_text = sub_chunk_text.strip() + + if len(sub_chunk_text) < (self.min_characters_per_chunk or 50): + continue + + chunk = Chunk( + content=f"{msg.role}: {sub_chunk_text}", + speaker=msg.role, # 直接继承角色 + metadata={ + "message_index": msg_idx, + "message_role": msg.role, + "sub_chunk_index": idx, + "total_sub_chunks": len(sub_chunks), + "chunker_strategy": self.chunker_config.chunker_strategy, + }, + ) + dialogue.chunks.append(chunk) else: - # 异步分块器(如LLMChunker) - chunks = await self.chunker(content) - - # 过滤空块和过小的块 - valid_chunks = [] - for c in chunks: - chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c - if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50): - valid_chunks.append(c) - - dialogue.chunks = [ - Chunk( - content=c.text if hasattr(c, 'text') else str(c), + # 消息不长,直接作为一个 chunk + chunk = Chunk( + content=f"{msg.role}: {msg_content}", + speaker=msg.role, # 直接继承角色 metadata={ - "start_index": getattr(c, "start_index", None), - "end_index": getattr(c, "end_index", None), + "message_index": msg_idx, + "message_role": msg.role, "chunker_strategy": self.chunker_config.chunker_strategy, }, ) - for c in valid_chunks - ] - return dialogue - - except Exception as e: - print(f"分块失败: {e}") - - # 改进的后备方案:尝试按对话回合分割 - try: - # 简单的按对话分割 - dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)' - matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL) - - class SimpleChunk: - def __init__(self, text, start_index, end_index): - self.text = text - self.start_index = start_index - self.end_index = end_index - - chunks = [] - current_chunk = "" - current_start = 0 - - for match in matches: - speaker, ct = match[0], match[1].strip() - turn_text = f"{speaker} {ct}" - - if len(current_chunk) + len(turn_text) > (self.chunk_size or 500): - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - current_chunk = turn_text - current_start = dialogue.content.find(turn_text, current_start) - else: - current_chunk += ("\n" + turn_text) if current_chunk else turn_text - - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - - dialogue.chunks = [ - Chunk( - content=c.text, - metadata={ - "start_index": c.start_index, - "end_index": c.end_index, - "chunker_strategy": "DialogueTurnFallback", - }, - ) - for c in chunks - ] - - except Exception: - # 最后的手段:单一大块 - dialogue.chunks = [Chunk( - content=dialogue.content, - metadata={"chunker_strategy": "SingleChunkFallback"}, - )] - - return dialogue + dialogue.chunks.append(chunk) + + # Validate we generated at least one chunk + if not dialogue.chunks: + raise ValueError( + f"No valid chunks generated for dialogue {dialogue.ref_id}. " + f"All messages were either empty or too short. " + f"Messages count: {len(dialogue.context.msgs)}" + ) + + return dialogue def evaluate_chunking(self, dialogue: DialogData) -> dict: - """ - 评估分块质量 - """ + """Evaluate chunking quality.""" if not getattr(dialogue, 'chunks', None): return {} @@ -304,11 +297,8 @@ class ChunkerClient: return metrics def save_chunking_results(self, dialogue: DialogData, output_path: str): - """ - 保存分块结果到文件,文件名包含策略名称 - """ + """Save chunking results to file with strategy name in filename.""" strategy_name = self.chunker_config.chunker_strategy - # 在文件名中添加策略名称 base_name, ext = os.path.splitext(output_path) strategy_output_path = f"{base_name}_{strategy_name}{ext}" diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index dce7b495..43c2b445 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -92,8 +92,6 @@ class OpenAIClient(LLMClient): config["callbacks"] = [self.langfuse_handler] response = await chain.ainvoke({"messages": messages}, config=config) - - logger.debug(f"LLM 响应成功: {len(str(response))} 字符") return response except Exception as e: @@ -149,13 +147,10 @@ class OpenAIClient(LLMClient): config=config ) - logger.debug(f"使用 PydanticOutputParser 解析成功") return parsed except Exception as e: - logger.warning( - f"PydanticOutputParser 解析失败,尝试其他方法: {e}" - ) + logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}") # 方法 2: 使用 LangChain 的 with_structured_output template = """{question}""" @@ -173,13 +168,17 @@ class OpenAIClient(LLMClient): # 验证并返回结果 try: - return response_model.model_validate(parsed) + result = response_model.model_validate(parsed) + return result except Exception: # 如果已经是 Pydantic 实例,直接返回 if hasattr(parsed, "model_dump"): return parsed # 尝试从 JSON 解析 - return response_model.model_validate_json(json.dumps(parsed)) + result = response_model.model_validate_json(json.dumps(parsed)) + return result + else: + logger.warning("with_structured_output 方法不可用") except Exception as e: logger.error(f"结构化输出失败: {e}") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 39d618fc..7a48d6cb 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -224,6 +224,7 @@ class StatementNode(Node): chunk_id: ID of the parent chunk this statement belongs to stmt_type: Type of the statement (from ontology) statement: The actual statement text content + speaker: Optional speaker identifier ('用户' for user messages, 'AI' for AI responses) emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node emotion_target: Optional emotion target (person or object name) emotion_subject: Optional emotion subject (self/other/object) @@ -249,6 +250,12 @@ class StatementNode(Node): stmt_type: str = Field(..., description="Type of the statement") statement: str = Field(..., description="The statement text content") + # Speaker identification + speaker: Optional[str] = Field( + None, + description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses" + ) + # Emotion fields (ordered as requested, emotion_intensity first for display) emotion_intensity: Optional[float] = Field( None, diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index 199bdd75..bcf08999 100644 --- a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -25,10 +25,10 @@ class ConversationMessage(BaseModel): """Represents a single message in a conversation. Attributes: - role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant) + role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant) msg: Text content of the message """ - role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').") + role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") msg: str = Field(..., description="The text content of the message.") @@ -57,6 +57,7 @@ class Statement(BaseModel): chunk_id: ID of the parent chunk this statement belongs to group_id: Optional group ID for multi-tenancy statement: The actual statement text content + speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses) statement_embedding: Optional embedding vector for the statement stmt_type: Type of the statement (from ontology) temporal_info: Temporal information extracted from the statement @@ -74,6 +75,7 @@ class Statement(BaseModel): chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") statement: str = Field(..., description="The text content of the statement.") + speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") stmt_type: StatementType = Field(..., description="The type of the statement.") temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.") @@ -118,36 +120,36 @@ class Chunk(BaseModel): Attributes: id: Unique identifier for the chunk - text: List of messages in the chunk content: The content of the chunk as a formatted string + speaker: The speaker/role for this chunk (user/assistant) statements: List of statements extracted from this chunk chunk_embedding: Optional embedding vector for the chunk metadata: Additional metadata as key-value pairs """ id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.") - text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.") content: str = Field(..., description="The content of the chunk as a string.") + speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") @classmethod - def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None): - """Create a chunk from a list of messages. + def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None): + """Create a chunk from a single message (1 Message = 1 Chunk). Args: - messages: List of conversation messages + message: Single conversation message metadata: Optional metadata dictionary Returns: - Chunk instance with formatted content + Chunk instance with speaker directly from message.role """ - if metadata is None: - metadata = {} - # Generate content from messages - content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages]) - return cls(text=messages, content=content, metadata=metadata) - + return cls( + content=f"{message.role}: {message.msg}", + speaker=message.role, + metadata=metadata or {} + ) + class DialogData(BaseModel): """Represents the complete data structure for a dialog record. diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 75aaa7df..46ba1dde 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -550,7 +550,7 @@ class ExtractionOrchestrator: self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ - 从对话中提取情绪信息(优化版:全局陈述句级并行) + 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) Args: dialog_data_list: 对话数据列表 @@ -558,7 +558,7 @@ class ExtractionOrchestrator: Returns: 情绪信息映射列表,每个对话对应一个字典 """ - logger.info("开始情绪信息提取(全局陈述句级并行)") + logger.info("开始情绪信息提取(仅处理用户消息)") # 收集所有陈述句及其配置 all_statements = [] @@ -597,15 +597,22 @@ class ExtractionOrchestrator: if not data_config or not data_config.emotion_enabled: logger.info("情绪提取未启用,跳过") return [{} for _ in dialog_data_list] + + # 收集所有陈述句(只收集 speaker 为 "user" 的) + total_statements = 0 + filtered_statements = 0 - # 收集所有陈述句 for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: - all_statements.append((statement, data_config)) - statement_metadata.append((d_idx, statement.id)) + total_statements += 1 + # 只处理用户的陈述句 (role 为 "user") + if hasattr(statement, 'speaker') and statement.speaker == "user": + all_statements.append((statement, data_config)) + statement_metadata.append((d_idx, statement.id)) + filtered_statements += 1 - logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪") + logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪") # 初始化情绪提取服务 from app.services.emotion_extraction_service import EmotionExtractionService @@ -1033,6 +1040,7 @@ class ExtractionOrchestrator: apply_id=dialog_data.apply_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id statement=statement.statement, + speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 statement_embedding=statement.statement_embedding, valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py index edb60a4d..40e98507 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py @@ -22,12 +22,12 @@ class DialogueChunker: Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker + Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker """ self.chunker_strategy = chunker_strategy chunker_config_dict = get_chunker_config(chunker_strategy) self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) - # 对于 LLMChunker,需要传入 llm_client + if self.chunker_config.chunker_strategy == "LLMChunker": self.chunker_client = ChunkerClient(self.chunker_config, llm_client) else: @@ -41,29 +41,19 @@ class DialogueChunker: Returns: A list of Chunk objects + + Raises: + ValueError: If chunking fails or returns empty chunks """ result_dialogue = await self.chunker_client.generate_chunks(dialogue) - # Defensive fallback: ensure at least one chunk is returned for non-empty content - try: - chunks = result_dialogue.chunks - except Exception: - chunks = [] + chunks = result_dialogue.chunks if not chunks or len(chunks) == 0: - # If the dialogue has content, return a single fallback chunk built from messages - content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "") - if content_str and len(content_str.strip()) > 0: - fallback_chunk = Chunk.from_messages( - dialogue.context.msgs, - metadata={ - "fallback": "single_chunk", - "chunker_strategy": self.chunker_config.chunker_strategy, - "source": "DialogueChunkerFallback", - }, - ) - return [fallback_chunk] - # No content: return empty list - return [] + raise ValueError( + f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. " + f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, " + f"Strategy: {self.chunker_config.chunker_strategy}" + ) return chunks @@ -72,22 +62,25 @@ class DialogueChunker: Args: dialogue: The processed DialogData object with chunks - output_path: Optional path to save the output (default: chunker_output_{strategy}.txt) + output_path: Optional path to save the output Returns: The path where the output was saved """ if not output_path: - output_path = os.path.join(os.path.dirname(__file__), "..", "..", - f"chunker_output_{self.chunker_strategy.lower()}.txt") + output_path = os.path.join( + os.path.dirname(__file__), "..", "..", + f"chunker_output_{self.chunker_strategy.lower()}.txt" + ) - output_lines = [] - output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===") - output_lines.append(f"Dialogue ID: {dialogue.ref_id}") - output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages") - output_lines.append(f"Total characters: {len(dialogue.content)}") - - output_lines.append(f"Generated {len(dialogue.chunks)} chunks:") + output_lines = [ + f"=== Chunking Results ({self.chunker_strategy}) ===", + f"Dialogue ID: {dialogue.ref_id}", + f"Original conversation has {len(dialogue.context.msgs)} messages", + f"Total characters: {len(dialogue.content)}", + f"Generated {len(dialogue.chunks)} chunks:" + ] + for i, chunk in enumerate(dialogue.chunks): output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters") output_lines.append(f" Content preview: {chunk.content}...") diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py index 8d37f5d2..fb1b539a 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py @@ -5,8 +5,6 @@ from datetime import datetime from typing import Any, Dict, List, Optional from app.core.memory.models.message_models import DialogData, Statement - -#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 from app.core.memory.models.variate_config import StatementExtractionConfig from app.core.memory.utils.data.ontology import ( LABEL_DEFINITIONS, @@ -22,11 +20,10 @@ logger = logging.getLogger(__name__) class ExtractedStatement(BaseModel): """Schema for extracted statement from LLM""" statement: str = Field(..., description="The extracted statement text") - statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION") + statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION") temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL") relevence: str = Field(..., description="RELEVANT or IRRELEVANT") -# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句) class StatementExtractionResponse(BaseModel): statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements") @@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel): return v class StatementExtractor: - """Class for extracting statements from dialog chunks using LLM (relations separated)""" + """Class for extracting statements from dialog chunks using LLM""" def __init__(self, llm_client: Any, config: StatementExtractionConfig = None): - # 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 """Initialize the StatementExtractor with an LLM client and configuration Args: @@ -71,6 +67,21 @@ class StatementExtractor: self.llm_client = llm_client self.config = config or StatementExtractionConfig() + def _get_speaker_from_chunk(self, chunk) -> Optional[str]: + """Get speaker directly from Chunk + + Args: + chunk: Chunk object containing speaker field + + Returns: + Speaker role ("user"/"assistant") or None if cannot be determined + """ + if hasattr(chunk, 'speaker') and chunk.speaker: + return chunk.speaker + + logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty") + return None + async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: """Process a single chunk and return extracted statements @@ -82,10 +93,12 @@ class StatementExtractor: Returns: List of ExtractedStatement objects extracted from the chunk """ - # Prepare the chunk content for processing chunk_content = chunk.content + + if not chunk_content or len(chunk_content.strip()) < 5: + logger.warning(f"Chunk {chunk.id} content too short or empty, skipping") + return [] - # Render the prompt using helper function prompt_content = await render_statement_extraction_prompt( chunk_content=chunk_content, definitions=LABEL_DEFINITIONS, @@ -136,7 +149,9 @@ class StatementExtractor: relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT except (KeyError, ValueError): relevence_info = RelevenceInfo.RELEVANT - + + chunk_speaker = self._get_speaker_from_chunk(chunk) + chunk_statement = Statement( statement=extracted_stmt.statement, stmt_type=stmt_type, @@ -144,7 +159,9 @@ class StatementExtractor: relevence_info=relevence_info, chunk_id=chunk.id, group_id=group_id, + speaker=chunk_speaker, ) + chunk_statements.append(chunk_statement) # 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata @@ -226,12 +243,7 @@ class StatementExtractor: return output_path def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str: - """按对话分组聚合强/弱关系并写入 TXT 文件。 - - 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content` - - 在该对话段内再分为 Strong Relations / Weak Relations 两部分 - - Strong: 逐条输出 `Chunk ID` 与 `Triple` - - Weak: 逐条输出 `Chunk ID` 与 `Entity` - """ + """Group and aggregate strong/weak relations by dialogue and write to TXT file.""" print("\n=== Relations Classify ===") # 使用全局配置的输出路径 diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 1e24eeae..cf60a773 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -101,6 +101,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC # "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else [] # }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}), "statement_embedding": statement.statement_embedding if statement.statement_embedding else None, + # 添加 speaker 字段(用于基于角色的情绪提取) + "speaker": statement.speaker if hasattr(statement, 'speaker') else None, # 添加情绪字段处理 "emotion_type": statement.emotion_type, "emotion_intensity": statement.emotion_intensity, @@ -163,7 +165,9 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> "chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None, "sequence_number": chunk.sequence_number, "start_index": metadata.get("start_index"), - "end_index": metadata.get("end_index") + "end_index": metadata.get("end_index"), + # 添加 speaker 字段(用于基于角色的情绪提取) + "speaker": chunk.speaker if hasattr(chunk, 'speaker') else None } flattened_chunks.append(flattened_chunk) diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index 47dc6b2a..fbc0e45c 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -12,7 +12,7 @@ class UserInput(BaseModel): class Write_UserInput(BaseModel): - message: str + messages: list[dict] group_id: str config_id: Optional[str] = None diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index fd0cb0eb..65dd628a 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -20,11 +20,13 @@ from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results from app.core.memory.agent.utils.type_classifier import status_typle +from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_base_service import Translation_English from app.services.memory_config_service import MemoryConfigService @@ -260,13 +262,13 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: + async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id Args: group_id: Group identifier (also used as end_user_id) - message: Message to write + messages: Structured message list [{"role": "user", "content": "..."}, ...] config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -287,7 +289,7 @@ class MemoryAgentService: raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): - raise # Re-raise our specific error + raise logger.error(f"Failed to get connected config for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") @@ -315,14 +317,28 @@ class MemoryAgentService: try: if storage_type == "rag": - result = await write_rag(group_id, message, user_rag_memory_id) + # For RAG storage, convert messages to single string + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + result = await write_rag(group_id, message_text, user_rag_memory_id) return result else: async with make_write_graph() as graph: config = {"configurable": {"thread_id": group_id}} + # Convert structured messages to LangChain messages + langchain_messages = [] + for msg in messages: + if msg['role'] == 'user': + langchain_messages.append(HumanMessage(content=msg['content'])) + elif msg['role'] == 'assistant': + from langchain_core.messages import AIMessage + langchain_messages.append(AIMessage(content=msg['content'])) + # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, - "memory_config": memory_config} + initial_state = { + "messages": langchain_messages, + "group_id": group_id, + "memory_config": memory_config + } # 获取节点更新信息 async for update_event in graph.astream( @@ -335,7 +351,9 @@ class MemoryAgentService: massages = node_data massagesstatus = massages.get('write_result')['status'] contents = massages.get('write_result') - return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents) + # Convert messages back to string for logging + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" @@ -531,7 +549,49 @@ class MemoryAgentService: ) raise ValueError(error_msg) - + def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: + """ + Get standardized message list from user input. + + Args: + user_input: Write_UserInput object + + Returns: + list[dict]: Message list, each message contains role and content + + Raises: + ValueError: If messages is empty or format is incorrect + """ + from app.core.logging_config import get_api_logger + logger = get_api_logger() + + if len(user_input.messages) == 0: + logger.error("Validation failed: Message list cannot be empty") + raise ValueError("Message list cannot be empty") + + for idx, msg in enumerate(user_input.messages): + if not isinstance(msg, dict): + logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}") + raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") + + if 'role' not in msg: + logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}") + raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}") + + if 'content' not in msg: + logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}") + raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}") + + if msg['role'] not in ['user', 'assistant']: + logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}") + raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}") + + if not msg['content'] or not msg['content'].strip(): + logger.error(f"Validation failed: Message {idx} content is empty") + raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}") + + logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}") + return user_input.messages async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: """ diff --git a/api/app/tasks.py b/api/app/tasks.py index fba9f290..e375de35 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -472,13 +472,19 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: +def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. + 支持两种消息格式: + 1. 字符串格式(向后兼容):message="user: xxx\nassistant: yyy" + 2. 结构化消息列表(推荐):message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}] + Args: group_id: Group ID for the memory agent (also used as end_user_id) - message: Message to write + message: Message to write (str or list[dict]) config_id: Optional configuration ID + storage_type: Storage type (neo4j/rag) + user_rag_memory_id: RAG memory ID Returns: Dict containing the result and metadata From 37ef497f4cbd342c4ccb472efe808bdc4a6ca636 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:04:16 +0800 Subject: [PATCH 3/7] Feature/distinction role (#167) * [feature]A set of information for role recognition writing * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [fix]Based on the AI review to fix the code * [changes]Disable the function of batch writing multiple groups of conversations in a cumulative manner * [fix]Addressing vulnerability risks * [fix]Fixing short-term memory writing * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [fix]Based on the AI review to fix the code * [fix]Fixing short-term memory writing --- api/app/services/memory_agent_service.py | 51 ++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 65dd628a..692e9a9a 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -518,6 +518,57 @@ class MemoryAgentService: optimized_outputs = merge_multiple_search_results(_intermediate_outputs) result = reorder_output_results(optimized_outputs) + # 保存短期记忆到数据库 + # 只有 search_switch 不为 "2"(快速检索)时才保存 + try: + from app.repositories.memory_short_repository import ShortTermMemoryRepository + + retrieved_content = [] + repo = ShortTermMemoryRepository(db) + + if str(search_switch) != "2": + for intermediate in _intermediate_outputs: + logger.debug(f"处理中间结果: {intermediate}") + intermediate_type = intermediate.get('type', '') + + if intermediate_type == "search_result": + query = intermediate.get('query', '') + raw_results = intermediate.get('raw_results', {}) + reranked_results = raw_results.get('reranked_results', []) + + try: + statements = [statement['statement'] for statement in reranked_results.get('statements', [])] + except Exception: + statements = [] + + # 去重 + statements = list(set(statements)) + + if query and statements: + retrieved_content.append({query: statements}) + + # 如果 retrieved_content 为空,设置为空字符串 + if retrieved_content == []: + retrieved_content = '' + + # 只有当回答不是"信息不足"且不是快速检索时才保存 + if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": + # 使用 upsert 方法 + repo.upsert( + end_user_id=group_id, + messages=message, + aimessages=summary, + retrieved_content=retrieved_content, + search_switch=str(search_switch) + ) + logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}") + else: + logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") + + except Exception as save_error: + # 保存失败不应该影响主流程,只记录错误 + logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) + # Log successful operation if audit_logger: duration = time.time() - start_time From c24fb731472804b1f2a7266cdc46bdd8a9c5ebb4 Mon Sep 17 00:00:00 2001 From: Ke Sun <33739460+keeees@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:58:46 +0800 Subject: [PATCH 4/7] Fix/memory celery fix (#168) * refactor(celery): optimize task routing and worker configuration - Simplify Celery queue configuration with single default 'io_tasks' queue - Implement task routing strategy separating IO-bound and CPU-bound tasks - Add Flower monitoring support with task event tracking enabled - Add summary node search optimization to only retrieve summary nodes - Clean up unused imports and reorganize import statements for consistency - Update docker-compose configuration to support multi-queue worker setup * chore(celery): simplify flower configuration and add gevent dependency * chore(dependencies): add gevent dependency to requirements - Add gevent==24.11.1 to api/requirements.txt - Gevent is required for async worker support in Celery - Complements existing flower and celery configuration * refactor(celery): simplify async event loop handling and reorganize task queues - Replace complex nest_asyncio and manual event loop management with asyncio.run() in read_message_task, write_message_task, regenerate_memory_cache, and workspace_reflection_task - Rename task queues from io_tasks/cpu_tasks to memory_tasks/document_tasks for better semantic clarity - Update task routing configuration to reflect new queue names for memory agent tasks and document processing tasks - Remove redundant exception handling comments and simplify error handling logic - Update README with improved community support section including GitHub Issues, Pull Requests, Discussions, and WeChat community links - Simplifies event loop management by leveraging asyncio.run() which handles loop creation and cleanup automatically, reducing code complexity and potential race conditions --- README.md | 13 +- api/app/celery_app.py | 75 ++++---- .../langgraph_graph/nodes/summary_nodes.py | 8 +- .../validators/memory_config_validators.py | 33 ++-- api/app/repositories/neo4j/graph_search.py | 92 ++++++---- api/app/services/draft_run_service.py | 15 +- api/app/services/memory_agent_service.py | 23 ++- api/app/services/memory_config_service.py | 33 ++-- api/app/tasks.py | 173 +++++------------- api/docker-compose.yml | 45 ++++- api/pyproject.toml | 2 + api/requirements.txt | 1 + 12 files changed, 254 insertions(+), 259 deletions(-) diff --git a/README.md b/README.md index 7d26f7f7..32a779d2 100644 --- a/README.md +++ b/README.md @@ -334,7 +334,12 @@ step6: Log In to the Frontend Interface. ## License This project is licensed under the Apache License 2.0. For details, see the LICENSE file. -## Acknowledgements & Community -- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions. -- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines. -- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com \ No newline at end of file +## Community & Support + +Join our community to ask questions, share your work, and connect with fellow developers. + +- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/redbear-ai/memorybear/issues). +- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/redbear-ai/memorybear/pulls). +- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/redbear-ai/memorybear/discussions). +- **WeChat**: Scan the QR code below to join our WeChat community group. +- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com \ No newline at end of file diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 85ad0643..185d746c 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -1,4 +1,5 @@ import os +import platform from datetime import timedelta from urllib.parse import quote @@ -14,27 +15,12 @@ celery_app = Celery( backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", ) -# 配置使用本地队列,避免与远程 worker 冲突 -celery_app.conf.task_default_queue = 'localhost_test_wyl' -celery_app.conf.task_default_exchange = 'localhost_test_wyl' -celery_app.conf.task_default_routing_key = 'localhost_test_wyl' +# Default queue for unrouted tasks +celery_app.conf.task_default_queue = 'memory_tasks' # macOS 兼容性配置 -import platform - -if platform.system() == 'Darwin': # macOS - # 设置环境变量解决 fork 问题 +if platform.system() == 'Darwin': os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') - - # 使用 solo 池避免多进程问题 - celery_app.conf.worker_pool = 'solo' - - # 设置唯一的节点名称 - import socket - import time - hostname = socket.gethostname() - timestamp = int(time.time()) - celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}" # Celery 配置 celery_app.conf.update( @@ -52,36 +38,47 @@ celery_app.conf.update( task_ignore_result=False, # 超时设置 - task_time_limit=30 * 60, # 30 分钟硬超时 - task_soft_time_limit=25 * 60, # 25 分钟软超时 + task_time_limit=1800, # 30分钟硬超时 + task_soft_time_limit=1500, # 25分钟软超时 - # Worker 设置 - 针对 macOS 优化 - worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积 - worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏 - worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker + # Worker 设置 (per-worker settings are in docker-compose command line) + worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution # 结果过期时间 - result_expires=3600, # 结果保存 1 小时 + result_expires=3600, # 结果保存1小时 # 任务确认设置 - task_acks_late=True, # 任务完成后才确认,避免任务丢失 - worker_disable_rate_limits=True, # 禁用速率限制 + task_acks_late=True, + task_reject_on_worker_lost=True, + worker_disable_rate_limits=True, - # 任务路由(可选,用于不同队列) - # task_routes={ - # 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'}, - # 'app.core.memory.agent.read_message': {'queue': 'memory_processing'}, - # 'app.core.memory.agent.write_message': {'queue': 'memory_processing'}, - # 'tasks.process_item': {'queue': 'default'}, - # }, + # FLower setting + worker_send_task_events=True, + task_send_sent_event=True, + + # task routing + task_routes={ + # Memory tasks → memory_tasks queue (threads worker) + 'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, + + # Document tasks → document_tasks queue (prefork worker) + 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, + 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, + + # Beat/periodic tasks → document_tasks queue (prefork worker) + 'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'}, + 'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'}, + 'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'}, + 'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'}, + }, ) # 自动发现任务模块 celery_app.autodiscover_tasks(['app']) # Celery Beat schedule for periodic tasks -reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS) -health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME @@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘 # 构建定时任务配置 beat_schedule_config = { - - # "check-read-service": { - # "task": "app.core.memory.agent.health.check_read_service", - # "schedule": health_schedule, - # "args": (), - # }, "run-workspace-reflection": { "task": "app.tasks.workspace_reflection_task", "schedule": workspace_reflection_schedule, diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 0d0b57b0..44f89c6a 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -4,12 +4,11 @@ import os import time from app.core.logging_config import get_agent_logger, log_time -from app.db import get_db - from app.core.memory.agent.models.summary_models import ( RetrieveSummaryResponse, SummaryResponse, ) +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.services.search_service import SearchService from app.core.memory.agent.utils.llm_tools import ( PROJECT_ROOT_, @@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin +from app.db import get_db template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') logger = get_agent_logger(__name__) @@ -182,7 +181,8 @@ async def Input_Summary(state: ReadState) -> ReadState: search_params = { "group_id": group_id, "question": data, - "return_raw_results": True + "return_raw_results": True, + "include": ["summaries"] # Only search summary nodes for faster performance } try: diff --git a/api/app/core/validators/memory_config_validators.py b/api/app/core/validators/memory_config_validators.py index 6ccf3ddb..333572e6 100644 --- a/api/app/core/validators/memory_config_validators.py +++ b/api/app/core/validators/memory_config_validators.py @@ -89,14 +89,15 @@ def validate_model_exists_and_active( start_time = time.time() try: - # First check if model exists at all (without tenant filtering) - model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None) - - # Then check with tenant filtering + # OPTIMIZED: Single query with tenant filter + # We'll check tenant mismatch in the error handling model = ModelConfigRepository.get_by_id(db, model_id, tenant_id) elapsed_ms = (time.time() - start_time) * 1000 if not model: + # Model not found with tenant filter - check if it exists without filter + model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None) + if model_without_tenant: # Model exists but belongs to different tenant logger.warning( @@ -208,8 +209,11 @@ def validate_embedding_model( db: Session, tenant_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None -) -> UUID: - """Validate that embedding model is available and return its UUID. +) -> tuple[UUID, str]: + """Validate that embedding model is available and return its UUID and name. + + Returns: + Tuple of (embedding_uuid, embedding_name) Raises: InvalidConfigError: If embedding_id is not provided or invalid @@ -225,14 +229,19 @@ def validate_embedding_model( workspace_id=workspace_id ) - embedding_uuid, _ = validate_and_resolve_model_id( + embedding_uuid, embedding_name = validate_and_resolve_model_id( embedding_id, "embedding", db, tenant_id, required=True, config_id=config_id, workspace_id=workspace_id ) - print(100*'-') - print(embedding_uuid) - print(_) - print(100*'-') + + logger.debug( + "Embedding model validated", + extra={ + "embedding_uuid": str(embedding_uuid), + "embedding_name": embedding_name, + "config_id": config_id + } + ) if embedding_uuid is None: raise InvalidConfigError( @@ -243,7 +252,7 @@ def validate_embedding_model( workspace_id=workspace_id ) - return embedding_uuid + return embedding_uuid, embedding_name def validate_llm_model( diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 0b6a27c6..6f5764b4 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -305,12 +305,19 @@ async def search_graph( results[key] = _deduplicate_results(results[key]) # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) - results = await _update_search_results_activation( - connector=connector, - results=results, - group_id=group_id + # Skip activation updates if only searching summaries (optimization) + needs_activation_update = any( + key in include and key in results and results[key] + for key in ['statements', 'entities', 'chunks'] ) + if needs_activation_update: + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + return results @@ -339,7 +346,7 @@ async def search_graph_by_embedding( embed_start = time.time() embeddings = await embedder_client.response([query_text]) embed_time = time.time() - embed_start - print(f"[PERF] Embedding generation took: {embed_time:.4f}s") + logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s") if not embeddings or not embeddings[0]: return {"statements": [], "chunks": [], "entities": [], "summaries": []} @@ -393,7 +400,7 @@ async def search_graph_by_embedding( query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) query_time = time.time() - query_start - print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") + logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") # Build results dictionary results: Dict[str, List[Dict[str, Any]]] = { @@ -417,14 +424,23 @@ async def search_graph_by_embedding( results[key] = _deduplicate_results(results[key]) # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) - update_start = time.time() - results = await _update_search_results_activation( - connector=connector, - results=results, - group_id=group_id + # Skip activation updates if only searching summaries (optimization) + needs_activation_update = any( + key in include and key in results and results[key] + for key in ['statements', 'entities', 'chunks'] ) - update_time = time.time() - update_start - print(f"[PERF] Activation value updates took: {update_time:.4f}s") + + if needs_activation_update: + update_start = time.time() + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + update_time = time.time() - update_start + logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s") + else: + logger.info(f"[PERF] Skipping activation updates (only summaries)") return results async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 @@ -535,7 +551,7 @@ async def search_graph_by_keyword_temporal( - Returns up to 'limit' statements """ if not query_text: - print(f"query_text不能为空") + logger.warning(f"query_text cannot be empty") return {"statements": []} statements = await connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, @@ -549,7 +565,7 @@ async def search_graph_by_keyword_temporal( invalid_date=invalid_date, limit=limit, ) - print(f"查询结果为:\n{statements}") + logger.debug(f"Temporal keyword search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -594,9 +610,9 @@ async def search_graph_by_temporal( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}") + logger.debug(f"Temporal search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -623,7 +639,7 @@ async def search_graph_by_dialog_id( - Returns up to 'limit' dialogues """ if not dialog_id: - print(f"dialog_id不能为空") + logger.warning(f"dialog_id cannot be empty") return {"dialogues": []} dialogues = await connector.execute_query( @@ -642,7 +658,7 @@ async def search_graph_by_chunk_id( limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: if not chunk_id: - print(f"chunk_id不能为空") + logger.warning(f"chunk_id cannot be empty") return {"chunks": []} chunks = await connector.execute_query( SEARCH_CHUNK_BY_CHUNK_ID, @@ -679,9 +695,9 @@ async def search_graph_by_created_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -719,9 +735,9 @@ async def search_graph_by_valid_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -759,9 +775,9 @@ async def search_graph_g_created_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -799,9 +815,9 @@ async def search_graph_g_valid_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -839,9 +855,9 @@ async def search_graph_l_created_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -879,9 +895,9 @@ async def search_graph_l_valid_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 50934226..46bda5f6 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -10,11 +10,6 @@ import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy import select -from sqlalchemy.orm import Session - from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger from app.services.tool_service import ToolService +from langchain.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): @@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str "app.core.memory.agent.read_message", args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id] ) - result = task_service.get_task_memory_read_result(task.id) - status = result.get("status") - logger.info(f"读取任务状态:{status}") + # result = task_service.get_task_memory_read_result(task.id) + # status = result.get("status") + # logger.info(f"读取任务状态:{status}") finally: db.close() diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 692e9a9a..6748d6c7 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -10,15 +10,17 @@ import re import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional -import redis -from langchain_core.messages import HumanMessage +import redis from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer -from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results +from app.core.memory.agent.utils.messages_tools import ( + merge_multiple_search_results, + reorder_output_results, +) from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags @@ -33,6 +35,7 @@ from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) +from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session @@ -404,6 +407,7 @@ class MemoryAgentService: import time start_time = time.time() + logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}") # Resolve config_id if None using end_user's connected config if config_id is None: @@ -427,13 +431,15 @@ class MemoryAgentService: audit_logger = None + config_load_start = time.time() try: config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( config_id=config_id, service_name="MemoryAgentService" ) - logger.info(f"Configuration loaded successfully: {memory_config.config_name}") + config_load_time = time.time() - config_load_start + logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}") except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) @@ -457,6 +463,7 @@ class MemoryAgentService: logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") # Step 3: Initialize MCP client and execute read workflow + graph_exec_start = time.time() try: async with make_read_graph() as graph: config = {"configurable": {"thread_id": group_id}} @@ -513,6 +520,9 @@ class MemoryAgentService: if summary_n and summary_n != [] and summary_n != {}: _intermediate_outputs.append(summary_n) + graph_exec_time = time.time() - graph_exec_start + logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s") + _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] optimized_outputs = merge_multiple_search_results(_intermediate_outputs) @@ -570,6 +580,8 @@ class MemoryAgentService: logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) # Log successful operation + total_time = time.time() - start_time + logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( @@ -587,7 +599,8 @@ class MemoryAgentService: except Exception as e: # Ensure proper error handling and logging error_msg = f"Read operation failed: {str(e)}" - logger.error(error_msg) + total_time = time.time() - start_time + logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index 09e980a0..0099eb18 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -125,7 +125,11 @@ class MemoryConfigService: try: validated_config_id = _validate_config_id(config_id) + # Step 1: Get config and workspace + db_query_start = time.time() result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id) + db_query_time = time.time() - db_query_start + logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s") if not result: elapsed_ms = (time.time() - start_time) * 1000 config_logger.error( @@ -144,16 +148,20 @@ class MemoryConfigService: memory_config, workspace = result - # Validate embedding model - embedding_uuid = validate_embedding_model( + # Step 2: Validate embedding model (returns both UUID and name) + embed_start = time.time() + embedding_uuid, embedding_name = validate_embedding_model( validated_config_id, memory_config.embedding_id, self.db, workspace.tenant_id, workspace.id, ) + embed_time = time.time() - embed_start + logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s") - # Resolve LLM model + # Step 3: Resolve LLM model + llm_start = time.time() llm_uuid, llm_name = validate_and_resolve_model_id( memory_config.llm_id, "llm", @@ -163,8 +171,11 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id, ) + llm_time = time.time() - llm_start + logger.info(f"[PERF] LLM validation: {llm_time:.4f}s") - # Resolve optional rerank model + # Step 4: Resolve optional rerank model + rerank_start = time.time() rerank_uuid = None rerank_name = None if memory_config.rerank_id: @@ -177,16 +188,12 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id, ) + rerank_time = time.time() - rerank_start + if memory_config.rerank_id: + logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") - # Get embedding model name - embedding_name, _ = validate_model_exists_and_active( - embedding_uuid, - "embedding", - self.db, - workspace.tenant_id, - config_id=validated_config_id, - workspace_id=workspace.id, - ) + # Note: embedding_name is now returned from validate_embedding_model above + # No need for redundant query! # Create immutable MemoryConfig object config = MemoryConfig( diff --git a/api/app/tasks.py b/api/app/tasks.py index e375de35..fa9d1fdf 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -425,24 +425,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, db.close() try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time return { @@ -455,7 +438,6 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, } except BaseException as e: elapsed_time = time.time() - start_time - # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] detailed_error = "; ".join(error_messages) @@ -528,24 +510,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ db.close() try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") @@ -560,7 +525,6 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ } except BaseException as e: elapsed_time = time.time() - start_time - # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] detailed_error = "; ".join(error_messages) @@ -600,53 +564,53 @@ def reflection_timer_task() -> None: """ reflection_engine() - -@celery_app.task(name="app.core.memory.agent.health.check_read_service") -def check_read_service_task() -> Dict[str, str]: - """Call read_service and write latest status to Redis. +# unused task +# @celery_app.task(name="app.core.memory.agent.health.check_read_service") +# def check_read_service_task() -> Dict[str, str]: +# """Call read_service and write latest status to Redis. - Returns status data dict that gets written to Redis. - """ - client = redis.Redis( - host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - db=settings.REDIS_DB, - password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None - ) - try: - api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" - payload = { - "user_id": "健康检查", - "apply_id": "健康检查", - "group_id": "健康检查", - "message": "你好", - "history": [], - "search_switch": "2", - } - resp = requests.post(api_url, json=payload, timeout=15) - ok = resp.status_code == 200 - status = "Success" if ok else "Fail" - msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" - error = "" if ok else resp.text - code = 0 if ok else 500 - except Exception as e: - status = "Fail" - msg = "接口请求失败" - error = str(e) - code = 500 +# Returns status data dict that gets written to Redis. +# """ +# client = redis.Redis( +# host=settings.REDIS_HOST, +# port=settings.REDIS_PORT, +# db=settings.REDIS_DB, +# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None +# ) +# try: +# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" +# payload = { +# "user_id": "健康检查", +# "apply_id": "健康检查", +# "group_id": "健康检查", +# "message": "你好", +# "history": [], +# "search_switch": "2", +# } +# resp = requests.post(api_url, json=payload, timeout=15) +# ok = resp.status_code == 200 +# status = "Success" if ok else "Fail" +# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" +# error = "" if ok else resp.text +# code = 0 if ok else 500 +# except Exception as e: +# status = "Fail" +# msg = "接口请求失败" +# error = str(e) +# code = 500 - data = { - "status": status, - "msg": msg, - "error": error, - "code": str(code), - "time": str(int(time.time())), - } +# data = { +# "status": status, +# "msg": msg, +# "error": error, +# "code": str(code), +# "time": str(int(time.time())), +# } - client.hset("memsci:health:read_service", mapping=data) - client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) +# client.hset("memsci:health:read_service", mapping=data) +# client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) - return data +# return data @celery_app.task(name="app.controllers.memory_storage_controller.search_all") @@ -911,24 +875,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id @@ -1055,24 +1002,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id @@ -1148,11 +1078,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str "duration_seconds": duration } - # 运行异步函数 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete(_run()) - return result - finally: - loop.close() + return asyncio.run(_run()) diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 8bc19f3a..a7337689 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -7,10 +7,6 @@ services: - "8002:8000" env_file: - .env - environment: - - SERVER_IP=0.0.0.0 - # 如果代码里必须要 MCP_SERVER_URL,可以先注释或指向占位 - # - MCP_SERVER_URL= volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro @@ -19,20 +15,53 @@ services: networks: - default - celery + depends_on: + - worker-memory + - worker-document - # Celery worker - worker: + # Memory worker - Memory read/write tasks (threads pool for asyncio) + worker-memory: image: redbear-mem-open:latest - container_name: worker + container_name: worker-memory env_file: - .env volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro - command: celery -A app.celery_worker.celery_app worker --loglevel=info + command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h restart: unless-stopped networks: - celery + # Document worker - Document parsing tasks (prefork for CPU-bound) + worker-document: + image: redbear-mem-open:latest + container_name: worker-document + env_file: + - .env + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h + restart: unless-stopped + networks: + - celery + + # Celery Beat - scheduler + beat: + image: redbear-mem-open:latest + container_name: celery-beat + env_file: + - .env + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: celery -A app.celery_worker.celery_app beat --loglevel=info + restart: unless-stopped + networks: + - celery + depends_on: + - worker-memory + networks: celery: diff --git a/api/pyproject.toml b/api/pyproject.toml index 6da684de..414ba372 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "bcrypt==5.0.0", "billiard==4.2.2", "celery==5.5.3", + "flower==2.0.1", "cffi==2.0.0", "click==8.3.0", "click-didyoumean==0.3.1", @@ -138,6 +139,7 @@ dependencies = [ "python-calamine>=0.4.0", "xlrd==2.0.2", "deprecated>=1.3.1", + "flower>=2.0.1", ] [tool.pytest.ini_options] diff --git a/api/requirements.txt b/api/requirements.txt index 99252e09..444a194b 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -6,6 +6,7 @@ async-timeout==5.0.1 bcrypt==5.0.0 billiard==4.2.2 celery==5.5.3 +flower==2.0.1 cffi==2.0.0 click==8.3.0 click-didyoumean==0.3.1 From 1e5acd85ffbed63cbddd70b54a2e7b8cb333dd28 Mon Sep 17 00:00:00 2001 From: Ke Sun <33739460+keeees@users.noreply.github.com> Date: Wed, 21 Jan 2026 18:11:50 +0800 Subject: [PATCH 5/7] Update community links in README.md --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 32a779d2..2f53a996 100644 --- a/README.md +++ b/README.md @@ -338,8 +338,9 @@ This project is licensed under the Apache License 2.0. For details, see the LICE Join our community to ask questions, share your work, and connect with fellow developers. -- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/redbear-ai/memorybear/issues). -- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/redbear-ai/memorybear/pulls). -- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/redbear-ai/memorybear/discussions). +- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues). +- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls). +- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions). - **WeChat**: Scan the QR code below to join our WeChat community group. -- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com \ No newline at end of file +- ![wecom-temp-114020-47fe87a75da439f09f5dc93a01593046](https://github.com/user-attachments/assets/8c81885c-4134-40d5-96e2-7f78cc082dc6) +- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com From b6e6dbf27f04cd577482a8e876342c9f6f65d9fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Wed, 21 Jan 2026 18:20:28 +0800 Subject: [PATCH 6/7] Fix/memory interface (#169) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [changes]《Modify the interface》 1.Remove the "/search/entity_graph" interface 2.Reconstruct the "/updated_end_user/profile" interface 3.Remove the "Update Username" interface 4.Fix the batch query of user association memory configuration * [changes]《Modify the interface》 1.Remove the "/search/entity_graph" interface 2.Reconstruct the "/updated_end_user/profile" interface 3.Remove the "Update Username" interface 4.Fix the batch query of user association memory configuration * [fix]Fix the error response type --- .../controllers/memory_agent_controller.py | 2 +- .../memory_dashboard_controller.py | 48 ---------- .../controllers/memory_storage_controller.py | 17 +--- .../controllers/user_memory_controllers.py | 70 ++++---------- .../repositories/data_config_repository.py | 32 ------- api/app/repositories/end_user_repository.py | 36 ------- api/app/schemas/memory_agent_schema.py | 4 - api/app/services/memory_agent_service.py | 8 +- api/app/services/memory_storage_service.py | 21 ---- api/app/services/user_memory_service.py | 95 +++++++++++++++++++ 10 files changed, 119 insertions(+), 214 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 416ed710..7707522c 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -682,7 +682,7 @@ async def get_user_profile_api( current_user: User = Depends(get_current_user) ): """ - 获取用户详情,包含: + 获取工作空间下Popular Memory Tags,包含: - name: 用户名字(直接使用 end_user_id) - tags: 3个用户特征标签(从语句和实体中LLM总结) - hot_tags: 4个热门记忆标签 diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 2afff491..e03c1846 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -5,7 +5,6 @@ from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user from app.models.user_model import User -from app.schemas.memory_agent_schema import End_User_Information from app.schemas.response_schema import ApiResponse from app.services import memory_dashboard_service, memory_storage_service, workspace_service @@ -40,54 +39,7 @@ def get_workspace_total_end_users( api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}") return success(data=total_end_users, msg="用户数量获取成功") -@router.post("/update/end_users", response_model=ApiResponse) -async def update_workspace_end_users( - user_input: End_User_Information, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 更新工作空间的宿主信息 - """ - username = user_input.end_user_name # 要更新的用户名 - end_user_input_id = user_input.id # 宿主ID - workspace_id = current_user.current_workspace_id - - api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息") - api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}") - try: - # 导入更新函数 - from app.repositories.end_user_repository import update_end_user_other_name - import uuid - - # 转换 end_user_id 为 UUID 类型 - end_user_uuid = uuid.UUID(end_user_input_id) - - # 直接更新数据库中的 other_name 字段 - updated_count = update_end_user_other_name( - db=db, - end_user_id=end_user_uuid, - other_name=username - ) - - api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}") - - return success( - data={ - "updated_count": updated_count, - "end_user_id": end_user_input_id, - "updated_other_name": username - }, - msg=f"成功更新 {updated_count} 个宿主的信息" - ) - - except Exception as e: - api_logger.error(f"更新宿主信息失败: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"更新宿主信息失败: {str(e)}" - ) diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index 63d9078a..f4175923 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -28,7 +28,6 @@ from app.services.memory_storage_service import ( search_dialogue, search_edges, search_entity, - search_entity_graph, search_statement, ) from fastapi import APIRouter, Depends @@ -412,21 +411,7 @@ async def search_entity_edges( api_logger.error(f"Search edges failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) -@router.get("/search/entity_graph", response_model=ApiResponse) -async def search_for_entity_graph( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - """ - 搜索所有实体之间的关系网络 - """ - api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}") - try: - result = await search_entity_graph(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search entity graph failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e)) + @router.get("/analytics/hot_memory_tags", response_model=ApiResponse) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index d99eb47e..3b7345b6 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -351,12 +351,11 @@ async def update_end_user_profile( 该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。 所有字段都是可选的,只更新提供的字段。 - """ workspace_id = current_user.current_workspace_id end_user_id = profile_update.end_user_id - # 检查用户是否已选择工作空间 + # 验证工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") @@ -366,57 +365,24 @@ async def update_end_user_profile( f"workspace={workspace_id}" ) - try: - # 查询终端用户 - end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() + # 调用 Service 层处理业务逻辑 + result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update) - if not end_user: - api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") - return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") - - # 更新字段(只更新提供的字段,排除 end_user_id) - # 允许 None 值来重置字段(如 hire_date) - update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) - - # 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime - if 'hire_date' in update_data: - hire_date_timestamp = update_data['hire_date'] - if hire_date_timestamp is not None: - update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp) - # 如果是 None,保持 None(允许清空) - - for field, value in update_data.items(): - setattr(end_user, field, value) - - # 更新 updated_at 时间戳 - end_user.updated_at = datetime.datetime.now() - - # 更新 updatetime_profile 为当前时间 - end_user.updatetime_profile = datetime.datetime.now() - - # 提交更改 - db.commit() - db.refresh(end_user) - - # 构建响应数据 - profile_data = EndUserProfileResponse( - id=end_user.id, - other_name=end_user.other_name, - position=end_user.position, - department=end_user.department, - contact=end_user.contact, - phone=end_user.phone, - hire_date=end_user.hire_date, - updatetime_profile=end_user.updatetime_profile - ) - - api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}") - return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功") - - except Exception as e: - db.rollback() - api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}") - return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e)) + if result["success"]: + api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}") + return success(data=result["data"], msg="更新成功") + else: + error_msg = result["error"] + api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}") + + # 根据错误类型映射到合适的业务错误码 + if error_msg == "终端用户不存在": + return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg) + elif error_msg == "无效的用户ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg) + else: + # 只有未预期的错误才使用 INTERNAL_ERROR + return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg) @router.get("/memory_space/timeline_memories", response_model=ApiResponse) async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh", diff --git a/api/app/repositories/data_config_repository.py b/api/app/repositories/data_config_repository.py index d26058b2..3df7f800 100644 --- a/api/app/repositories/data_config_repository.py +++ b/api/app/repositories/data_config_repository.py @@ -104,38 +104,6 @@ class DataConfigRepository: r.statement AS statement """ - # Entity graph within group (source node, edge, target node) - SEARCH_FOR_ENTITY_GRAPH = """ - MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) - WHERE n.group_id = $group_id - RETURN - { - entity_idx: n.entity_idx, - connect_strength: n.connect_strength, - description: n.description, - entity_type: n.entity_type, - name: n.name, - fact_summary: COALESCE(n.fact_summary, ''), - id: n.id - } AS sourceNode, - { - rel_id: elementId(r), - source_id: startNode(r).id, - target_id: endNode(r).id, - predicate: r.predicate, - statement_id: r.statement_id, - statement: r.statement - } AS edge, - { - entity_idx: m.entity_idx, - connect_strength: m.connect_strength, - description: m.description, - entity_type: m.entity_type, - name: m.name, - fact_summary: COALESCE(m.fact_summary, ''), - id: m.id - } AS targetNode - """ @staticmethod def update_reflection_config( db: Session, diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index b9e82693..c7d13f8f 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -276,42 +276,6 @@ def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser] end_user = repo.get_end_user_by_id(end_user_id) return end_user -def update_end_user_other_name( - db: Session, - end_user_id: uuid.UUID, - other_name: str -) -> int: - """ - 通过 end_user_id 更新 end_user 表中的 other_name 字段 - - Args: - db: 数据库会话 - end_user_id: 宿主ID - other_name: 要更新的用户名 - - Returns: - int: 更新的记录数 - """ - try: - # 执行更新 - updated_count = ( - db.query(EndUser) - .filter(EndUser.id == end_user_id) - .update( - {EndUser.other_name: other_name}, - synchronize_session=False - ) - ) - - db.commit() - db_logger.info(f"成功更新宿主 {end_user_id} 的 other_name 为: {other_name}") - return updated_count - - except Exception as e: - db.rollback() - db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}") - raise - # 新增的缓存操作函数(保持与类方法一致的接口) def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]: """根据ID获取终端用户(用于缓存操作)""" diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index fbc0e45c..d4354c40 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -15,7 +15,3 @@ class Write_UserInput(BaseModel): messages: list[dict] group_id: str config_id: Optional[str] = None - -class End_User_Information(BaseModel): - end_user_name: str # 这是要更新的用户名 - id: str # 宿主ID,用于匹配条件 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 6748d6c7..d744b766 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -1157,7 +1157,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) """ from app.models.app_release_model import AppRelease from app.models.end_user_model import EndUser - from app.models.memory_config_model import MemoryConfig + from app.models.data_config_model import DataConfig from sqlalchemy import select logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users") @@ -1215,8 +1215,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 批量查询 memory_config_name config_id_to_name = {} if memory_config_ids: - memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all() - config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs} + memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all() + config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs} # 4. 构建最终结果 for end_user_id, app_id in user_to_app.items(): @@ -1233,7 +1233,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None # 获取配置名称 - memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None + memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None result[end_user_id] = { "memory_config_id": memory_config_id, diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 9cac26ec..83d5923d 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -506,27 +506,6 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any] return result -async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]: - """搜索所有实体之间的关系网络(group 维度)。""" - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH, - group_id=end_user_id, - ) - # 对source_node 和 target_node 的 fact_summary进行截取,只截取前三条的内容(需要提取前三条“来源”) - for item in result: - source_fact = item["sourceNode"]["fact_summary"] - target_fact = item["targetNode"]["fact_summary"] - # 截取前三条“来源” - item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else [] - item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else [] - # 与现有返回风格保持一致,携带搜索类型、数量与详情 - data = { - "search_for": "entity_graph", - "num": len(result), - "detials": result, - } - return data - async def analytics_hot_memory_tags( db: Session, diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index ae07256a..863bccb0 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -357,6 +357,101 @@ class UserMemoryService: data[key] = UserMemoryService._datetime_to_timestamp(original_value) return data + def update_end_user_profile( + self, + db: Session, + end_user_id: str, + profile_update: Any + ) -> Dict[str, Any]: + """ + 更新终端用户的基本信息 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + profile_update: 包含更新字段的 Pydantic 模型 + + Returns: + { + "success": bool, + "data": dict, # 更新后的用户档案数据 + "error": Optional[str] + } + """ + try: + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.warning(f"终端用户不存在: end_user_id={end_user_id}") + return { + "success": False, + "data": None, + "error": "终端用户不存在" + } + + # 获取更新数据(排除 end_user_id 字段) + update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) + + # 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime + if 'hire_date' in update_data: + hire_date_timestamp = update_data['hire_date'] + if hire_date_timestamp is not None: + from app.core.api_key_utils import timestamp_to_datetime + update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp) + # 如果是 None,保持 None(允许清空) + + # 更新字段 + for field, value in update_data.items(): + setattr(end_user, field, value) + + # 更新时间戳 + end_user.updated_at = datetime.now() + end_user.updatetime_profile = datetime.now() + + # 提交更改 + db.commit() + db.refresh(end_user) + + # 构建响应数据 + from app.schemas.end_user_schema import EndUserProfileResponse + profile_data = EndUserProfileResponse( + id=end_user.id, + other_name=end_user.other_name, + position=end_user.position, + department=end_user.department, + contact=end_user.contact, + phone=end_user.phone, + hire_date=end_user.hire_date, + updatetime_profile=end_user.updatetime_profile + ) + + logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}") + + return { + "success": True, + "data": self.convert_profile_to_dict_with_timestamp(profile_data), + "error": None + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "success": False, + "data": None, + "error": "无效的用户ID格式" + } + except Exception as e: + db.rollback() + logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}") + return { + "success": False, + "data": None, + "error": str(e) + } + async def get_cached_memory_insight( self, db: Session, From fb25495f1b44a5d6c62744113d463562cd00e2d3 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Wed, 21 Jan 2026 18:21:51 +0800 Subject: [PATCH 7/7] Fix/memory mcp2 1 (#170) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * feat(celery): add comprehensive logging to worker and write task - Initialize logging system in Celery worker entry point with LoggingConfig - Add logger instance and startup message to celery_worker.py - Reorganize imports in tasks.py for better readability and consistency - Add detailed logging to write_message_task for debugging and monitoring - Log task start with group_id, config_id, and storage_type parameters - Log service execution and completion status with results - Add exception handling with error logging and stack trace capture - Log task completion time and Celery task ID for performance tracking - Improves observability and troubleshooting of async task execution * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 快速检索,需要在接口部分添加LLM整合 * 快速检索,需要在接口部分添加LLM整合 --------- Co-authored-by: Ke Sun --- .../controllers/memory_agent_controller.py | 15 +++++ .../langgraph_graph/nodes/problem_nodes.py | 41 ++++++------ .../agent/langgraph_graph/read_graph.py | 1 - .../agent/services/optimized_llm_service.py | 4 +- api/app/services/memory_agent_service.py | 62 ++++++++++++++++++- 5 files changed, 99 insertions(+), 24 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 7707522c..22830890 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -9,6 +9,8 @@ from app.db import get_db from app.dependencies import cur_workspace_access_guard, get_current_user from app.models import ModelApiKey from app.models.user_model import User +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.redis_tool import store from app.repositories import knowledge_repository, WorkspaceRepository from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.response_schema import ApiResponse @@ -291,6 +293,19 @@ async def read_server( storage_type, user_rag_memory_id ) + if str(user_input.search_switch) == "2": + retrieve_info = result['answer'] + history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) + query = user_input.message + + # 调用 memory_agent_service 的方法生成最终答案 + result['answer'] = await memory_agent_service.generate_summary_from_retrieve( + retrieve_info=retrieve_info, + history=history, + query=query, + config_id=config_id, + db=db + ) return success(data=result, msg="回复对话消息成功") except BaseException as e: # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index e02ef62b..697a13bd 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -18,16 +18,19 @@ template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') db_session = next(get_db()) logger = get_agent_logger(__name__) + class ProblemNodeService(LLMServiceMixin): """问题处理节点服务类""" - + def __init__(self): super().__init__() self.template_service = TemplateService(template_root) + # 创建全局服务实例 problem_service = ProblemNodeService() + async def Split_The_Problem(state: ReadState) -> ReadState: """问题分解节点""" # 从状态中获取数据 @@ -36,10 +39,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState: memory_config = state.get('memory_config', None) history = await SessionService(store).get_history(group_id, group_id, group_id) - + # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = ProblemExtensionResponse.model_json_schema() - + system_prompt = await problem_service.template_service.render_template( template_name='problem_breakdown_prompt.jinja2', operation_name='split_the_problem', @@ -47,7 +50,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState: sentence=content, json_schema=json_schema ) - + try: # 使用优化的LLM服务 structured = await problem_service.call_llm_structured( @@ -57,10 +60,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState: response_model=ProblemExtensionResponse, fallback_value=[] ) - + # 添加更详细的日志记录 logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") - + # 验证结构化响应 if not structured or not hasattr(structured, 'root'): logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") @@ -73,17 +76,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState: [item.model_dump() for item in structured.root], ensure_ascii=False ) - + split_result_dict = [] for index, item in enumerate(json.loads(split_result)): split_data = { - "id": f"Q{index+1}", + "id": f"Q{index + 1}", "question": item['extended_question'], "type": item['type'], "reason": item['reason'] } split_result_dict.append(split_data) - + logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项") result = { @@ -96,13 +99,13 @@ async def Split_The_Problem(state: ReadState) -> ReadState: "original_query": content } } - + except Exception as e: logger.error( f"Split_The_Problem failed: {e}", exc_info=True ) - + # 提供更详细的错误信息 error_details = { "error_type": type(e).__name__, @@ -110,9 +113,9 @@ async def Split_The_Problem(state: ReadState) -> ReadState: "content_length": len(content), "llm_model_id": memory_config.llm_model_id if memory_config else None } - + logger.error(f"Split_The_Problem error details: {error_details}") - + # 创建默认的空结果 result = { "context": json.dumps([], ensure_ascii=False), @@ -126,10 +129,11 @@ async def Split_The_Problem(state: ReadState) -> ReadState: "error": error_details } } - + # 返回更新后的状态,包含spit_context字段 return {"spit_data": result} + async def Problem_Extension(state: ReadState) -> ReadState: """问题扩展节点""" # 获取原始数据和分解结果 @@ -153,10 +157,10 @@ async def Problem_Extension(state: ReadState) -> ReadState: data = [] history = await SessionService(store).get_history(group_id, group_id, group_id) - + # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = ProblemExtensionResponse.model_json_schema() - + system_prompt = await problem_service.template_service.render_template( template_name='Problem_Extension_prompt.jinja2', operation_name='problem_extension', @@ -242,7 +246,4 @@ async def Problem_Extension(state: ReadState) -> ReadState: } } - return {"problem_extension": result} - - - + return {"problem_extension": result} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index c01889a9..19011a5f 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -59,7 +59,6 @@ async def make_read_graph(): workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_edge("Retrieve_Summary", END) workflow.add_conditional_edges("Verify", Verify_continue) - workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary", END) diff --git a/api/app/core/memory/agent/services/optimized_llm_service.py b/api/app/core/memory/agent/services/optimized_llm_service.py index 68919c4a..6942d421 100644 --- a/api/app/core/memory/agent/services/optimized_llm_service.py +++ b/api/app/core/memory/agent/services/optimized_llm_service.py @@ -162,7 +162,7 @@ class OptimizedLLMService: return fallback_value elif isinstance(fallback_value, dict): return response_model(**fallback_value) - + # 尝试创建空的响应模型 if hasattr(response_model, 'root'): # RootModel类型 @@ -170,7 +170,7 @@ class OptimizedLLMService: else: # 普通BaseModel类型 return response_model() - + except Exception as e: logger.error(f"创建降级响应失败: {e}") # 最后的降级策略 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index d744b766..8170bdd8 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -683,7 +683,67 @@ class MemoryAgentService: logger.debug(f"Message type: {status}") return status - # ==================== 新增的三个接口方法 ==================== + async def generate_summary_from_retrieve( + self, + retrieve_info: str, + history: List[Dict], + query: str, + config_id: str, + db: Session + ) -> str: + """ + 基于检索信息、历史对话和查询生成最终答案 + + 使用 Retrieve_Summary_prompt.jinja2 模板调用大模型生成答案 + + Args: + retrieve_info: 检索到的信息 + history: 历史对话记录 + query: 用户查询 + config_id: 配置ID + db: 数据库会话 + + Returns: + 生成的答案文本 + """ + logger.info(f"Generating summary from retrieve info for query: {query[:50]}...") + + try: + # 加载配置 + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="MemoryAgentService" + ) + + # 导入必要的模块 + from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm + from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse + + # 构建状态对象 + state = { + "data": query, + "memory_config": memory_config + } + + # 直接调用 summary_llm 函数 + answer = await summary_llm( + state=state, + history=history, + retrieve_info=retrieve_info, + template_name='Retrieve_Summary_prompt.jinja2', + operation_name='retrieve_summary', + response_model=RetrieveSummaryResponse, + search_mode="1" + ) + + logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...") + return answer if answer else "信息不足,无法回答。" + + except Exception as e: + logger.error(f"生成摘要失败: {str(e)}", exc_info=True) + return "信息不足,无法回答。" + async def get_knowledge_type_stats( self,