Merge pull request #308 from SuanmoSuanyangTechnology/fix/release_memory_bug
Fix/release memory bug
This commit is contained in:
@@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float
|
|||||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||||
from app.core.rag.llm.chat_model import Base
|
from app.core.rag.llm.chat_model import Base
|
||||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def knowledge_retrieval(
|
def knowledge_retrieval(
|
||||||
query: str,
|
query: str,
|
||||||
@@ -62,7 +64,15 @@ def knowledge_retrieval(
|
|||||||
merge_strategy = config.get("merge_strategy", "weight")
|
merge_strategy = config.get("merge_strategy", "weight")
|
||||||
reranker_id = config.get("reranker_id")
|
reranker_id = config.get("reranker_id")
|
||||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||||
use_graph = config.get("use_graph", "false").lower() == "true"
|
# use_graph = config.get("use_graph", "false").lower() == "true"
|
||||||
|
|
||||||
|
use_graph_value = config.get("use_graph", False)
|
||||||
|
if isinstance(use_graph_value, bool):
|
||||||
|
use_graph = use_graph_value
|
||||||
|
elif isinstance(use_graph_value, str):
|
||||||
|
use_graph = use_graph_value.lower() in ("true", "1", "yes")
|
||||||
|
else:
|
||||||
|
use_graph = False
|
||||||
|
|
||||||
file_names_filter = []
|
file_names_filter = []
|
||||||
if user_ids:
|
if user_ids:
|
||||||
@@ -159,13 +169,29 @@ def knowledge_retrieval(
|
|||||||
|
|
||||||
# Use the specified reranker for re-ranking
|
# Use the specified reranker for re-ranking
|
||||||
if reranker_id:
|
if reranker_id:
|
||||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
try:
|
||||||
# use graph
|
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||||
|
except Exception as rerank_error:
|
||||||
|
# If reranker fails, log warning and continue with original results
|
||||||
|
logger.warning(
|
||||||
|
"Reranker failed, falling back to original results",
|
||||||
|
extra={
|
||||||
|
"reranker_id": reranker_id,
|
||||||
|
"query": query,
|
||||||
|
"doc_count": len(all_results),
|
||||||
|
"error": str(rerank_error),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if use_graph:
|
if use_graph:
|
||||||
from app.core.rag.common.settings import kg_retriever
|
try:
|
||||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
from app.core.rag.common.settings import kg_retriever
|
||||||
if doc:
|
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||||
all_results.insert(0, doc)
|
if doc:
|
||||||
|
all_results.insert(0, doc)
|
||||||
|
except Exception as graph_error:
|
||||||
|
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
|
||||||
|
|
||||||
return all_results
|
return all_results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user