[ADD]Support graph search
This commit is contained in:
@@ -18,6 +18,9 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
@@ -389,36 +392,41 @@ async def retrieve_chunks(
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
existing_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
private_items = knowledge_service.get_chunded_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
private_kb_ids = [item[0] for item in private_items]
|
||||
private_workspace_ids = [item[1] for item in private_items]
|
||||
filters = [
|
||||
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
|
||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
share_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
items = knowledge_service.get_chunded_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
if share_ids:
|
||||
if items:
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
|
||||
]
|
||||
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
||||
share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
existing_ids.extend(items)
|
||||
if not existing_ids:
|
||||
share_kb_ids = [item[0] for item in share_items]
|
||||
share_workspace_ids = [item[1] for item in share_items]
|
||||
private_kb_ids.extend(share_kb_ids)
|
||||
private_workspace_ids.extend(share_workspace_ids)
|
||||
if not private_kb_ids:
|
||||
return success(data=[], msg="retrieval successful")
|
||||
kb_id = existing_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||
kb_id = private_kb_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids]
|
||||
indices = ",".join(uuid_strs)
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
@@ -448,4 +456,21 @@ async def retrieve_chunks(
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
rs.insert(0, doc)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
Reference in New Issue
Block a user