# File: MemoryBear/api/app/core/models/rerank.py
# 85 lines, 3.5 KiB, Python

from typing import Any, Dict, List, Optional, Sequence, Type, Union
from copy import deepcopy
from urllib.parse import urlparse
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.runnables import RunnableSerializable
from langchain_core.callbacks import Callbacks
from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory
from app.models import ModelProvider
class RedBearRerank(BaseDocumentCompressor):
    """Rerank documents through the provider selected by the model config.

    The concrete rerank model is resolved from ``config.provider`` and kept
    internally; this class exposes it through the ``BaseDocumentCompressor``
    interface (``compress_documents``) plus a lower-level ``rerank`` method.
    """

    def __init__(self, config: RedBearModelConfig):
        """Create the compressor and its provider-specific rerank model.

        Args:
            config: Model configuration; its ``provider`` and ``base_url``
                attributes are read by this class.
        """
        # Let the base class run its own initialisation before any instance
        # state is attached to the object.
        super().__init__()
        self._model = self._create_model(config)
        self._config = config

    def _create_model(self, config: RedBearModelConfig):
        """Instantiate the internal rerank model for ``config.provider``.

        Args:
            config: Model configuration used to pick the model class and to
                build its constructor parameters.

        Returns:
            An instance of the provider's rerank class.
        """
        model_class = get_provider_rerank_class(config.provider)
        model_params = RedBearModelFactory.get_rerank_model_params(config)
        # The parameters are intentionally not printed: they presumably
        # include credentials (e.g. an API key) that must not reach stdout.
        return model_class(**model_params)

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank ``documents`` against ``query`` with the configured provider.

        Args:
            documents: A sequence of documents to rerank.
            query: The query the documents are scored against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            Copies of the selected documents, in the order returned by
            ``rerank``, each with ``metadata["relevance_score"]`` set. The
            input documents are not modified.
        """
        compressed = []
        for res in self.rerank(documents, query):
            doc = documents[res["index"]]
            # Deep-copy the metadata so adding the score never leaks into the
            # caller's original document.
            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
            doc_copy.metadata["relevance_score"] = res["relevance_score"]
            compressed.append(doc_copy)
        return compressed

    @staticmethod
    def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]:
        """Return ``base_url`` expanded to a full ``.../v1/rerank`` endpoint.

        A URL already ending in ``/v1/rerank`` is returned as-is (minus any
        trailing slashes); one ending in ``/v1`` gets ``/rerank`` appended;
        anything else gets ``/v1/rerank`` appended. An empty or ``None``
        value yields ``None``.
        """
        if not base_url:
            return None
        url = base_url.rstrip('/')
        if url.endswith("/v1/rerank"):
            return url
        if url.endswith("/v1"):
            return url + "/rerank"
        return url + "/v1/rerank"

    def rerank(
        self,
        documents: Sequence[Union[str, Document, dict]],
        query: str,
        *,
        top_n: Optional[int] = -1,
    ) -> List[Dict[str, Any]]:
        """Score ``documents`` against ``query`` using the internal model.

        Args:
            documents: Items to rerank (strings, documents, or dicts).
            query: The query the documents are scored against.
            top_n: Forwarded unchanged to the underlying model's ``rerank``.
                NOTE(review): -1 presumably means "use the model's own
                default" -- confirm against the provider implementation.

        Returns:
            The underlying model's result list; ``compress_documents`` reads
            the ``"index"`` and ``"relevance_score"`` keys of each entry.

        Raises:
            ValueError: If the configured provider is not supported.
        """
        provider = self._config.provider.lower()
        if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
            import langchain_community.document_compressors.jina_rerank as jina_mod

            jina_base = self._normalize_jina_base(self._config.base_url)
            if jina_base:
                # Point the Jina client at the full rerank endpoint, e.g.
                # http://host:port/v1/rerank. This rebinds a module-level
                # attribute, so the override is process-wide rather than
                # per-instance.
                jina_mod.JINA_API_URL = jina_base
            from langchain_community.document_compressors import JinaRerank

            model_instance: JinaRerank = self._model
            return model_instance.rerank(documents=documents, query=query, top_n=top_n)
        elif provider == ModelProvider.DASHSCOPE:
            from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank

            model_instance: DashScopeRerank = self._model
            return model_instance.rerank(documents=documents, query=query, top_n=top_n)
        else:
            raise ValueError(f"不支持的模型提供商: {provider}")