refactor(memory): consolidate memory search services and update model client handling
- Consolidate memory search services by removing separate content_search.py and perceptual_search.py - Update model client handling in base_pipeline.py to use ModelApiKeyService for LLM client initialization - Add new prompt files and modify existing services to support consolidated search architecture - Refactor memory read pipeline and related services to use updated model client approach
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.memory.enums import StorageType, SearchStrategy
|
||||
from app.core.memory.models.service_models import Memory, MemoryContext
|
||||
from app.core.memory.models.service_models import MemoryContext, MemorySearchResult
|
||||
from app.core.memory.pipelines.memory_read import ReadPipeLine
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
@@ -35,9 +35,14 @@ class MemoryService:
|
||||
async def write(self, messages: list[dict]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def read(self, query: str, history: list, search_switch: SearchStrategy) -> list[Memory]:
|
||||
async def read(
|
||||
self,
|
||||
query: str,
|
||||
search_switch: SearchStrategy,
|
||||
limit: int = 10,
|
||||
) -> MemorySearchResult:
|
||||
with get_db_context() as db:
|
||||
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit=10)
|
||||
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit)
|
||||
|
||||
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict, model_validator, computed_field
|
||||
|
||||
from app.core.memory.enums import Neo4jNodeType, StorageType
|
||||
from app.core.validators import file_validator
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
@@ -24,3 +25,17 @@ class Memory(BaseModel):
|
||||
@field_serializer("source")
|
||||
def serialize_source(self, v) -> str:
|
||||
return v.value
|
||||
|
||||
|
||||
class MemorySearchResult(BaseModel):
|
||||
memories: list[Memory]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def content(self) -> str:
|
||||
return "\n".join([memory.content for memory in self.memories])
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self.memories)
|
||||
|
||||
@@ -4,22 +4,32 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||
from app.core.memory.models.service_models import MemoryContext
|
||||
from app.core.models import RedBearModelConfig
|
||||
from app.core.models import RedBearModelConfig, RedBearLLM, RedBearEmbeddings
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
|
||||
class ModelClientMixin(ABC):
|
||||
@staticmethod
|
||||
def get_llm_client(db: Session, model_id: uuid.UUID):
|
||||
pass
|
||||
def get_llm_client(db: Session, model_id: uuid.UUID) -> RedBearLLM:
|
||||
api_config = ModelApiKeyService.get_available_api_key(db, model_id)
|
||||
return RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=api_config.model_name,
|
||||
provider=api_config.provider,
|
||||
api_key=api_config.api_key,
|
||||
base_url=api_config.api_base,
|
||||
is_omni=api_config.is_omni,
|
||||
support_thinking="thinking" in (api_config.capability or []),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_embedding_client(db: Session, model_id: uuid.UUID) -> OpenAIEmbedderClient:
|
||||
def get_embedding_client(db: Session, model_id: uuid.UUID) -> RedBearEmbeddings:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_client_config = config_service.get_embedder_config(str(model_id))
|
||||
return OpenAIEmbedderClient(
|
||||
return RedBearEmbeddings(
|
||||
RedBearModelConfig(
|
||||
model_name=embedder_client_config["model_name"],
|
||||
provider=embedder_client_config["provider"],
|
||||
@@ -30,10 +40,15 @@ class ModelClientMixin(ABC):
|
||||
|
||||
|
||||
class BasePipeline(ABC):
|
||||
def __init__(self, ctx: MemoryContext, db: Session):
|
||||
def __init__(self, ctx: MemoryContext):
|
||||
self.ctx = ctx
|
||||
self.db = db
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, *args, **kwargs) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
class DBRequiredPipeline(BasePipeline, ABC):
|
||||
def __init__(self, ctx: MemoryContext, db: Session):
|
||||
super().__init__(ctx)
|
||||
self.db = db
|
||||
|
||||
@@ -1,31 +1,41 @@
|
||||
from app.core.memory.enums import SearchStrategy
|
||||
from app.core.memory.pipelines.base_pipeline import BasePipeline, ModelClientMixin
|
||||
from app.core.memory.read_services.content_search import Neo4jSearchService
|
||||
from app.core.memory.enums import SearchStrategy, StorageType
|
||||
from app.core.memory.models.service_models import MemorySearchResult
|
||||
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
|
||||
from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService
|
||||
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
|
||||
|
||||
|
||||
class ReadPipeLine(ModelClientMixin, BasePipeline):
|
||||
async def run(self, query: str, search_switch: SearchStrategy, limit: int = 10):
|
||||
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
||||
async def run(self, query: str, search_switch: SearchStrategy, limit: int = 10, includes=None) -> MemorySearchResult:
|
||||
query = QueryPreprocessor.process(query)
|
||||
if self.ctx.storage_type == StorageType.RAG:
|
||||
return await self._rag_read(query, limit)
|
||||
match search_switch:
|
||||
case SearchStrategy.DEEP:
|
||||
return await self._deep_read()
|
||||
return await self._deep_read(query, limit, includes)
|
||||
case SearchStrategy.NORMAL:
|
||||
return await self._normal_read(query)
|
||||
return await self._normal_read(query, limit, includes)
|
||||
case SearchStrategy.QUICK:
|
||||
return await self._quick_read(query, limit)
|
||||
return await self._quick_read(query, limit, includes)
|
||||
case _:
|
||||
raise RuntimeError("Unsupported search strategy")
|
||||
|
||||
async def _deep_read(self):
|
||||
async def _rag_read(self, query: str, limit: int) -> MemorySearchResult:
|
||||
service = RAGSearchService(
|
||||
self.ctx
|
||||
)
|
||||
return await service.search()
|
||||
|
||||
async def _deep_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
||||
pass
|
||||
|
||||
async def _normal_read(self, query):
|
||||
async def _normal_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
||||
pass
|
||||
|
||||
async def _quick_read(self, query, limit):
|
||||
async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
||||
search_service = Neo4jSearchService(
|
||||
self.ctx,
|
||||
self.get_embedding_client(self.db, self.ctx.memory_config.embedding_model_id)
|
||||
self.get_embedding_client(self.db, self.ctx.memory_config.embedding_model_id),
|
||||
includes=includes,
|
||||
)
|
||||
return await search_service.search(query, limit)
|
||||
|
||||
85
api/app/core/memory/prompt/__init__.py
Normal file
85
api/app/core/memory/prompt/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, TemplateNotFound, TemplateSyntaxError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROMPT_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
class PromptRenderError(Exception):
|
||||
def __init__(self, template_name: str, error: Exception):
|
||||
self.template_name = template_name
|
||||
self.error = error
|
||||
super().__init__(f"Failed to render prompt '{template_name}': {error}")
|
||||
|
||||
|
||||
class PromptManager:
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._init_once()
|
||||
return cls._instance
|
||||
|
||||
def _init_once(self):
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(str(PROMPT_DIR)),
|
||||
autoescape=False,
|
||||
keep_trailing_newline=True,
|
||||
)
|
||||
logger.info(f"PromptManager initialized: template_dir={PROMPT_DIR}")
|
||||
|
||||
def __repr__(self):
|
||||
templates = self.list_templates()
|
||||
return f"<PromptManager: {len(templates)} prompts: {templates}>"
|
||||
|
||||
def list_templates(self) -> list[str]:
|
||||
return [
|
||||
Path(name).stem
|
||||
for name in self.env.loader.list_templates()
|
||||
if name.endswith('.jinja2')
|
||||
]
|
||||
|
||||
def get(self, name: str) -> str:
|
||||
template_name = self._resolve_name(name)
|
||||
try:
|
||||
source, _, _ = self.env.loader.get_source(self.env, template_name)
|
||||
return source
|
||||
except TemplateNotFound:
|
||||
raise FileNotFoundError(
|
||||
f"Prompt '{name}' not found. "
|
||||
f"Available: {self.list_templates()}"
|
||||
)
|
||||
|
||||
def render(self, name: str, **kwargs) -> str:
|
||||
template_name = self._resolve_name(name)
|
||||
try:
|
||||
template = self.env.get_template(template_name)
|
||||
return template.render(**kwargs)
|
||||
except TemplateNotFound:
|
||||
raise FileNotFoundError(
|
||||
f"Prompt '{name}' not found. "
|
||||
f"Available: {self.list_templates()}"
|
||||
)
|
||||
except TemplateSyntaxError as e:
|
||||
logger.error(f"Prompt syntax error in '{name}': {e}", exc_info=True)
|
||||
raise PromptRenderError(name, e)
|
||||
except Exception as e:
|
||||
logger.error(f"Prompt render failed for '{name}': {e}", exc_info=True)
|
||||
raise PromptRenderError(name, e)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_name(name: str) -> str:
|
||||
if not name.endswith('.jinja2'):
|
||||
return f"{name}.jinja2"
|
||||
return name
|
||||
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
212
api/app/core/memory/prompt/problem_split.jinja2
Normal file
212
api/app/core/memory/prompt/problem_split.jinja2
Normal file
@@ -0,0 +1,212 @@
|
||||
|
||||
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
|
||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||
## 目标:
|
||||
你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。
|
||||
---
|
||||
|
||||
### 历史信息参考
|
||||
在生成扩展问题时,你可以参考以下历史数据(如果提供):
|
||||
- 历史对话或任务的主题;
|
||||
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
|
||||
- 历史中已解答的问题(避免重复);
|
||||
- 历史推理链(保持逻辑一致性)。
|
||||
|
||||
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
|
||||
输入历史信息内容:{{history}}
|
||||
|
||||
## User Input
|
||||
{{ sentence }}
|
||||
|
||||
## 需求:
|
||||
1:首先判断类型(单跳、多跳、开放域、时间)。
|
||||
2:根据类型进行拆分。
|
||||
3:拆分后的内容需保证信息完整且可独立处理。
|
||||
4:对每个拆分条目,可附加示例或说明。
|
||||
5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
## 指代消歧规则(Coreference Resolution):
|
||||
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||
|
||||
1. **"用户"的消歧**:
|
||||
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人
|
||||
- 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||
|
||||
2. **"我"的消歧**:
|
||||
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||
|
||||
3. **"他/她/它"的消歧**:
|
||||
- 从上下文或历史中找出最近提到的同类实体
|
||||
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||
|
||||
4. **"那个人/这个人"的消歧**:
|
||||
- 从历史中找出最近提到的人物
|
||||
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||
|
||||
5. **优先级**:
|
||||
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||
|
||||
## 指令:
|
||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||
单跳(Single-hop)
|
||||
描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。
|
||||
拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。
|
||||
示例:
|
||||
输入数据:"请列出今年诺贝尔物理学奖的得主"
|
||||
拆分结果:[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "今年诺贝尔物理学奖得主是谁",
|
||||
"type": "单跳’",
|
||||
}
|
||||
]
|
||||
注意: 当遇到上下文依赖问题时,明确指出缺失的信息类型并且,question可填写输入问题
|
||||
多跳(Multi-hop):
|
||||
描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。
|
||||
拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。
|
||||
示例:
|
||||
输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果"
|
||||
拆分结果:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 今年诺贝尔物理学奖得主是谁?",
|
||||
"type": "多跳’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "该得主的研究领域是什么?",
|
||||
"type": "多跳’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "该得主的代表性成果有哪些?",
|
||||
"type": "多跳’"
|
||||
}
|
||||
]
|
||||
开放域(Open-domain):
|
||||
描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。。
|
||||
拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性)
|
||||
需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。
|
||||
示例:
|
||||
输入数据:"介绍量子计算的最新研究进展"
|
||||
拆分结果:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 量子计算的基本概念是什么?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "当前量子计算的主要研究方向有哪些?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "近期在量子计算领域有哪些重大进展?",
|
||||
"type": "开放域’",
|
||||
}
|
||||
]
|
||||
|
||||
时间(Temporal):
|
||||
描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。
|
||||
拆分策略:根据事件时间或时间段拆分为独立条目或问题。
|
||||
示例:
|
||||
输入数据:"列出苹果公司过去五年的重大事件"
|
||||
拆分结果:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 苹果公司2019年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "苹果公司2020年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "苹果公司2021年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "苹果公司2022年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
}
|
||||
,
|
||||
{
|
||||
"id": "Q4",
|
||||
"question": "苹果公司2023年的重大事件有哪些?",
|
||||
"type": "时间’",
|
||||
}
|
||||
]
|
||||
|
||||
输出要求:
|
||||
- 每个子问题包括:
|
||||
- `id`: 子问题编号(Q1, Q2...)
|
||||
- `question`: 子问题内容
|
||||
- `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)
|
||||
- `reason`: 拆分的理由(为什么要这样拆)
|
||||
- 格式案例:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": 量子计算的基本概念是什么?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q2",
|
||||
"question": "当前量子计算的主要研究方向有哪些?",
|
||||
"type": "开放域’",
|
||||
},
|
||||
{
|
||||
"id": "Q3",
|
||||
"question": "近期在量子计算领域有哪些重大进展?",
|
||||
"type": "开放域’",
|
||||
}
|
||||
]
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
|
||||
## 指代消歧示例(重要):
|
||||
示例1 - "用户"的消歧:
|
||||
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||
输入问题:"用户是谁?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "李建国是谁?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||
}
|
||||
]
|
||||
|
||||
示例2 - "我"的消歧:
|
||||
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||
输入问题:"我推荐的书是什么?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "张曼玉推荐的书是什么?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||
}
|
||||
]
|
||||
|
||||
- 关键的JSON格式要求
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
@@ -1,37 +1,28 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.enums import Neo4jNodeType
|
||||
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||
from app.core.memory.memory_service import MemoryContext
|
||||
from app.core.memory.models.service_models import Memory
|
||||
from app.core.memory.models.service_models import Memory, MemorySearchResult
|
||||
from app.core.memory.read_services.result_builder import data_builder_factory
|
||||
from app.core.models import RedBearEmbeddings
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemorySearchResult(BaseModel):
|
||||
memories: dict[str, list[dict]] = Field(default_factory=dict)
|
||||
content: str = Field(default="")
|
||||
count: int = Field(default=0)
|
||||
DEFAULT_ALPHA = 0.7
|
||||
DEFAULT_FULLTEXT_SCORE_THRESHOLD = 1
|
||||
DEFAULT_COSINE_SCORE_THRESHOLD = 0.5
|
||||
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
|
||||
|
||||
|
||||
class Neo4jSearchService:
|
||||
DEFAULT_ALPHA = 0.6
|
||||
DEFAULT_FULLTEXT_SCORE_THRESHOLD = 1
|
||||
DEFAULT_COSINE_SCORE_THRESHOLD = 0.5
|
||||
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ctx: MemoryContext,
|
||||
embedder: OpenAIEmbedderClient,
|
||||
embedder: RedBearEmbeddings,
|
||||
includes: list[Neo4jNodeType] | None = None,
|
||||
alpha: float = DEFAULT_ALPHA,
|
||||
fulltext_score_threshold: float = DEFAULT_FULLTEXT_SCORE_THRESHOLD,
|
||||
@@ -44,7 +35,7 @@ class Neo4jSearchService:
|
||||
self.cosine_score_threshold = cosine_score_threshold
|
||||
self.content_score_threshold = content_score_threshold
|
||||
|
||||
self.embedder: OpenAIEmbedderClient = embedder
|
||||
self.embedder: RedBearEmbeddings = embedder
|
||||
self.connector: Neo4jConnector | None = None
|
||||
|
||||
self.includes = includes
|
||||
@@ -121,9 +112,12 @@ class Neo4jSearchService:
|
||||
kw = float(combined[item_id].get("kw_score", 0) or 0)
|
||||
emb = float(combined[item_id].get("embedding_score", 0) or 0)
|
||||
base = self.alpha * emb + (1 - self.alpha) * kw
|
||||
combined[item_id]["content_score"] = base + min(1 - base, 0.1 * kw * emb)
|
||||
combined[item_id]["content_score"] = base + min(1 - base, kw * emb)
|
||||
results = sorted(combined.values(), key=lambda x: x["content_score"], reverse=True)
|
||||
# results = [res for res in results if res["content_score"] > self.content_score_threshold]
|
||||
# results = [
|
||||
# res for res in results
|
||||
# if res["content_score"] > self.content_score_threshold
|
||||
# ]
|
||||
results = results[:limit]
|
||||
|
||||
logger.info(
|
||||
@@ -137,14 +131,14 @@ class Neo4jSearchService:
|
||||
return items
|
||||
scores = [float(it.get("score", 0) or 0) for it in items]
|
||||
for it, s in zip(items, scores):
|
||||
it[f"normalized_kw_score"] = 1 / (1 + math.exp(-(s - self.fulltext_score_threshold) / 2))
|
||||
it[f"normalized_kw_score"] = 1 / (1 + math.exp(-(s - self.fulltext_score_threshold) / 2)) if s else 0
|
||||
return items
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
) -> list[Memory]:
|
||||
) -> MemorySearchResult:
|
||||
async with Neo4jConnector() as connector:
|
||||
self.connector = connector
|
||||
kw_task = self._keyword_search(query, limit)
|
||||
@@ -175,4 +169,12 @@ class Neo4jSearchService:
|
||||
query=query
|
||||
))
|
||||
memories.sort(key=lambda x: x.score, reverse=True)
|
||||
return memories[:limit]
|
||||
return MemorySearchResult(memories=memories[:limit])
|
||||
|
||||
|
||||
class RAGSearchService:
|
||||
def __init__(self, ctx: MemoryContext):
|
||||
pass
|
||||
|
||||
async def search(self) -> MemorySearchResult:
|
||||
pass
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/4/9 16:48
|
||||
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||
from app.core.memory.memory_service import MemoryContext
|
||||
|
||||
|
||||
class ContentSearch:
|
||||
def __init__(self, ctx: MemoryContext):
|
||||
self.ctx = ctx
|
||||
|
||||
async def search(self, query):
|
||||
pass
|
||||
@@ -1,228 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||
from app.core.memory.memory_service import MemoryContext
|
||||
from app.core.memory.utils.data import escape_lucene_query
|
||||
from app.repositories.neo4j.graph_search import search_perceptual, search_perceptual_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PerceptualResult(BaseModel):
|
||||
memories: list[dict[str, Any]] = []
|
||||
content: str = ""
|
||||
keyword_raw: int = 0
|
||||
embedding_raw: int = 0
|
||||
|
||||
|
||||
class PerceptualRetrieverService:
|
||||
DEFAULT_ALPHA = 0.6
|
||||
DEFAULT_FULLTEXT_SCORE_THRESHOLD = 0.5
|
||||
DEFAULT_COSINE_SCORE_THRESHOLD = 0.7
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ctx: MemoryContext,
|
||||
embedder: OpenAIEmbedderClient,
|
||||
alpha: float = DEFAULT_ALPHA,
|
||||
fulltext_score_threshold: float = DEFAULT_FULLTEXT_SCORE_THRESHOLD,
|
||||
cosine_score_threshold: float = DEFAULT_COSINE_SCORE_THRESHOLD
|
||||
):
|
||||
self.ctx = ctx
|
||||
self.alpha = alpha
|
||||
self.fulltext_score_threshold = fulltext_score_threshold
|
||||
self.cosine_score_threshold = cosine_score_threshold
|
||||
|
||||
self.embedder: OpenAIEmbedderClient = embedder
|
||||
self.connector = Neo4jConnector()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
keywords: list[str] | None = None,
|
||||
limit: int = 10
|
||||
) -> PerceptualResult:
|
||||
if keywords is None:
|
||||
keywords = [query] if query else []
|
||||
|
||||
try:
|
||||
kw_task = self._keyword_search(keywords, limit)
|
||||
emb_task = self._embedding_search(query, limit)
|
||||
kw_results, emb_results = await asyncio.gather(kw_task, emb_task, return_exceptions=True)
|
||||
if isinstance(kw_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
|
||||
kw_results = []
|
||||
if isinstance(emb_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
|
||||
emb_results = []
|
||||
|
||||
reranked = self._rerank(kw_results, emb_results, limit)
|
||||
|
||||
memories = []
|
||||
content_parts = []
|
||||
for record in reranked:
|
||||
fmt = self._format_result(record)
|
||||
fmt["score"] = round(record.get("content_score", 0), 4)
|
||||
memories.append(fmt)
|
||||
content_parts.append(self._build_content_text(fmt))
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] {len(memories)} results after rerank "
|
||||
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
|
||||
)
|
||||
return PerceptualResult(
|
||||
memories=memories,
|
||||
content="\n\n".join(content_parts),
|
||||
keyword_raw=len(kw_results),
|
||||
embedding_raw=len(emb_results),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[PerceptualSearch] search failed: {e}", exc_info=True)
|
||||
return PerceptualResult()
|
||||
finally:
|
||||
await self.connector.close()
|
||||
|
||||
async def _keyword_search(
|
||||
self,
|
||||
keywords: list[str],
|
||||
limit: int
|
||||
) -> list[dict]:
|
||||
seen_ids: set = set()
|
||||
all_results: list[dict] = []
|
||||
|
||||
async def _one(kw: str):
|
||||
escaped = escape_lucene_query(kw)
|
||||
if not escaped.strip():
|
||||
return []
|
||||
r = await search_perceptual(
|
||||
connector=self.connector, q=escaped,
|
||||
end_user_id=self.ctx.end_user_id, limit=limit
|
||||
)
|
||||
perceptuals = r.get("perceptuals", [])
|
||||
return [perceptual for perceptual in perceptuals if perceptual["score"] > self.fulltext_score_threshold]
|
||||
|
||||
tasks = [_one(kw) for kw in keywords]
|
||||
batch = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for result in batch:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
|
||||
continue
|
||||
for rec in result:
|
||||
rid = rec.get("id", "")
|
||||
if rid and rid not in seen_ids:
|
||||
seen_ids.add(rid)
|
||||
all_results.append(rec)
|
||||
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
|
||||
return all_results[:limit]
|
||||
|
||||
async def _embedding_search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int
|
||||
) -> list[dict]:
|
||||
r = await search_perceptual_by_embedding(
|
||||
connector=self.connector,
|
||||
embedder_client=self.embedder,
|
||||
query_text=query,
|
||||
end_user_id=self.ctx.end_user_id,
|
||||
limit=limit
|
||||
)
|
||||
perceptuals = r.get("perceptuals", [])
|
||||
return [perceptual for perceptual in perceptuals if perceptual["score"] > self.cosine_score_threshold]
|
||||
|
||||
def _rerank(
|
||||
self,
|
||||
keyword_results: list[dict],
|
||||
embedding_results: list[dict],
|
||||
limit: int,
|
||||
) -> list[dict]:
|
||||
keyword_results = self._normalize_scores(keyword_results)
|
||||
embedding_results = self._normalize_scores(embedding_results)
|
||||
|
||||
kw_norm_map = {}
|
||||
for item in keyword_results:
|
||||
item_id = item["id"]
|
||||
kw_norm_map[item_id] = float(item.get("normalized_score", 0))
|
||||
|
||||
emb_norm_map = {}
|
||||
for item in embedding_results:
|
||||
item_id = item["id"]
|
||||
emb_norm_map[item_id] = float(item.get("normalized_score", 0))
|
||||
|
||||
combined = {}
|
||||
for item in keyword_results:
|
||||
item_id = item["id"]
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
|
||||
|
||||
for item in embedding_results:
|
||||
item_id = item["id"]
|
||||
if item_id in combined:
|
||||
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
|
||||
else:
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
|
||||
|
||||
for item in combined.values():
|
||||
kw = float(item.get("kw_score", 0) or 0)
|
||||
emb = float(item.get("embedding_score", 0) or 0)
|
||||
item["content_score"] = self.alpha * emb + (1 - self.alpha) * kw
|
||||
|
||||
results = list(combined.values())
|
||||
results.sort(key=lambda x: x["content_score"], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] rerank: merged={len(combined)}, after_threshold={len(results)} "
|
||||
f"(alpha={self.alpha})"
|
||||
)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scores(items: list[dict], field: str = "score") -> list[dict]:
|
||||
"""Min-max 归一化,将分数线性映射到 [0, 1]。"""
|
||||
if not items:
|
||||
return items
|
||||
scores = [float(it.get(field, 0) or 0) for it in items]
|
||||
min_s = min(scores)
|
||||
max_s = max(scores)
|
||||
diff = max_s - min_s
|
||||
for it, s in zip(items, scores):
|
||||
it[f"normalized_{field}"] = (s - min_s) / diff if diff > 0 else 1.0
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def _format_result(record: dict) -> dict:
|
||||
return {
|
||||
"id": record.get("id", ""),
|
||||
"perceptual_type": record.get("perceptual_type", ""),
|
||||
"file_name": record.get("file_name", ""),
|
||||
"file_path": record.get("file_path", ""),
|
||||
"summary": record.get("summary", ""),
|
||||
"topic": record.get("topic", ""),
|
||||
"domain": record.get("domain", ""),
|
||||
"keywords": record.get("keywords", []),
|
||||
"created_at": str(record.get("created_at", "")),
|
||||
"file_type": record.get("file_type", ""),
|
||||
"score": record.get("score", 0),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_content_text(formatted: dict) -> str:
|
||||
content_text = (f"<history-file-info>"
|
||||
f"<file-name>{formatted["file_name"]}</file-name>"
|
||||
f"<file-path>{formatted["file_path"]}</file-path>"
|
||||
f"<file-type>{formatted["file_type"]}</file-type>"
|
||||
f"<file-topic>{formatted["topic"]}</file-topic>"
|
||||
f"<file-domain>{formatted["keywords"]}</file-domain>"
|
||||
f"<file-summary>{formatted["summary"]}</file-summary>"
|
||||
f"</history-file-info>")
|
||||
return content_text
|
||||
@@ -1,11 +1,13 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/4/8 18:11
|
||||
import logging
|
||||
import re
|
||||
|
||||
from app.core.memory.prompt import prompt_manager
|
||||
from app.core.memory.utils.llm.llm_utils import StructResponse
|
||||
from app.core.models import RedBearLLM
|
||||
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QueryPreprocessor:
|
||||
@staticmethod
|
||||
@@ -16,3 +18,22 @@ class QueryPreprocessor:
|
||||
|
||||
text = re.sub(rf"{"|".join(AgentMemoryDataset.PRONOUN)}", AgentMemoryDataset.NAME, text)
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
async def split(query: str, llm_client: RedBearLLM):
|
||||
system_prompt = prompt_manager.render(
|
||||
name="problem_split",
|
||||
history=[],
|
||||
sentence=query,
|
||||
)
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
try:
|
||||
sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json')
|
||||
except Exception as e:
|
||||
logger.error(f"[QueryPreprocessor] Sub-question segmentation failed - {e}")
|
||||
sub_queries = None
|
||||
return sub_queries or query
|
||||
|
||||
@staticmethod
|
||||
async def extension(query: str, llm_client: RedBearLLM):
|
||||
pass
|
||||
|
||||
@@ -114,7 +114,7 @@ class PerceptualBuilder(BaseBuilder):
|
||||
f"<domain>{self.record.get('domain')}</domain>"
|
||||
f"<keywords>{self.record.get('keywords')}</keywords>"
|
||||
f"<file-type>{self.record.get('file_type')}</file-type>"
|
||||
"</<history-file-info>")
|
||||
"</history-file-info>")
|
||||
|
||||
|
||||
class CommunityBuilder(BaseBuilder):
|
||||
|
||||
11
api/app/core/memory/read_services/retrieval_summary.py
Normal file
11
api/app/core/memory/read_services/retrieval_summary.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from app.core.models import RedBearLLM
|
||||
|
||||
|
||||
class RetrievalSummaryProcessor:
|
||||
@staticmethod
|
||||
def summary(content: str, llm_client: RedBearLLM):
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def verify(content: str, llm_client: RedBearLLM):
|
||||
return
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Literal, Type
|
||||
|
||||
from json_repair import json_repair
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -13,6 +16,27 @@ async def handle_response(response: type[BaseModel]) -> dict:
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
class StructResponse:
|
||||
def __init__(self, mode: Literal["json", "pydantic"], model: Type[BaseModel] = None):
|
||||
self.mode = mode
|
||||
if mode == "pydantic" and model is None:
|
||||
raise ValueError("Pydantic model is required")
|
||||
|
||||
self.model = model
|
||||
|
||||
def __ror__(self, other: AIMessage):
|
||||
if not isinstance(other, AIMessage):
|
||||
raise RuntimeError(f"Unsupported struct type {type(other)}")
|
||||
text = ''
|
||||
for block in other.content_blocks:
|
||||
if block.get("type") == "text":
|
||||
text += block.get("text", "")
|
||||
fixed_json = json_repair.repair_json(text, return_objects=True)
|
||||
if self.mode == "json":
|
||||
return fixed_json
|
||||
return self.model.model_validate(fixed_json)
|
||||
|
||||
|
||||
class MemoryClientFactory:
|
||||
"""
|
||||
Factory for creating LLM, embedder, and reranker clients.
|
||||
@@ -24,21 +48,21 @@ class MemoryClientFactory:
|
||||
>>> llm_client = factory.get_llm_client(model_id)
|
||||
>>> embedder_client = factory.get_embedder_client(embedding_id)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, db: Session):
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
self._config_service = MemoryConfigService(db)
|
||||
|
||||
|
||||
def get_llm_client(self, llm_id: str) -> OpenAIClient:
|
||||
"""Get LLM client by model ID."""
|
||||
if not llm_id:
|
||||
raise ValueError("LLM ID is required")
|
||||
|
||||
|
||||
try:
|
||||
model_config = self._config_service.get_model_config(llm_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
|
||||
|
||||
|
||||
try:
|
||||
return OpenAIClient(
|
||||
RedBearModelConfig(
|
||||
@@ -52,19 +76,19 @@ class MemoryClientFactory:
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
|
||||
def get_embedder_client(self, embedding_id: str):
|
||||
"""Get embedder client by model ID."""
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
|
||||
|
||||
if not embedding_id:
|
||||
raise ValueError("Embedding ID is required")
|
||||
|
||||
|
||||
try:
|
||||
embedder_config = self._config_service.get_embedder_config(embedding_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e
|
||||
|
||||
|
||||
try:
|
||||
return OpenAIEmbedderClient(
|
||||
RedBearModelConfig(
|
||||
@@ -77,17 +101,17 @@ class MemoryClientFactory:
|
||||
except Exception as e:
|
||||
model_name = embedder_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize embedder client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
|
||||
def get_reranker_client(self, rerank_id: str) -> OpenAIClient:
|
||||
"""Get reranker client by model ID."""
|
||||
if not rerank_id:
|
||||
raise ValueError("Rerank ID is required")
|
||||
|
||||
|
||||
try:
|
||||
model_config = self._config_service.get_model_config(rerank_id)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
|
||||
|
||||
|
||||
try:
|
||||
return OpenAIClient(
|
||||
RedBearModelConfig(
|
||||
|
||||
@@ -8,6 +8,7 @@ import numpy as np
|
||||
from app.core.memory.enums import Neo4jNodeType
|
||||
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
from app.core.models import RedBearEmbeddings
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
EXPAND_COMMUNITY_STATEMENTS,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
@@ -358,7 +359,7 @@ async def search_by_embedding(
|
||||
USER_ID_QUERY_CYPHER_MAPPING[node_type],
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
records = [record for record in records if record if record["embedding"] is not None]
|
||||
records = [record for record in records if record and record.get("embedding") is not None]
|
||||
ids = [item['id'] for item in records]
|
||||
vectors = [item['embedding'] for item in records]
|
||||
sim_res = cosine_similarity_search(query_embedding, vectors, limit=limit)
|
||||
@@ -469,7 +470,7 @@ async def search_graph(
|
||||
|
||||
async def search_graph_by_embedding(
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
embedder_client: RedBearEmbeddings | OpenAIEmbedderClient,
|
||||
query_text: str,
|
||||
end_user_id: str,
|
||||
limit: int = 50,
|
||||
@@ -495,7 +496,10 @@ async def search_graph_by_embedding(
|
||||
Neo4jNodeType.PERCEPTUAL
|
||||
]
|
||||
|
||||
embeddings = await embedder_client.response([query_text])
|
||||
if isinstance(embedder_client, RedBearEmbeddings):
|
||||
embeddings = embedder_client.embed_documents([query_text])
|
||||
else:
|
||||
embeddings = await embedder_client.response([query_text])
|
||||
if not embeddings or not embeddings[0]:
|
||||
logger.warning(f"search_graph_by_embedding: embedding generation failed for '{query_text[:50]}'")
|
||||
return {search_key: [] for search_key in include}
|
||||
|
||||
@@ -34,7 +34,7 @@ Readability Guideline: Ensure optimized prompts have good readability and logica
|
||||
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %}
|
||||
|
||||
Constraints
|
||||
Output Constraint: Must output in JSON format including the fields "prompt" and "desc".
|
||||
Output Constraint: Must output in JSON format including the string fields "prompt" and "desc".
|
||||
Content Constraint: Must not include any explanations, analyses, or additional comments.
|
||||
Language Constraint: Must use clear and concise language.
|
||||
{% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %}
|
||||
|
||||
Reference in New Issue
Block a user