refactor(memory): consolidate memory search services and update model client handling

- Consolidate memory search services by removing separate content_search.py and perceptual_search.py
- Update model client handling in base_pipeline.py to use ModelApiKeyService for LLM client initialization
- Add new prompt files and modify existing services to support consolidated search architecture
- Refactor memory read pipeline and related services to use updated model client approach
This commit is contained in:
Eternity
2026-04-16 13:27:36 +08:00
parent 2716a55c7f
commit a01525e239
16 changed files with 471 additions and 309 deletions

View File

@@ -1,7 +1,7 @@
from sqlalchemy.orm import Session
from app.core.memory.enums import StorageType, SearchStrategy
from app.core.memory.models.service_models import Memory, MemoryContext
from app.core.memory.models.service_models import MemoryContext, MemorySearchResult
from app.core.memory.pipelines.memory_read import ReadPipeLine
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
@@ -35,9 +35,14 @@ class MemoryService:
async def write(self, messages: list[dict]) -> str:
raise NotImplementedError
async def read(self, query: str, history: list, search_switch: SearchStrategy) -> list[Memory]:
async def read(
self,
query: str,
search_switch: SearchStrategy,
limit: int = 10,
) -> MemorySearchResult:
with get_db_context() as db:
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit=10)
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit)
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
raise NotImplementedError

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field, field_serializer, ConfigDict
from pydantic import BaseModel, Field, field_serializer, ConfigDict, model_validator, computed_field
from app.core.memory.enums import Neo4jNodeType, StorageType
from app.core.validators import file_validator
from app.schemas.memory_config_schema import MemoryConfig
@@ -24,3 +25,17 @@ class Memory(BaseModel):
@field_serializer("source")
def serialize_source(self, v) -> str:
return v.value
class MemorySearchResult(BaseModel):
memories: list[Memory]
@computed_field
@property
def content(self) -> str:
return "\n".join([memory.content for memory in self.memories])
@computed_field
@property
def count(self) -> int:
return len(self.memories)

View File

@@ -4,22 +4,32 @@ from typing import Any
from sqlalchemy.orm import Session
from app.core.memory.llm_tools import OpenAIEmbedderClient
from app.core.memory.models.service_models import MemoryContext
from app.core.models import RedBearModelConfig
from app.core.models import RedBearModelConfig, RedBearLLM, RedBearEmbeddings
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelApiKeyService
class ModelClientMixin(ABC):
@staticmethod
def get_llm_client(db: Session, model_id: uuid.UUID):
pass
def get_llm_client(db: Session, model_id: uuid.UUID) -> RedBearLLM:
api_config = ModelApiKeyService.get_available_api_key(db, model_id)
return RedBearLLM(
RedBearModelConfig(
model_name=api_config.model_name,
provider=api_config.provider,
api_key=api_config.api_key,
base_url=api_config.api_base,
is_omni=api_config.is_omni,
support_thinking="thinking" in (api_config.capability or []),
)
)
@staticmethod
def get_embedding_client(db: Session, model_id: uuid.UUID) -> OpenAIEmbedderClient:
def get_embedding_client(db: Session, model_id: uuid.UUID) -> RedBearEmbeddings:
config_service = MemoryConfigService(db)
embedder_client_config = config_service.get_embedder_config(str(model_id))
return OpenAIEmbedderClient(
return RedBearEmbeddings(
RedBearModelConfig(
model_name=embedder_client_config["model_name"],
provider=embedder_client_config["provider"],
@@ -30,10 +40,15 @@ class ModelClientMixin(ABC):
class BasePipeline(ABC):
def __init__(self, ctx: MemoryContext, db: Session):
def __init__(self, ctx: MemoryContext):
self.ctx = ctx
self.db = db
@abstractmethod
async def run(self, *args, **kwargs) -> Any:
pass
class DBRequiredPipeline(BasePipeline, ABC):
def __init__(self, ctx: MemoryContext, db: Session):
super().__init__(ctx)
self.db = db

View File

@@ -1,31 +1,41 @@
from app.core.memory.enums import SearchStrategy
from app.core.memory.pipelines.base_pipeline import BasePipeline, ModelClientMixin
from app.core.memory.read_services.content_search import Neo4jSearchService
from app.core.memory.enums import SearchStrategy, StorageType
from app.core.memory.models.service_models import MemorySearchResult
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
class ReadPipeLine(ModelClientMixin, BasePipeline):
async def run(self, query: str, search_switch: SearchStrategy, limit: int = 10):
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
async def run(self, query: str, search_switch: SearchStrategy, limit: int = 10, includes=None) -> MemorySearchResult:
query = QueryPreprocessor.process(query)
if self.ctx.storage_type == StorageType.RAG:
return await self._rag_read(query, limit)
match search_switch:
case SearchStrategy.DEEP:
return await self._deep_read()
return await self._deep_read(query, limit, includes)
case SearchStrategy.NORMAL:
return await self._normal_read(query)
return await self._normal_read(query, limit, includes)
case SearchStrategy.QUICK:
return await self._quick_read(query, limit)
return await self._quick_read(query, limit, includes)
case _:
raise RuntimeError("Unsupported search strategy")
async def _deep_read(self):
async def _rag_read(self, query: str, limit: int) -> MemorySearchResult:
service = RAGSearchService(
self.ctx
)
return await service.search()
async def _deep_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
pass
async def _normal_read(self, query):
async def _normal_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
pass
async def _quick_read(self, query, limit):
async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
search_service = Neo4jSearchService(
self.ctx,
self.get_embedding_client(self.db, self.ctx.memory_config.embedding_model_id)
self.get_embedding_client(self.db, self.ctx.memory_config.embedding_model_id),
includes=includes,
)
return await search_service.search(query, limit)

View File

@@ -0,0 +1,85 @@
import logging
import threading
from pathlib import Path
from jinja2 import Environment, FileSystemLoader, TemplateNotFound, TemplateSyntaxError
logger = logging.getLogger(__name__)
PROMPT_DIR = Path(__file__).parent
class PromptRenderError(Exception):
def __init__(self, template_name: str, error: Exception):
self.template_name = template_name
self.error = error
super().__init__(f"Failed to render prompt '{template_name}': {error}")
class PromptManager:
_instance = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._init_once()
return cls._instance
def _init_once(self):
self.env = Environment(
loader=FileSystemLoader(str(PROMPT_DIR)),
autoescape=False,
keep_trailing_newline=True,
)
logger.info(f"PromptManager initialized: template_dir={PROMPT_DIR}")
def __repr__(self):
templates = self.list_templates()
return f"<PromptManager: {len(templates)} prompts: {templates}>"
def list_templates(self) -> list[str]:
return [
Path(name).stem
for name in self.env.loader.list_templates()
if name.endswith('.jinja2')
]
def get(self, name: str) -> str:
template_name = self._resolve_name(name)
try:
source, _, _ = self.env.loader.get_source(self.env, template_name)
return source
except TemplateNotFound:
raise FileNotFoundError(
f"Prompt '{name}' not found. "
f"Available: {self.list_templates()}"
)
def render(self, name: str, **kwargs) -> str:
template_name = self._resolve_name(name)
try:
template = self.env.get_template(template_name)
return template.render(**kwargs)
except TemplateNotFound:
raise FileNotFoundError(
f"Prompt '{name}' not found. "
f"Available: {self.list_templates()}"
)
except TemplateSyntaxError as e:
logger.error(f"Prompt syntax error in '{name}': {e}", exc_info=True)
raise PromptRenderError(name, e)
except Exception as e:
logger.error(f"Prompt render failed for '{name}': {e}", exc_info=True)
raise PromptRenderError(name, e)
@staticmethod
def _resolve_name(name: str) -> str:
if not name.endswith('.jinja2'):
return f"{name}.jinja2"
return name
prompt_manager = PromptManager()

View File

@@ -0,0 +1,212 @@
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
## 目标:
你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。
---
### 历史信息参考
在生成扩展问题时,你可以参考以下历史数据(如果提供):
- 历史对话或任务的主题;
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
- 历史中已解答的问题(避免重复);
- 历史推理链(保持逻辑一致性)。
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
输入历史信息内容:{{history}}
## User Input
{{ sentence }}
## 需求:
1:首先判断类型(单跳、多跳、开放域、时间)。
2:根据类型进行拆分。
3:拆分后的内容需保证信息完整且可独立处理。
4:对每个拆分条目,可附加示例或说明。
5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书书名是什么', 'ANswer': '张曼玉推荐了《小王子》'}]
拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么
## 指代消歧规则Coreference Resolution
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
1. **"用户"的消歧**
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人
- 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
2. **"我"的消歧**
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么"
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
3. **"他/她/它"的消歧**
- 从上下文或历史中找出最近提到的同类实体
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
4. **"那个人/这个人"的消歧**
- 从历史中找出最近提到的人物
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
5. **优先级**
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
- 如果无法从历史中确定指代对象保留原问题但在reason中说明"无法确定指代对象"
## 指令:
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
单跳Single-hop
描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。
拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。
示例:
输入数据:"请列出今年诺贝尔物理学奖的得主"
拆分结果:[
{
"id": "Q1",
"question": "今年诺贝尔物理学奖得主是谁",
"type": "单跳’",
}
]
注意: 当遇到上下文依赖问题时明确指出缺失的信息类型并且question可填写输入问题
多跳Multi-hop:
描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。
拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。
示例:
输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果"
拆分结果:
[
{
"id": "Q1",
"question": 今年诺贝尔物理学奖得主是谁?",
"type": "多跳’",
},
{
"id": "Q2",
"question": "该得主的研究领域是什么?",
"type": "多跳’",
},
{
"id": "Q3",
"question": "该得主的代表性成果有哪些?",
"type": "多跳’"
}
]
开放域Open-domain:
描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。。
拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性)
需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。
示例:
输入数据:"介绍量子计算的最新研究进展"
拆分结果:
[
{
"id": "Q1",
"question": 量子计算的基本概念是什么?",
"type": "开放域’",
},
{
"id": "Q2",
"question": "当前量子计算的主要研究方向有哪些?",
"type": "开放域’",
},
{
"id": "Q3",
"question": "近期在量子计算领域有哪些重大进展?",
"type": "开放域’",
}
]
时间Temporal:
描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。
拆分策略:根据事件时间或时间段拆分为独立条目或问题。
示例:
输入数据:"列出苹果公司过去五年的重大事件"
拆分结果:
[
{
"id": "Q1",
"question": 苹果公司2019年的重大事件有哪些",
"type": "时间’",
},
{
"id": "Q2",
"question": "苹果公司2020年的重大事件有哪些",
"type": "时间’",
},
{
"id": "Q3",
"question": "苹果公司2021年的重大事件有哪些",
"type": "时间’",
},
{
"id": "Q3",
"question": "苹果公司2022年的重大事件有哪些",
"type": "时间’",
}
,
{
"id": "Q4",
"question": "苹果公司2023年的重大事件有哪些",
"type": "时间’",
}
]
输出要求:
- 每个子问题包括:
- `id`: 子问题编号Q1, Q2...
- `question`: 子问题内容
- `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)
- `reason`: 拆分的理由(为什么要这样拆)
- 格式案例:
[
{
"id": "Q1",
"question": 量子计算的基本概念是什么?",
"type": "开放域’",
},
{
"id": "Q2",
"question": "当前量子计算的主要研究方向有哪些?",
"type": "开放域’",
},
{
"id": "Q3",
"question": "近期在量子计算领域有哪些重大进展?",
"type": "开放域’",
}
]
- 必须通过json.loads()的格式支持的形式输出
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
## 指代消歧示例(重要):
示例1 - "用户"的消歧:
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
输入问题:"用户是谁?"
输出:
[
{
"id": "Q1",
"question": "李建国是谁?",
"type": "单跳",
"reason": "历史中反复提到'老李/李建国/建国哥''用户'指的就是对话发起者李建国"
}
]
示例2 - "我"的消歧:
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
输入问题:"我推荐的书是什么?"
输出:
[
{
"id": "Q1",
"question": "张曼玉推荐的书是什么?",
"type": "单跳",
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
}
]
- 关键的JSON格式要求
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子“statement”“Zhang Xinhua said\”我非常喜欢这本书\""
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -1,37 +1,28 @@
import asyncio
import logging
import math
import time
from pydantic import BaseModel, Field
from app.core.memory.enums import Neo4jNodeType
from app.core.memory.llm_tools import OpenAIEmbedderClient
from app.core.memory.memory_service import MemoryContext
from app.core.memory.models.service_models import Memory
from app.core.memory.models.service_models import Memory, MemorySearchResult
from app.core.memory.read_services.result_builder import data_builder_factory
from app.core.models import RedBearEmbeddings
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
class MemorySearchResult(BaseModel):
memories: dict[str, list[dict]] = Field(default_factory=dict)
content: str = Field(default="")
count: int = Field(default=0)
DEFAULT_ALPHA = 0.7
DEFAULT_FULLTEXT_SCORE_THRESHOLD = 1
DEFAULT_COSINE_SCORE_THRESHOLD = 0.5
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
class Neo4jSearchService:
DEFAULT_ALPHA = 0.6
DEFAULT_FULLTEXT_SCORE_THRESHOLD = 1
DEFAULT_COSINE_SCORE_THRESHOLD = 0.5
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
def __init__(
self,
ctx: MemoryContext,
embedder: OpenAIEmbedderClient,
embedder: RedBearEmbeddings,
includes: list[Neo4jNodeType] | None = None,
alpha: float = DEFAULT_ALPHA,
fulltext_score_threshold: float = DEFAULT_FULLTEXT_SCORE_THRESHOLD,
@@ -44,7 +35,7 @@ class Neo4jSearchService:
self.cosine_score_threshold = cosine_score_threshold
self.content_score_threshold = content_score_threshold
self.embedder: OpenAIEmbedderClient = embedder
self.embedder: RedBearEmbeddings = embedder
self.connector: Neo4jConnector | None = None
self.includes = includes
@@ -121,9 +112,12 @@ class Neo4jSearchService:
kw = float(combined[item_id].get("kw_score", 0) or 0)
emb = float(combined[item_id].get("embedding_score", 0) or 0)
base = self.alpha * emb + (1 - self.alpha) * kw
combined[item_id]["content_score"] = base + min(1 - base, 0.1 * kw * emb)
combined[item_id]["content_score"] = base + min(1 - base, kw * emb)
results = sorted(combined.values(), key=lambda x: x["content_score"], reverse=True)
# results = [res for res in results if res["content_score"] > self.content_score_threshold]
# results = [
# res for res in results
# if res["content_score"] > self.content_score_threshold
# ]
results = results[:limit]
logger.info(
@@ -137,14 +131,14 @@ class Neo4jSearchService:
return items
scores = [float(it.get("score", 0) or 0) for it in items]
for it, s in zip(items, scores):
it[f"normalized_kw_score"] = 1 / (1 + math.exp(-(s - self.fulltext_score_threshold) / 2))
it[f"normalized_kw_score"] = 1 / (1 + math.exp(-(s - self.fulltext_score_threshold) / 2)) if s else 0
return items
async def search(
self,
query: str,
limit: int = 10,
) -> list[Memory]:
) -> MemorySearchResult:
async with Neo4jConnector() as connector:
self.connector = connector
kw_task = self._keyword_search(query, limit)
@@ -175,4 +169,12 @@ class Neo4jSearchService:
query=query
))
memories.sort(key=lambda x: x.score, reverse=True)
return memories[:limit]
return MemorySearchResult(memories=memories[:limit])
class RAGSearchService:
def __init__(self, ctx: MemoryContext):
pass
async def search(self) -> MemorySearchResult:
pass

View File

@@ -1,14 +0,0 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/4/9 16:48
from app.core.memory.llm_tools import OpenAIEmbedderClient
from app.core.memory.memory_service import MemoryContext
class ContentSearch:
def __init__(self, ctx: MemoryContext):
self.ctx = ctx
async def search(self, query):
pass

View File

@@ -1,228 +0,0 @@
import asyncio
import logging
from typing import Any
from pydantic import BaseModel
from app.core.memory.llm_tools import OpenAIEmbedderClient
from app.core.memory.memory_service import MemoryContext
from app.core.memory.utils.data import escape_lucene_query
from app.repositories.neo4j.graph_search import search_perceptual, search_perceptual_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
class PerceptualResult(BaseModel):
memories: list[dict[str, Any]] = []
content: str = ""
keyword_raw: int = 0
embedding_raw: int = 0
class PerceptualRetrieverService:
DEFAULT_ALPHA = 0.6
DEFAULT_FULLTEXT_SCORE_THRESHOLD = 0.5
DEFAULT_COSINE_SCORE_THRESHOLD = 0.7
def __init__(
self,
ctx: MemoryContext,
embedder: OpenAIEmbedderClient,
alpha: float = DEFAULT_ALPHA,
fulltext_score_threshold: float = DEFAULT_FULLTEXT_SCORE_THRESHOLD,
cosine_score_threshold: float = DEFAULT_COSINE_SCORE_THRESHOLD
):
self.ctx = ctx
self.alpha = alpha
self.fulltext_score_threshold = fulltext_score_threshold
self.cosine_score_threshold = cosine_score_threshold
self.embedder: OpenAIEmbedderClient = embedder
self.connector = Neo4jConnector()
async def search(
self,
query: str,
keywords: list[str] | None = None,
limit: int = 10
) -> PerceptualResult:
if keywords is None:
keywords = [query] if query else []
try:
kw_task = self._keyword_search(keywords, limit)
emb_task = self._embedding_search(query, limit)
kw_results, emb_results = await asyncio.gather(kw_task, emb_task, return_exceptions=True)
if isinstance(kw_results, Exception):
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
kw_results = []
if isinstance(emb_results, Exception):
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
emb_results = []
reranked = self._rerank(kw_results, emb_results, limit)
memories = []
content_parts = []
for record in reranked:
fmt = self._format_result(record)
fmt["score"] = round(record.get("content_score", 0), 4)
memories.append(fmt)
content_parts.append(self._build_content_text(fmt))
logger.info(
f"[PerceptualSearch] {len(memories)} results after rerank "
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
)
return PerceptualResult(
memories=memories,
content="\n\n".join(content_parts),
keyword_raw=len(kw_results),
embedding_raw=len(emb_results),
)
except Exception as e:
logger.error(f"[PerceptualSearch] search failed: {e}", exc_info=True)
return PerceptualResult()
finally:
await self.connector.close()
async def _keyword_search(
self,
keywords: list[str],
limit: int
) -> list[dict]:
seen_ids: set = set()
all_results: list[dict] = []
async def _one(kw: str):
escaped = escape_lucene_query(kw)
if not escaped.strip():
return []
r = await search_perceptual(
connector=self.connector, q=escaped,
end_user_id=self.ctx.end_user_id, limit=limit
)
perceptuals = r.get("perceptuals", [])
return [perceptual for perceptual in perceptuals if perceptual["score"] > self.fulltext_score_threshold]
tasks = [_one(kw) for kw in keywords]
batch = await asyncio.gather(*tasks, return_exceptions=True)
for result in batch:
if isinstance(result, Exception):
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
continue
for rec in result:
rid = rec.get("id", "")
if rid and rid not in seen_ids:
seen_ids.add(rid)
all_results.append(rec)
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
return all_results[:limit]
async def _embedding_search(
self,
query: str,
limit: int
) -> list[dict]:
r = await search_perceptual_by_embedding(
connector=self.connector,
embedder_client=self.embedder,
query_text=query,
end_user_id=self.ctx.end_user_id,
limit=limit
)
perceptuals = r.get("perceptuals", [])
return [perceptual for perceptual in perceptuals if perceptual["score"] > self.cosine_score_threshold]
def _rerank(
self,
keyword_results: list[dict],
embedding_results: list[dict],
limit: int,
) -> list[dict]:
keyword_results = self._normalize_scores(keyword_results)
embedding_results = self._normalize_scores(embedding_results)
kw_norm_map = {}
for item in keyword_results:
item_id = item["id"]
kw_norm_map[item_id] = float(item.get("normalized_score", 0))
emb_norm_map = {}
for item in embedding_results:
item_id = item["id"]
emb_norm_map[item_id] = float(item.get("normalized_score", 0))
combined = {}
for item in keyword_results:
item_id = item["id"]
combined[item_id] = item.copy()
combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
for item in embedding_results:
item_id = item["id"]
if item_id in combined:
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
else:
combined[item_id] = item.copy()
combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
for item in combined.values():
kw = float(item.get("kw_score", 0) or 0)
emb = float(item.get("embedding_score", 0) or 0)
item["content_score"] = self.alpha * emb + (1 - self.alpha) * kw
results = list(combined.values())
results.sort(key=lambda x: x["content_score"], reverse=True)
results = results[:limit]
logger.info(
f"[PerceptualSearch] rerank: merged={len(combined)}, after_threshold={len(results)} "
f"(alpha={self.alpha})"
)
return results
@staticmethod
def _normalize_scores(items: list[dict], field: str = "score") -> list[dict]:
"""Min-max 归一化,将分数线性映射到 [0, 1]。"""
if not items:
return items
scores = [float(it.get(field, 0) or 0) for it in items]
min_s = min(scores)
max_s = max(scores)
diff = max_s - min_s
for it, s in zip(items, scores):
it[f"normalized_{field}"] = (s - min_s) / diff if diff > 0 else 1.0
return items
@staticmethod
def _format_result(record: dict) -> dict:
return {
"id": record.get("id", ""),
"perceptual_type": record.get("perceptual_type", ""),
"file_name": record.get("file_name", ""),
"file_path": record.get("file_path", ""),
"summary": record.get("summary", ""),
"topic": record.get("topic", ""),
"domain": record.get("domain", ""),
"keywords": record.get("keywords", []),
"created_at": str(record.get("created_at", "")),
"file_type": record.get("file_type", ""),
"score": record.get("score", 0),
}
@staticmethod
def _build_content_text(formatted: dict) -> str:
content_text = (f"<history-file-info>"
f"<file-name>{formatted["file_name"]}</file-name>"
f"<file-path>{formatted["file_path"]}</file-path>"
f"<file-type>{formatted["file_type"]}</file-type>"
f"<file-topic>{formatted["topic"]}</file-topic>"
f"<file-domain>{formatted["keywords"]}</file-domain>"
f"<file-summary>{formatted["summary"]}</file-summary>"
f"</history-file-info>")
return content_text

View File

@@ -1,11 +1,13 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/4/8 18:11
import logging
import re
from app.core.memory.prompt import prompt_manager
from app.core.memory.utils.llm.llm_utils import StructResponse
from app.core.models import RedBearLLM
from app.schemas.memory_agent_schema import AgentMemoryDataset
logger = logging.getLogger(__name__)
class QueryPreprocessor:
@staticmethod
@@ -16,3 +18,22 @@ class QueryPreprocessor:
text = re.sub(rf"{"|".join(AgentMemoryDataset.PRONOUN)}", AgentMemoryDataset.NAME, text)
return text
@staticmethod
async def split(query: str, llm_client: RedBearLLM):
system_prompt = prompt_manager.render(
name="problem_split",
history=[],
sentence=query,
)
messages = [{"role": "system", "content": system_prompt}]
try:
sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json')
except Exception as e:
logger.error(f"[QueryPreprocessor] Sub-question segmentation failed - {e}")
sub_queries = None
return sub_queries or query
@staticmethod
async def extension(query: str, llm_client: RedBearLLM):
pass

View File

@@ -114,7 +114,7 @@ class PerceptualBuilder(BaseBuilder):
f"<domain>{self.record.get('domain')}</domain>"
f"<keywords>{self.record.get('keywords')}</keywords>"
f"<file-type>{self.record.get('file_type')}</file-type>"
"</<history-file-info>")
"</history-file-info>")
class CommunityBuilder(BaseBuilder):

View File

@@ -0,0 +1,11 @@
from app.core.models import RedBearLLM
class RetrievalSummaryProcessor:
@staticmethod
def summary(content: str, llm_client: RedBearLLM):
return
@staticmethod
def verify(content: str, llm_client: RedBearLLM):
return

View File

@@ -1,4 +1,7 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal, Type
from json_repair import json_repair
from langchain_core.messages import AIMessage
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
@@ -13,6 +16,27 @@ async def handle_response(response: type[BaseModel]) -> dict:
return response.model_dump()
class StructResponse:
def __init__(self, mode: Literal["json", "pydantic"], model: Type[BaseModel] = None):
self.mode = mode
if mode == "pydantic" and model is None:
raise ValueError("Pydantic model is required")
self.model = model
def __ror__(self, other: AIMessage):
if not isinstance(other, AIMessage):
raise RuntimeError(f"Unsupported struct type {type(other)}")
text = ''
for block in other.content_blocks:
if block.get("type") == "text":
text += block.get("text", "")
fixed_json = json_repair.repair_json(text, return_objects=True)
if self.mode == "json":
return fixed_json
return self.model.model_validate(fixed_json)
class MemoryClientFactory:
"""
Factory for creating LLM, embedder, and reranker clients.
@@ -24,21 +48,21 @@ class MemoryClientFactory:
>>> llm_client = factory.get_llm_client(model_id)
>>> embedder_client = factory.get_embedder_client(embedding_id)
"""
def __init__(self, db: Session):
from app.services.memory_config_service import MemoryConfigService
self._config_service = MemoryConfigService(db)
def get_llm_client(self, llm_id: str) -> OpenAIClient:
"""Get LLM client by model ID."""
if not llm_id:
raise ValueError("LLM ID is required")
try:
model_config = self._config_service.get_model_config(llm_id)
except Exception as e:
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
try:
return OpenAIClient(
RedBearModelConfig(
@@ -52,19 +76,19 @@ class MemoryClientFactory:
except Exception as e:
model_name = model_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
def get_embedder_client(self, embedding_id: str):
"""Get embedder client by model ID."""
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
if not embedding_id:
raise ValueError("Embedding ID is required")
try:
embedder_config = self._config_service.get_embedder_config(embedding_id)
except Exception as e:
raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e
try:
return OpenAIEmbedderClient(
RedBearModelConfig(
@@ -77,17 +101,17 @@ class MemoryClientFactory:
except Exception as e:
model_name = embedder_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize embedder client for model '{model_name}': {str(e)}") from e
def get_reranker_client(self, rerank_id: str) -> OpenAIClient:
"""Get reranker client by model ID."""
if not rerank_id:
raise ValueError("Rerank ID is required")
try:
model_config = self._config_service.get_model_config(rerank_id)
except Exception as e:
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
try:
return OpenAIClient(
RedBearModelConfig(

View File

@@ -8,6 +8,7 @@ import numpy as np
from app.core.memory.enums import Neo4jNodeType
from app.core.memory.llm_tools import OpenAIEmbedderClient
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.core.models import RedBearEmbeddings
from app.repositories.neo4j.cypher_queries import (
EXPAND_COMMUNITY_STATEMENTS,
SEARCH_CHUNK_BY_CHUNK_ID,
@@ -358,7 +359,7 @@ async def search_by_embedding(
USER_ID_QUERY_CYPHER_MAPPING[node_type],
end_user_id=end_user_id,
)
records = [record for record in records if record if record["embedding"] is not None]
records = [record for record in records if record and record.get("embedding") is not None]
ids = [item['id'] for item in records]
vectors = [item['embedding'] for item in records]
sim_res = cosine_similarity_search(query_embedding, vectors, limit=limit)
@@ -469,7 +470,7 @@ async def search_graph(
async def search_graph_by_embedding(
connector: Neo4jConnector,
embedder_client,
embedder_client: RedBearEmbeddings | OpenAIEmbedderClient,
query_text: str,
end_user_id: str,
limit: int = 50,
@@ -495,7 +496,10 @@ async def search_graph_by_embedding(
Neo4jNodeType.PERCEPTUAL
]
embeddings = await embedder_client.response([query_text])
if isinstance(embedder_client, RedBearEmbeddings):
embeddings = embedder_client.embed_documents([query_text])
else:
embeddings = await embedder_client.response([query_text])
if not embeddings or not embeddings[0]:
logger.warning(f"search_graph_by_embedding: embedding generation failed for '{query_text[:50]}'")
return {search_key: [] for search_key in include}

View File

@@ -34,7 +34,7 @@ Readability Guideline: Ensure optimized prompts have good readability and logica
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %}
Constraints
Output Constraint: Must output in JSON format including the fields "prompt" and "desc".
Output Constraint: Must output in JSON format including the string fields "prompt" and "desc".
Content Constraint: Must not include any explanations, analyses, or additional comments.
Language Constraint: Must use clear and concise language.
{% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %}