diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index cc1f8c98..3012d159 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -457,7 +457,7 @@ async def retrieve_chunks( if doc.metadata["doc_id"] not in seen_ids: seen_ids.add(doc.metadata["doc_id"]) unique_rs.append(doc) - rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) + rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) if unique_rs else [] if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph: kb_ids = [str(kb_id) for kb_id in private_kb_ids] workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids] diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 61540ee4..4b99986b 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -113,7 +113,7 @@ def knowledge_retrieval( continue # Use the specified reranker for re-ranking - if reranker_id: + if reranker_id and all_results: try: all_results = rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) except Exception as rerank_error: diff --git a/api/app/core/rag/utils/es_conn.py b/api/app/core/rag/utils/es_conn.py index 7fbf0e38..9a0edd24 100644 --- a/api/app/core/rag/utils/es_conn.py +++ b/api/app/core/rag/utils/es_conn.py @@ -68,9 +68,9 @@ class ESConnection(DocStoreConnection): client_config = { "hosts": [hosts], "basic_auth": (os.getenv("ELASTICSEARCH_USERNAME", "elastic"), os.getenv("ELASTICSEARCH_PASSWORD", "elastic")), - "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)), + "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)), "retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true", - "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)), + "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)), } # Only add SSL settings if using HTTPS diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 386920e0..cc9ec120 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -1,25 +1,22 @@ import os import logging -from typing import Any, cast +import threading +from typing import Any from urllib.parse import urlparse -import uuid import requests from elasticsearch import Elasticsearch, helpers from elasticsearch.helpers import BulkIndexError from packaging.version import parse as parse_version -from pydantic import BaseModel, model_validator -from abc import ABC # langchain-community # langchain-xinference # from langchain_community.embeddings import XinferenceEmbeddings # from langchain_xinference import XinferenceRerank from langchain_core.documents import Document from app.core.models.base import RedBearModelConfig -from app.core.models import RedBearLLM, RedBearRerank +from app.core.models import RedBearRerank from app.core.models.embedding import RedBearEmbeddings -from app.models.models_model import ModelConfig, ModelApiKey -from app.services.model_service import ModelConfigService +from app.models.models_model import ModelApiKey from app.models.knowledge_model import Knowledge from app.core.rag.vdb.field import Field @@ -29,37 +26,9 @@ from app.core.rag.models.chunk import DocumentChunk logger = logging.getLogger(__name__) -class ElasticSearchConfig(BaseModel): - # Regular Elasticsearch config - host: str | None = None - port: int | None = None - username: str | None = None - password: str | None = None - - # Common config - ca_certs: str | None = None - verify_certs: bool = False - request_timeout: int = 100000 - retry_on_timeout: bool = True - max_retries: int = 10000 - - @model_validator(mode="before") - @classmethod - def validate_config(cls, values: dict): - # Regular Elasticsearch validation - if not values.get("host"): - raise ValueError("config HOST is required for regular Elasticsearch") - if not values.get("port"): - raise ValueError("config PORT is required for regular Elasticsearch") - if not values.get("username"): - raise ValueError("config USERNAME is required for regular Elasticsearch") - if not values.get("password"): - raise ValueError("config PASSWORD is required for regular Elasticsearch") - return values - - class ElasticSearchVector(BaseVector): - def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): + def __init__(self, index_name: str, client: Elasticsearch, + embedding_config: ModelApiKey, reranker_config: ModelApiKey): super().__init__(index_name.lower()) # 初始化 Embedding 模型(自动支持火山引擎多模态) @@ -77,58 +46,8 @@ class ElasticSearchVector(BaseVector): api_key=reranker_config.api_key, base_url=reranker_config.api_base )) - self._client = self._init_client(config) - self._version = self._get_version() - self._check_version() - - def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: - """ - Initialize Elasticsearch client for regular Elasticsearch. - """ - try: - # Regular Elasticsearch configuration - parsed_url = urlparse(config.host or "") - if parsed_url.scheme in {"http", "https"}: - hosts = f"{config.host}:{config.port}" - use_https = parsed_url.scheme == "https" - else: - hosts = f"https://{config.host}:{config.port}" - use_https = False - - client_config = { - "hosts": [hosts], - "basic_auth": (config.username, config.password), - "request_timeout": config.request_timeout, - "retry_on_timeout": config.retry_on_timeout, - "max_retries": config.max_retries, - } - - # Only add SSL settings if using HTTPS - if use_https: - client_config["verify_certs"] = config.verify_certs - if config.ca_certs: - client_config["ca_certs"] = config.ca_certs - - client = Elasticsearch(**client_config) - - # Test connection - if not client.ping(): - raise ConnectionError("Failed to connect to Elasticsearch") - - except requests.ConnectionError as e: - raise ConnectionError(f"Vector database connection error: {str(e)}") - except Exception as e: - raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") - - return client - - def _get_version(self) -> str: - info = self._client.info() - return cast(str, info["version"]["number"]) - - def _check_version(self): - if parse_version(self._version) < parse_version("8.0.0"): - raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") + # 使用外部传入的共享客户端 + self._client = client def get_type(self) -> str: return "elasticsearch" @@ -745,29 +664,79 @@ class ElasticSearchVector(BaseVector): class ElasticSearchVectorFactory: - @staticmethod - def init_vector(knowledge: Knowledge) -> ElasticSearchVector: + """ES 向量服务工厂 - 单例共享连接""" + + _client: Elasticsearch | None = None + _lock = threading.Lock() + _version_checked = False + + @classmethod + def _get_shared_client(cls) -> Elasticsearch: + """获取共享的 ES 客户端(线程安全的懒加载单例)""" + if cls._client is not None: + return cls._client + + with cls._lock: + # 双重检查,防止并发时重复创建 + if cls._client is not None: + return cls._client + + try: + parsed_url = urlparse(os.getenv("ELASTICSEARCH_HOST", "127.0.0.1") or "") + if parsed_url.scheme in {"http", "https"}: + hosts = f'{os.getenv("ELASTICSEARCH_HOST")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}' + use_https = parsed_url.scheme == "https" + else: + hosts = f'https://{os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}' + use_https = False + + client_config = { + "hosts": [hosts], + "basic_auth": ( + os.getenv("ELASTICSEARCH_USERNAME", "elastic"), + os.getenv("ELASTICSEARCH_PASSWORD", "elastic"), + ), + "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)), + "retry_on_timeout": True, + "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)), + "connections_per_node": int(os.getenv("ELASTICSEARCH_CONNECTIONS_PER_NODE", 10)), + } + + if use_https: + client_config["verify_certs"] = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "false") == "true" + ca_certs = os.getenv("ELASTICSEARCH_CA_CERTS") + if ca_certs: + client_config["ca_certs"] = str(ca_certs) + + client = Elasticsearch(**client_config) + + if not client.ping(): + raise ConnectionError("Failed to connect to Elasticsearch") + + # 版本检查只做一次 + if not cls._version_checked: + info = client.info() + version = info["version"]["number"] + if parse_version(version) < parse_version("8.0.0"): + raise ValueError(f"Elasticsearch version must be >= 8.0.0, got {version}") + cls._version_checked = True + logger.info(f"Elasticsearch shared client initialized, version: {version}") + + cls._client = client + + except requests.ConnectionError as e: + raise ConnectionError(f"Vector database connection error: {str(e)}") + except Exception as e: + raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") + + return cls._client + + @classmethod + def init_vector(cls, knowledge: Knowledge) -> ElasticSearchVector: + """创建向量服务实例(共享 ES 连接)""" + client = cls._get_shared_client() collection_name = f"Vector_index_{knowledge.id}_Node" - # Use regular Elasticsearch with config values - config_dict = { - "host": os.getenv("ELASTICSEARCH_HOST", "127.0.0.1"), - "port": os.getenv("ELASTICSEARCH_PORT", 9200), - "username": os.getenv("ELASTICSEARCH_USERNAME", "elastic"), - "password": os.getenv("ELASTICSEARCH_PASSWORD", "elastic"), - } - - # Common configuration - config_dict.update( - { - "ca_certs": str(os.getenv("ELASTICSEARCH_CA_CERTS")) if os.getenv("ELASTICSEARCH_CA_CERTS") else None, - "verify_certs": os.getenv("ELASTICSEARCH_VERIFY_CERTS", False) == "true", - "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)), - "retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true", - "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)), - } - ) - if knowledge.embedding is None: raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") if knowledge.reranker is None: @@ -775,9 +744,9 @@ class ElasticSearchVectorFactory: return ElasticSearchVector( index_name=collection_name, - config=ElasticSearchConfig(**config_dict), + client=client, embedding_config=knowledge.embedding.api_keys[0], - reranker_config=knowledge.reranker.api_keys[0] + reranker_config=knowledge.reranker.api_keys[0], )