From a01525e239b00f3fbeed87c769eb8ada5b08863b Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Thu, 16 Apr 2026 13:27:36 +0800 Subject: [PATCH] refactor(memory): consolidate memory search services and update model client handling - Consolidate memory search services by removing separate content_search.py and perceptual_search.py - Update model client handling in base_pipeline.py to use ModelApiKeyService for LLM client initialization - Add new prompt files and modify existing services to support consolidated search architecture - Refactor memory read pipeline and related services to use updated model client approach --- api/app/core/memory/memory_service.py | 11 +- api/app/core/memory/models/service_models.py | 17 +- .../core/memory/pipelines/base_pipeline.py | 31 ++- api/app/core/memory/pipelines/memory_read.py | 34 ++- api/app/core/memory/prompt/__init__.py | 85 +++++++ .../core/memory/prompt/problem_split.jinja2 | 212 ++++++++++++++++ .../memory/read_services/content_search.py | 46 ++-- .../read_services/memory_search/__init__.py | 0 .../memory_search/content_search.py | 14 -- .../memory_search/perceptual_search.py | 228 ------------------ .../read_services/query_preprocessor.py | 29 ++- .../memory/read_services/result_builder.py | 2 +- .../memory/read_services/retrieval_summary.py | 11 + api/app/core/memory/utils/llm/llm_utils.py | 48 +++- api/app/repositories/neo4j/graph_search.py | 10 +- .../prompt/prompt_optimizer_system.jinja2 | 2 +- 16 files changed, 471 insertions(+), 309 deletions(-) create mode 100644 api/app/core/memory/prompt/__init__.py create mode 100644 api/app/core/memory/prompt/problem_split.jinja2 delete mode 100644 api/app/core/memory/read_services/memory_search/__init__.py delete mode 100644 api/app/core/memory/read_services/memory_search/content_search.py delete mode 100644 api/app/core/memory/read_services/memory_search/perceptual_search.py create mode 100644 api/app/core/memory/read_services/retrieval_summary.py diff --git a/api/app/core/memory/memory_service.py b/api/app/core/memory/memory_service.py index 67c814b1..15ea14b4 100644 --- a/api/app/core/memory/memory_service.py +++ b/api/app/core/memory/memory_service.py @@ -1,7 +1,7 @@ from sqlalchemy.orm import Session from app.core.memory.enums import StorageType, SearchStrategy -from app.core.memory.models.service_models import Memory, MemoryContext +from app.core.memory.models.service_models import MemoryContext, MemorySearchResult from app.core.memory.pipelines.memory_read import ReadPipeLine from app.db import get_db_context from app.services.memory_config_service import MemoryConfigService @@ -35,9 +35,14 @@ class MemoryService: async def write(self, messages: list[dict]) -> str: raise NotImplementedError - async def read(self, query: str, history: list, search_switch: SearchStrategy) -> list[Memory]: + async def read( + self, + query: str, + search_switch: SearchStrategy, + limit: int = 10, + ) -> MemorySearchResult: with get_db_context() as db: - return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit=10) + return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit) async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict: raise NotImplementedError diff --git a/api/app/core/memory/models/service_models.py b/api/app/core/memory/models/service_models.py index 82a867c7..477c0ba8 100644 --- a/api/app/core/memory/models/service_models.py +++ b/api/app/core/memory/models/service_models.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel, Field, field_serializer, ConfigDict +from pydantic import BaseModel, Field, field_serializer, ConfigDict, model_validator, computed_field from app.core.memory.enums import Neo4jNodeType, StorageType +from app.core.validators import file_validator from app.schemas.memory_config_schema import MemoryConfig @@ -24,3 +25,17 @@ class Memory(BaseModel): @field_serializer("source") def serialize_source(self, v) -> str: return v.value + + +class MemorySearchResult(BaseModel): + memories: list[Memory] + + @computed_field + @property + def content(self) -> str: + return "\n".join([memory.content for memory in self.memories]) + + @computed_field + @property + def count(self) -> int: + return len(self.memories) diff --git a/api/app/core/memory/pipelines/base_pipeline.py b/api/app/core/memory/pipelines/base_pipeline.py index 322f6787..60c48b9d 100644 --- a/api/app/core/memory/pipelines/base_pipeline.py +++ b/api/app/core/memory/pipelines/base_pipeline.py @@ -4,22 +4,32 @@ from typing import Any from sqlalchemy.orm import Session -from app.core.memory.llm_tools import OpenAIEmbedderClient from app.core.memory.models.service_models import MemoryContext -from app.core.models import RedBearModelConfig +from app.core.models import RedBearModelConfig, RedBearLLM, RedBearEmbeddings from app.services.memory_config_service import MemoryConfigService +from app.services.model_service import ModelApiKeyService class ModelClientMixin(ABC): @staticmethod - def get_llm_client(db: Session, model_id: uuid.UUID): - pass + def get_llm_client(db: Session, model_id: uuid.UUID) -> RedBearLLM: + api_config = ModelApiKeyService.get_available_api_key(db, model_id) + return RedBearLLM( + RedBearModelConfig( + model_name=api_config.model_name, + provider=api_config.provider, + api_key=api_config.api_key, + base_url=api_config.api_base, + is_omni=api_config.is_omni, + support_thinking="thinking" in (api_config.capability or []), + ) + ) @staticmethod - def get_embedding_client(db: Session, model_id: uuid.UUID) -> OpenAIEmbedderClient: + def get_embedding_client(db: Session, model_id: uuid.UUID) -> RedBearEmbeddings: config_service = MemoryConfigService(db) embedder_client_config = config_service.get_embedder_config(str(model_id)) - return OpenAIEmbedderClient( + return RedBearEmbeddings( RedBearModelConfig( model_name=embedder_client_config["model_name"], provider=embedder_client_config["provider"], @@ -30,10 +40,15 @@ class ModelClientMixin(ABC): class BasePipeline(ABC): - def __init__(self, ctx: MemoryContext, db: Session): + def __init__(self, ctx: MemoryContext): self.ctx = ctx - self.db = db @abstractmethod async def run(self, *args, **kwargs) -> Any: pass + + +class DBRequiredPipeline(BasePipeline, ABC): + def __init__(self, ctx: MemoryContext, db: Session): + super().__init__(ctx) + self.db = db diff --git a/api/app/core/memory/pipelines/memory_read.py b/api/app/core/memory/pipelines/memory_read.py index 5f5a1a1f..83662d90 100644 --- a/api/app/core/memory/pipelines/memory_read.py +++ b/api/app/core/memory/pipelines/memory_read.py @@ -1,31 +1,41 @@ -from app.core.memory.enums import SearchStrategy -from app.core.memory.pipelines.base_pipeline import BasePipeline, ModelClientMixin -from app.core.memory.read_services.content_search import Neo4jSearchService +from app.core.memory.enums import SearchStrategy, StorageType +from app.core.memory.models.service_models import MemorySearchResult +from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline +from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService from app.core.memory.read_services.query_preprocessor import QueryPreprocessor -class ReadPipeLine(ModelClientMixin, BasePipeline): - async def run(self, query: str, search_switch: SearchStrategy, limit: int = 10): +class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): + async def run(self, query: str, search_switch: SearchStrategy, limit: int = 10, includes=None) -> MemorySearchResult: query = QueryPreprocessor.process(query) + if self.ctx.storage_type == StorageType.RAG: + return await self._rag_read(query, limit) match search_switch: case SearchStrategy.DEEP: - return await self._deep_read() + return await self._deep_read(query, limit, includes) case SearchStrategy.NORMAL: - return await self._normal_read(query) + return await self._normal_read(query, limit, includes) case SearchStrategy.QUICK: - return await self._quick_read(query, limit) + return await self._quick_read(query, limit, includes) case _: raise RuntimeError("Unsupported search strategy") - async def _deep_read(self): + async def _rag_read(self, query: str, limit: int) -> MemorySearchResult: + service = RAGSearchService( + self.ctx + ) + return await service.search() + + async def _deep_read(self, query: str, limit: int, includes=None) -> MemorySearchResult: pass - async def _normal_read(self, query): + async def _normal_read(self, query: str, limit: int, includes=None) -> MemorySearchResult: pass - async def _quick_read(self, query, limit): + async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult: search_service = Neo4jSearchService( self.ctx, - self.get_embedding_client(self.db, self.ctx.memory_config.embedding_model_id) + self.get_embedding_client(self.db, self.ctx.memory_config.embedding_model_id), + includes=includes, ) return await search_service.search(query, limit) diff --git a/api/app/core/memory/prompt/__init__.py b/api/app/core/memory/prompt/__init__.py new file mode 100644 index 00000000..299470f8 --- /dev/null +++ b/api/app/core/memory/prompt/__init__.py @@ -0,0 +1,85 @@ +import logging +import threading +from pathlib import Path + +from jinja2 import Environment, FileSystemLoader, TemplateNotFound, TemplateSyntaxError + +logger = logging.getLogger(__name__) + +PROMPT_DIR = Path(__file__).parent + + +class PromptRenderError(Exception): + def __init__(self, template_name: str, error: Exception): + self.template_name = template_name + self.error = error + super().__init__(f"Failed to render prompt '{template_name}': {error}") + + +class PromptManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._init_once() + return cls._instance + + def _init_once(self): + self.env = Environment( + loader=FileSystemLoader(str(PROMPT_DIR)), + autoescape=False, + keep_trailing_newline=True, + ) + logger.info(f"PromptManager initialized: template_dir={PROMPT_DIR}") + + def __repr__(self): + templates = self.list_templates() + return f"" + + def list_templates(self) -> list[str]: + return [ + Path(name).stem + for name in self.env.loader.list_templates() + if name.endswith('.jinja2') + ] + + def get(self, name: str) -> str: + template_name = self._resolve_name(name) + try: + source, _, _ = self.env.loader.get_source(self.env, template_name) + return source + except TemplateNotFound: + raise FileNotFoundError( + f"Prompt '{name}' not found. " + f"Available: {self.list_templates()}" + ) + + def render(self, name: str, **kwargs) -> str: + template_name = self._resolve_name(name) + try: + template = self.env.get_template(template_name) + return template.render(**kwargs) + except TemplateNotFound: + raise FileNotFoundError( + f"Prompt '{name}' not found. " + f"Available: {self.list_templates()}" + ) + except TemplateSyntaxError as e: + logger.error(f"Prompt syntax error in '{name}': {e}", exc_info=True) + raise PromptRenderError(name, e) + except Exception as e: + logger.error(f"Prompt render failed for '{name}': {e}", exc_info=True) + raise PromptRenderError(name, e) + + @staticmethod + def _resolve_name(name: str) -> str: + if not name.endswith('.jinja2'): + return f"{name}.jinja2" + return name + + +prompt_manager = PromptManager() diff --git a/api/app/core/memory/prompt/problem_split.jinja2 b/api/app/core/memory/prompt/problem_split.jinja2 new file mode 100644 index 00000000..ff134ddb --- /dev/null +++ b/api/app/core/memory/prompt/problem_split.jinja2 @@ -0,0 +1,212 @@ + +# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#} +你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型: +## 目标: +你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。 +--- + +### 历史信息参考 +在生成扩展问题时,你可以参考以下历史数据(如果提供): +- 历史对话或任务的主题; +- 历史中出现的关键实体(时间、人物、地点、研究主题等); +- 历史中已解答的问题(避免重复); +- 历史推理链(保持逻辑一致性)。 + +> 如果没有提供历史信息,则仅根据当前输入问题进行分析。 +输入历史信息内容:{{history}} + +## User Input +{{ sentence }} + +## 需求: +1:首先判断类型(单跳、多跳、开放域、时间)。 +2:根据类型进行拆分。 +3:拆分后的内容需保证信息完整且可独立处理。 +4:对每个拆分条目,可附加示例或说明。 +5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。 + 比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}] + 拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么? + +## 指代消歧规则(Coreference Resolution): +在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化: + +1. **"用户"的消歧**: + - "用户是谁?" → 分析历史记录,找出对话发起者的姓名 + - 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人 + - 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?" + +2. **"我"的消歧**: + - "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?" + - 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?" + +3. **"他/她/它"的消歧**: + - 从上下文或历史中找出最近提到的同类实体 + - 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?" + +4. **"那个人/这个人"的消歧**: + - 从历史中找出最近提到的人物 + - 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?" + +5. **优先级**: + - 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人 + - 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象" + +## 指令: +你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型: +单跳(Single-hop) + 描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。 + 拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。 + 示例: + 输入数据:"请列出今年诺贝尔物理学奖的得主" + 拆分结果:[ + { + "id": "Q1", + "question": "今年诺贝尔物理学奖得主是谁", + "type": "单跳’", + } + ] + 注意: 当遇到上下文依赖问题时,明确指出缺失的信息类型并且,question可填写输入问题 +多跳(Multi-hop): + 描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。 + 拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。 + 示例: + 输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果" + 拆分结果: + [ + { + "id": "Q1", + "question": 今年诺贝尔物理学奖得主是谁?", + "type": "多跳’", + }, + { + "id": "Q2", + "question": "该得主的研究领域是什么?", + "type": "多跳’", + }, + { + "id": "Q3", + "question": "该得主的代表性成果有哪些?", + "type": "多跳’" + } + ] +开放域(Open-domain): + 描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。。 + 拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性) + 需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。 + 示例: + 输入数据:"介绍量子计算的最新研究进展" + 拆分结果: + [ + { + "id": "Q1", + "question": 量子计算的基本概念是什么?", + "type": "开放域’", + }, + { + "id": "Q2", + "question": "当前量子计算的主要研究方向有哪些?", + "type": "开放域’", + }, + { + "id": "Q3", + "question": "近期在量子计算领域有哪些重大进展?", + "type": "开放域’", + } + ] + +时间(Temporal): + 描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。 + 拆分策略:根据事件时间或时间段拆分为独立条目或问题。 + 示例: + 输入数据:"列出苹果公司过去五年的重大事件" + 拆分结果: + [ + { + "id": "Q1", + "question": 苹果公司2019年的重大事件有哪些?", + "type": "时间’", + }, + { + "id": "Q2", + "question": "苹果公司2020年的重大事件有哪些?", + "type": "时间’", + }, + { + "id": "Q3", + "question": "苹果公司2021年的重大事件有哪些?", + "type": "时间’", + }, + { + "id": "Q3", + "question": "苹果公司2022年的重大事件有哪些?", + "type": "时间’", + } + , + { + "id": "Q4", + "question": "苹果公司2023年的重大事件有哪些?", + "type": "时间’", + } + ] + +输出要求: +- 每个子问题包括: + - `id`: 子问题编号(Q1, Q2...) + - `question`: 子问题内容 + - `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等) + - `reason`: 拆分的理由(为什么要这样拆) +- 格式案例: +[ + { + "id": "Q1", + "question": 量子计算的基本概念是什么?", + "type": "开放域’", + }, + { + "id": "Q2", + "question": "当前量子计算的主要研究方向有哪些?", + "type": "开放域’", + }, + { + "id": "Q3", + "question": "近期在量子计算领域有哪些重大进展?", + "type": "开放域’", + } +] +- 必须通过json.loads()的格式支持的形式输出 +- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。 + +## 指代消歧示例(重要): +示例1 - "用户"的消歧: +输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}] +输入问题:"用户是谁?" +输出: +[ + { + "id": "Q1", + "question": "李建国是谁?", + "type": "单跳", + "reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国" + } +] + +示例2 - "我"的消歧: +输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}] +输入问题:"我推荐的书是什么?" +输出: +[ + { + "id": "Q1", + "question": "张曼玉推荐的书是什么?", + "type": "单跳", + "reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉" + } +] + +- 关键的JSON格式要求 +1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号 +2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们 +3.确保所有JSON字符串都正确关闭并以逗号分隔 +4.JSON字符串值中不包括换行符 +5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\"" +6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby``` diff --git a/api/app/core/memory/read_services/content_search.py b/api/app/core/memory/read_services/content_search.py index 69ca6b11..58356e84 100644 --- a/api/app/core/memory/read_services/content_search.py +++ b/api/app/core/memory/read_services/content_search.py @@ -1,37 +1,28 @@ import asyncio import logging import math -import time - -from pydantic import BaseModel, Field from app.core.memory.enums import Neo4jNodeType -from app.core.memory.llm_tools import OpenAIEmbedderClient from app.core.memory.memory_service import MemoryContext -from app.core.memory.models.service_models import Memory +from app.core.memory.models.service_models import Memory, MemorySearchResult from app.core.memory.read_services.result_builder import data_builder_factory +from app.core.models import RedBearEmbeddings from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.repositories.neo4j.neo4j_connector import Neo4jConnector logger = logging.getLogger(__name__) - -class MemorySearchResult(BaseModel): - memories: dict[str, list[dict]] = Field(default_factory=dict) - content: str = Field(default="") - count: int = Field(default=0) +DEFAULT_ALPHA = 0.7 +DEFAULT_FULLTEXT_SCORE_THRESHOLD = 1 +DEFAULT_COSINE_SCORE_THRESHOLD = 0.5 +DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5 class Neo4jSearchService: - DEFAULT_ALPHA = 0.6 - DEFAULT_FULLTEXT_SCORE_THRESHOLD = 1 - DEFAULT_COSINE_SCORE_THRESHOLD = 0.5 - DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5 - def __init__( self, ctx: MemoryContext, - embedder: OpenAIEmbedderClient, + embedder: RedBearEmbeddings, includes: list[Neo4jNodeType] | None = None, alpha: float = DEFAULT_ALPHA, fulltext_score_threshold: float = DEFAULT_FULLTEXT_SCORE_THRESHOLD, @@ -44,7 +35,7 @@ class Neo4jSearchService: self.cosine_score_threshold = cosine_score_threshold self.content_score_threshold = content_score_threshold - self.embedder: OpenAIEmbedderClient = embedder + self.embedder: RedBearEmbeddings = embedder self.connector: Neo4jConnector | None = None self.includes = includes @@ -121,9 +112,12 @@ class Neo4jSearchService: kw = float(combined[item_id].get("kw_score", 0) or 0) emb = float(combined[item_id].get("embedding_score", 0) or 0) base = self.alpha * emb + (1 - self.alpha) * kw - combined[item_id]["content_score"] = base + min(1 - base, 0.1 * kw * emb) + combined[item_id]["content_score"] = base + min(1 - base, kw * emb) results = sorted(combined.values(), key=lambda x: x["content_score"], reverse=True) - # results = [res for res in results if res["content_score"] > self.content_score_threshold] + # results = [ + # res for res in results + # if res["content_score"] > self.content_score_threshold + # ] results = results[:limit] logger.info( @@ -137,14 +131,14 @@ class Neo4jSearchService: return items scores = [float(it.get("score", 0) or 0) for it in items] for it, s in zip(items, scores): - it[f"normalized_kw_score"] = 1 / (1 + math.exp(-(s - self.fulltext_score_threshold) / 2)) + it[f"normalized_kw_score"] = 1 / (1 + math.exp(-(s - self.fulltext_score_threshold) / 2)) if s else 0 return items async def search( self, query: str, limit: int = 10, - ) -> list[Memory]: + ) -> MemorySearchResult: async with Neo4jConnector() as connector: self.connector = connector kw_task = self._keyword_search(query, limit) @@ -175,4 +169,12 @@ class Neo4jSearchService: query=query )) memories.sort(key=lambda x: x.score, reverse=True) - return memories[:limit] + return MemorySearchResult(memories=memories[:limit]) + + +class RAGSearchService: + def __init__(self, ctx: MemoryContext): + pass + + async def search(self) -> MemorySearchResult: + pass diff --git a/api/app/core/memory/read_services/memory_search/__init__.py b/api/app/core/memory/read_services/memory_search/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/api/app/core/memory/read_services/memory_search/content_search.py b/api/app/core/memory/read_services/memory_search/content_search.py deleted file mode 100644 index f5e58696..00000000 --- a/api/app/core/memory/read_services/memory_search/content_search.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: UTF-8 -*- -# Author: Eternity -# @Email: 1533512157@qq.com -# @Time : 2026/4/9 16:48 -from app.core.memory.llm_tools import OpenAIEmbedderClient -from app.core.memory.memory_service import MemoryContext - - -class ContentSearch: - def __init__(self, ctx: MemoryContext): - self.ctx = ctx - - async def search(self, query): - pass \ No newline at end of file diff --git a/api/app/core/memory/read_services/memory_search/perceptual_search.py b/api/app/core/memory/read_services/memory_search/perceptual_search.py deleted file mode 100644 index db81e2f8..00000000 --- a/api/app/core/memory/read_services/memory_search/perceptual_search.py +++ /dev/null @@ -1,228 +0,0 @@ -import asyncio -import logging -from typing import Any - -from pydantic import BaseModel - -from app.core.memory.llm_tools import OpenAIEmbedderClient -from app.core.memory.memory_service import MemoryContext -from app.core.memory.utils.data import escape_lucene_query -from app.repositories.neo4j.graph_search import search_perceptual, search_perceptual_by_embedding -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - -logger = logging.getLogger(__name__) - - -class PerceptualResult(BaseModel): - memories: list[dict[str, Any]] = [] - content: str = "" - keyword_raw: int = 0 - embedding_raw: int = 0 - - -class PerceptualRetrieverService: - DEFAULT_ALPHA = 0.6 - DEFAULT_FULLTEXT_SCORE_THRESHOLD = 0.5 - DEFAULT_COSINE_SCORE_THRESHOLD = 0.7 - - def __init__( - self, - ctx: MemoryContext, - embedder: OpenAIEmbedderClient, - alpha: float = DEFAULT_ALPHA, - fulltext_score_threshold: float = DEFAULT_FULLTEXT_SCORE_THRESHOLD, - cosine_score_threshold: float = DEFAULT_COSINE_SCORE_THRESHOLD - ): - self.ctx = ctx - self.alpha = alpha - self.fulltext_score_threshold = fulltext_score_threshold - self.cosine_score_threshold = cosine_score_threshold - - self.embedder: OpenAIEmbedderClient = embedder - self.connector = Neo4jConnector() - - async def search( - self, - query: str, - keywords: list[str] | None = None, - limit: int = 10 - ) -> PerceptualResult: - if keywords is None: - keywords = [query] if query else [] - - try: - kw_task = self._keyword_search(keywords, limit) - emb_task = self._embedding_search(query, limit) - kw_results, emb_results = await asyncio.gather(kw_task, emb_task, return_exceptions=True) - if isinstance(kw_results, Exception): - logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}") - kw_results = [] - if isinstance(emb_results, Exception): - logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}") - emb_results = [] - - reranked = self._rerank(kw_results, emb_results, limit) - - memories = [] - content_parts = [] - for record in reranked: - fmt = self._format_result(record) - fmt["score"] = round(record.get("content_score", 0), 4) - memories.append(fmt) - content_parts.append(self._build_content_text(fmt)) - - logger.info( - f"[PerceptualSearch] {len(memories)} results after rerank " - f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})" - ) - return PerceptualResult( - memories=memories, - content="\n\n".join(content_parts), - keyword_raw=len(kw_results), - embedding_raw=len(emb_results), - ) - except Exception as e: - logger.error(f"[PerceptualSearch] search failed: {e}", exc_info=True) - return PerceptualResult() - finally: - await self.connector.close() - - async def _keyword_search( - self, - keywords: list[str], - limit: int - ) -> list[dict]: - seen_ids: set = set() - all_results: list[dict] = [] - - async def _one(kw: str): - escaped = escape_lucene_query(kw) - if not escaped.strip(): - return [] - r = await search_perceptual( - connector=self.connector, q=escaped, - end_user_id=self.ctx.end_user_id, limit=limit - ) - perceptuals = r.get("perceptuals", []) - return [perceptual for perceptual in perceptuals if perceptual["score"] > self.fulltext_score_threshold] - - tasks = [_one(kw) for kw in keywords] - batch = await asyncio.gather(*tasks, return_exceptions=True) - - for result in batch: - if isinstance(result, Exception): - logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}") - continue - for rec in result: - rid = rec.get("id", "") - if rid and rid not in seen_ids: - seen_ids.add(rid) - all_results.append(rec) - all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True) - return all_results[:limit] - - async def _embedding_search( - self, - query: str, - limit: int - ) -> list[dict]: - r = await search_perceptual_by_embedding( - connector=self.connector, - embedder_client=self.embedder, - query_text=query, - end_user_id=self.ctx.end_user_id, - limit=limit - ) - perceptuals = r.get("perceptuals", []) - return [perceptual for perceptual in perceptuals if perceptual["score"] > self.cosine_score_threshold] - - def _rerank( - self, - keyword_results: list[dict], - embedding_results: list[dict], - limit: int, - ) -> list[dict]: - keyword_results = self._normalize_scores(keyword_results) - embedding_results = self._normalize_scores(embedding_results) - - kw_norm_map = {} - for item in keyword_results: - item_id = item["id"] - kw_norm_map[item_id] = float(item.get("normalized_score", 0)) - - emb_norm_map = {} - for item in embedding_results: - item_id = item["id"] - emb_norm_map[item_id] = float(item.get("normalized_score", 0)) - - combined = {} - for item in keyword_results: - item_id = item["id"] - combined[item_id] = item.copy() - combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0) - combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0) - - for item in embedding_results: - item_id = item["id"] - if item_id in combined: - combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0) - else: - combined[item_id] = item.copy() - combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0) - combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0) - - for item in combined.values(): - kw = float(item.get("kw_score", 0) or 0) - emb = float(item.get("embedding_score", 0) or 0) - item["content_score"] = self.alpha * emb + (1 - self.alpha) * kw - - results = list(combined.values()) - results.sort(key=lambda x: x["content_score"], reverse=True) - results = results[:limit] - - logger.info( - f"[PerceptualSearch] rerank: merged={len(combined)}, after_threshold={len(results)} " - f"(alpha={self.alpha})" - ) - return results - - @staticmethod - def _normalize_scores(items: list[dict], field: str = "score") -> list[dict]: - """Min-max 归一化,将分数线性映射到 [0, 1]。""" - if not items: - return items - scores = [float(it.get(field, 0) or 0) for it in items] - min_s = min(scores) - max_s = max(scores) - diff = max_s - min_s - for it, s in zip(items, scores): - it[f"normalized_{field}"] = (s - min_s) / diff if diff > 0 else 1.0 - return items - - @staticmethod - def _format_result(record: dict) -> dict: - return { - "id": record.get("id", ""), - "perceptual_type": record.get("perceptual_type", ""), - "file_name": record.get("file_name", ""), - "file_path": record.get("file_path", ""), - "summary": record.get("summary", ""), - "topic": record.get("topic", ""), - "domain": record.get("domain", ""), - "keywords": record.get("keywords", []), - "created_at": str(record.get("created_at", "")), - "file_type": record.get("file_type", ""), - "score": record.get("score", 0), - } - - @staticmethod - def _build_content_text(formatted: dict) -> str: - content_text = (f"" - f"{formatted["file_name"]}" - f"{formatted["file_path"]}" - f"{formatted["file_type"]}" - f"{formatted["topic"]}" - f"{formatted["keywords"]}" - f"{formatted["summary"]}" - f"") - return content_text diff --git a/api/app/core/memory/read_services/query_preprocessor.py b/api/app/core/memory/read_services/query_preprocessor.py index 02d757c9..123cae40 100644 --- a/api/app/core/memory/read_services/query_preprocessor.py +++ b/api/app/core/memory/read_services/query_preprocessor.py @@ -1,11 +1,13 @@ -# -*- coding: UTF-8 -*- -# Author: Eternity -# @Email: 1533512157@qq.com -# @Time : 2026/4/8 18:11 +import logging import re +from app.core.memory.prompt import prompt_manager +from app.core.memory.utils.llm.llm_utils import StructResponse +from app.core.models import RedBearLLM from app.schemas.memory_agent_schema import AgentMemoryDataset +logger = logging.getLogger(__name__) + class QueryPreprocessor: @staticmethod @@ -16,3 +18,22 @@ class QueryPreprocessor: text = re.sub(rf"{"|".join(AgentMemoryDataset.PRONOUN)}", AgentMemoryDataset.NAME, text) return text + + @staticmethod + async def split(query: str, llm_client: RedBearLLM): + system_prompt = prompt_manager.render( + name="problem_split", + history=[], + sentence=query, + ) + messages = [{"role": "system", "content": system_prompt}] + try: + sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json') + except Exception as e: + logger.error(f"[QueryPreprocessor] Sub-question segmentation failed - {e}") + sub_queries = None + return sub_queries or query + + @staticmethod + async def extension(query: str, llm_client: RedBearLLM): + pass diff --git a/api/app/core/memory/read_services/result_builder.py b/api/app/core/memory/read_services/result_builder.py index 10ff8c86..949ff3ed 100644 --- a/api/app/core/memory/read_services/result_builder.py +++ b/api/app/core/memory/read_services/result_builder.py @@ -114,7 +114,7 @@ class PerceptualBuilder(BaseBuilder): f"{self.record.get('domain')}" f"{self.record.get('keywords')}" f"{self.record.get('file_type')}" - "") + "") class CommunityBuilder(BaseBuilder): diff --git a/api/app/core/memory/read_services/retrieval_summary.py b/api/app/core/memory/read_services/retrieval_summary.py new file mode 100644 index 00000000..6b166cf2 --- /dev/null +++ b/api/app/core/memory/read_services/retrieval_summary.py @@ -0,0 +1,11 @@ +from app.core.models import RedBearLLM + + +class RetrievalSummaryProcessor: + @staticmethod + def summary(content: str, llm_client: RedBearLLM): + return + + @staticmethod + def verify(content: str, llm_client: RedBearLLM): + return \ No newline at end of file diff --git a/api/app/core/memory/utils/llm/llm_utils.py b/api/app/core/memory/utils/llm/llm_utils.py index 19d76d68..c4eee82f 100644 --- a/api/app/core/memory/utils/llm/llm_utils.py +++ b/api/app/core/memory/utils/llm/llm_utils.py @@ -1,4 +1,7 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, Type + +from json_repair import json_repair +from langchain_core.messages import AIMessage from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.models.base import RedBearModelConfig @@ -13,6 +16,27 @@ async def handle_response(response: type[BaseModel]) -> dict: return response.model_dump() +class StructResponse: + def __init__(self, mode: Literal["json", "pydantic"], model: Type[BaseModel] = None): + self.mode = mode + if mode == "pydantic" and model is None: + raise ValueError("Pydantic model is required") + + self.model = model + + def __ror__(self, other: AIMessage): + if not isinstance(other, AIMessage): + raise RuntimeError(f"Unsupported struct type {type(other)}") + text = '' + for block in other.content_blocks: + if block.get("type") == "text": + text += block.get("text", "") + fixed_json = json_repair.repair_json(text, return_objects=True) + if self.mode == "json": + return fixed_json + return self.model.model_validate(fixed_json) + + class MemoryClientFactory: """ Factory for creating LLM, embedder, and reranker clients. @@ -24,21 +48,21 @@ class MemoryClientFactory: >>> llm_client = factory.get_llm_client(model_id) >>> embedder_client = factory.get_embedder_client(embedding_id) """ - + def __init__(self, db: Session): from app.services.memory_config_service import MemoryConfigService self._config_service = MemoryConfigService(db) - + def get_llm_client(self, llm_id: str) -> OpenAIClient: """Get LLM client by model ID.""" if not llm_id: raise ValueError("LLM ID is required") - + try: model_config = self._config_service.get_model_config(llm_id) except Exception as e: raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e - + try: return OpenAIClient( RedBearModelConfig( @@ -52,19 +76,19 @@ class MemoryClientFactory: except Exception as e: model_name = model_config.get('model_name', 'unknown') raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e - + def get_embedder_client(self, embedding_id: str): """Get embedder client by model ID.""" from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient - + if not embedding_id: raise ValueError("Embedding ID is required") - + try: embedder_config = self._config_service.get_embedder_config(embedding_id) except Exception as e: raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e - + try: return OpenAIEmbedderClient( RedBearModelConfig( @@ -77,17 +101,17 @@ class MemoryClientFactory: except Exception as e: model_name = embedder_config.get('model_name', 'unknown') raise ValueError(f"Failed to initialize embedder client for model '{model_name}': {str(e)}") from e - + def get_reranker_client(self, rerank_id: str) -> OpenAIClient: """Get reranker client by model ID.""" if not rerank_id: raise ValueError("Rerank ID is required") - + try: model_config = self._config_service.get_model_config(rerank_id) except Exception as e: raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e - + try: return OpenAIClient( RedBearModelConfig( diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 336f4134..354c0e23 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -8,6 +8,7 @@ import numpy as np from app.core.memory.enums import Neo4jNodeType from app.core.memory.llm_tools import OpenAIEmbedderClient from app.core.memory.utils.data.text_utils import escape_lucene_query +from app.core.models import RedBearEmbeddings from app.repositories.neo4j.cypher_queries import ( EXPAND_COMMUNITY_STATEMENTS, SEARCH_CHUNK_BY_CHUNK_ID, @@ -358,7 +359,7 @@ async def search_by_embedding( USER_ID_QUERY_CYPHER_MAPPING[node_type], end_user_id=end_user_id, ) - records = [record for record in records if record if record["embedding"] is not None] + records = [record for record in records if record and record.get("embedding") is not None] ids = [item['id'] for item in records] vectors = [item['embedding'] for item in records] sim_res = cosine_similarity_search(query_embedding, vectors, limit=limit) @@ -469,7 +470,7 @@ async def search_graph( async def search_graph_by_embedding( connector: Neo4jConnector, - embedder_client, + embedder_client: RedBearEmbeddings | OpenAIEmbedderClient, query_text: str, end_user_id: str, limit: int = 50, @@ -495,7 +496,10 @@ async def search_graph_by_embedding( Neo4jNodeType.PERCEPTUAL ] - embeddings = await embedder_client.response([query_text]) + if isinstance(embedder_client, RedBearEmbeddings): + embeddings = embedder_client.embed_documents([query_text]) + else: + embeddings = await embedder_client.response([query_text]) if not embeddings or not embeddings[0]: logger.warning(f"search_graph_by_embedding: embedding generation failed for '{query_text[:50]}'") return {search_key: [] for search_key in include} diff --git a/api/app/services/prompt/prompt_optimizer_system.jinja2 b/api/app/services/prompt/prompt_optimizer_system.jinja2 index 39a4ba68..5611ae94 100644 --- a/api/app/services/prompt/prompt_optimizer_system.jinja2 +++ b/api/app/services/prompt/prompt_optimizer_system.jinja2 @@ -34,7 +34,7 @@ Readability Guideline: Ensure optimized prompts have good readability and logica Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %} Constraints -Output Constraint: Must output in JSON format including the fields "prompt" and "desc". +Output Constraint: Must output in JSON format including the string fields "prompt" and "desc". Content Constraint: Must not include any explanations, analyses, or additional comments. Language Constraint: Must use clear and concise language. {% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %}