|
|
|
|
@@ -8,8 +8,6 @@ import requests
|
|
|
|
|
from elasticsearch import Elasticsearch, helpers
|
|
|
|
|
from elasticsearch.helpers import BulkIndexError
|
|
|
|
|
from packaging.version import parse as parse_version
|
|
|
|
|
from pydantic import BaseModel, model_validator
|
|
|
|
|
from abc import ABC
|
|
|
|
|
# langchain-community
|
|
|
|
|
# langchain-xinference
|
|
|
|
|
# from langchain_community.embeddings import XinferenceEmbeddings
|
|
|
|
|
@@ -29,37 +27,9 @@ from app.core.rag.models.chunk import DocumentChunk
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ElasticSearchConfig(BaseModel):
|
|
|
|
|
# Regular Elasticsearch config
|
|
|
|
|
host: str | None = None
|
|
|
|
|
port: int | None = None
|
|
|
|
|
username: str | None = None
|
|
|
|
|
password: str | None = None
|
|
|
|
|
|
|
|
|
|
# Common config
|
|
|
|
|
ca_certs: str | None = None
|
|
|
|
|
verify_certs: bool = False
|
|
|
|
|
request_timeout: int = 100000
|
|
|
|
|
retry_on_timeout: bool = True
|
|
|
|
|
max_retries: int = 10000
|
|
|
|
|
|
|
|
|
|
@model_validator(mode="before")
|
|
|
|
|
@classmethod
|
|
|
|
|
def validate_config(cls, values: dict):
|
|
|
|
|
# Regular Elasticsearch validation
|
|
|
|
|
if not values.get("host"):
|
|
|
|
|
raise ValueError("config HOST is required for regular Elasticsearch")
|
|
|
|
|
if not values.get("port"):
|
|
|
|
|
raise ValueError("config PORT is required for regular Elasticsearch")
|
|
|
|
|
if not values.get("username"):
|
|
|
|
|
raise ValueError("config USERNAME is required for regular Elasticsearch")
|
|
|
|
|
if not values.get("password"):
|
|
|
|
|
raise ValueError("config PASSWORD is required for regular Elasticsearch")
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ElasticSearchVector(BaseVector):
|
|
|
|
|
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
|
|
|
|
def __init__(self, index_name: str, client: Elasticsearch,
|
|
|
|
|
embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
|
|
|
|
super().__init__(index_name.lower())
|
|
|
|
|
|
|
|
|
|
# 初始化 Embedding 模型(自动支持火山引擎多模态)
|
|
|
|
|
@@ -77,58 +47,8 @@ class ElasticSearchVector(BaseVector):
|
|
|
|
|
api_key=reranker_config.api_key,
|
|
|
|
|
base_url=reranker_config.api_base
|
|
|
|
|
))
|
|
|
|
|
self._client = self._init_client(config)
|
|
|
|
|
self._version = self._get_version()
|
|
|
|
|
self._check_version()
|
|
|
|
|
|
|
|
|
|
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
|
|
|
|
"""
|
|
|
|
|
Initialize Elasticsearch client for regular Elasticsearch.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
# Regular Elasticsearch configuration
|
|
|
|
|
parsed_url = urlparse(config.host or "")
|
|
|
|
|
if parsed_url.scheme in {"http", "https"}:
|
|
|
|
|
hosts = f"{config.host}:{config.port}"
|
|
|
|
|
use_https = parsed_url.scheme == "https"
|
|
|
|
|
else:
|
|
|
|
|
hosts = f"https://{config.host}:{config.port}"
|
|
|
|
|
use_https = False
|
|
|
|
|
|
|
|
|
|
client_config = {
|
|
|
|
|
"hosts": [hosts],
|
|
|
|
|
"basic_auth": (config.username, config.password),
|
|
|
|
|
"request_timeout": config.request_timeout,
|
|
|
|
|
"retry_on_timeout": config.retry_on_timeout,
|
|
|
|
|
"max_retries": config.max_retries,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Only add SSL settings if using HTTPS
|
|
|
|
|
if use_https:
|
|
|
|
|
client_config["verify_certs"] = config.verify_certs
|
|
|
|
|
if config.ca_certs:
|
|
|
|
|
client_config["ca_certs"] = config.ca_certs
|
|
|
|
|
|
|
|
|
|
client = Elasticsearch(**client_config)
|
|
|
|
|
|
|
|
|
|
# Test connection
|
|
|
|
|
if not client.ping():
|
|
|
|
|
raise ConnectionError("Failed to connect to Elasticsearch")
|
|
|
|
|
|
|
|
|
|
except requests.ConnectionError as e:
|
|
|
|
|
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
return client
|
|
|
|
|
|
|
|
|
|
def _get_version(self) -> str:
|
|
|
|
|
info = self._client.info()
|
|
|
|
|
return cast(str, info["version"]["number"])
|
|
|
|
|
|
|
|
|
|
def _check_version(self):
|
|
|
|
|
if parse_version(self._version) < parse_version("8.0.0"):
|
|
|
|
|
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
|
|
|
|
|
# 使用外部传入的共享客户端
|
|
|
|
|
self._client = client
|
|
|
|
|
|
|
|
|
|
def get_type(self) -> str:
|
|
|
|
|
return "elasticsearch"
|
|
|
|
|
@@ -744,30 +664,83 @@ class ElasticSearchVector(BaseVector):
|
|
|
|
|
self._client.indices.create(index=self._collection_name, body=index_mapping)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import threading
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ElasticSearchVectorFactory:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
|
|
|
|
|
"""ES 向量服务工厂 - 单例共享连接"""
|
|
|
|
|
|
|
|
|
|
_client: Elasticsearch | None = None
|
|
|
|
|
_lock = threading.Lock()
|
|
|
|
|
_version_checked = False
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _get_shared_client(cls) -> Elasticsearch:
|
|
|
|
|
"""获取共享的 ES 客户端(线程安全的懒加载单例)"""
|
|
|
|
|
if cls._client is not None:
|
|
|
|
|
return cls._client
|
|
|
|
|
|
|
|
|
|
with cls._lock:
|
|
|
|
|
# 双重检查,防止并发时重复创建
|
|
|
|
|
if cls._client is not None:
|
|
|
|
|
return cls._client
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
parsed_url = urlparse(os.getenv("ELASTICSEARCH_HOST", "127.0.0.1") or "")
|
|
|
|
|
if parsed_url.scheme in {"http", "https"}:
|
|
|
|
|
hosts = f'{os.getenv("ELASTICSEARCH_HOST")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}'
|
|
|
|
|
use_https = parsed_url.scheme == "https"
|
|
|
|
|
else:
|
|
|
|
|
hosts = f'https://{os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}'
|
|
|
|
|
use_https = False
|
|
|
|
|
|
|
|
|
|
client_config = {
|
|
|
|
|
"hosts": [hosts],
|
|
|
|
|
"basic_auth": (
|
|
|
|
|
os.getenv("ELASTICSEARCH_USERNAME", "elastic"),
|
|
|
|
|
os.getenv("ELASTICSEARCH_PASSWORD", "elastic"),
|
|
|
|
|
),
|
|
|
|
|
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)),
|
|
|
|
|
"retry_on_timeout": True,
|
|
|
|
|
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)),
|
|
|
|
|
"connections_per_node": int(os.getenv("ELASTICSEARCH_CONNECTIONS_PER_NODE", 10)),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if use_https:
|
|
|
|
|
client_config["verify_certs"] = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "false") == "true"
|
|
|
|
|
ca_certs = os.getenv("ELASTICSEARCH_CA_CERTS")
|
|
|
|
|
if ca_certs:
|
|
|
|
|
client_config["ca_certs"] = str(ca_certs)
|
|
|
|
|
|
|
|
|
|
client = Elasticsearch(**client_config)
|
|
|
|
|
|
|
|
|
|
if not client.ping():
|
|
|
|
|
raise ConnectionError("Failed to connect to Elasticsearch")
|
|
|
|
|
|
|
|
|
|
# 版本检查只做一次
|
|
|
|
|
if not cls._version_checked:
|
|
|
|
|
info = client.info()
|
|
|
|
|
version = info["version"]["number"]
|
|
|
|
|
if parse_version(version) < parse_version("8.0.0"):
|
|
|
|
|
raise ValueError(f"Elasticsearch version must be >= 8.0.0, got {version}")
|
|
|
|
|
cls._version_checked = True
|
|
|
|
|
logger.info(f"Elasticsearch shared client initialized, version: {version}")
|
|
|
|
|
|
|
|
|
|
cls._client = client
|
|
|
|
|
|
|
|
|
|
except requests.ConnectionError as e:
|
|
|
|
|
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
return cls._client
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def init_vector(cls, knowledge: Knowledge) -> ElasticSearchVector:
|
|
|
|
|
"""创建向量服务实例(共享 ES 连接)"""
|
|
|
|
|
client = cls._get_shared_client()
|
|
|
|
|
collection_name = f"Vector_index_{knowledge.id}_Node"
|
|
|
|
|
|
|
|
|
|
# Use regular Elasticsearch with config values
|
|
|
|
|
config_dict = {
|
|
|
|
|
"host": os.getenv("ELASTICSEARCH_HOST", "127.0.0.1"),
|
|
|
|
|
"port": os.getenv("ELASTICSEARCH_PORT", 9200),
|
|
|
|
|
"username": os.getenv("ELASTICSEARCH_USERNAME", "elastic"),
|
|
|
|
|
"password": os.getenv("ELASTICSEARCH_PASSWORD", "elastic"),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Common configuration
|
|
|
|
|
config_dict.update(
|
|
|
|
|
{
|
|
|
|
|
"ca_certs": str(os.getenv("ELASTICSEARCH_CA_CERTS")) if os.getenv("ELASTICSEARCH_CA_CERTS") else None,
|
|
|
|
|
"verify_certs": os.getenv("ELASTICSEARCH_VERIFY_CERTS", False) == "true",
|
|
|
|
|
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
|
|
|
|
|
"retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true",
|
|
|
|
|
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)),
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if knowledge.embedding is None:
|
|
|
|
|
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
|
|
|
|
|
if knowledge.reranker is None:
|
|
|
|
|
@@ -775,9 +748,9 @@ class ElasticSearchVectorFactory:
|
|
|
|
|
|
|
|
|
|
return ElasticSearchVector(
|
|
|
|
|
index_name=collection_name,
|
|
|
|
|
config=ElasticSearchConfig(**config_dict),
|
|
|
|
|
client=client,
|
|
|
|
|
embedding_config=knowledge.embedding.api_keys[0],
|
|
|
|
|
reranker_config=knowledge.reranker.api_keys[0]
|
|
|
|
|
reranker_config=knowledge.reranker.api_keys[0],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|