[ADD]Support graph search

This commit is contained in:
lixiangcheng1
2025-12-30 11:53:16 +08:00
parent 6defcaf982
commit 0078028992
6 changed files with 59 additions and 32 deletions

View File

@@ -18,6 +18,9 @@ from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.rag.common.settings import kg_retriever
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
@@ -389,36 +392,41 @@ async def retrieve_chunks(
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
existing_ids = knowledge_service.get_chunded_knowledgeids(
private_items = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
private_kb_ids = [item[0] for item in private_items]
private_workspace_ids = [item[1] for item in private_items]
filters = [
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.get_chunded_knowledgeids(
items = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
if share_ids:
if items:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
]
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters,
current_user=current_user
)
existing_ids.extend(items)
if not existing_ids:
share_kb_ids = [item[0] for item in share_items]
share_workspace_ids = [item[1] for item in share_items]
private_kb_ids.extend(share_kb_ids)
private_workspace_ids.extend(share_workspace_ids)
if not private_kb_ids:
return success(data=[], msg="retrieval successful")
kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
kb_id = private_kb_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids]
indices = ",".join(uuid_strs)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
@@ -448,4 +456,21 @@ async def retrieve_chunks(
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
# Prepare to configure chat_mdl、embedding_model、vision_model information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
rs.insert(0, doc)
return success(data=rs, msg="retrieval successful")