[ADD]Add functions related to knowledge base graph:
Add functions related to knowledge base graph: 1. Entity type generation, 2. Knowledge base graph acquisition, 3. Hard deletion of knowledge base graph, 4. Knowledge base graph reconstruction (asynchronous)
This commit is contained in:
290
api/app/core/rag/llm/embedding_model.py
Normal file
290
api/app/core/rag/llm/embedding_model.py
Normal file
@@ -0,0 +1,290 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import dashscope
|
||||
import numpy as np
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
|
||||
from app.core.rag.common.log_utils import log_exception
|
||||
from app.core.rag.common.token_utils import num_tokens_from_string, truncate
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
Constructor for abstract base class.
|
||||
Parameters are accepted for interface consistency but are not stored.
|
||||
Subclasses should implement their own initialization as needed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def encode(self, texts: list):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def total_token_count(self, resp):
|
||||
try:
|
||||
return resp.usage.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return resp["usage"]["total_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
|
||||
class OpenAIEmbed(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
# OpenAI requires batch size <=16
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8191) for t in texts]
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
total_tokens += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
|
||||
|
||||
class LocalAIEmbed(Base):
|
||||
_FACTORY_NAME = "LocalAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("Local embedding model url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key="empty", base_url=base_url)
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
# local embedding for LmStudio donot count tokens
|
||||
return np.array(ress), 1024
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text])
|
||||
return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class AzureEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
api_key = json.loads(key).get("api_key", "")
|
||||
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class BaiChuanEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "BaiChuan"
|
||||
|
||||
def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.baichuan-ai.com/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class QWenEmbed(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
|
||||
self.key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
import time
|
||||
|
||||
import dashscope
|
||||
|
||||
batch_size = 4
|
||||
res = []
|
||||
token_count = 0
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
for i in range(0, len(texts), batch_size):
|
||||
retry_max = 5
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
|
||||
while (resp["output"] is None or resp["output"].get("embeddings") is None) and retry_max > 0:
|
||||
time.sleep(10)
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
|
||||
retry_max -= 1
|
||||
if retry_max == 0 and (resp["output"] is None or resp["output"].get("embeddings") is None):
|
||||
if resp.get("message"):
|
||||
log_exception(ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}"))
|
||||
else:
|
||||
log_exception(ValueError("Retry_max reached, calling embedding model failed"))
|
||||
raise
|
||||
try:
|
||||
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
||||
for e in resp["output"]["embeddings"]:
|
||||
embds[e["text_index"]] = e["embedding"]
|
||||
res.extend(embds)
|
||||
token_count += self.total_token_count(resp)
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
raise
|
||||
return np.array(res), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
|
||||
try:
|
||||
return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
|
||||
|
||||
class XinferenceEmbed(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name="", base_url=""):
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = None
|
||||
try:
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
total_tokens += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = None
|
||||
try:
|
||||
res = self.client.embeddings.create(input=[text], model=self.model_name)
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
|
||||
class NvidiaEmbed(Base):
|
||||
_FACTORY_NAME = "NVIDIA"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
|
||||
if not base_url:
|
||||
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
|
||||
self.api_key = key
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
self.model_name = model_name
|
||||
if model_name == "nvidia/embed-qa-4":
|
||||
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
|
||||
self.model_name = "NV-Embed-QA"
|
||||
if model_name == "snowflake/arctic-embed-l":
|
||||
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
payload = {
|
||||
"input": texts[i : i + batch_size],
|
||||
"input_type": "query",
|
||||
"model": self.model_name,
|
||||
"encoding_format": "float",
|
||||
"truncate": "END",
|
||||
}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=payload)
|
||||
try:
|
||||
res = response.json()
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
ress.extend([d["embedding"] for d in res["data"]])
|
||||
token_count += self.total_token_count(res)
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text])
|
||||
return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class HuggingFaceEmbed(Base):
|
||||
_FACTORY_NAME = "HuggingFace"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
if not model_name:
|
||||
raise ValueError("Model name cannot be None")
|
||||
self.key = key
|
||||
self.model_name = model_name.split("___")[0]
|
||||
self.base_url = base_url or "http://127.0.0.1:8080"
|
||||
|
||||
def encode(self, texts: list):
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"})
|
||||
if response.status_code == 200:
|
||||
embeddings = response.json()
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
|
||||
if response.status_code == 200:
|
||||
embedding = response.json()[0]
|
||||
return np.array(embedding), num_tokens_from_string(text)
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
|
||||
|
||||
class VolcEngineEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "VolcEngine"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
|
||||
if not base_url:
|
||||
base_url = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
ark_api_key = json.loads(key).get("ark_api_key", "")
|
||||
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
|
||||
super().__init__(ark_api_key, model_name, base_url)
|
||||
|
||||
|
||||
class GPUStackEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
Reference in New Issue
Block a user