feat(model): add volcano model
This commit is contained in:
@@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel):
|
||||
class ElasticSearchVector(BaseVector):
|
||||
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
||||
super().__init__(index_name.lower())
|
||||
# self.embeddings = XinferenceEmbeddings(
|
||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port
|
||||
# model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
|
||||
# )
|
||||
# Remove debug printing to avoid leaking sensitive information
|
||||
# print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
|
||||
|
||||
# 初始化 Embedding 模型(自动支持火山引擎多模态)
|
||||
self.embeddings = RedBearEmbeddings(RedBearModelConfig(
|
||||
model_name=embedding_config.model_name,
|
||||
provider=embedding_config.provider,
|
||||
api_key=embedding_config.api_key,
|
||||
base_url=embedding_config.api_base
|
||||
))
|
||||
# self.reranker = XinferenceRerank(
|
||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
|
||||
# model_uid="bge-reranker-large"
|
||||
# )
|
||||
# Remove debug printing to avoid leaking sensitive information
|
||||
# print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
|
||||
self.is_multimodal_embedding = self.embeddings.is_multimodal_supported()
|
||||
|
||||
self.reranker = RedBearRerank(RedBearModelConfig(
|
||||
model_name=reranker_config.model_name,
|
||||
provider=reranker_config.provider,
|
||||
@@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector):
|
||||
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
||||
# 实现 Elasticsearch 保存向量
|
||||
texts = [chunk.page_content for chunk in chunks]
|
||||
embeddings = self.embeddings.embed_documents(list(texts))
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
embeddings = self.embeddings.embed_batch(texts)
|
||||
else:
|
||||
embeddings = self.embeddings.embed_documents(list(texts))
|
||||
self.create(chunks, embeddings, **kwargs)
|
||||
|
||||
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
||||
@@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector):
|
||||
updated count.
|
||||
"""
|
||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
chunk.vector = self.embeddings.embed_text(chunk.page_content)
|
||||
else:
|
||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||
|
||||
body = {
|
||||
"script": {
|
||||
@@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
|
||||
"""Search the nearest neighbors to a vector."""
|
||||
query_vector = self.embeddings.embed_query(query)
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
query_vector = self.embeddings.embed_text(query)
|
||||
else:
|
||||
query_vector = self.embeddings.embed_query(query)
|
||||
top_k = kwargs.get("top_k", 1024)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.3)
|
||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||
|
||||
Reference in New Issue
Block a user