[ADD]Add functions related to knowledge base graph:

Add functions related to knowledge base graph:
1. Entity type generation,
2. Knowledge base graph acquisition,
3. Hard deletion of knowledge base graph,
4. Knowledge base graph reconstruction (asynchronous)
This commit is contained in:
lixiangcheng1
2025-12-27 13:53:10 +08:00
parent 06f64809c3
commit a0c362244e
35 changed files with 6267 additions and 143 deletions

View File

@@ -1,5 +1,6 @@
from typing import Optional
import datetime
import json
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy import or_
@@ -13,8 +14,13 @@ from app.schemas import knowledge_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service
from app.core.rag.llm.chat_model import Base
from app.core.rag.prompts.generator import graph_entity_types
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.common import settings
from app.celery_app import celery_app
# Obtain a dedicated API logger
api_logger = get_api_logger()
@@ -306,3 +312,171 @@ async def delete_knowledge(
except Exception as e:
api_logger.error(f"Failed to delete from the knowledge base: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.get("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
async def get_knowledge_graph(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve knowledge_graph base information based on knowledge_id
"""
api_logger.info(f"Obtain details of the knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Query knowledge base information from the database
api_logger.debug(f"Query knowledge base: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
req = {
"kb_id": [str(db_knowledge.id)],
"knowledge_graph_kwd": ["graph"]
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(str(db_knowledge.workspace_id)), str(db_knowledge.id)):
return success(data=obj, msg="Successfully obtained knowledge graph information")
sres = settings.retriever.search(req, search.index_name(str(db_knowledge.workspace_id)), [str(db_knowledge.id)])
if not len(sres.ids):
return success(data=obj, msg="Successfully obtained knowledge graph information")
for id in sres.ids[:1]:
ty = sres.field[id]["knowledge_graph_kwd"]
try:
content_json = json.loads(sres.field[id]["page_content"])
except Exception:
continue
obj[ty] = content_json
if "nodes" in obj["graph"]:
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
if "edges" in obj["graph"]:
node_id_set = {o["id"] for o in obj["graph"]["nodes"]}
filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
return success(data=obj, msg="Successfully obtained knowledge graph information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Knowledge graph query failed: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.delete("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
async def delete_knowledge_graph(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Soft-delete knowledge graph
"""
api_logger.info(f"Request to delete knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Check whether the knowledge base exists
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or you do not have permission to access it"
)
# 2. delete knowledge graph
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(str(db_knowledge.workspace_id)), str(db_knowledge.id))
api_logger.info(f"The knowledge graph has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
return success(msg="The knowledge graph has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the knowledge base: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.post("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
async def rebuild_knowledge_graph(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
rebuild knowledge graph
"""
api_logger.info(f"Request to rebuild knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Check whether the knowledge base exists
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(
f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or you do not have permission to access it"
)
# 2. delete knowledge graph
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(str(db_knowledge.workspace_id)), str(db_knowledge.id))
# 3. build knowledge graph
# from app.tasks import build_graphrag_for_kb
# build_graphrag_for_kb(kb_id)
task = celery_app.send_task("app.core.rag.tasks.build_graphrag_for_kb", args=[knowledge_id])
result = {
"task_id": task.id
}
return success(data=result, msg="Task accepted. rebuild knowledge graph is being processed in the background.")
except Exception as e:
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.get("/{knowledge_id}/knowledge_graph_entity_types", response_model=ApiResponse)
async def get_knowledge_graph_entity_types(
knowledge_id: uuid.UUID,
scenario: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
get knowledge graph entity types based on knowledge_id
"""
api_logger.info(f"Obtain details of the knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Check whether the knowledge base exists
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(
f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or you do not have permission to access it"
)
# 2. Prepare to configure chat_mdl information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
response = graph_entity_types(chat_model, scenario)
return success(data=response, msg="Successfully obtained knowledge graph entity types")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"get knowledge graph entity types failed: knowledge_id={knowledge_id} - {str(e)}")
raise

View File

@@ -0,0 +1,13 @@
import os
def singleton(cls, *args, **kw):
instances = {}
def _singleton():
key = str(cls) + str(os.getpid())
if key not in instances:
instances[key] = cls(*args, **kw)
return instances[key]
return _singleton

View File

@@ -0,0 +1,4 @@
class TaskCanceledException(Exception):
def __init__(self, msg):
self.msg = msg

View File

@@ -0,0 +1,13 @@
import os
import logging
def log_exception(e, *args):
logging.exception(e)
for a in args:
if hasattr(a, "text"):
logging.error(a.text)
raise Exception(a.text)
else:
logging.error(str(a))
raise e

View File

@@ -1,2 +1,24 @@
from app.core.rag.utils import es_conn
from app.core.rag.nlp import search
from app.core.rag.graphrag import search as kg_search
PARALLEL_DEVICES: int = 0
docStoreConn = None
retriever = None
kg_retriever = None
def init_settings():
global docStoreConn, retriever, kg_retriever
if docStoreConn is None:
docStoreConn = es_conn.ESConnection()
if retriever is None:
retriever = search.Dealer(docStoreConn)
if kg_retriever is None:
kg_retriever = kg_search.KGSearch(docStoreConn)
init_settings()

View File

@@ -0,0 +1,240 @@
import logging
import itertools
import os
import re
from dataclasses import dataclass
from typing import Any, Callable
import networkx as nx
import trio
from app.core.rag.graphrag.general.extractor import Extractor
from app.core.rag.nlp import is_english
import editdistance
from app.core.rag.graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from app.core.rag.llm.chat_model import Base as CompletionLLM
from app.core.rag.graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange, has_canceled
from app.core.rag.common.exceptions import TaskCanceledException
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
@dataclass
class EntityResolutionResult:
"""Entity resolution result class definition."""
graph: nx.Graph
change: GraphChange
class EntityResolution(Extractor):
"""Entity resolution class definition."""
_resolution_prompt: str
_output_formatter_prompt: str
_record_delimiter_key: str
_entity_index_delimiter_key: str
_resolution_result_delimiter_key: str
def __init__(
self,
llm_invoker: CompletionLLM,
):
super().__init__(llm_invoker)
"""Init method definition."""
self._llm = llm_invoker
self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
self._record_delimiter_key = "record_delimiter"
self._entity_index_delimiter_key = "entity_index_delimiter"
self._resolution_result_delimiter_key = "resolution_result_delimiter"
self._input_text_key = "input_text"
async def __call__(self, graph: nx.Graph,
subgraph_nodes: set[str],
prompt_variables: dict[str, Any] | None = None,
callback: Callable | None = None,
task_id: str = "") -> EntityResolutionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
# Wire defaults into the prompt variables
self.prompt_variables = {
**prompt_variables,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
self._entity_index_delimiter_key: prompt_variables.get(self._entity_index_delimiter_key)
or DEFAULT_ENTITY_INDEX_DELIMITER,
self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key)
or DEFAULT_RESOLUTION_RESULT_DELIMITER,
}
nodes = sorted(graph.nodes())
entity_types = sorted(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
node_clusters = {entity_type: [] for entity_type in entity_types}
for node in nodes:
node_clusters[graph.nodes[node].get('entity_type', '-')].append(node)
candidate_resolution = {entity_type: [] for entity_type in entity_types}
for k, v in node_clusters.items():
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)]
num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
callback(msg=f"Identified {num_candidates} candidate pairs")
remain_candidates_to_resolve = num_candidates
resolution_result = set()
resolution_result_lock = trio.Lock()
resolution_batch_size = 100
max_concurrent_tasks = 5
semaphore = trio.Semaphore(max_concurrent_tasks)
async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
nonlocal remain_candidates_to_resolve, callback
async with semaphore:
try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
if cancel_scope.cancelled_caught:
logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ")
except Exception as e:
logging.error(f"Error resolving candidate batch: {e}")
async with trio.open_nursery() as nursery:
for candidate_resolution_i in candidate_resolution.items():
if not candidate_resolution_i[1]:
continue
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock)
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
change = GraphChange()
connect_graph = nx.Graph()
connect_graph.add_edges_from(resolution_result)
async def limited_merge_nodes(graph, nodes, change):
async with semaphore:
await self._merge_graph_nodes(graph, nodes, change, task_id)
async with trio.open_nursery() as nursery:
for sub_connect_graph in nx.connected_components(connect_graph):
merging_nodes = list(sub_connect_graph)
nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change)
# Update pagerank
pr = nx.pagerank(graph)
for node_name, pagerank in pr.items():
graph.nodes[node_name]["pagerank"] = pagerank
return EntityResolutionResult(
graph=graph,
change=change,
)
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
for index, candidate in enumerate(candidate_resolution_i[1]):
pair_txt.append(
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
pair_txt.append(
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
pair_prompt = '\n'.join(pair_txt)
variables = {
**self.prompt_variables,
self._input_text_key: pair_prompt
}
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
async with chat_limiter:
try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
if cancel_scope.cancelled_caught:
logging.warning("_resolve_candidate._chat timeout, skipping...")
return
except Exception as e:
logging.error(f"_resolve_candidate._chat failed: {e}")
return
logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
result = self._process_results(len(candidate_resolution_i[1]), response,
self.prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),
self.prompt_variables.get(self._entity_index_delimiter_key,
DEFAULT_ENTITY_INDEX_DELIMITER),
self.prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
async with resolution_result_lock:
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
def _process_results(
self,
records_length: int,
results: str,
record_delimiter: str,
entity_index_delimiter: str,
resolution_result_delimiter: str
) -> list:
ans_list = []
records = [r.strip() for r in results.split(record_delimiter)]
for record in records:
pattern_int = f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
match_int = re.search(pattern_int, record)
res_int = int(str(match_int.group(1) if match_int else '0'))
if res_int > records_length:
continue
pattern_bool = f"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}"
match_bool = re.search(pattern_bool, record)
res_bool = str(match_bool.group(1) if match_bool else '')
if res_int and res_bool:
if res_bool.lower() == 'yes':
ans_list.append((res_int, "yes"))
return ans_list
def _has_digit_in_2gram_diff(self, a, b):
def to_2gram_set(s):
return {s[i:i+2] for i in range(len(s) - 1)}
set_a = to_2gram_set(a)
set_b = to_2gram_set(b)
diff = set_a ^ set_b
return any(any(c.isdigit() for c in pair) for pair in diff)
def is_similarity(self, a, b):
if self._has_digit_in_2gram_diff(a, b):
return False
if is_english(a) and is_english(b):
if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
return True
return False
a, b = set(a), set(b)
max_l = max(len(a), len(b))
if max_l < 4:
return len(a & b) > 1
return len(a & b)*1./max_l >= 0.8

View File

@@ -0,0 +1,58 @@
ENTITY_RESOLUTION_PROMPT = """
-Goal-
Please answer the following Question as required
-Steps-
1. Identify each line of questioning as required
2. Return output in English as a single list of each line answer in steps 1. Use **{record_delimiter}** as the list delimiter.
######################
-Examples-
######################
Example 1:
Question:
When determining whether two Products are the same, you should only focus on critical properties and overlook noisy factors.
Demonstration 1: name of Product A is : "computer", name of Product B is :"phone" No, Product A and Product B are different products.
Question 1: name of Product A is : "television", name of Product B is :"TV"
Question 2: name of Product A is : "cup", name of Product B is :"mug"
Question 3: name of Product A is : "soccer", name of Product B is :"football"
Question 4: name of Product A is : "pen", name of Product B is :"eraser"
Use domain knowledge of Products to help understand the text and answer the above 4 questions in the format: For Question i, Yes, Product A and Product B are the same product. or No, Product A and Product B are different products. For Question i+1, (repeat the above procedures)
################
Output:
(For question {entity_index_delimiter}1{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter}
(For question {entity_index_delimiter}2{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter}
(For question {entity_index_delimiter}3{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, Product A and Product B are the same product.){record_delimiter}
(For question {entity_index_delimiter}4{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, Product A and Product B are different products.){record_delimiter}
#############################
Example 2:
Question:
When determining whether two toponym are the same, you should only focus on critical properties and overlook noisy factors.
Demonstration 1: name of toponym A is : "nanjing", name of toponym B is :"nanjing city" No, toponym A and toponym B are same toponym.
Question 1: name of toponym A is : "Chicago", name of toponym B is :"ChiTown"
Question 2: name of toponym A is : "Shanghai", name of toponym B is :"Zhengzhou"
Question 3: name of toponym A is : "Beijing", name of toponym B is :"Peking"
Question 4: name of toponym A is : "Los Angeles", name of toponym B is :"Cleveland"
Use domain knowledge of toponym to help understand the text and answer the above 4 questions in the format: For Question i, Yes, toponym A and toponym B are the same toponym. or No, toponym A and toponym B are different toponym. For Question i+1, (repeat the above procedures)
################
Output:
(For question {entity_index_delimiter}1{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, toponym A and toponym B are same toponym.){record_delimiter}
(For question {entity_index_delimiter}2{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, toponym A and toponym B are different toponym.){record_delimiter}
(For question {entity_index_delimiter}3{entity_index_delimiter}, {resolution_result_delimiter}yes{resolution_result_delimiter}, toponym A and toponym B are the same toponym.){record_delimiter}
(For question {entity_index_delimiter}4{entity_index_delimiter}, {resolution_result_delimiter}no{resolution_result_delimiter}, toponym A and toponym B are different toponym.){record_delimiter}
#############################
-Real Data-
######################
Question:{input_text}
######################
Output:
"""

View File

@@ -0,0 +1,158 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/community_report.py)
"""
COMMUNITY_REPORT_PROMPT = """
You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
# Goal
Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims.
# Report Structure
The report should include the following sections:
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content):
{{
"title": <report_title>,
"summary": <executive_summary>,
"rating": <impact_severity_rating>,
"rating_explanation": <rating_explanation>,
"findings": [
{{
"summary":<insight_1_summary>,
"explanation": <insight_1_explanation>
}},
{{
"summary":<insight_2_summary>,
"explanation": <insight_2_explanation>
}}
]
}}
# Grounding Rules
Points supported by data should list their data references as follows:
"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
For example:
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
Do not include information where the supporting evidence for it is not provided.
# Example Input
-----------
Text:
-Entities-
id,entity,description
5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March
6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza
-Relationships-
id,source,target,description
37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March
38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza
39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza
40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza
41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march
43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March
Output:
{{
"title": "Verdant Oasis Plaza and Unity March",
"summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.",
"rating": 5.0,
"rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.",
"findings": [
{{
"summary": "Verdant Oasis Plaza as the central location",
"explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]"
}},
{{
"summary": "Harmony Assembly's role in the community",
"explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community. [Data: Entities(6), Relationships (38, 43)]"
}},
{{
"summary": "Unity March as a significant event",
"explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]"
}},
{{
"summary": "Role of Tribune Spotlight",
"explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]"
}}
]
}}
# Real Data
Use the following text for your answer. Do not make anything up in your answer.
Text:
-Entities-
{entity_df}
-Relationships-
{relation_df}
The report should include the following sections:
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content):
{{
"title": <report_title>,
"summary": <executive_summary>,
"rating": <impact_severity_rating>,
"rating_explanation": <rating_explanation>,
"findings": [
{{
"summary":<insight_1_summary>,
"explanation": <insight_1_explanation>
}},
{{
"summary":<insight_2_summary>,
"explanation": <insight_2_explanation>
}}
]
}}
# Grounding Rules
Points supported by data should list their data references as follows:
"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
For example:
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
Do not include information where the supporting evidence for it is not provided.
Output:"""

View File

@@ -0,0 +1,178 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import json
import os
import re
from typing import Callable
from dataclasses import dataclass
import networkx as nx
import pandas as pd
from app.core.rag.common.exceptions import TaskCanceledException
from app.core.rag.common.connection_utils import timeout
from app.core.rag.graphrag.general import leiden
from app.core.rag.graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from app.core.rag.graphrag.general.extractor import Extractor
from app.core.rag.graphrag.general.leiden import add_community_info2graph
from app.core.rag.llm.chat_model import Base as CompletionLLM
from app.core.rag.graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter, has_canceled
from app.core.rag.common.token_utils import num_tokens_from_string
import trio
@dataclass
class CommunityReportsResult:
"""Community reports result class definition."""
output: list[str]
structured_output: list[dict]
class CommunityReportsExtractor(Extractor):
"""Community reports extractor class definition."""
_extraction_prompt: str
_output_formatter_prompt: str
_max_report_length: int
def __init__(
self,
llm_invoker: CompletionLLM,
max_report_length: int | None = None,
):
super().__init__(llm_invoker)
"""Init method definition."""
self._llm = llm_invoker
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
self._max_report_length = max_report_length or 1500
async def __call__(self, graph: nx.Graph, callback: Callable | None = None, task_id: str = ""):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
communities: dict[int, dict[str, dict]] = leiden.run(graph, {})
total = sum([len(comm.items()) for _, comm in communities.items()])
res_str = []
res_dict = []
over, token_count = 0, 0
@timeout(120)
async def extract_community_report(community):
nonlocal res_str, res_dict, over, token_count
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during community report extraction.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
cm_id, cm = community
weight = cm["weight"]
ents = cm["nodes"]
if len(ents) < 2:
return
ent_list = [{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents]
ent_df = pd.DataFrame(ent_list)
rela_list = []
k = 0
for i in range(0, len(ents)):
if k >= 10000:
break
for j in range(i + 1, len(ents)):
if k >= 10000:
break
edge = graph.get_edge_data(ents[i], ents[j])
if edge is None:
continue
rela_list.append({"source": ents[i], "target": ents[j], "description": edge["description"]})
k += 1
rela_df = pd.DataFrame(rela_list)
prompt_variables = {
"entity_df": ent_df.to_csv(index_label="id"),
"relation_df": rela_df.to_csv(index_label="id")
}
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
async with chat_limiter:
try:
with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
if cancel_scope.cancelled_caught:
logging.warning("extract_community_report._chat timeout, skipping...")
return
except Exception as e:
logging.error(f"extract_community_report._chat failed: {e}")
return
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
response = re.sub(r"\{\{", "{", response)
response = re.sub(r"\}\}", "}", response)
logging.debug(response)
try:
response = json.loads(response)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse JSON response: {e}")
logging.error(f"Response content: {response}")
return
if not dict_has_keys_with_types(response, [
("title", str),
("summary", str),
("findings", list),
("rating", float),
("rating_explanation", str),
]):
return
response["weight"] = weight
response["entities"] = ents
add_community_info2graph(graph, ents, response["title"])
res_str.append(self._get_text_output(response))
res_dict.append(response)
over += 1
if callback:
callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
st = trio.current_time()
async with trio.open_nursery() as nursery:
for level, comm in communities.items():
logging.info(f"Level {level}: Community: {len(comm.keys())}")
for community in comm.items():
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before community processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
nursery.start_soon(extract_community_report, community)
if callback:
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
return CommunityReportsResult(
structured_output=res_dict,
output=res_str,
)
def _get_text_output(self, parsed_output: dict) -> str:
title = parsed_output.get("title", "Report")
summary = parsed_output.get("summary", "")
findings = parsed_output.get("findings", [])
def finding_summary(finding: dict):
if isinstance(finding, str):
return finding
return finding.get("summary")
def finding_explanation(finding: dict):
if isinstance(finding, str):
return ""
return finding.get("explanation")
report_sections = "\n\n".join(
f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
)
return f"# {title}\n\n{summary}\n\n{report_sections}"

View File

@@ -0,0 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
from typing import Any
import numpy as np
import networkx as nx
from dataclasses import dataclass
from app.core.rag.graphrag.general.leiden import stable_largest_connected_component
import graspologic as gc
@dataclass
class NodeEmbeddings:
"""Node embeddings class definition."""
nodes: list[str]
embeddings: np.ndarray
def embed_node2vec(
graph: nx.Graph | nx.DiGraph,
dimensions: int = 1536,
num_walks: int = 10,
walk_length: int = 40,
window_size: int = 2,
iterations: int = 3,
random_seed: int = 86,
) -> NodeEmbeddings:
"""Generate node embeddings using Node2Vec."""
# generate embedding
lcc_tensors = gc.embed.node2vec_embed( # type: ignore
graph=graph,
dimensions=dimensions,
window_size=window_size,
iterations=iterations,
num_walks=num_walks,
walk_length=walk_length,
random_seed=random_seed,
)
return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])
def run(graph: nx.Graph, args: dict[str, Any]) -> dict:
"""Run method definition."""
if args.get("use_lcc", True):
graph = stable_largest_connected_component(graph)
# create graph embedding using node2vec
embeddings = embed_node2vec(
graph=graph,
dimensions=args.get("dimensions", 1536),
num_walks=args.get("num_walks", 10),
walk_length=args.get("walk_length", 40),
window_size=args.get("window_size", 2),
iterations=args.get("iterations", 3),
random_seed=args.get("random_seed", 86),
)
pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
sorted_pairs = sorted(pairs, key=lambda x: x[0])
return dict(sorted_pairs)

View File

@@ -0,0 +1,300 @@
import logging
import os
import re
from collections import Counter, defaultdict
from copy import deepcopy
from typing import Callable
import networkx as nx
import trio
from app.core.rag.common.connection_utils import timeout
from app.core.rag.common.token_utils import truncate
from app.core.rag.graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from app.core.rag.graphrag.utils import (
GraphChange,
chat_limiter,
flat_uniq_list,
get_from_to,
get_llm_cache,
handle_single_entity_extraction,
handle_single_relationship_extraction,
set_llm_cache,
split_string_by_multi_markers,
has_canceled,
)
from app.core.rag.llm.chat_model import Base as CompletionLLM
from app.core.rag.prompts.generator import message_fit_in
from app.core.rag.common.exceptions import TaskCanceledException
GRAPH_FIELD_SEP = "<SEP>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 2
MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK = int(os.environ.get("MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK", 10))
class Extractor:
_llm: CompletionLLM
def __init__(
self,
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
):
self._llm = llm_invoker
self._language = language
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
@timeout(60 * 20)
def _chat(self, system, history, gen_conf={}, task_id=""):
hist = deepcopy(history)
conf = deepcopy(gen_conf)
response = get_llm_cache(self._llm.model_name, system, hist, conf)
if response:
return response
_, system_msg = message_fit_in([{"role": "system", "content": system}], int(getattr(self._llm, 'max_length', 8096) * 0.92))
response = ""
for attempt in range(3):
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
try:
response = self._llm.chat(system_msg[0]["content"], hist, conf)
if isinstance(response, tuple):
response = response[0]
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(self._llm.model_name, system, response, history, gen_conf)
except Exception as e:
logging.exception(e)
if attempt == 2:
raise
return response
def _entities_and_relations(self, chunk_key: str, records: list, tuple_delimiter: str):
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
ent_types = [t.lower() for t in self._entity_types]
for record in records:
record_attributes = split_string_by_multi_markers(record, [tuple_delimiter])
if_entities = handle_single_entity_extraction(record_attributes, chunk_key)
if if_entities is not None and if_entities.get("entity_type", "unknown").lower() in ent_types:
maybe_nodes[if_entities["entity_name"]].append(if_entities)
continue
if_relation = handle_single_relationship_extraction(record_attributes, chunk_key)
if if_relation is not None:
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(if_relation)
return dict(maybe_nodes), dict(maybe_edges)
async def __call__(self, document_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
self.callback = callback
start_ts = trio.current_time()
async def extract_all(document_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = []
error_count = 0
max_errors = 3
limiter = trio.Semaphore(max_concurrency)
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
nonlocal error_count
async with limiter:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during entity extraction")
try:
await self._process_single_content(chunk_key_dp, idx, total, out_results, task_id)
except Exception as e:
error_count += 1
error_msg = f"Error processing chunk {idx + 1}/{total}: {str(e)}"
logging.warning(error_msg)
if self.callback:
self.callback(msg=error_msg)
if error_count > max_errors:
raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}")
async with trio.open_nursery() as nursery:
for i, ck in enumerate(chunks):
nursery.start_soon(worker, (document_id, ck), i, len(chunks), task_id)
if error_count > 0:
warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
logging.warning(warning_msg)
if self.callback:
self.callback(msg=warning_msg)
return out_results
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before entity extraction")
out_results = await extract_all(document_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=task_id)
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after entity extraction")
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
sum_token_count = 0
for m_nodes, m_edges, token_count in out_results:
for k, v in m_nodes.items():
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
sum_token_count += token_count
now = trio.current_time()
if self.callback:
self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
start_ts = now
logging.info("Entities merging...")
all_entities_data = []
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
async with trio.open_nursery() as nursery:
for en_nm, ents in maybe_nodes.items():
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id)
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")
now = trio.current_time()
if self.callback:
self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
start_ts = now
logging.info("Relationships merging...")
all_relationships_data = []
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
async with trio.open_nursery() as nursery:
for (src, tgt), rels in maybe_edges.items():
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id)
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")
now = trio.current_time()
if self.callback:
self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
if not len(all_entities_data) and not len(all_relationships_data):
logging.warning("Didn't extract any entities and relationships, maybe your LLM is not working")
if not len(all_entities_data):
logging.warning("Didn't extract any entities")
if not len(all_relationships_data):
logging.warning("Didn't extract any relationships")
return all_entities_data, all_relationships_data
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data, task_id=""):
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during merge nodes")
if not entities:
return
entity_type = sorted(
Counter([dp["entity_type"] for dp in entities]).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]
description = GRAPH_FIELD_SEP.join(sorted(set([dp["description"] for dp in entities])))
already_source_ids = flat_uniq_list(entities, "source_id")
description = await self._handle_entity_relation_summary(entity_name, description, task_id=task_id)
node_data = dict(
entity_type=entity_type,
description=description,
source_id=already_source_ids,
)
node_data["entity_name"] = entity_name
all_relationships_data.append(node_data)
async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None, task_id=""):
if not edges_data:
return
weight = sum([edge["weight"] for edge in edges_data])
description = GRAPH_FIELD_SEP.join(sorted(set([edge["description"] for edge in edges_data])))
description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description, task_id=task_id)
keywords = flat_uniq_list(edges_data, "keywords")
source_id = flat_uniq_list(edges_data, "source_id")
edge_data = dict(src_id=src_id, tgt_id=tgt_id, description=description, keywords=keywords, weight=weight, source_id=source_id)
all_relationships_data.append(edge_data)
async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange, task_id=""):
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during merge graph nodes")
if len(nodes) <= 1:
return
change.added_updated_nodes.add(nodes[0])
change.removed_nodes.update(nodes[1:])
nodes_set = set(nodes)
node0_attrs = graph.nodes[nodes[0]]
node0_neighbors = set(graph.neighbors(nodes[0]))
for node1 in nodes[1:]:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during merge_graph nodes")
# Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged.
node1_attrs = graph.nodes[node1]
node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}"
node0_attrs["source_id"] = sorted(set(node0_attrs["source_id"] + node1_attrs["source_id"]))
for neighbor in graph.neighbors(node1):
change.removed_edges.add(get_from_to(node1, neighbor))
if neighbor not in nodes_set:
edge1_attrs = graph.get_edge_data(node1, neighbor)
if neighbor in node0_neighbors:
# Merge two edges
change.added_updated_edges.add(get_from_to(nodes[0], neighbor))
edge0_attrs = graph.get_edge_data(nodes[0], neighbor)
edge0_attrs["weight"] += edge1_attrs["weight"]
edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}"
for attr in ["keywords", "source_id"]:
edge0_attrs[attr] = sorted(set(edge0_attrs[attr] + edge1_attrs[attr]))
edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"], task_id=task_id)
graph.add_edge(nodes[0], neighbor, **edge0_attrs)
else:
graph.add_edge(nodes[0], neighbor, **edge1_attrs)
graph.remove_node(node1)
node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"], task_id=task_id)
graph.nodes[nodes[0]].update(node0_attrs)
async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str, task_id="") -> str:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
summary_max_tokens = 512
use_description = truncate(description, summary_max_tokens)
description_list = use_description.split(GRAPH_FIELD_SEP)
if len(description_list) <= 12:
return use_description
prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
context_base = dict(
entity_name=entity_or_relation_name,
description_list=description_list,
language=self._language,
)
use_prompt = prompt_template.format(**context_base)
logging.info(f"Trigger summary: {entity_or_relation_name}")
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter:
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary

View File

@@ -0,0 +1,150 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import re
from typing import Any
from dataclasses import dataclass
import tiktoken
import trio
from app.core.rag.graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
from app.core.rag.graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from app.core.rag.graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter, split_string_by_multi_markers
from app.core.rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from app.core.rag.common.token_utils import num_tokens_from_string
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
output: nx.Graph
source_docs: dict[Any, Any]
class GraphExtractor(Extractor):
"""Unipartite graph extractor class definition."""
_join_descriptions: bool
_tuple_delimiter_key: str
_record_delimiter_key: str
_entity_types_key: str
_input_text_key: str
_completion_delimiter_key: str
_entity_name_key: str
_input_descriptions_key: str
_extraction_prompt: str
_summarization_prompt: str
_loop_args: dict[str, Any]
_max_gleanings: int
_on_error: ErrorHandlerFn
def __init__(
self,
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
input_text_key: str | None = None,
entity_types_key: str | None = None,
completion_delimiter_key: str | None = None,
join_descriptions=True,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
super().__init__(llm_invoker, language, entity_types)
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._join_descriptions = join_descriptions
self._input_text_key = input_text_key or "input_text"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._entity_types_key = entity_types_key or "entity_types"
self._extraction_prompt = GRAPH_EXTRACTION_PROMPT
self._max_gleanings = (
max_gleanings
if max_gleanings is not None
else ENTITY_EXTRACTION_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)
self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
# Construct the looping arguments
encoding = tiktoken.get_encoding("cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
# Wire defaults into the prompt variables
self._prompt_variables = {
self._tuple_delimiter_key: DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: DEFAULT_COMPLETION_DELIMITER,
self._entity_types_key: ",".join(entity_types),
}
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results, task_id=""):
token_count = 0
chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1]
variables = {
**self._prompt_variables,
self._input_text_key: content,
}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter:
response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
history.append({"role": "user", "content": CONTINUE_PROMPT})
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat("", history, {}))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter:
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "Y":
break
history.append({"role": "assistant", "content": "Y"})
records = split_string_by_multi_markers(
results,
[self._prompt_variables[self._record_delimiter_key], self._prompt_variables[self._completion_delimiter_key]],
)
rcds = []
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
rcds.append(record.group(1))
records = rcds
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._prompt_variables[self._tuple_delimiter_key])
out_results.append((maybe_nodes, maybe_edges, token_count))
if self.callback:
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")

View File

@@ -0,0 +1,124 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/extract_graph.py)
"""
GRAPH_EXTRACTION_PROMPT = """
-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized, in language of 'Text'
- entity_type: One of the following types: [{entity_types}]
- entity_description: Comprehensive description of the entity's attributes and activities in language of 'Text'
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other in language of 'Text'
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
3. Return output as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
4. When finished, output {completion_delimiter}
######################
-Examples-
######################
Example 1:
Entity_types: [person, technology, mission, organization, location]
Text:
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
################
Output:
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}6){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}5){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}9){completion_delimiter}
#############################
Example 2:
Entity_types: [person, technology, mission, organization, location]
Text:
They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
#############
Output:
("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}9){completion_delimiter}
#############################
Example 3:
Entity_types: [person, role, technology, organization, event, location, concept]
Text:
their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
#############
Output:
("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}9){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}10){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}7){completion_delimiter}
#############################
-Real Data-
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:"""
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n"
LOOP_PROMPT = "It appears some entities may have still been missed. Answer Y if there are still entities that need to be added, or N if there are none. Please answer with a single letter Y or N.\n"
SUMMARIZE_DESCRIPTIONS_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we the have full context.
Use {language} as output language.
#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
"""

View File

@@ -0,0 +1,554 @@
import json
import logging
import os
import networkx as nx
import trio
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVector
from app.core.rag.common.exceptions import TaskCanceledException
from app.core.rag.common.misc_utils import get_uuid
from app.core.rag.common.connection_utils import timeout
from app.core.rag.graphrag.entity_resolution import EntityResolution
from app.core.rag.graphrag.general.community_reports_extractor import CommunityReportsExtractor
from app.core.rag.graphrag.general.extractor import Extractor
from app.core.rag.graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from app.core.rag.graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from app.core.rag.graphrag.utils import (
GraphChange,
chunk_id,
does_graph_contains,
get_graph,
graph_merge,
set_graph,
tidy_graph,
has_canceled,
)
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.utils.redis_conn import RedisDistributedLock
from app.core.rag.common import settings
def init_graphrag(row, vector_size: int):
idxnm = search.index_name(row["workspace_id"])
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
async def run_graphrag(
row: dict,
language,
with_resolution: bool,
with_community: bool,
chat_model,
embedding_model,
callback,
):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = trio.current_time()
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
chunks = []
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id"], sort_by_position=True):
chunks.append(d["page_content"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
subgraph = await generate_subgraph(
LightKGExt if "method" not in row["parser_config"].get("graphrag", {}) or row["parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt,
workspace_id,
kb_id,
document_id,
chunks,
language,
row["parser_config"]["graphrag"].get("entity_types", []),
chat_model,
embedding_model,
callback,
)
if not subgraph:
return
graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=document_id, timeout=1200)
await graphrag_task_lock.spin_acquire()
callback(msg=f"run_graphrag {document_id} graphrag_task_lock acquired")
try:
subgraph_nodes = set(subgraph.nodes())
new_graph = await merge_subgraph(
workspace_id,
kb_id,
document_id,
subgraph,
embedding_model,
callback,
)
assert new_graph is not None
if not with_resolution and not with_community:
return
if with_resolution:
await graphrag_task_lock.spin_acquire()
callback(msg=f"run_graphrag {document_id} graphrag_task_lock acquired")
await resolve_entities(
new_graph,
subgraph_nodes,
workspace_id,
kb_id,
document_id,
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
if with_community:
await graphrag_task_lock.spin_acquire()
callback(msg=f"run_graphrag {document_id} graphrag_task_lock acquired")
await extract_community(
new_graph,
workspace_id,
kb_id,
document_id,
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
finally:
graphrag_task_lock.release()
now = trio.current_time()
callback(msg=f"GraphRAG for doc {document_id} done in {now - start:.2f} seconds.")
return
async def run_graphrag_for_kb(
row: dict,
document_ids: list[str],
language: str,
parser_config: dict,
vector_service: ElasticSearchVector,
chat_model,
embedding_model,
callback,
*,
with_resolution: bool = True,
with_community: bool = True,
max_parallel_documents: int = 4,
) -> dict:
workspace_id, kb_id = row["workspace_id"], row["kb_id"]
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = trio.current_time()
document_ids = list(dict.fromkeys(document_ids)) # Remove duplicate elements
if not document_ids:
callback(msg=f"[GraphRAG] kb:{kb_id} has no processable document_id.")
return {"ok_documents": [], "failed_documents": [], "total_documents": 0, "total_chunks": 0, "seconds": 0.0}
def load_doc_chunks(document_id: str) -> list[str]:
from app.core.rag.common.token_utils import num_tokens_from_string
chunks = []
current_chunk = ""
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
for doc in items:
content = doc.page_content
if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content
else:
if current_chunk:
chunks.append(current_chunk)
current_chunk = content
if current_chunk:
chunks.append(current_chunk)
return chunks
all_document_chunks: dict[str, list[str]] = {}
total_chunks = 0
for document_id in document_ids:
chunks = load_doc_chunks(document_id)
all_document_chunks[document_id] = chunks
total_chunks += len(chunks)
if total_chunks == 0:
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_documents": [], "failed_documents": document_ids, "total_documents": len(document_ids), "total_chunks": 0, "seconds": 0.0}
semaphore = trio.Semaphore(max_parallel_documents)
subgraphs: dict[str, object] = {}
failed_documents: list[tuple[str, str]] = [] # (document_id, error)
async def build_one(document_id: str):
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled, stopping execution.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
chunks = all_document_chunks.get(document_id, [])
if not chunks:
callback(msg=f"[GraphRAG] doc:{document_id} has no available chunks, skip generation.")
return
kg_extractor = LightKGExt if ("method" not in parser_config.get("graphrag", {}) or parser_config["graphrag"]["method"] != "general") else GeneralKGExt
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
async with semaphore:
try:
msg = f"[GraphRAG] build_subgraph document:{document_id}"
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
with trio.fail_after(deadline):
sg = await generate_subgraph(
kg_extractor,
workspace_id,
kb_id,
document_id,
chunks,
language,
parser_config.get("graphrag", {}).get("entity_types", []),
chat_model,
embedding_model,
callback,
task_id=row["id"]
)
if sg:
subgraphs[document_id] = sg
callback(msg=f"{msg} done")
else:
failed_documents.append((document_id, "subgraph is empty"))
callback(msg=f"{msg} empty")
except TaskCanceledException as canceled:
callback(msg=f"[GraphRAG] build_subgraph document:{document_id} FAILED: {canceled}")
except Exception as e:
failed_documents.append((document_id, repr(e)))
callback(msg=f"[GraphRAG] build_subgraph document:{document_id} FAILED: {e!r}")
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled before processing documents.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
async with trio.open_nursery() as nursery:
for document_id in document_ids:
nursery.start_soon(build_one, document_id)
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled after document processing.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
ok_documents = [d for d in document_ids if d in subgraphs]
if not ok_documents:
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
now = trio.current_time()
return {"ok_documents": [], "failed_documents": failed_documents, "total_documents": len(document_ids), "total_chunks": total_chunks, "seconds": now - start}
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled before merging subgraphs.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
try:
union_nodes: set = set()
final_graph = None
for document_id in ok_documents:
sg = subgraphs[document_id]
union_nodes.update(set(sg.nodes()))
new_graph = await merge_subgraph(
workspace_id,
kb_id,
document_id,
sg,
embedding_model,
callback,
)
if new_graph is not None:
final_graph = new_graph
if final_graph is None:
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
else:
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
finally:
kb_lock.release()
if not with_resolution and not with_community:
now = trio.current_time()
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_documents)} / total={len(document_ids)}")
return {"ok_documents": ok_documents, "failed_documents": failed_documents, "total_documents": len(document_ids), "total_chunks": total_chunks, "seconds": now - start}
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled before resolution/community extraction.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
try:
subgraph_nodes = set()
for sg in subgraphs.values():
subgraph_nodes.update(set(sg.nodes()))
if with_resolution:
await resolve_entities(
final_graph,
subgraph_nodes,
workspace_id,
kb_id,
None,
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
if with_community:
await extract_community(
final_graph,
workspace_id,
kb_id,
None,
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
finally:
kb_lock.release()
now = trio.current_time()
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_documents)} failed={len(failed_documents)} total_documents={len(document_ids)} total_chunks={total_chunks}")
return {
"ok_documents": ok_documents,
"failed_documents": failed_documents, # [(document_id, error), ...]
"total_documents": len(document_ids),
"total_chunks": total_chunks,
"seconds": now - start,
}
async def generate_subgraph(
extractor: Extractor,
workspace_id: str,
kb_id: str,
document_id: str,
chunks: list[str],
language,
entity_types,
llm_bdl,
embed_bdl,
callback,
task_id: str = "",
):
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during subgraph generation for document {document_id}.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
contains = await does_graph_contains(workspace_id, kb_id, document_id)
if contains:
callback(msg=f"Graph already contains {document_id}")
return None
start = trio.current_time()
ext = extractor(
llm_bdl,
language=language,
entity_types=entity_types,
)
ents, rels = await ext(document_id, chunks, callback, task_id=task_id)
subgraph = nx.Graph()
for ent in ents:
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during entity processing for document {document_id}.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
assert "description" in ent, f"entity {ent} does not have description"
ent["source_id"] = [document_id]
subgraph.add_node(ent["entity_name"], **ent)
ignored_rels = 0
for rel in rels:
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during relationship processing for document {document_id}.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
assert "description" in rel, f"relation {rel} does not have description"
if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
ignored_rels += 1
continue
rel["source_id"] = [document_id]
subgraph.add_edge(
rel["src_id"],
rel["tgt_id"],
**rel,
)
if ignored_rels:
callback(msg=f"ignored {ignored_rels} relations due to missing entities.")
tidy_graph(subgraph, callback, check_attribute=False)
subgraph.graph["source_id"] = [document_id]
chunk = {
"page_content": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
"knowledge_graph_kwd": "subgraph",
"kb_id": kb_id,
"source_id": [document_id],
"available_int": 0,
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": document_id}, search.index_name(workspace_id), kb_id)
await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(workspace_id), kb_id)
now = trio.current_time()
callback(msg=f"generated subgraph for document {document_id} in {now - start:.2f} seconds.")
return subgraph
@timeout(60 * 3)
async def merge_subgraph(
workspace_id: str,
kb_id: str,
document_id: str,
subgraph: nx.Graph,
embedding_model,
callback,
):
start = trio.current_time()
change = GraphChange()
old_graph = await get_graph(workspace_id, kb_id, subgraph.graph["source_id"])
if old_graph is not None:
logging.info("Merge with an exiting graph...................")
tidy_graph(old_graph, callback)
new_graph = graph_merge(old_graph, subgraph, change)
else:
new_graph = subgraph
change.added_updated_nodes = set(new_graph.nodes())
change.added_updated_edges = set(new_graph.edges())
pr = nx.pagerank(new_graph)
for node_name, pagerank in pr.items():
new_graph.nodes[node_name]["pagerank"] = pagerank
await set_graph(workspace_id, kb_id, embedding_model, new_graph, change, callback)
now = trio.current_time()
callback(msg=f"merging subgraph for document {document_id} into the global graph done in {now - start:.2f} seconds.")
return new_graph
@timeout(60 * 30, 1)
async def resolve_entities(
graph,
subgraph_nodes: set[str],
workspace_id: str,
kb_id: str,
document_id: str,
llm_bdl,
embed_bdl,
callback,
task_id: str = "",
):
# Check if task has been canceled before resolution
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during entity resolution.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
start = trio.current_time()
er = EntityResolution(
llm_bdl,
)
reso = await er(graph, subgraph_nodes, callback=callback, task_id=task_id)
graph = reso.graph
change = reso.change
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
callback(msg="Graph resolution updated pagerank.")
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled after entity resolution.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
await set_graph(workspace_id, kb_id, embed_bdl, graph, change, callback)
now = trio.current_time()
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
@timeout(60 * 30, 1)
async def extract_community(
graph,
workspace_id: str,
kb_id: str,
document_id: str,
llm_bdl,
embed_bdl,
callback,
task_id: str = "",
):
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled before community extraction.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
start = trio.current_time()
ext = CommunityReportsExtractor(
llm_bdl,
)
cr = await ext(graph, callback=callback, task_id=task_id)
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during community extraction.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
community_structure = cr.structured_output
community_reports = cr.output
document_ids = graph.graph["source_id"]
now = trio.current_time()
callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
start = now
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled during community indexing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
chunks = []
for stru, rep in zip(community_structure, community_reports):
obj = {
"report": rep,
"evidences": "\n".join([f.get("explanation", "") for f in stru["findings"]]),
}
chunk = {
"id": get_uuid(),
"docnm_kwd": stru["title"],
"title_tks": rag_tokenizer.tokenize(stru["title"]),
"page_content": json.dumps(obj, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(obj["report"] + " " + obj["evidences"]),
"knowledge_graph_kwd": "community_report",
"weight_flt": stru["weight"],
"entities_kwd": stru["entities"],
"important_kwd": stru["entities"],
"kb_id": kb_id,
"source_id": list(document_ids),
"available_int": 0,
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
search.index_name(workspace_id),
kb_id,
)
)
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
document_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(workspace_id), kb_id))
if document_store_result:
error_message = f"Insert chunk error: {document_store_result}, please check log file and Elasticsearch status!"
raise Exception(error_message)
if task_id and has_canceled(task_id):
callback(msg=f"Task {task_id} cancelled after community indexing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
now = trio.current_time()
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
return community_structure, community_reports

View File

@@ -0,0 +1,149 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import html
from typing import Any, cast
from graspologic.partition import hierarchical_leiden
from graspologic.utils import largest_connected_component
import networkx as nx
from networkx import is_empty
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Ensure an undirected graph with the same relationships will always be read the same way."""
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True))
# If the graph is undirected, we create the edges in a stable way, so we get the same results
# for example:
# A -> B
# in graph theory is the same as
# B -> A
# in an undirected graph
# however, this can lead to downstream issues because sometimes
# consumers read graph.nodes() which ends up being [A, B] and sometimes it's [B, A]
# but they base some of their logic on the order of the nodes, so the order ends up being important
# so we sort the nodes in the edge in a stable way, so that we always get the same order
if not graph.is_directed():
def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
temp = source
source = target
target = temp
return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges]
def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
fixed_graph.add_edges_from(edges)
return fixed_graph
def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
"""Normalize node names."""
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
return nx.relabel_nodes(graph, node_mapping)
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""Return the largest connected component of the graph, with nodes and edges sorted in a stable way."""
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
graph = normalize_node_names(graph)
return _stabilize_graph(graph)
def _compute_leiden_communities(
graph: nx.Graph | nx.DiGraph,
max_cluster_size: int,
use_lcc: bool,
seed=0xDEADBEEF,
) -> dict[int, dict[str, int]]:
"""Return Leiden root communities."""
results: dict[int, dict[str, int]] = {}
if is_empty(graph):
return results
if use_lcc:
graph = stable_largest_connected_component(graph)
community_mapping = hierarchical_leiden(
graph, max_cluster_size=max_cluster_size, random_seed=seed
)
for partition in community_mapping:
results[partition.level] = results.get(partition.level, {})
results[partition.level][partition.node] = partition.cluster
return results
def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
"""Run method definition."""
max_cluster_size = args.get("max_cluster_size", 12)
use_lcc = args.get("use_lcc", True)
if args.get("verbose", False):
logging.debug(
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
)
nodes = set(graph.nodes())
if not nodes:
return {}
node_id_to_community_map = _compute_leiden_communities(
graph=graph,
max_cluster_size=max_cluster_size,
use_lcc=use_lcc,
seed=args.get("seed", 0xDEADBEEF),
)
levels = args.get("levels")
# If they don't pass in levels, use them all
if levels is None:
levels = sorted(node_id_to_community_map.keys())
results_by_level: dict[int, dict[str, dict[str, int | list]]] = {}
for level in levels:
result = {}
results_by_level[level] = result
for node_id, raw_community_id in node_id_to_community_map[level].items():
if node_id not in nodes:
logging.warning(f"Node {node_id} not found in the graph.")
continue
community_id = str(raw_community_id)
if community_id not in result:
result[community_id] = {"weight": 0, "nodes": []}
result[community_id]["nodes"].append(node_id)
result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1)
weights = [comm["weight"] for _, comm in result.items()]
if not weights:
continue
max_weight = max(weights)
if max_weight == 0:
continue
for _, comm in result.items():
comm["weight"] /= max_weight
return results_by_level
def add_community_info2graph(graph: nx.Graph, nodes: list[str], community_title):
for n in nodes:
if "communities" not in graph.nodes[n]:
graph.nodes[n]["communities"] = []
graph.nodes[n]["communities"].append(community_title)
graph.nodes[n]["communities"] = list(set(graph.nodes[n]["communities"]))

View File

@@ -0,0 +1,163 @@
import logging
import collections
import re
from typing import Any
from dataclasses import dataclass
import trio
from app.core.rag.graphrag.general.extractor import Extractor
from app.core.rag.graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from app.core.rag.graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
from app.core.rag.llm.chat_model import Base as CompletionLLM
import markdown_to_json
from functools import reduce
from app.core.rag.common.token_utils import num_tokens_from_string
@dataclass
class MindMapResult:
"""Unipartite Mind Graph result class definition."""
output: dict
class MindMapExtractor(Extractor):
_input_text_key: str
_mind_map_prompt: str
_on_error: ErrorHandlerFn
def __init__(
self,
llm_invoker: CompletionLLM,
prompt: str | None = None,
input_text_key: str | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._input_text_key = input_text_key or "input_text"
self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
def _key(self, k):
return re.sub(r"\*+", "", k)
def _be_children(self, obj: dict, keyset: set):
if isinstance(obj, str):
obj = [obj]
if isinstance(obj, list):
keyset.update(obj)
obj = [re.sub(r"\*+", "", i) for i in obj]
return [{"id": i, "children": []} for i in obj if i]
arr = []
for k, v in obj.items():
k = self._key(k)
if k and k not in keyset:
keyset.add(k)
arr.append(
{
"id": k,
"children": self._be_children(v, keyset)
}
)
return arr
async def __call__(
self, sections: list[str], prompt_variables: dict[str, Any] | None = None
) -> MindMapResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
res = []
token_count = max(getattr(self._llm, 'max_length', 8096) * 0.8, getattr(self._llm, 'max_length', 8096) - 512)
texts = []
cnt = 0
async with trio.open_nursery() as nursery:
for i in range(len(sections)):
section_cnt = num_tokens_from_string(sections[i])
if cnt + section_cnt >= token_count and texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
texts = []
cnt = 0
texts.append(sections[i])
cnt += section_cnt
if texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
if not res:
return MindMapResult(output={"id": "root", "children": []})
merge_json = reduce(self._merge, res)
if len(merge_json) > 1:
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
keyset = set(i for i in keys if i)
merge_json = {
"id": "root",
"children": [
{
"id": self._key(k),
"children": self._be_children(v, keyset)
}
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
]
}
else:
k = self._key(list(merge_json.keys())[0])
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
return MindMapResult(output=merge_json)
def _merge(self, d1, d2):
for k in d1:
if k in d2:
if isinstance(d1[k], dict) and isinstance(d2[k], dict):
self._merge(d1[k], d2[k])
elif isinstance(d1[k], list) and isinstance(d2[k], list):
d2[k].extend(d1[k])
else:
d2[k] = d1[k]
else:
d2[k] = d1[k]
return d2
def _list_to_kv(self, data):
for key, value in data.items():
if isinstance(value, dict):
self._list_to_kv(value)
elif isinstance(value, list):
new_value = {}
for i in range(len(value)):
if isinstance(value[i], list) and i > 0:
new_value[value[i - 1]] = value[i][0]
data[key] = new_value
else:
continue
return data
def _todict(self, layer: collections.OrderedDict):
to_ret = layer
if isinstance(layer, collections.OrderedDict):
to_ret = dict(layer)
try:
for key, value in to_ret.items():
to_ret[key] = self._todict(value)
except AttributeError:
pass
return self._list_to_kv(to_ret)
async def _process_document(
self, text: str, prompt_variables: dict[str, str], out_res
):
variables = {
**prompt_variables,
self._input_text_key: text,
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], {}))
response = re.sub(r"```[^\n]*", "", response)
logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response)))
out_res.append(self._todict(markdown_to_json.dictify(response)))

View File

@@ -0,0 +1,19 @@
MIND_MAP_EXTRACTION_PROMPT = """
- Role: You're a talent text processor to summarize a piece of text into a mind map.
- Step of task:
1. Generate a title for user's 'TEXT'
2. Classify the 'TEXT' into sections of a mind map.
3. If the subject matter is really complex, split them into sub-sections and sub-subsections.
4. Add a shot content summary of the bottom level section.
- Output requirement:
- Generate at least 4 levels.
- Always try to maximize the number of sub-sections.
- In language of 'Text'
- MUST IN FORMAT OF MARKDOWN
-TEXT-
{input_text}
"""

View File

@@ -0,0 +1,131 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import re
from dataclasses import dataclass
from typing import Any
import networkx as nx
import trio
from app.core.rag.graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor
from app.core.rag.graphrag.light.graph_prompt import PROMPTS
from app.core.rag.graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split_string_by_multi_markers
from app.core.rag.llm.chat_model import Base as CompletionLLM
from app.core.rag.common.token_utils import num_tokens_from_string
@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
output: nx.Graph
source_docs: dict[Any, Any]
class GraphExtractor(Extractor):
_max_gleanings: int
def __init__(
self,
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
example_number: int = 2,
max_gleanings: int | None = None,
):
super().__init__(llm_invoker, language, entity_types)
"""Init method definition."""
self._max_gleanings = max_gleanings if max_gleanings is not None else ENTITY_EXTRACTION_MAX_GLEANINGS
self._example_number = example_number
examples = "\n".join(PROMPTS["entity_extraction_examples"][: int(self._example_number)])
example_context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(self._entity_types),
language=self._language,
)
# add example's format
examples = examples.format(**example_context_base)
self._entity_extract_prompt = PROMPTS["entity_extraction"]
self._context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(self._entity_types),
examples=examples,
language=self._language,
)
self._continue_prompt = PROMPTS["entity_continue_extraction"].format(**self._context_base)
self._if_loop_prompt = PROMPTS["entity_if_loop_extraction"]
self._left_token_count = getattr(llm_invoker, 'max_length', 8096) - num_tokens_from_string(self._entity_extract_prompt.format(**self._context_base, input_text=""))
self._left_token_count = max(getattr(llm_invoker, 'max_length', 8096) * 0.6, self._left_token_count)
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results, task_id=""):
token_count = 0
chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1]
hint_prompt = self._entity_extract_prompt.format(**self._context_base, input_text=content)
gen_conf = {}
final_result = ""
glean_result = ""
if_loop_result = ""
history = []
logging.info(f"Start processing for {chunk_key}: {content[:25]}...")
if self.callback:
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
async with chat_limiter:
final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
async with chat_limiter:
# glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
history.extend([{"role": "assistant", "content": glean_result}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
if now_glean_index == self._max_gleanings - 1:
break
history.extend([{"role": "user", "content": self._if_loop_prompt}])
async with chat_limiter:
if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
history.extend([{"role": "assistant", "content": if_loop_result}, {"role": "user", "content": self._continue_prompt}])
logging.info(f"Completed processing for {chunk_key}: {content[:25]}... after {now_glean_index} gleanings, {token_count} tokens.")
if self.callback:
self.callback(msg=f"Completed processing for {chunk_key}: {content[:25]}... after {now_glean_index} gleanings, {token_count} tokens.")
records = split_string_by_multi_markers(
final_result,
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
)
rcds = []
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
rcds.append(record.group(1))
records = rcds
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
out_results.append((maybe_nodes, maybe_edges, token_count))
if self.callback:
self.callback(
0.5 + 0.1 * len(out_results) / num_chunks,
msg=f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.",
)

View File

@@ -0,0 +1,331 @@
# Licensed under the MIT License
"""
Reference:
- [LightRAG](https://github.com/HKUDS/LightRAG/blob/main/lightrag/prompt.py)
"""
from typing import Any
PROMPTS: dict[str, Any] = {}
PROMPTS["DEFAULT_LANGUAGE"] = "English"
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
PROMPTS["DEFAULT_USER_PROMPT"] = "n/a"
PROMPTS["entity_extraction"] = """---Goal---
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
Use {language} as output language.
---Steps---
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, use same language as input text. If English, capitalized the name.
- entity_type: One of the following types: [{entity_types}]
- entity_description: Provide a comprehensive description of the entity's attributes and activities *based solely on the information present in the input text*. **Do not infer or hallucinate information not explicitly stated.** If the text provides insufficient information to create a comprehensive description, state "Description not available in text."
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
5. When finished, output {completion_delimiter}
######################
---Examples---
######################
{examples}
#############################
---Real Data---
######################
Entity_types: [{entity_types}]
Text:
{input_text}
######################
Output:"""
PROMPTS["entity_extraction_examples"] = [
"""Example 1:
Entity_types: [person, technology, mission, organization, location]
Text:
```
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. "If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us."
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
```
Output:
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
#############################""",
"""Example 2:
Entity_types: [company, index, commodity, market_trend, economic_policy, biological]
Text:
```
Stock markets faced a sharp downturn today as tech giants saw significant declines, with the Global Tech Index dropping by 3.4% in midday trading. Analysts attribute the selloff to investor concerns over rising interest rates and regulatory uncertainty.
Among the hardest hit, Nexon Technologies saw its stock plummet by 7.8% after reporting lower-than-expected quarterly earnings. In contrast, Omega Energy posted a modest 2.1% gain, driven by rising oil prices.
Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1.5%, reaching $2,080 per ounce, as investors sought safe-haven assets. Crude oil prices continued their rally, climbing to $87.60 per barrel, supported by supply constraints and strong demand.
Financial experts are closely watching the Federal Reserve's next move, as speculation grows over potential rate hikes. The upcoming policy announcement is expected to influence investor confidence and overall market stability.
```
Output:
("entity"{tuple_delimiter}"Global Tech Index"{tuple_delimiter}"index"{tuple_delimiter}"The Global Tech Index tracks the performance of major technology stocks and experienced a 3.4% decline today."){record_delimiter}
("entity"{tuple_delimiter}"Nexon Technologies"{tuple_delimiter}"company"{tuple_delimiter}"Nexon Technologies is a tech company that saw its stock decline by 7.8% after disappointing earnings."){record_delimiter}
("entity"{tuple_delimiter}"Omega Energy"{tuple_delimiter}"company"{tuple_delimiter}"Omega Energy is an energy company that gained 2.1% in stock value due to rising oil prices."){record_delimiter}
("entity"{tuple_delimiter}"Gold Futures"{tuple_delimiter}"commodity"{tuple_delimiter}"Gold futures rose by 1.5%, indicating increased investor interest in safe-haven assets."){record_delimiter}
("entity"{tuple_delimiter}"Crude Oil"{tuple_delimiter}"commodity"{tuple_delimiter}"Crude oil prices rose to $87.60 per barrel due to supply constraints and strong demand."){record_delimiter}
("entity"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"market_trend"{tuple_delimiter}"Market selloff refers to the significant decline in stock values due to investor concerns over interest rates and regulations."){record_delimiter}
("entity"{tuple_delimiter}"Federal Reserve Policy Announcement"{tuple_delimiter}"economic_policy"{tuple_delimiter}"The Federal Reserve's upcoming policy announcement is expected to impact investor confidence and market stability."){record_delimiter}
("relationship"{tuple_delimiter}"Global Tech Index"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"The decline in the Global Tech Index is part of the broader market selloff driven by investor concerns."{tuple_delimiter}"market performance, investor sentiment"{tuple_delimiter}9){record_delimiter}
("relationship"{tuple_delimiter}"Nexon Technologies"{tuple_delimiter}"Global Tech Index"{tuple_delimiter}"Nexon Technologies' stock decline contributed to the overall drop in the Global Tech Index."{tuple_delimiter}"company impact, index movement"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Gold Futures"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"Gold prices rose as investors sought safe-haven assets during the market selloff."{tuple_delimiter}"market reaction, safe-haven investment"{tuple_delimiter}10){record_delimiter}
("relationship"{tuple_delimiter}"Federal Reserve Policy Announcement"{tuple_delimiter}"Market Selloff"{tuple_delimiter}"Speculation over Federal Reserve policy changes contributed to market volatility and investor selloff."{tuple_delimiter}"interest rate impact, financial regulation"{tuple_delimiter}7){record_delimiter}
("content_keywords"{tuple_delimiter}"market downturn, investor sentiment, commodities, Federal Reserve, stock performance"){completion_delimiter}
#############################""",
"""Example 3:
Entity_types: [economic_policy, athlete, event, location, record, organization, equipment]
Text:
```
At the World Athletics Championship in Tokyo, Noah Carter broke the 100m sprint record using cutting-edge carbon-fiber spikes.
```
Output:
("entity"{tuple_delimiter}"World Athletics Championship"{tuple_delimiter}"event"{tuple_delimiter}"The World Athletics Championship is a global sports competition featuring top athletes in track and field."){record_delimiter}
("entity"{tuple_delimiter}"Tokyo"{tuple_delimiter}"location"{tuple_delimiter}"Tokyo is the host city of the World Athletics Championship."){record_delimiter}
("entity"{tuple_delimiter}"Noah Carter"{tuple_delimiter}"athlete"{tuple_delimiter}"Noah Carter is a sprinter who set a new record in the 100m sprint at the World Athletics Championship."){record_delimiter}
("entity"{tuple_delimiter}"100m Sprint Record"{tuple_delimiter}"record"{tuple_delimiter}"The 100m sprint record is a benchmark in athletics, recently broken by Noah Carter."){record_delimiter}
("entity"{tuple_delimiter}"Carbon-Fiber Spikes"{tuple_delimiter}"equipment"{tuple_delimiter}"Carbon-fiber spikes are advanced sprinting shoes that provide enhanced speed and traction."){record_delimiter}
("entity"{tuple_delimiter}"World Athletics Federation"{tuple_delimiter}"organization"{tuple_delimiter}"The World Athletics Federation is the governing body overseeing the World Athletics Championship and record validations."){record_delimiter}
("relationship"{tuple_delimiter}"World Athletics Championship"{tuple_delimiter}"Tokyo"{tuple_delimiter}"The World Athletics Championship is being hosted in Tokyo."{tuple_delimiter}"event location, international competition"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Noah Carter"{tuple_delimiter}"100m Sprint Record"{tuple_delimiter}"Noah Carter set a new 100m sprint record at the championship."{tuple_delimiter}"athlete achievement, record-breaking"{tuple_delimiter}10){record_delimiter}
("relationship"{tuple_delimiter}"Noah Carter"{tuple_delimiter}"Carbon-Fiber Spikes"{tuple_delimiter}"Noah Carter used carbon-fiber spikes to enhance performance during the race."{tuple_delimiter}"athletic equipment, performance boost"{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"World Athletics Federation"{tuple_delimiter}"100m Sprint Record"{tuple_delimiter}"The World Athletics Federation is responsible for validating and recognizing new sprint records."{tuple_delimiter}"sports regulation, record certification"{tuple_delimiter}9){record_delimiter}
("content_keywords"{tuple_delimiter}"athletics, sprinting, record-breaking, sports technology, competition"){completion_delimiter}
#############################""",
]
PROMPTS["summarize_entity_descriptions"] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we the have full context.
Use {language} as output language.
#######
---Data---
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""
PROMPTS["entity_continue_extraction"] = """
MANY entities and relationships were missed in the last extraction. Please find only the missing entities and relationships from previous text.
---Remember Steps---
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, use same language as input text. If English, capitalized the name
- entity_type: One of the following types: [{entity_types}]
- entity_description: Provide a comprehensive description of the entity's attributes and activities *based solely on the information present in the input text*. **Do not infer or hallucinate information not explicitly stated.** If the text provides insufficient information to create a comprehensive description, state "Description not available in text."
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
5. When finished, output {completion_delimiter}
---Output---
Add new entities and relations below using the same format, and do not include entities and relations that have been previously extracted. :\n
""".strip()
PROMPTS["entity_if_loop_extraction"] = """
---Goal---'
It appears some entities may have still been missed.
---Output---
Answer ONLY by `YES` OR `NO` if there are still entities that need to be added.
""".strip()
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question.[no-context]"
PROMPTS["rag_response"] = """---Role---
You are a helpful assistant responding to user query about Knowledge Graph and Document Chunks provided in JSON format below.
---Goal---
Generate a concise response based on Knowledge Base and follow Response Rules, considering both current query and the conversation history if provided. Summarize all information in the provided Knowledge Base, and incorporating general knowledge relevant to the Knowledge Base. Do not include information not provided by Knowledge Base.
---Conversation History---
{history}
---Knowledge Graph and Document Chunks---
{context_data}
---RESPONSE GUIDELINES---
**1. Content & Adherence:**
- Strictly adhere to the provided context from the Knowledge Base. Do not invent, assume, or include any information not present in the source data.
- If the answer cannot be found in the provided context, state that you do not have enough information to answer.
- Ensure the response maintains continuity with the conversation history.
**2. Formatting & Language:**
- Format the response using markdown with appropriate section headings.
- The response language must in the same language as the user's question.
- Target format and length: {response_type}
**3. Citations / References:**
- At the end of the response, under a "References" section, each citation must clearly indicate its origin (KG or DC).
- The maximum number of citations is 5, including both KG and DC.
- Use the following formats for citations:
- For a Knowledge Graph Entity: `[KG] <entity_name>`
- For a Knowledge Graph Relationship: `[KG] <entity1_name> - <entity2_name>`
- For a Document Chunk: `[DC] <file_path_or_document_name>`
---USER CONTEXT---
- Additional user prompt: {user_prompt}
Response:"""
PROMPTS["keywords_extraction"] = """---Role---
You are an expert keyword extractor, specializing in analyzing user queries for a Retrieval-Augmented Generation (RAG) system. Your purpose is to identify both high-level and low-level keywords in the user's query that will be used for effective document retrieval.
---Goal---
Given a user query, your task is to extract two distinct types of keywords:
1. **high_level_keywords**: for overarching concepts or themes, capturing user's core intent, the subject area, or the type of question being asked.
2. **low_level_keywords**: for specific entities or details, identifying the specific entities, proper nouns, technical jargon, product names, or concrete items.
---Instructions & Constraints---
1. **Output Format**: Your output MUST be a valid JSON object and nothing else. Do not include any explanatory text, markdown code fences (like ```json), or any other text before or after the JSON. It will be parsed directly by a JSON parser.
2. **Source of Truth**: All keywords must be explicitly derived from the user query, with both high-level and low-level keyword categories required to contain content.
3. **Concise & Meaningful**: Keywords should be concise words or meaningful phrases. Prioritize multi-word phrases when they represent a single concept. For example, from "latest financial report of Apple Inc.", you should extract "latest financial report" and "Apple Inc." rather than "latest", "financial", "report", and "Apple".
4. **Handle Edge Cases**: For queries that are too simple, vague, or nonsensical (e.g., "hello", "ok", "asdfghjkl"), you must return a JSON object with empty lists for both keyword types.
---Examples---
{examples}
---Real Data---
User Query: {query}
---Output---
"""
PROMPTS["keywords_extraction_examples"] = [
"""Example 1:
Query: "How does international trade influence global economic stability?"
Output:
{
"high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
"low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}
""",
"""Example 2:
Query: "What are the environmental consequences of deforestation on biodiversity?"
Output:
{
"high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
"low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
}
""",
"""Example 3:
Query: "What is the role of education in reducing poverty?"
Output:
{
"high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
"low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
}
""",
]
PROMPTS["naive_rag_response"] = """---Role---
You are a helpful assistant responding to user query about Document Chunks provided provided in JSON format below.
---Goal---
Generate a concise response based on Document Chunks and follow Response Rules, considering both the conversation history and the current query. Summarize all information in the provided Document Chunks, and incorporating general knowledge relevant to the Document Chunks. Do not include information not provided by Document Chunks.
---Conversation History---
{history}
---Document Chunks(DC)---
{content_data}
---RESPONSE GUIDELINES---
**1. Content & Adherence:**
- Strictly adhere to the provided context from the Knowledge Base. Do not invent, assume, or include any information not present in the source data.
- If the answer cannot be found in the provided context, state that you do not have enough information to answer.
- Ensure the response maintains continuity with the conversation history.
**2. Formatting & Language:**
- Format the response using markdown with appropriate section headings.
- The response language must match the user's question language.
- Target format and length: {response_type}
**3. Citations / References:**
- At the end of the response, under a "References" section, cite a maximum of 5 most relevant sources used.
- Use the following formats for citations: `[DC] <file_path_or_document_name>`
---USER CONTEXT---
- Additional user prompt: {user_prompt}
Response:"""

View File

@@ -0,0 +1,218 @@
# Licensed under the MIT License
"""
Reference:
- [LightRag](https://github.com/HKUDS/LightRAG)
- [MiniRAG](https://github.com/HKUDS/MiniRAG)
"""
PROMPTS = {}
PROMPTS["minirag_query2kwd"] = """---Role---
You are a helpful assistant tasked with identifying both answer-type and low-level keywords in the user's query.
---Goal---
Given the query, list both answer-type and low-level keywords.
answer_type_keywords focus on the type of the answer to the certain query, while low-level keywords focus on specific entities, details, or concrete terms.
The answer_type_keywords must be selected from Answer type pool.
This pool is in the form of a dictionary, where the key represents the Type you should choose from and the value represents the example samples.
---Instructions---
- Output the keywords in JSON format.
- The JSON should have three keys:
- "answer_type_keywords" for the types of the answer. In this list, the types with the highest likelihood should be placed at the forefront. No more than 3.
- "entities_from_query" for specific entities or details. It must be extracted from the query.
######################
-Examples-
######################
Example 1:
Query: "How does international trade influence global economic stability?"
Answer type pool: {{
'PERSONAL LIFE': ['FAMILY TIME', 'HOME MAINTENANCE'],
'STRATEGY': ['MARKETING PLAN', 'BUSINESS EXPANSION'],
'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
'PERSON': ['JANE DOE', 'JOHN SMITH'],
'FOOD': ['PASTA', 'SUSHI'],
'EMOTION': ['HAPPINESS', 'ANGER'],
'PERSONAL EXPERIENCE': ['TRAVEL ABROAD', 'STUDYING ABROAD'],
'INTERACTION': ['TEAM MEETING', 'NETWORKING EVENT'],
'BEVERAGE': ['COFFEE', 'TEA'],
'PLAN': ['ANNUAL BUDGET', 'PROJECT TIMELINE'],
'GEO': ['NEW YORK CITY', 'SOUTH AFRICA'],
'GEAR': ['CAMPING TENT', 'CYCLING HELMET'],
'EMOJI': ['🎉', '🚀'],
'BEHAVIOR': ['POSITIVE FEEDBACK', 'NEGATIVE CRITICISM'],
'TONE': ['FORMAL', 'INFORMAL'],
'LOCATION': ['DOWNTOWN', 'SUBURBS']
}}
################
Output:
{{
"answer_type_keywords": ["STRATEGY","PERSONAL LIFE"],
"entities_from_query": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}}
#############################
Example 2:
Query: "When was SpaceX's first rocket launch?"
Answer type pool: {{
'DATE AND TIME': ['2023-10-10 10:00', 'THIS AFTERNOON'],
'ORGANIZATION': ['GLOBAL INITIATIVES CORPORATION', 'LOCAL COMMUNITY CENTER'],
'PERSONAL LIFE': ['DAILY EXERCISE ROUTINE', 'FAMILY VACATION PLANNING'],
'STRATEGY': ['NEW PRODUCT LAUNCH', 'YEAR-END SALES BOOST'],
'SERVICE FACILITATION': ['REMOTE IT SUPPORT', 'ON-SITE TRAINING SESSIONS'],
'PERSON': ['ALEXANDER HAMILTON', 'MARIA CURIE'],
'FOOD': ['GRILLED SALMON', 'VEGETARIAN BURRITO'],
'EMOTION': ['EXCITEMENT', 'DISAPPOINTMENT'],
'PERSONAL EXPERIENCE': ['BIRTHDAY CELEBRATION', 'FIRST MARATHON'],
'INTERACTION': ['OFFICE WATER COOLER CHAT', 'ONLINE FORUM DEBATE'],
'BEVERAGE': ['ICED COFFEE', 'GREEN SMOOTHIE'],
'PLAN': ['WEEKLY MEETING SCHEDULE', 'MONTHLY BUDGET OVERVIEW'],
'GEO': ['MOUNT EVEREST BASE CAMP', 'THE GREAT BARRIER REEF'],
'GEAR': ['PROFESSIONAL CAMERA EQUIPMENT', 'OUTDOOR HIKING GEAR'],
'EMOJI': ['📅', ''],
'BEHAVIOR': ['PUNCTUALITY', 'HONESTY'],
'TONE': ['CONFIDENTIAL', 'SATIRICAL'],
'LOCATION': ['CENTRAL PARK', 'DOWNTOWN LIBRARY']
}}
################
Output:
{{
"answer_type_keywords": ["DATE AND TIME", "ORGANIZATION", "PLAN"],
"entities_from_query": ["SpaceX", "Rocket launch", "Aerospace", "Power Recovery"]
}}
#############################
Example 3:
Query: "What is the role of education in reducing poverty?"
Answer type pool: {{
'PERSONAL LIFE': ['MANAGING WORK-LIFE BALANCE', 'HOME IMPROVEMENT PROJECTS'],
'STRATEGY': ['MARKETING STRATEGIES FOR Q4', 'EXPANDING INTO NEW MARKETS'],
'SERVICE FACILITATION': ['CUSTOMER SATISFACTION SURVEYS', 'STAFF RETENTION PROGRAMS'],
'PERSON': ['ALBERT EINSTEIN', 'MARIA CALLAS'],
'FOOD': ['PAN-FRIED STEAK', 'POACHED EGGS'],
'EMOTION': ['OVERWHELM', 'CONTENTMENT'],
'PERSONAL EXPERIENCE': ['LIVING ABROAD', 'STARTING A NEW JOB'],
'INTERACTION': ['SOCIAL MEDIA ENGAGEMENT', 'PUBLIC SPEAKING'],
'BEVERAGE': ['CAPPUCCINO', 'MATCHA LATTE'],
'PLAN': ['ANNUAL FITNESS GOALS', 'QUARTERLY BUSINESS REVIEW'],
'GEO': ['THE AMAZON RAINFOREST', 'THE GRAND CANYON'],
'GEAR': ['SURFING ESSENTIALS', 'CYCLING ACCESSORIES'],
'EMOJI': ['💻', '📱'],
'BEHAVIOR': ['TEAMWORK', 'LEADERSHIP'],
'TONE': ['FORMAL MEETING', 'CASUAL CONVERSATION'],
'LOCATION': ['URBAN CITY CENTER', 'RURAL COUNTRYSIDE']
}}
################
Output:
{{
"answer_type_keywords": ["STRATEGY", "PERSON"],
"entities_from_query": ["School access", "Literacy rates", "Job training", "Income inequality"]
}}
#############################
Example 4:
Query: "Where is the capital of the United States?"
Answer type pool: {{
'ORGANIZATION': ['GREENPEACE', 'RED CROSS'],
'PERSONAL LIFE': ['DAILY WORKOUT', 'HOME COOKING'],
'STRATEGY': ['FINANCIAL INVESTMENT', 'BUSINESS EXPANSION'],
'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
'PERSON': ['ALBERTA SMITH', 'BENJAMIN JONES'],
'FOOD': ['PASTA CARBONARA', 'SUSHI PLATTER'],
'EMOTION': ['HAPPINESS', 'SADNESS'],
'PERSONAL EXPERIENCE': ['TRAVEL ADVENTURE', 'BOOK CLUB'],
'INTERACTION': ['TEAM BUILDING', 'NETWORKING MEETUP'],
'BEVERAGE': ['LATTE', 'GREEN TEA'],
'PLAN': ['WEIGHT LOSS', 'CAREER DEVELOPMENT'],
'GEO': ['PARIS', 'NEW YORK'],
'GEAR': ['CAMERA', 'HEADPHONES'],
'EMOJI': ['🏢', '🌍'],
'BEHAVIOR': ['POSITIVE THINKING', 'STRESS MANAGEMENT'],
'TONE': ['FRIENDLY', 'PROFESSIONAL'],
'LOCATION': ['DOWNTOWN', 'SUBURBS']
}}
################
Output:
{{
"answer_type_keywords": ["LOCATION"],
"entities_from_query": ["capital of the United States", "Washington", "New York"]
}}
#############################
-Real Data-
######################
Query: {query}
Answer type pool:{TYPE_POOL}
######################
Output:
"""
PROMPTS["keywords_extraction"] = """---Role---
You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
---Goal---
Given the query, list both high-level and low-level keywords. High-level keywords focus on overarching concepts or themes, while low-level keywords focus on specific entities, details, or concrete terms.
---Instructions---
- Output the keywords in JSON format.
- The JSON should have two keys:
- "high_level_keywords" for overarching concepts or themes.
- "low_level_keywords" for specific entities or details.
######################
-Examples-
######################
{examples}
#############################
-Real Data-
######################
Query: {query}
######################
The `Output` should be human text, not unicode characters. Keep the same language as `Query`.
Output:
"""
PROMPTS["keywords_extraction_examples"] = [
"""Example 1:
Query: "How does international trade influence global economic stability?"
################
Output:
{
"high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
"low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}
#############################""",
"""Example 2:
Query: "What are the environmental consequences of deforestation on biodiversity?"
################
Output:
{
"high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
"low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
}
#############################""",
"""Example 3:
Query: "What is the role of education in reducing poverty?"
################
Output:
{
"high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
"low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
}
#############################""",
]

View File

@@ -0,0 +1,301 @@
import json
import logging
from collections import defaultdict
from copy import deepcopy
import json_repair
import pandas as pd
import trio
from app.core.rag.common.misc_utils import get_uuid
from app.core.rag.graphrag.query_analyze_prompt import PROMPTS
from app.core.rag.common.token_utils import num_tokens_from_string
from app.core.rag.utils.doc_store_conn import OrderByExpr
from app.core.rag.nlp.search import Dealer, index_name
from app.core.rag.common.float_utils import get_float
class KGSearch(Dealer):
def _chat(self, llm_bdl, system, history, gen_conf):
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
response = get_llm_cache(llm_bdl.model_name, system, history, gen_conf)
if response:
return response
response = llm_bdl.chat(system, history, gen_conf)
if isinstance(response, tuple):
response = response[0]
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(llm_bdl.model_name, system, response, history, gen_conf)
return response
def query_rewrite(self, llm, question, idxnms, kb_ids):
from app.core.rag.graphrag.utils import get_entity_type2samples
ty2ents = trio.run(lambda: get_entity_type2samples(idxnms, kb_ids))
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
try:
keywords_data = json_repair.loads(result)
type_keywords = keywords_data.get("answer_type_keywords", [])
entities_from_query = keywords_data.get("entities_from_query", [])[:5]
return type_keywords, entities_from_query
except json_repair.JSONDecodeError:
try:
result = result.replace(hint_prompt[:-1], '').replace('user', '').replace('model', '').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
keywords_data = json_repair.loads(result)
type_keywords = keywords_data.get("answer_type_keywords", [])
entities_from_query = keywords_data.get("entities_from_query", [])[:5]
return type_keywords, entities_from_query
# Handle parsing error
except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}")
raise e
def _ent_info_from_(self, es_res, sim_thr=0.3):
res = {}
flds = ["page_content", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"]
es_res = self.dataStore.getFields(es_res, flds)
for _, ent in es_res.items():
for f in flds:
if f in ent and ent[f] is None:
del ent[f]
if get_float(ent.get("_score", 0)) < sim_thr:
continue
if isinstance(ent["entity_kwd"], list):
ent["entity_kwd"] = ent["entity_kwd"][0]
res[ent["entity_kwd"]] = {
"sim": get_float(ent.get("_score", 0)),
"pagerank": get_float(ent.get("rank_flt", 0)),
"n_hop_ents": json.loads(ent.get("n_hop_with_weight", "[]")),
"description": ent.get("page_content", "{}")
}
return res
def _relation_info_from_(self, es_res, sim_thr=0.3):
res = {}
es_res = self.dataStore.getFields(es_res, ["page_content", "_score", "from_entity_kwd", "to_entity_kwd",
"weight_int"])
for _, ent in es_res.items():
if get_float(ent["_score"]) < sim_thr:
continue
f, t = sorted([ent["from_entity_kwd"], ent["to_entity_kwd"]])
if isinstance(f, list):
f = f[0]
if isinstance(t, list):
t = t[0]
res[(f, t)] = {
"sim": get_float(ent["_score"]),
"pagerank": get_float(ent.get("weight_int", 0)),
"description": ent["page_content"]
}
return res
def get_relevant_ents_by_keywords(self, keywords, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
if not keywords:
return {}
filters = deepcopy(filters)
filters["knowledge_graph_kwd"] = "entity"
matchDense = self.get_vector(", ".join(keywords), emb_mdl, 1024, sim_thr)
es_res = self.dataStore.search(["page_content", "entity_kwd", "rank_flt"], [], filters, [matchDense],
OrderByExpr(), 0, N,
idxnms, kb_ids)
return self._ent_info_from_(es_res, sim_thr)
def get_relevant_relations_by_txt(self, txt, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
if not txt:
return {}
filters = deepcopy(filters)
filters["knowledge_graph_kwd"] = "relation"
matchDense = self.get_vector(txt, emb_mdl, 1024, sim_thr)
es_res = self.dataStore.search(
["page_content", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"],
[], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids)
return self._relation_info_from_(es_res, sim_thr)
def get_relevant_ents_by_types(self, types, filters, idxnms, kb_ids, N=56):
if not types:
return {}
filters = deepcopy(filters)
filters["knowledge_graph_kwd"] = "entity"
filters["entity_type_kwd"] = types
ordr = OrderByExpr()
ordr.desc("rank_flt")
es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N,
idxnms, kb_ids)
return self._ent_info_from_(es_res, 0)
def retrieval(self, question: str,
workspace_ids: str | list[str],
kb_ids: list[str],
emb_mdl,
llm,
max_token: int = 8196,
ent_topn: int = 6,
rel_topn: int = 6,
comm_topn: int = 1,
ent_sim_threshold: float = 0.3,
rel_sim_threshold: float = 0.3,
**kwargs
):
qst = question
filters = self.get_filters({"kb_ids": kb_ids})
if isinstance(workspace_ids, str):
workspace_ids = workspace_ids.split(",")
idxnms = [index_name(workspace_id) for workspace_id in workspace_ids]
ty_kwds = []
try:
ty_kwds, ents = self.query_rewrite(llm, qst, [index_name(workspace_id) for workspace_id in workspace_ids], kb_ids)
logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
except Exception as e:
logging.exception(e)
ents = [qst]
pass
ents_from_query = self.get_relevant_ents_by_keywords(ents, filters, idxnms, kb_ids, emb_mdl, ent_sim_threshold)
ents_from_types = self.get_relevant_ents_by_types(ty_kwds, filters, idxnms, kb_ids, 10000)
rels_from_txt = self.get_relevant_relations_by_txt(qst, filters, idxnms, kb_ids, emb_mdl, rel_sim_threshold)
nhop_pathes = defaultdict(dict)
for _, ent in ents_from_query.items():
nhops = ent.get("n_hop_ents", [])
if not isinstance(nhops, list):
logging.warning(f"Abnormal n_hop_ents: {nhops}")
continue
for nbr in nhops:
path = nbr["path"]
wts = nbr["weights"]
for i in range(len(path) - 1):
f, t = path[i], path[i + 1]
if (f, t) in nhop_pathes:
nhop_pathes[(f, t)]["sim"] += ent["sim"] / (2 + i)
else:
nhop_pathes[(f, t)]["sim"] = ent["sim"] / (2 + i)
nhop_pathes[(f, t)]["pagerank"] = wts[i]
logging.info("Retrieved entities: {}".format(list(ents_from_query.keys())))
logging.info("Retrieved relations: {}".format(list(rels_from_txt.keys())))
logging.info("Retrieved entities from types({}): {}".format(ty_kwds, list(ents_from_types.keys())))
logging.info("Retrieved N-hops: {}".format(list(nhop_pathes.keys())))
# P(E|Q) => P(E) * P(Q|E) => pagerank * sim
for ent in ents_from_types.keys():
if ent not in ents_from_query:
continue
ents_from_query[ent]["sim"] *= 2
for (f, t) in rels_from_txt.keys():
pair = tuple(sorted([f, t]))
s = 0
if pair in nhop_pathes:
s += nhop_pathes[pair]["sim"]
del nhop_pathes[pair]
if f in ents_from_types:
s += 1
if t in ents_from_types:
s += 1
rels_from_txt[(f, t)]["sim"] *= s + 1
# This is for the relations from n-hop but not by query search
for (f, t) in nhop_pathes.keys():
s = 0
if f in ents_from_types:
s += 1
if t in ents_from_types:
s += 1
rels_from_txt[(f, t)] = {
"sim": nhop_pathes[(f, t)]["sim"] * (s + 1),
"pagerank": nhop_pathes[(f, t)]["pagerank"]
}
ents_from_query = sorted(ents_from_query.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
:ent_topn]
rels_from_txt = sorted(rels_from_txt.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
:rel_topn]
ents = []
relas = []
for n, ent in ents_from_query:
ents.append({
"Entity": n,
"Score": "%.2f" % (ent["sim"] * ent["pagerank"]),
"Description": json.loads(ent["description"]).get("description", "") if ent["description"] else ""
})
max_token -= num_tokens_from_string(str(ents[-1]))
if max_token <= 0:
ents = ents[:-1]
break
for (f, t), rel in rels_from_txt:
if not rel.get("description"):
for workspace_id in workspace_ids:
from app.core.rag.graphrag.utils import get_relation
rela = get_relation(workspace_id, kb_ids, f, t)
if rela:
break
else:
continue
rel["description"] = rela["description"]
desc = rel["description"]
try:
desc = json.loads(desc).get("description", "")
except Exception:
pass
relas.append({
"From Entity": f,
"To Entity": t,
"Score": "%.2f" % (rel["sim"] * rel["pagerank"]),
"Description": desc
})
max_token -= num_tokens_from_string(str(relas[-1]))
if max_token <= 0:
relas = relas[:-1]
break
if ents:
ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv())
else:
ents = ""
if relas:
relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv())
else:
relas = ""
return {
"chunk_id": get_uuid(),
"content_ltks": "",
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
comm_topn, max_token),
"document_id": "",
"docnm_kwd": "Related content in Knowledge Graph",
"kb_id": kb_ids,
"important_kwd": [],
"image_id": "",
"similarity": 1.,
"vector_similarity": 1.,
"term_similarity": 0,
"vector": [],
"positions": [],
}
def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token):
## Community retrieval
fields = ["docnm_kwd", "page_content"]
odr = OrderByExpr()
odr.desc("weight_flt")
fltr = deepcopy(condition)
fltr["knowledge_graph_kwd"] = "community_report"
fltr["entities_kwd"] = entities
comm_res = self.dataStore.search(fields, [], fltr, [],
OrderByExpr(), 0, topn, idxnms, kb_ids)
comm_res_fields = self.dataStore.getFields(comm_res, fields)
txts = []
for ii, (_, row) in enumerate(comm_res_fields.items()):
obj = json.loads(row["page_content"])
txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format(
ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"]))
max_token -= num_tokens_from_string(str(txts[-1]))
if not txts:
return ""
return "\n---- Community Report ----\n" + "\n".join(txts)

View File

@@ -1,20 +1,102 @@
import xxhash
import redis
from app.core.config import settings
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
- [LightRag](https://github.com/HKUDS/LightRAG)
"""
redis_client = redis.StrictRedis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD,
decode_responses=True,
max_connections=30
)
import dataclasses
import html
import json
import logging
import os
import re
import time
from collections import defaultdict
from hashlib import md5
from typing import Any, Callable, Set, Tuple
import networkx as nx
import numpy as np
import redis
import trio
import xxhash
from networkx.readwrite import json_graph
from app.core.rag.common.misc_utils import get_uuid
from app.core.rag.common.connection_utils import timeout
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.utils.doc_store_conn import OrderByExpr
from app.core.rag.common import settings
from app.core.rag.utils.redis_conn import redis_conn_params
redis_client = redis.StrictRedis(**redis_conn_params)
GRAPH_FIELD_SEP = "<SEP>"
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
@dataclasses.dataclass
class GraphChange:
removed_nodes: Set[str] = dataclasses.field(default_factory=set)
added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
def perform_variable_replacements(input: str, history: list[dict] | None = None, variables: dict | None = None) -> str:
"""Perform variable replacements on the input string and in a chat log."""
if history is None:
history = []
if variables is None:
variables = {}
result = input
def replace_all(input: str) -> str:
result = input
for k, v in variables.items():
result = result.replace(f"{{{k}}}", str(v))
return result
result = replace_all(result)
for i, entry in enumerate(history):
if entry.get("role") == "system":
entry["content"] = replace_all(entry.get("content") or "")
return result
def clean_str(input: Any) -> str:
"""Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
# If we get non-string input, just give it back
if not isinstance(input, str):
return input
result = html.unescape(input.strip())
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)
def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]]) -> bool:
"""Return True if the given dictionary has the given keys with the given types."""
for field, field_type in expected_fields:
if field not in data:
return False
value = data[field]
if not isinstance(value, field_type):
return False
return True
def get_llm_cache(llmnm, txt, history, genconf):
hasher = xxhash.xxh64()
hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8"))
hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
k = hasher.hexdigest()
bin = redis_client.get(k)
@@ -25,6 +107,528 @@ def get_llm_cache(llmnm, txt, history, genconf):
def set_llm_cache(llmnm, txt, v, history, genconf):
hasher = xxhash.xxh64()
hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8"))
hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
k = hasher.hexdigest()
redis_client.set(k, v.encode("utf-8"), 24 * 3600)
def get_embed_cache(llmnm, txt):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))
k = hasher.hexdigest()
bin = redis_client.get(k)
if not bin:
return
return np.array(json.loads(bin))
def set_embed_cache(llmnm, txt, arr):
hasher = xxhash.xxh64()
hasher.update(str(llmnm).encode("utf-8"))
hasher.update(str(txt).encode("utf-8"))
k = hasher.hexdigest()
arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
redis_client.set(k, arr.encode("utf-8"), 24 * 3600)
def get_tags_from_cache(kb_ids):
hasher = xxhash.xxh64()
hasher.update(str(kb_ids).encode("utf-8"))
k = hasher.hexdigest()
bin = redis_client.get(k)
if not bin:
return
return bin
def set_tags_to_cache(kb_ids, tags):
hasher = xxhash.xxh64()
hasher.update(str(kb_ids).encode("utf-8"))
k = hasher.hexdigest()
redis_client.set(k, json.dumps(tags).encode("utf-8"), 600)
def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
"""
Ensure all nodes and edges in the graph have some essential attribute.
"""
def is_valid_item(node_attrs: dict) -> bool:
valid_node = True
for attr in ["description", "source_id"]:
if attr not in node_attrs:
valid_node = False
break
return valid_node
if check_attribute:
purged_nodes = []
for node, node_attrs in graph.nodes(data=True):
if not is_valid_item(node_attrs):
purged_nodes.append(node)
for node in purged_nodes:
graph.remove_node(node)
if purged_nodes and callback:
callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")
purged_edges = []
for source, target, attr in graph.edges(data=True):
if check_attribute:
if not is_valid_item(attr):
purged_edges.append((source, target))
if "keywords" not in attr:
attr["keywords"] = []
for source, target in purged_edges:
graph.remove_edge(source, target)
if purged_edges and callback:
callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")
def get_from_to(node1, node2):
if node1 < node2:
return (node1, node2)
else:
return (node2, node1)
def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
"""Merge graph g2 into g1 in place."""
for node_name, attr in g2.nodes(data=True):
change.added_updated_nodes.add(node_name)
if not g1.has_node(node_name):
g1.add_node(node_name, **attr)
continue
node = g1.nodes[node_name]
node["description"] += GRAPH_FIELD_SEP + attr["description"]
# A node's source_id indicates which chunks it came from.
node["source_id"] += attr["source_id"]
for source, target, attr in g2.edges(data=True):
change.added_updated_edges.add(get_from_to(source, target))
edge = g1.get_edge_data(source, target)
if edge is None:
g1.add_edge(source, target, **attr)
continue
edge["weight"] += attr.get("weight", 0)
edge["description"] += GRAPH_FIELD_SEP + attr["description"]
edge["keywords"] += attr["keywords"]
# A edge's source_id indicates which chunks it came from.
edge["source_id"] += attr["source_id"]
for node_degree in g1.degree:
g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
# A graph's source_id indicates which documents it came from.
if "source_id" not in g1.graph:
g1.graph["source_id"] = []
g1.graph["source_id"] += g2.graph.get("source_id", [])
return g1
def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()
def handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
# add this record as a node in the G
entity_name = clean_str(record_attributes[1].upper())
if not entity_name.strip():
return None
entity_type = clean_str(record_attributes[2].upper())
entity_description = clean_str(record_attributes[3])
entity_source_id = chunk_key
return dict(
entity_name=entity_name.upper(),
entity_type=entity_type.upper(),
description=entity_description,
source_id=entity_source_id,
)
def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None
# add this record as edge
source = clean_str(record_attributes[1].upper())
target = clean_str(record_attributes[2].upper())
edge_description = clean_str(record_attributes[3])
edge_keywords = clean_str(record_attributes[4])
edge_source_id = chunk_key
weight = float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
pair = sorted([source.upper(), target.upper()])
return dict(
src_id=pair[0],
tgt_id=pair[1],
weight=weight,
description=edge_description,
keywords=edge_keywords,
source_id=edge_source_id,
metadata={"created_at": time.time()},
)
def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [{"role": roles[i % 2], "content": content} for i, content in enumerate(args)]
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
return [content]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def chunk_id(chunk):
return xxhash.xxh64((chunk["page_content"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
global chat_limiter
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
chunk = {
"id": get_uuid(),
"important_kwd": [ent_name],
"title_tks": rag_tokenizer.tokenize(ent_name),
"entity_kwd": ent_name,
"knowledge_graph_kwd": "entity",
"entity_type_kwd": meta["entity_type"],
"page_content": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"source_id": meta["source_id"],
"kb_id": kb_id,
"available_int": 0,
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
ebd = get_embed_cache(embd_mdl.model_name, ent_name)
if ebd is None:
async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
ebd = ebd[0]
set_embed_cache(embd_mdl.model_name, ent_name, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)
@timeout(3, 3)
def get_relation(workspace_id, kb_id, from_ent_name, to_ent_name, size=1):
ents = from_ent_name
if isinstance(ents, str):
ents = [from_ent_name]
if isinstance(to_ent_name, str):
to_ent_name = [to_ent_name]
ents.extend(to_ent_name)
ents = list(set(ents))
conds = {"fields": ["page_content"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
res = []
es_res = settings.retriever.search(conds, search.index_name(workspace_id), [kb_id] if isinstance(kb_id, str) else kb_id)
for id in es_res.ids:
try:
if size == 1:
return json.loads(es_res.field[id]["page_content"])
res.append(json.loads(es_res.field[id]["page_content"]))
except Exception:
continue
return res
async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
chunk = {
"id": get_uuid(),
"from_entity_kwd": from_ent_name,
"to_entity_kwd": to_ent_name,
"knowledge_graph_kwd": "relation",
"page_content": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"important_kwd": meta["keywords"],
"source_id": meta["source_id"],
"weight_int": int(meta["weight"]),
"kb_id": kb_id,
"available_int": 0,
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
txt = f"{from_ent_name}->{to_ent_name}"
ebd = get_embed_cache(embd_mdl.model_name, txt)
if ebd is None:
async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 300000000):
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"]))
ebd = ebd[0]
set_embed_cache(embd_mdl.model_name, txt, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)
async def does_graph_contains(workspace_id, kb_id, document_id):
# Get document_ids of graph
fields = ["source_id"]
condition = {
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(workspace_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
graph_document_ids = set()
for chunk_id in fields2.keys():
graph_document_ids = set(fields2[chunk_id]["source_id"])
return document_id in graph_document_ids
async def get_graph_document_ids(workspace_id, kb_id) -> list[str]:
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(workspace_id), [kb_id]))
document_ids = []
if res.total == 0:
return document_ids
for id in res.ids:
document_ids = res.field[id]["source_id"]
return document_ids
async def get_graph(workspace_id, kb_id, exclude_rebuild=None):
conds = {"fields": ["page_content", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(workspace_id), [kb_id])
if not res.total == 0:
for id in res.ids:
try:
if res.field[id]["removed_kwd"] == "N":
g = json_graph.node_link_graph(json.loads(res.field[id]["page_content"]), edges="edges")
if "source_id" not in g.graph:
g.graph["source_id"] = res.field[id]["source_id"]
else:
g = await rebuild_graph(workspace_id, kb_id, exclude_rebuild)
return g
except Exception:
continue
result = None
return result
async def set_graph(workspace_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
global chat_limiter
start = trio.current_time()
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(workspace_id), kb_id)
if change.removed_nodes:
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(workspace_id), kb_id)
if change.removed_edges:
async def del_edges(from_node, to_node):
async with chat_limiter:
await trio.to_thread.run_sync(
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(workspace_id), kb_id
)
async with trio.open_nursery() as nursery:
for from_node, to_node in change.removed_edges:
nursery.start_soon(del_edges, from_node, to_node)
now = trio.current_time()
if callback:
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
start = now
chunks = [
{
"id": get_uuid(),
"page_content": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
"knowledge_graph_kwd": "graph",
"kb_id": kb_id,
"source_id": graph.graph.get("source_id", []),
"available_int": 0,
"removed_kwd": "N",
}
]
# generate updated subgraphs
for source in graph.graph["source_id"]:
subgraph = graph.subgraph([n for n in graph.nodes if source in graph.nodes[n]["source_id"]]).copy()
subgraph.graph["source_id"] = [source]
for n in subgraph.nodes:
subgraph.nodes[n]["source_id"] = [source]
chunks.append(
{
"id": get_uuid(),
"page_content": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
"knowledge_graph_kwd": "subgraph",
"kb_id": kb_id,
"source_id": [source],
"available_int": 0,
"removed_kwd": "N",
}
)
async with trio.open_nursery() as nursery:
for ii, node in enumerate(change.added_updated_nodes):
node_attrs = graph.nodes[node]
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
async with trio.open_nursery() as nursery:
for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
edge_attrs = graph.get_edge_data(from_node, to_node)
if not edge_attrs:
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
continue
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
now = trio.current_time()
if callback:
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
start = now
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(workspace_id), kb_id))
if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch status!"
raise Exception(error_message)
now = trio.current_time()
if callback:
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
def is_continuous_subsequence(subseq, seq):
def find_all_indexes(tup, value):
indexes = []
start = 0
while True:
try:
index = tup.index(value, start)
indexes.append(index)
start = index + 1
except ValueError:
break
return indexes
index_list = find_all_indexes(seq, subseq[0])
for idx in index_list:
if idx != len(seq) - 1:
if seq[idx + 1] == subseq[-1]:
return True
return False
def merge_tuples(list1, list2):
result = []
for tup in list1:
last_element = tup[-1]
if last_element in tup[:-1]:
result.append(tup)
else:
matching_tuples = [t for t in list2 if t[0] == last_element]
already_match_flag = 0
for match in matching_tuples:
matchh = (match[1], match[0])
if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
continue
already_match_flag = 1
merged_tuple = tup + match[1:]
result.append(merged_tuple)
if not already_match_flag:
result.append(tup)
return result
async def get_entity_type2samples(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["page_content"]}, idxnms, kb_ids))
res = defaultdict(list)
for id in es_res.ids:
smp = es_res.field[id].get("page_content")
if not smp:
continue
try:
smp = json.loads(smp)
except Exception as e:
logging.exception(e)
for ty, ents in smp.items():
res[ty].extend(ents)
return res
def flat_uniq_list(arr, key):
res = []
for a in arr:
a = a[key]
if isinstance(a, list):
res.extend(a)
else:
res.append(a)
return list(set(res))
async def rebuild_graph(workspace_id, kb_id, exclude_rebuild=None):
graph = nx.Graph()
flds = ["knowledge_graph_kwd", "page_content", "source_id"]
bs = 256
for i in range(0, 1024 * bs, bs):
es_res = await trio.to_thread.run_sync(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(workspace_id), [kb_id])
)
# tot = settings.docStoreConn.getTotal(es_res)
es_res = settings.docStoreConn.getFields(es_res, flds)
if len(es_res) == 0:
break
for id, d in es_res.items():
assert d["knowledge_graph_kwd"] == "subgraph"
if isinstance(exclude_rebuild, list):
if sum([n in d["source_id"] for n in exclude_rebuild]):
continue
elif exclude_rebuild in d["source_id"]:
continue
next_graph = json_graph.node_link_graph(json.loads(d["page_content"]), edges="edges")
merged_graph = nx.compose(graph, next_graph)
merged_source = {n: graph.nodes[n]["source_id"] + next_graph.nodes[n]["source_id"] for n in graph.nodes & next_graph.nodes}
nx.set_node_attributes(merged_graph, merged_source, "source_id")
if "source_id" in graph.graph:
merged_graph.graph["source_id"] = graph.graph["source_id"] + next_graph.graph["source_id"]
else:
merged_graph.graph["source_id"] = next_graph.graph["source_id"]
graph = merged_graph
if len(graph.nodes) == 0:
return None
graph.graph["source_id"] = sorted(graph.graph["source_id"])
return graph
def has_canceled(task_id):
try:
if redis_client.get(f"{task_id}-cancel"):
return True
except Exception as e:
logging.exception(e)
return False

View File

@@ -0,0 +1,290 @@
import json
from abc import ABC
from urllib.parse import urljoin
import dashscope
import numpy as np
import requests
from openai import OpenAI
from app.core.rag.common.log_utils import log_exception
from app.core.rag.common.token_utils import num_tokens_from_string, truncate
class Base(ABC):
def __init__(self, key, model_name, **kwargs):
"""
Constructor for abstract base class.
Parameters are accepted for interface consistency but are not stored.
Subclasses should implement their own initialization as needed.
"""
pass
def encode(self, texts: list):
raise NotImplementedError("Please implement encode method!")
def encode_queries(self, text: str):
raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
class OpenAIEmbed(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
def encode(self, texts: list):
# OpenAI requires batch size <=16
batch_size = 16
texts = [truncate(t, 8191) for t in texts]
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
try:
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
except Exception as _e:
log_exception(_e, res)
return np.array(ress), total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
return np.array(res.data[0].embedding), self.total_token_count(res)
class LocalAIEmbed(Base):
_FACTORY_NAME = "LocalAI"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local embedding model url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key="empty", base_url=base_url)
self.model_name = model_name.split("___")[0]
def encode(self, texts: list):
batch_size = 16
ress = []
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
except Exception as _e:
log_exception(_e, res)
# local embedding for LmStudio donot count tokens
return np.array(ress), 1024
def encode_queries(self, text):
embds, cnt = self.encode([text])
return np.array(embds[0]), cnt
class AzureEmbed(OpenAIEmbed):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, **kwargs):
from openai.lib.azure import AzureOpenAI
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
class BaiChuanEmbed(OpenAIEmbed):
_FACTORY_NAME = "BaiChuan"
def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
super().__init__(key, model_name, base_url)
class QWenEmbed(Base):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
self.key = key
self.model_name = model_name
def encode(self, texts: list):
import time
import dashscope
batch_size = 4
res = []
token_count = 0
texts = [truncate(t, 2048) for t in texts]
for i in range(0, len(texts), batch_size):
retry_max = 5
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
while (resp["output"] is None or resp["output"].get("embeddings") is None) and retry_max > 0:
time.sleep(10)
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
retry_max -= 1
if retry_max == 0 and (resp["output"] is None or resp["output"].get("embeddings") is None):
if resp.get("message"):
log_exception(ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}"))
else:
log_exception(ValueError("Retry_max reached, calling embedding model failed"))
raise
try:
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
for e in resp["output"]["embeddings"]:
embds[e["text_index"]] = e["embedding"]
res.extend(embds)
token_count += self.total_token_count(resp)
except Exception as _e:
log_exception(_e, resp)
raise
return np.array(res), token_count
def encode_queries(self, text):
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
try:
return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
except Exception as _e:
log_exception(_e, resp)
class XinferenceEmbed(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name="", base_url=""):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
def encode(self, texts: list):
batch_size = 16
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
res = None
try:
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
except Exception as _e:
log_exception(_e, res)
return np.array(ress), total_tokens
def encode_queries(self, text):
res = None
try:
res = self.client.embeddings.create(input=[text], model=self.model_name)
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
log_exception(_e, res)
class NvidiaEmbed(Base):
_FACTORY_NAME = "NVIDIA"
def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
if not base_url:
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
self.api_key = key
self.base_url = base_url
self.headers = {
"accept": "application/json",
"Content-Type": "application/json",
"authorization": f"Bearer {self.api_key}",
}
self.model_name = model_name
if model_name == "nvidia/embed-qa-4":
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
self.model_name = "NV-Embed-QA"
if model_name == "snowflake/arctic-embed-l":
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
def encode(self, texts: list):
batch_size = 16
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
payload = {
"input": texts[i : i + batch_size],
"input_type": "query",
"model": self.model_name,
"encoding_format": "float",
"truncate": "END",
}
response = requests.post(self.base_url, headers=self.headers, json=payload)
try:
res = response.json()
except Exception as _e:
log_exception(_e, response)
ress.extend([d["embedding"] for d in res["data"]])
token_count += self.total_token_count(res)
return np.array(ress), token_count
def encode_queries(self, text):
embds, cnt = self.encode([text])
return np.array(embds[0]), cnt
class HuggingFaceEmbed(Base):
_FACTORY_NAME = "HuggingFace"
def __init__(self, key, model_name, base_url=None, **kwargs):
if not model_name:
raise ValueError("Model name cannot be None")
self.key = key
self.model_name = model_name.split("___")[0]
self.base_url = base_url or "http://127.0.0.1:8080"
def encode(self, texts: list):
response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embeddings = response.json()
else:
raise Exception(f"Error: {response.status_code} - {response.text}")
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
def encode_queries(self, text: str):
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embedding = response.json()[0]
return np.array(embedding), num_tokens_from_string(text)
else:
raise Exception(f"Error: {response.status_code} - {response.text}")
class VolcEngineEmbed(OpenAIEmbed):
_FACTORY_NAME = "VolcEngine"
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
if not base_url:
base_url = "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get("ark_api_key", "")
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
super().__init__(ark_api_key, model_name, base_url)
class GPUStackEmbed(OpenAIEmbed):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name

View File

@@ -1,8 +1,16 @@
import json
import logging
import re
import math
import os
from collections import OrderedDict
from dataclasses import dataclass
import uuid
from typing import Dict, List, Any
import numpy as np
from sqlalchemy.orm import Session
from langchain_core.documents import Document
from app.db import get_db
from app.core.models.base import RedBearModelConfig
from app.core.models import RedBearLLM, RedBearRerank
@@ -12,6 +20,12 @@ from app.core.rag.models.chunk import DocumentChunk
from app.repositories import knowledge_repository, knowledgeshare_repository
from app.services.model_service import ModelConfigService
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.rag.prompts.generator import relevant_chunks_with_toc
from app.core.rag.nlp import rag_tokenizer, query
from app.core.rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
from app.core.rag.common.string_utils import remove_redundant_spaces
from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
def knowledge_retrieval(
@@ -46,7 +60,7 @@ def knowledge_retrieval(
reranker_id = config.get("reranker_id")
reranker_top_k = config.get("reranker_top_k", 1024)
file_names_filter=[]
file_names_filter = []
if user_ids:
file_names_filter.extend([f"{user_id}.txt" for user_id in user_ids])
@@ -190,3 +204,568 @@ def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk]
return result
except Exception as e:
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
def index_name(uid): return f"graphrag_{uid}"
class Dealer:
def __init__(self, dataStore: DocStoreConnection):
self.qryr = query.FulltextQueryer()
self.dataStore = dataStore
@dataclass
class SearchResult:
total: int
ids: list[str]
query_vector: list[float] | None = None
field: dict | None = None
highlight: dict | None = None
aggregation: list | dict | None = None
keywords: list[str] | None = None
group_docs: list[list] | None = None
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, _ = emb_mdl.encode_queries(txt)
shape = np.array(qv).shape
if len(shape) > 1:
raise Exception(
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
embedding_data = [get_float(v) for v in qv]
vector_column_name = f"q_{len(embedding_data)}_vec"
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
def get_filters(self, req):
condition = dict()
for key, field in {"kb_ids": "kb_id", "document_ids": "document_id"}.items():
if key in req and req[key] is not None:
condition[field] = req[key]
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
"removed_kwd"]:
if key in req and req[key] is not None:
condition[key] = req[key]
return condition
def search(self, req, idx_names: str | list[str],
kb_ids: list[str],
emb_mdl=None,
highlight: bool | list | None = None,
rank_feature: dict | None = None
):
if highlight is None:
highlight = False
filters = self.get_filters(req)
orderBy = OrderByExpr()
pg = int(req.get("page", 1)) - 1
topk = int(req.get("topk", 1024))
ps = int(req.get("size", topk))
offset, limit = pg * ps, ps
src = req.get("fields",
["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
"document_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
"question_kwd", "question_tks", "doc_type_kwd",
"available_int", "page_content", PAGERANK_FLD, TAG_FLD])
kwds = set([])
qst = req.get("question", "")
q_vec = []
if not qst:
if req.get("sort"):
orderBy.asc("page_num_int")
orderBy.asc("top_int")
orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"]
if not highlight:
highlightFields = []
elif isinstance(highlight, list):
highlightFields = highlight
matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None:
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data
src.append(f"q_{len(q_vec)}_vec")
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
matchExprs = [matchText, matchDense, fusionExpr]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match
if total == 0:
if filters.get("document_id"):
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
else:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids,
rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
for k in keywords:
kwds.add(k)
for kk in rag_tokenizer.fine_grained_tokenize(k).split():
if len(kk) < 2:
continue
if kk in kwds:
continue
kwds.add(kk)
logging.debug(f"TOTAL: {total}")
ids = self.dataStore.getChunkIds(res)
keywords = list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "page_content")
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
return self.SearchResult(
total=total,
ids=ids,
query_vector=q_vec,
aggregation=aggs,
highlight=highlight,
field=self.dataStore.getFields(res, src + ["_score"]),
keywords=keywords
)
@staticmethod
def trans2floats(txt):
return [get_float(t) for t in txt.split("\t")]
def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.1, vtweight=0.9):
assert len(chunks) == len(chunk_v)
if not chunks:
return answer, set([])
pieces = re.split(r"(```)", answer)
if len(pieces) >= 3:
i = 0
pieces_ = []
while i < len(pieces):
if pieces[i] == "```":
st = i
i += 1
while i < len(pieces) and pieces[i] != "```":
i += 1
if i < len(pieces):
i += 1
pieces_.append("".join(pieces[st: i]) + "\n")
else:
pieces_.extend(
re.split(
r"([^\|][;。?!\n]|[a-z][.?;!][ \n])",
pieces[i]))
i += 1
pieces = pieces_
else:
pieces = re.split(r"([^\|][;。?!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)):
if re.match(r"([^\|][;。?!\n]|[a-z][.?;!][ \n])", pieces[i]):
pieces[i - 1] += pieces[i][0]
pieces[i] = pieces[i][1:]
idx = []
pieces_ = []
for i, t in enumerate(pieces):
if len(t) < 5:
continue
idx.append(i)
pieces_.append(t)
logging.debug("{} => {}".format(answer, pieces_))
if not pieces_:
return answer, set([])
ans_v, _ = embd_mdl.encode(pieces_)
for i in range(len(chunk_v)):
if len(ans_v[0]) != len(chunk_v[i]):
chunk_v[i] = [0.0] * len(ans_v[0])
logging.warning(
"The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
for ck in chunks]
cites = {}
thr = 0.63
while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
for i, a in enumerate(pieces_):
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
chunk_v,
rag_tokenizer.tokenize(
self.qryr.rmWWW(pieces_[i])).split(),
chunks_tks,
tkweight, vtweight)
mx = np.max(sim) * 0.99
logging.debug("{} SIM: {}".format(pieces_[i], mx))
if mx < thr:
continue
cites[idx[i]] = list(
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
thr *= 0.8
res = ""
seted = set([])
for i, p in enumerate(pieces):
res += p
if i not in idx:
continue
if i not in cites:
continue
for c in cites[i]:
assert int(c) < len(chunk_v)
for c in cites[i]:
if c in seted:
continue
res += f" [ID:{c}]"
seted.add(c)
return res, seted
def _rank_feature_scores(self, query_rfea, search_res):
## For rank feature(tag_fea) scores.
rank_fea = []
pageranks = []
for chunk_id in search_res.ids:
pageranks.append(search_res.field[chunk_id].get(PAGERANK_FLD, 0))
pageranks = np.array(pageranks, dtype=float)
if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids:
nor, denor = 0, 0
if not search_res.field[i].get(TAG_FLD):
rank_fea.append(0)
continue
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
if t in query_rfea:
nor += query_rfea[t] * sc
denor += sc * sc
if denor == 0:
rank_fea.append(0)
else:
rank_fea.append(nor / np.sqrt(denor) / q_denor)
return np.array(rank_fea) * 10. + pageranks
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None
):
_, keywords = self.qryr.question(query)
vector_size = len(sres.query_vector)
vector_column = f"q_{vector_size}_vec"
zero_vector = [0.0] * vector_size
ins_embd = []
for chunk_id in sres.ids:
vector = sres.field[chunk_id].get(vector_column, zero_vector)
if isinstance(vector, str):
vector = [get_float(v) for v in vector.split("\t")]
ins_embd.append(vector)
if not ins_embd:
return [], [], []
for i in sres.ids:
if isinstance(sres.field[i].get("important_kwd", []), str):
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
ins_tw = []
for i in sres.ids:
content_ltks = list(OrderedDict.fromkeys(sres.field[i][cfield].split()))
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
important_kwd = sres.field[i].get("important_kwd", [])
tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6
ins_tw.append(tks)
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
keywords,
ins_tw, tkweight, vtweight)
return sim + rank_fea, tksim, vtsim
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None):
_, keywords = self.qryr.question(query)
for i in sres.ids:
if isinstance(sres.field[i].get("important_kwd", []), str):
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
ins_tw = []
for i in sres.ids:
content_ltks = sres.field[i][cfield].split()
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
important_kwd = sres.field[i].get("important_kwd", [])
tks = content_ltks + title_tks + important_kwd
ins_tw.append(tks)
tksim = self.qryr.token_similarity(keywords, ins_tw)
vtsim, _ = rerank_mdl.similarity(query, [remove_redundant_spaces(" ".join(tks)) for tks in ins_tw])
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
return tkweight * (np.array(tksim) + rank_fea) + vtweight * vtsim, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd,
ins_embd,
rag_tokenizer.tokenize(ans).split(),
rag_tokenizer.tokenize(inst).split())
def retrieval(self, question, embd_mdl, workspace_ids, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, document_ids=None, aggs=True,
rerank_mdl=None, highlight=False,
rank_feature: dict | None = {PAGERANK_FLD: 10}):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks
# Ensure RERANK_LIMIT is multiple of page_size
RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1
req = {"kb_ids": kb_ids, "document_ids": document_ids, "page": math.ceil(page_size * page / RERANK_LIMIT),
"size": RERANK_LIMIT,
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
"available_int": 1}
if isinstance(workspace_ids, str):
workspace_ids = workspace_ids.split(",")
sres = self.search(req, [index_name(workspace_id) for workspace_id in workspace_ids],
kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
if rerank_mdl and sres.total > 0:
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
sres, question, 1 - vector_similarity_weight,
vector_similarity_weight,
rank_feature=rank_feature)
else:
# ElasticSearch doesn't normalize each way score before fusion.
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
rank_feature=rank_feature)
# Already paginated in search function
max_pages = RERANK_LIMIT // page_size
page_index = (page % max_pages) - 1
begin = max(page_index * page_size, 0)
sim = sim[begin: begin + page_size]
sim_np = np.array(sim, dtype=np.float64)
idx = np.argsort(sim_np * -1)
dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
for i in idx:
if np.float64(sim[i]) < similarity_threshold:
break
id = sres.ids[i]
chunk = sres.field[id]
dnm = chunk.get("docnm_kwd", "")
did = chunk.get("document_id", "")
if len(ranks["chunks"]) >= page_size:
if aggs:
if dnm not in ranks["doc_aggs"]:
ranks["doc_aggs"][dnm] = {"document_id": did, "count": 0}
ranks["doc_aggs"][dnm]["count"] += 1
continue
break
position_int = chunk.get("position_int", [])
d = {
"chunk_id": id,
"content_ltks": chunk["content_ltks"],
"page_content": chunk["page_content"],
"document_id": did,
"docnm_kwd": dnm,
"kb_id": chunk["kb_id"],
"important_kwd": chunk.get("important_kwd", []),
"image_id": chunk.get("img_id", ""),
"similarity": sim[i],
"vector_similarity": vsim[i],
"term_similarity": tsim[i],
"vector": chunk.get(vector_column, zero_vector),
"positions": position_int,
"doc_type_kwd": chunk.get("doc_type_kwd", "")
}
if highlight and sres.highlight:
if id in sres.highlight:
d["highlight"] = remove_redundant_spaces(sres.highlight[id])
else:
d["highlight"] = d["page_content"]
ranks["chunks"].append(d)
if dnm not in ranks["doc_aggs"]:
ranks["doc_aggs"][dnm] = {"document_id": did, "count": 0}
ranks["doc_aggs"][dnm]["count"] += 1
ranks["doc_aggs"] = [{"doc_name": k,
"document_id": v["document_id"],
"count": v["count"]} for k,
v in sorted(ranks["doc_aggs"].items(),
key=lambda x: x[1]["count"] * -1)]
ranks["chunks"] = ranks["chunks"][:page_size]
return ranks
def sql_retrieval(self, sql, fetch_size=128, format="json"):
tbl = self.dataStore.sql(sql, fetch_size, format)
return tbl
def chunk_list(self, document_id: str, workspace_id: str,
kb_ids: list[str], max_count=1024,
offset=0,
fields=["docnm_kwd", "page_content", "img_id"],
sort_by_position: bool = False):
condition = {"document_id": document_id}
fields_set = set(fields or [])
if sort_by_position:
for need in ("page_num_int", "position_int", "top_int"):
if need not in fields_set:
fields_set.add(need)
fields = list(fields_set)
orderBy = OrderByExpr()
if sort_by_position:
orderBy.asc("page_num_int")
orderBy.asc("position_int")
orderBy.asc("top_int")
res = []
bs = 128
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(workspace_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
for id, doc in dict_chunks.items():
doc["id"] = id
if dict_chunks:
res.extend(dict_chunks.values())
if len(dict_chunks.values()) < bs:
break
return res
def all_tags(self, workspace_id: str, kb_ids: list[str], S=1000):
if not self.dataStore.indexExist(index_name(workspace_id), kb_ids[0]):
return []
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(workspace_id), kb_ids, ["tag_kwd"])
return self.dataStore.getAggregation(res, "tag_kwd")
def all_tags_in_portion(self, workspace_id: str, kb_ids: list[str], S=1000):
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(workspace_id), kb_ids, ["tag_kwd"])
res = self.dataStore.getAggregation(res, "tag_kwd")
total = np.sum([c for _, c in res])
return {t: (c + 1) / (total + S) for t, c in res}
def tag_content(self, workspace_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
idx_nm = index_name(workspace_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
return True
def tag_query(self, question: str, workspace_ids: str | list[str], kb_ids: list[str], all_tags, topn_tags=3, S=1000):
if isinstance(workspace_ids, str):
idx_nms = index_name(workspace_ids)
else:
idx_nms = [index_name(workspace_id) for workspace_id in workspace_ids]
match_txt, _ = self.qryr.question(question, min_match=0.0)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
def retrieval_by_toc(self, query: str, chunks: list[dict], workspace_ids: list[str], chat_mdl, topn: int = 6):
if not chunks:
return []
idx_nms = [index_name(workspace_id) for workspace_id in workspace_ids]
ranks, document_id2kb_id = {}, {}
for ck in chunks:
if ck["document_id"] not in ranks:
ranks[ck["document_id"]] = 0
ranks[ck["document_id"]] += ck["similarity"]
document_id2kb_id[ck["document_id"]] = ck["kb_id"]
document_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
kb_ids = [document_id2kb_id[document_id]]
es_res = self.dataStore.search(["page_content"], [], {"document_id": document_id, "toc_kwd": "toc"}, [],
OrderByExpr(), 0, 128, idx_nms,
kb_ids)
toc = []
dict_chunks = self.dataStore.getFields(es_res, ["page_content"])
for _, doc in dict_chunks.items():
try:
toc.extend(json.loads(doc["page_content"]))
except Exception as e:
logging.exception(e)
if not toc:
return chunks
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2)
if not ids:
return chunks
vector_size = 1024
id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)}
for cid, sim in ids:
if cid in id2idx:
chunks[id2idx[cid]]["similarity"] += sim
continue
chunk = self.dataStore.get(cid, idx_nms, kb_ids)
d = {
"chunk_id": cid,
"content_ltks": chunk["content_ltks"],
"page_content": chunk["page_content"],
"document_id": document_id,
"docnm_kwd": chunk.get("docnm_kwd", ""),
"kb_id": chunk["kb_id"],
"important_kwd": chunk.get("important_kwd", []),
"image_id": chunk.get("img_id", ""),
"similarity": sim,
"vector_similarity": sim,
"term_similarity": sim,
"vector": [0.0] * vector_size,
"positions": chunk.get("position_int", []),
"doc_type_kwd": chunk.get("doc_type_kwd", "")
}
for k in chunk.keys():
if k[-4:] == "_vec":
d["vector"] = chunk[k]
vector_size = len(chunk[k])
break
chunks.append(d)
return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]

View File

@@ -91,6 +91,7 @@ QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
STRUCTURED_OUTPUT_PROMPT = load_prompt("structured_output_prompt")
GRAPH_ENTITY_TYPES_PROMPT_TEMPLATE = load_prompt("graph_entity_types")
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
@@ -144,6 +145,21 @@ def question_proposal(chat_mdl, content, topn=3):
return kwd
def graph_entity_types(chat_mdl, scenario):
template = PROMPT_JINJA_ENV.from_string(GRAPH_ENTITY_TYPES_PROMPT_TEMPLATE)
rendered_prompt = template.render(scenario=scenario)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
def full_question(messages=[], language=None, chat_mdl=None):
conv = []
for m in messages:

View File

@@ -0,0 +1,49 @@
## Role
You are a knowledge graph entity type identifier.
## Task
Identify and extract all relevant entity types for constructing a knowledge graph based on a given scenario.
## Requirements
- Analyze the scenario and determine key entity categories (e.g., person, organization, location, event, concept).
- Return all applicable entity types as an English comma-delimited list (no duplicates).
- Entity types must be in lowercase and use underscores for multi-word terms (e.g., "movie_genre").
- Output only the entity types, no explanations or additional text.
---
## Examples
### Example 1
**Scenario:**
A knowledge base about historical battles, including commanders, armies, locations, and outcomes.
**Output:**
person, military_commander, army, location, battle_event, outcome, date
---
### Example 2
**Scenario:**
A system tracking scientific research papers, including authors, institutions, fields of study, and citations.
**Output:**
person, author, research_institution, academic_field, research_paper, citation, publication_date
---
### Example 3
**Scenario:**
A travel guide for cities, covering landmarks, restaurants, hotels, and local events.
**Output:**
city, landmark, restaurant, hotel, local_event, cuisine_type, tourist_attraction
---
## Real Data
**Scenario:**
{{ scenario }}

View File

@@ -0,0 +1,634 @@
import logging
import re
import json
import time
import os
from urllib.parse import urlparse
import copy
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
from app.core.rag.common.decorator import singleton
from app.core.rag.common.file_utils import get_project_base_directory
from app.core.rag.common.misc_utils import convert_bytes
from app.core.rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from app.core.rag.nlp import is_english, rag_tokenizer
from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
ATTEMPT_TIME = 2
logger = logging.getLogger('rag.es_conn')
@singleton
class ESConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f'Use Elasticsearch {os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")} as the doc engine.')
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
logger.warning(f'{str(e)}. Waiting Elasticsearch {os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")} to be healthy.')
time.sleep(5)
if not self.es.ping():
msg = f'Elasticsearch {os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")} is unhealthy in 120s.'
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.0.0"})
v = v["number"].split(".")[0]
if int(v) < 8:
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
logger.error(msg)
raise Exception(msg)
fp_mapping = os.path.join(get_project_base_directory(), "app/core/rag/res", "mapping.json")
if not os.path.exists(fp_mapping):
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f'Elasticsearch {os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")} is healthy.')
def _connect(self):
# Regular Elasticsearch configuration
parsed_url = urlparse(os.getenv("ELASTICSEARCH_HOST", "127.0.0.1") or "")
if parsed_url.scheme in {"http", "https"}:
hosts = f'{os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}'
use_https = parsed_url.scheme == "https"
else:
hosts = f'https://{os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}'
use_https = False
client_config = {
"hosts": [hosts],
"basic_auth": (os.getenv("ELASTICSEARCH_USERNAME", "elastic"), os.getenv("ELASTICSEARCH_PASSWORD", "elastic")),
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
"retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true",
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)),
}
# Only add SSL settings if using HTTPS
if use_https:
client_config["verify_certs"] = os.getenv("ELASTICSEARCH_VERIFY_CERTS", False) == "true"
if os.getenv("ELASTICSEARCH_CA_CERTS"):
client_config["ca_certs"] = str(os.getenv("ELASTICSEARCH_CA_CERTS"))
self.es = Elasticsearch(**client_config)
if self.es:
self.info = self.es.info()
return True
return False
"""
Database operations
"""
def dbType(self) -> str:
return "elasticsearch"
def health(self) -> dict:
health_dict = dict(self.es.cluster.health())
health_dict["type"] = "elasticsearch"
return health_dict
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
if self.indexExist(indexName, knowledgebaseId):
return True
try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=indexName,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception:
logger.exception("ESConnection.createIndex error %s" % (indexName))
def deleteIdx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index need to be alive after any kb deletion since all kb under this workspace are in one index.
return
try:
self.es.indices.delete(index=indexName, allow_no_indices=True)
except NotFoundError:
pass
except Exception:
logger.exception("ESConnection.deleteIdx error %s" % (indexName))
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
s = Index(indexName, self.es)
for i in range(ATTEMPT_TIME):
try:
return s.exists()
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.exception(e)
break
return False
"""
CRUD operations
"""
def search(
self, selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None
):
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
assert "_id" not in condition
bqry = Q("bool", must=[])
condition["kb_id"] = knowledgebaseIds
for k, v in condition.items():
if k == "available_int":
if v == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in matchExprs:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
MatchDenseExpr) and isinstance(
matchExprs[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in matchExprs:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bqry.must.append(Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bqry.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
assert (bqry is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(m.vector_column_name,
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bqry.to_dict(),
similarity=similarity,
)
if bqry and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bqry:
s = s.query(bqry)
for field in highlightFields:
s = s.highlight(field)
if orderBy:
orders = list()
for field, order in orderBy.fields:
order = "asc" if order == 0 else "desc"
if field in ["page_num_int", "top_int"]:
order_info = {"order": order, "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}
elif field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"}
else:
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
for fld in aggFields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=indexNames,
body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
logger.debug(f"ESConnection.search {str(indexNames)} res: " + str(res))
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
logger.exception(f"ESConnection.search {str(indexNames)} query: " + str(q) + str(e))
raise e
logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=(indexName),
id=chunkId, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
chunk = res["_source"]
chunk["id"] = chunkId
return chunk
except NotFoundError:
return None
except Exception as e:
logger.exception(f"ESConnection.get({chunkId}) got exception")
raise e
logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy = copy.deepcopy(d)
d_copy["kb_id"] = knowledgebaseId
meta_id = d_copy.pop("id", "")
operations.append(
{"index": {"_index": indexName, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(ATTEMPT_TIME):
try:
res = []
r = self.es.bulk(index=(indexName), operations=operations,
refresh=False, timeout="60s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for item in r["items"]:
for action in ["create", "delete", "index", "update"]:
if action in item and "error" in item[action]:
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
res.append(str(e))
logger.warning("ESConnection.insert got exception: " + str(e))
return res
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
doc = copy.deepcopy(newValue)
doc.pop("id", None)
condition["kb_id"] = knowledgebaseId
if "id" in condition and isinstance(condition["id"], str):
# update specific single document
chunkId = condition["id"]
for i in range(ATTEMPT_TIME):
for k in doc.keys():
if "feas" != k.split("_")[-1]:
continue
try:
self.es.update(index=indexName, id=chunkId, script=f"ctx._source.remove(\"{k}\");")
except Exception:
logger.exception(f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=indexName, id=chunkId, doc=doc)
return True
except Exception as e:
logger.exception(
f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: "+str(e))
break
return False
# update unspecific maybe-multiple documents
bqry = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exists":
bqry.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in newValue.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
params[f"p_{kk}"] = vv
continue
if k == "add":
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
params[f"pp_{kk}"] = vv.strip()
continue
if (not isinstance(k, str) or not v) and k != "available_int":
continue
if isinstance(v, str):
v = re.sub(r"(['\n\r]|\\.)", " ", v)
params[f"pp_{k}"] = v
scripts.append(f"ctx._source.{k}=params.pp_{k};")
elif isinstance(v, int) or isinstance(v, float):
scripts.append(f"ctx._source.{k}={v};")
elif isinstance(v, list):
scripts.append(f"ctx._source.{k}=params.pp_{k};")
params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for _ in range(ATTEMPT_TIME):
try:
_ = ubq.execute()
return True
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
break
return False
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
qry = None
assert "_id" not in condition
condition["kb_id"] = knowledgebaseId
if "id" in condition:
chunk_ids = condition["id"]
if not isinstance(chunk_ids, list):
chunk_ids = [chunk_ids]
if not chunk_ids: # when chunk_ids is empty, delete all
qry = Q("match_all")
else:
qry = Q("ids", values=chunk_ids)
else:
qry = Q("bool")
for k, v in condition.items():
if k == "exists":
qry.filter.append(Q("exists", field=v))
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
qry.must_not.append(Q("exists", field=vv))
elif isinstance(v, list):
qry.must.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
res = self.es.delete_by_query(
index=indexName,
body=Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
"""
Helper functions for search result
"""
def getTotal(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def getChunkIds(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
d["_source"]["_score"] = d["_score"]
rr.append(d["_source"])
return rr
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
for d in self.__getSource(res):
m = {n: d.get(n) for n in fields if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, list):
m[n] = v
continue
if n == "available_int" and isinstance(v, (int, float)):
m[n] = v
continue
if not isinstance(v, str):
m[n] = str(m[n])
# if n.find("tks") > 0:
# m[n] = remove_redundant_spaces(m[n])
if m:
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split()):
ans[d["_id"]] = txt
continue
txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
flags=re.IGNORECASE | re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
return ans
def getAggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
bkts = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
logger.debug(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
logger.debug(f"ESConnection.sql to es: {sql}")
for i in range(ATTEMPT_TIME):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
request_timeout="2s")
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception:
logger.exception("ESConnection.sql got exception")
break
logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None
def get_cluster_stats(self):
"""
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
"""
raw_stats = self.es.cluster.stats()
logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
try:
res = {
'cluster_name': raw_stats['cluster_name'],
'status': raw_stats['status']
}
indices_status = raw_stats['indices']
res.update({
'indices': indices_status['count'],
'indices_shards': indices_status['shards']['total']
})
doc_info = indices_status['docs']
res.update({
'docs': doc_info['count'],
'docs_deleted': doc_info['deleted']
})
store_info = indices_status['store']
res.update({
'store_size': convert_bytes(store_info['size_in_bytes']),
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
})
mappings_info = indices_status['mappings']
res.update({
'mappings_fields': mappings_info['total_field_count'],
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
})
node_info = raw_stats['nodes']
res.update({
'nodes': node_info['count']['total'],
'nodes_version': node_info['versions'],
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
})
return res
except Exception as e:
logger.exception(f"ESConnection.get_cluster_stats: {e}")
return None

View File

@@ -0,0 +1,382 @@
import logging
import json
import uuid
import valkey as redis
from app.core.rag.common.decorator import singleton
from valkey.lock import Lock
import trio
from app.core.config import settings as config_settings
redis_conn_params = {
"host": config_settings.REDIS_HOST,
"port": config_settings.REDIS_PORT,
"db": config_settings.REDIS_DB,
"password": config_settings.REDIS_PASSWORD,
"decode_responses": True,
"max_connections": 30,
}
class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message):
self.__consumer = consumer
self.__queue_name = queue_name
self.__group_name = group_name
self.__msg_id = msg_id
self.__message = json.loads(message["message"])
def ack(self):
try:
self.__consumer.xack(self.__queue_name, self.__group_name, self.__msg_id)
return True
except Exception as e:
logging.warning("[EXCEPTION]ack" + str(self.__queue_name) + "||" + str(e))
return False
def get_message(self):
return self.__message
def get_msg_id(self):
return self.__msg_id
@singleton
class RedisDB:
lua_delete_if_equal = None
LUA_DELETE_IF_EQUAL_SCRIPT = """
local current_value = redis.call('get', KEYS[1])
if current_value and current_value == ARGV[1] then
redis.call('del', KEYS[1])
return 1
end
return 0
"""
def __init__(self):
self.REDIS = None
self.__open__()
def __open__(self):
try:
self.REDIS = redis.StrictRedis(**redis_conn_params)
self.register_scripts()
except Exception as e:
logging.warning(f"Redis can't be connected. Error: {str(e)}")
return self.REDIS
def register_scripts(self) -> None:
cls = self.__class__
client = self.REDIS
cls.lua_delete_if_equal = client.register_script(cls.LUA_DELETE_IF_EQUAL_SCRIPT)
def health(self):
self.REDIS.ping()
a, b = "xx", "yy"
self.REDIS.set(a, b, 3)
if self.REDIS.get(a) == b:
return True
def info(self):
info = self.REDIS.info()
return {
'redis_version': info["redis_version"],
'server_mode': info["server_mode"],
'used_memory': info["used_memory_human"],
'total_system_memory': info["total_system_memory_human"],
'mem_fragmentation_ratio': info["mem_fragmentation_ratio"],
'connected_clients': info["connected_clients"],
'blocked_clients': info["blocked_clients"],
'instantaneous_ops_per_sec': info["instantaneous_ops_per_sec"],
'total_commands_processed': info["total_commands_processed"]
}
def is_alive(self):
return self.REDIS is not None
def exist(self, k):
if not self.REDIS:
return
try:
return self.REDIS.exists(k)
except Exception as e:
logging.warning("RedisDB.exist " + str(k) + " got exception: " + str(e))
self.__open__()
def get(self, k):
if not self.REDIS:
return
try:
return self.REDIS.get(k)
except Exception as e:
logging.warning("RedisDB.get " + str(k) + " got exception: " + str(e))
self.__open__()
def set_obj(self, k, obj, exp=3600):
try:
self.REDIS.set(k, json.dumps(obj, ensure_ascii=False), exp)
return True
except Exception as e:
logging.warning("RedisDB.set_obj " + str(k) + " got exception: " + str(e))
self.__open__()
return False
def set(self, k, v, exp=3600):
try:
self.REDIS.set(k, v, exp)
return True
except Exception as e:
logging.warning("RedisDB.set " + str(k) + " got exception: " + str(e))
self.__open__()
return False
def sadd(self, key: str, member: str):
try:
self.REDIS.sadd(key, member)
return True
except Exception as e:
logging.warning("RedisDB.sadd " + str(key) + " got exception: " + str(e))
self.__open__()
return False
def srem(self, key: str, member: str):
try:
self.REDIS.srem(key, member)
return True
except Exception as e:
logging.warning("RedisDB.srem " + str(key) + " got exception: " + str(e))
self.__open__()
return False
def smembers(self, key: str):
try:
res = self.REDIS.smembers(key)
return res
except Exception as e:
logging.warning(
"RedisDB.smembers " + str(key) + " got exception: " + str(e)
)
self.__open__()
return None
def zadd(self, key: str, member: str, score: float):
try:
self.REDIS.zadd(key, {member: score})
return True
except Exception as e:
logging.warning("RedisDB.zadd " + str(key) + " got exception: " + str(e))
self.__open__()
return False
def zcount(self, key: str, min: float, max: float):
try:
res = self.REDIS.zcount(key, min, max)
return res
except Exception as e:
logging.warning("RedisDB.zcount " + str(key) + " got exception: " + str(e))
self.__open__()
return 0
def zpopmin(self, key: str, count: int):
try:
res = self.REDIS.zpopmin(key, count)
return res
except Exception as e:
logging.warning("RedisDB.zpopmin " + str(key) + " got exception: " + str(e))
self.__open__()
return None
def zrangebyscore(self, key: str, min: float, max: float):
try:
res = self.REDIS.zrangebyscore(key, min, max)
return res
except Exception as e:
logging.warning(
"RedisDB.zrangebyscore " + str(key) + " got exception: " + str(e)
)
self.__open__()
return None
def transaction(self, key, value, exp=3600):
try:
pipeline = self.REDIS.pipeline(transaction=True)
pipeline.set(key, value, exp, nx=True)
pipeline.execute()
return True
except Exception as e:
logging.warning(
"RedisDB.transaction " + str(key) + " got exception: " + str(e)
)
self.__open__()
return False
def queue_product(self, queue, message) -> bool:
for _ in range(3):
try:
payload = {"message": json.dumps(message)}
self.REDIS.xadd(queue, payload)
return True
except Exception as e:
logging.exception(
"RedisDB.queue_product " + str(queue) + " got exception: " + str(e)
)
self.__open__()
return False
def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
"""https://redis.io/docs/latest/commands/xreadgroup/"""
for _ in range(3):
try:
try:
group_info = self.REDIS.xinfo_groups(queue_name)
if not any(gi["name"] == group_name for gi in group_info):
self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
except redis.exceptions.ResponseError as e:
if "no such key" in str(e).lower():
self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
elif "busygroup" in str(e).lower():
logging.warning("Group already exists, continue.")
pass
else:
raise
args = {
"groupname": group_name,
"consumername": consumer_name,
"count": 1,
"block": 5,
"streams": {queue_name: msg_id},
}
messages = self.REDIS.xreadgroup(**args)
if not messages:
return None
stream, element_list = messages[0]
if not element_list:
return None
msg_id, payload = element_list[0]
res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
return res
except Exception as e:
if str(e) == 'no such key':
pass
else:
logging.exception(
"RedisDB.queue_consumer "
+ str(queue_name)
+ " got exception: "
+ str(e)
)
self.__open__()
return None
def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
try:
for queue_name in queue_names:
try:
group_info = self.REDIS.xinfo_groups(queue_name)
except Exception as e:
if str(e) == 'no such key':
logging.warning(f"RedisDB.get_unacked_iterator queue {queue_name} doesn't exist")
continue
if not any(gi["name"] == group_name for gi in group_info):
logging.warning(f"RedisDB.get_unacked_iterator queue {queue_name} group {group_name} doesn't exist")
continue
current_min = 0
while True:
payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
if not payload:
break
current_min = payload.get_msg_id()
logging.info(f"RedisDB.get_unacked_iterator {queue_name} {consumer_name} {current_min}")
yield payload
except Exception:
logging.exception(
"RedisDB.get_unacked_iterator got exception: "
)
self.__open__()
def get_pending_msg(self, queue, group_name):
try:
messages = self.REDIS.xpending_range(queue, group_name, '-', '+', 10)
return messages
except Exception as e:
if 'No such key' not in (str(e) or ''):
logging.warning(
"RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
)
return []
def requeue_msg(self, queue: str, group_name: str, msg_id: str):
for _ in range(3):
try:
messages = self.REDIS.xrange(queue, msg_id, msg_id)
if messages:
self.REDIS.xadd(queue, messages[0][1])
self.REDIS.xack(queue, group_name, msg_id)
except Exception as e:
logging.warning(
"RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
)
self.__open__()
def queue_info(self, queue, group_name) -> dict | None:
for _ in range(3):
try:
groups = self.REDIS.xinfo_groups(queue)
for group in groups:
if group["name"] == group_name:
return group
except Exception as e:
logging.warning(
"RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
)
self.__open__()
return None
def delete_if_equal(self, key: str, expected_value: str) -> bool:
"""
Do following atomically:
Delete a key if its value is equals to the given one, do nothing otherwise.
"""
return bool(self.lua_delete_if_equal(keys=[key], args=[expected_value], client=self.REDIS))
def delete(self, key) -> bool:
try:
self.REDIS.delete(key)
return True
except Exception as e:
logging.warning("RedisDB.delete " + str(key) + " got exception: " + str(e))
self.__open__()
return False
REDIS_CONN = RedisDB()
class RedisDistributedLock:
def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
self.lock_key = lock_key
if lock_value:
self.lock_value = lock_value
else:
self.lock_value = str(uuid.uuid4())
self.timeout = timeout
self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout)
def acquire(self):
REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)
return self.lock.acquire(token=self.lock_value)
async def spin_acquire(self):
REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)
while True:
if self.lock.acquire(token=self.lock_value):
break
await trio.sleep(10)
def release(self):
REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)

View File

@@ -31,9 +31,9 @@ class Document(Base):
"person",
"geo",
"event",
"category",
"category"
],
"method": "general",
"method": "general"
}
}, comment="default parser config")
chunk_num = Column(Integer, default=0, comment="chunk num")

View File

@@ -70,9 +70,9 @@ class Knowledge(Base):
"person",
"geo",
"event",
"category",
"category"
],
"method": "general",
"method": "general"
}
},
comment="default parser config")

View File

@@ -1,4 +1,5 @@
import asyncio
import trio
import json
import os
import time
@@ -17,8 +18,10 @@ from app.core.config import settings
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
@@ -52,138 +55,325 @@ def parse_document(file_path: str, document_id: uuid.UUID):
"""
Document parsing, vectorization, and storage
"""
with get_db_context() as db:
db_document = None
db_knowledge = None
progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n"
try:
db_document = db.query(Document).filter(Document.id == document_id).first()
db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first()
# 1. Document parsing & segmentation
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n"
start_time = time.time()
db_document.progress = 0.0
db_document.progress_msg = progress_msg
db_document.process_begin_at = datetime.now(tz=timezone.utc)
db_document.process_duration = 0.0
db_document.run = 1
db.commit()
db.refresh(db_document)
db = next(get_db()) # Manually call the generator
db_document = None
db_knowledge = None
progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n"
try:
db_document = db.query(Document).filter(Document.id == document_id).first()
db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first()
# 1. Document parsing & segmentation
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n"
start_time = time.time()
db_document.progress = 0.0
db_document.progress_msg = progress_msg
db_document.process_begin_at = datetime.now(tz=timezone.utc)
db_document.process_duration = 0.0
db_document.run = 1
db.commit()
db.refresh(db_document)
def progress_callback(prog=None, msg=None):
nonlocal progress_msg # Declare the use of an external progress_msg variable
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n"
# Prepare to configure chat_mdl、vision_model information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
vision_model = QWenCV(
key=db_knowledge.image2text.api_keys[0].api_key,
model_name=db_knowledge.image2text.api_keys[0].model_name,
def progress_callback(prog=None, msg=None):
nonlocal progress_msg # Declare the use of an external progress_msg variable
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n"
# Prepare to configure chat_mdl、embedding_model、vision_model information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
vision_model = QWenCV(
key=db_knowledge.image2text.api_keys[0].api_key,
model_name=db_knowledge.image2text.api_keys[0].model_name,
lang="Chinese",
base_url=db_knowledge.image2text.api_keys[0].api_base
)
if re.search(r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", file_path,
re.IGNORECASE):
vision_model = QWenSeq2txt(
key=os.getenv("QWEN3_OMNI_API_KEY", ""),
model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"),
lang="Chinese",
base_url=db_knowledge.image2text.api_keys[0].api_base
base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
)
if re.search(r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", file_path, re.IGNORECASE):
vision_model = QWenSeq2txt(
key=os.getenv("QWEN3_OMNI_API_KEY", ""),
model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"),
lang="Chinese",
base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
)
elif re.search(r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", file_path, re.IGNORECASE):
vision_model = QWenCV(
key=os.getenv("QWEN3_OMNI_API_KEY", ""),
model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"),
lang="Chinese",
base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
)
else:
print(file_path)
from app.core.rag.app.naive import chunk
res = chunk(filename=file_path,
from_page=0,
to_page=100000,
callback=progress_callback,
vision_model=vision_model,
parser_config=db_document.parser_config,
is_root=False)
elif re.search(r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", file_path,
re.IGNORECASE):
vision_model = QWenCV(
key=os.getenv("QWEN3_OMNI_API_KEY", ""),
model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"),
lang="Chinese",
base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
)
else:
print(file_path)
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n"
db_document.progress = 0.8
from app.core.rag.app.naive import chunk
res = chunk(filename=file_path,
from_page=0,
to_page=100000,
callback=progress_callback,
vision_model=vision_model,
parser_config=db_document.parser_config,
is_root=False)
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n"
db_document.progress = 0.8
db_document.progress_msg = progress_msg
db.commit()
db.refresh(db_document)
# 2. Document vectorization and storage
total_chunks = len(res)
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n"
batch_size = 100
total_batches = ceil(total_chunks / batch_size)
progress_per_batch = 0.2 / total_batches # Progress of each batch
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# 2.1 Delete document vector index
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
# 2.2 Vectorize and import batch documents
for batch_start in range(0, total_chunks, batch_size):
batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds
batch = res[batch_start: batch_end] # Retrieve the current batch
chunks = []
# Process the current batch
for idx_in_batch, item in enumerate(batch):
global_idx = batch_start + idx_in_batch # Calculate global index
metadata = {
"doc_id": uuid.uuid4().hex,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(db_document.id),
"knowledge_id": str(db_document.kb_id),
"sort_id": global_idx,
"status": 1,
}
if db_document.parser_config.get("auto_questions", 0):
topn = db_document.parser_config["auto_questions"]
cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question",
{"topn": topn})
if not cached:
cached = question_proposal(chat_model, item["content_with_weight"], topn)
set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question",
{"topn": topn})
chunks.append(
DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}",
metadata=metadata))
else:
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
# Bulk segmented vector import
vector_service.add_chunks(chunks)
# Update progress
db_document.progress += progress_per_batch
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n"
db_document.progress_msg = progress_msg
db.commit()
db.refresh(db_document)
# 2. Document vectorization and storage
total_chunks = len(res)
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n"
batch_size = 100
total_batches = ceil(total_chunks / batch_size)
progress_per_batch = 0.2 / total_batches # Progress of each batch
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# 2.1 Delete document vector index
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
# 2.2 Vectorize and import batch documents
for batch_start in range(0, total_chunks, batch_size):
batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds
batch = res[batch_start: batch_end] # Retrieve the current batch
chunks = []
# Process the current batch
for idx_in_batch, item in enumerate(batch):
global_idx = batch_start + idx_in_batch # Calculate global index
metadata = {
"doc_id": uuid.uuid4().hex,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(db_document.id),
"knowledge_id": str(db_document.kb_id),
"sort_id": global_idx,
"status": 1,
}
if db_document.parser_config.get("auto_questions", 0):
topn = db_document.parser_config["auto_questions"]
cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", {"topn": topn})
if not cached:
cached = question_proposal(chat_model, item["content_with_weight"], topn)
set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", {"topn": topn})
chunks.append(DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", metadata=metadata))
else:
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
# Bulk segmented vector import
vector_service.add_chunks(chunks)
# Update progress
db_document.progress += progress_per_batch
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n"
db_document.progress_msg = progress_msg
db_document.process_duration = time.time() - start_time
db_document.run = 0
db.commit()
db.refresh(db_document)
# Vectorization and data entry completed
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n"
db_document.chunk_num = total_chunks
db_document.progress = 1.0
db_document.process_duration = time.time() - start_time
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n"
db_document.progress_msg = progress_msg
db_document.run = 0
db.commit()
result = f"parse document '{db_document.file_name}' processed successfully."
return result
except Exception as e:
if 'db_document' in locals():
db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n"
db_document.run = 0
db.commit()
result = f"parse document '{db_document.file_name}' failed."
return result
db.refresh(db_document)
# Vectorization and data entry completed
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n"
db_document.chunk_num = total_chunks
db_document.progress = 1.0
db_document.process_duration = time.time() - start_time
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n"
db_document.progress_msg = progress_msg
db_document.run = 0
db.commit()
# using graphrag
if db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False):
graphrag_conf = db_knowledge.parser_config.get("graphrag", {})
with_resolution = graphrag_conf.get("resolution", False)
with_community = graphrag_conf.get("community", False)
def callback(msg=None):
nonlocal progress_msg
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {msg}.\n"
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to run graphrag.\n"
start_time = time.time()
db_document.progress_msg = progress_msg
db.commit()
db.refresh(db_document)
task = {
"id": str(db_document.id),
"workspace_id": str(db_knowledge.workspace_id),
"kb_id": str(db_knowledge.id),
"parser_config": db_knowledge.parser_config,
}
# init_graphrag
vts, _ = embedding_model.encode(["ok"])
vector_size = len(vts[0])
init_graphrag(task, vector_size)
async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service,
chat_model, embedding_model, callback, with_resolution: bool = True,
with_community: bool = True, ) -> dict:
nonlocal progress_msg # Declare the use of an external progress_msg variable
result = await run_graphrag_for_kb(
row=row,
document_ids=document_ids,
language=language,
parser_config=parser_config,
vector_service=vector_service,
chat_model=chat_model,
embedding_model=embedding_model,
callback=callback,
with_resolution=with_resolution,
with_community=with_community,
)
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
return result
try:
trio.run(
lambda: _run(
row=task,
document_ids=[str(db_document.id)],
language="Chinese",
parser_config=db_knowledge.parser_config,
vector_service=vector_service,
chat_model=chat_model,
embedding_model=embedding_model,
callback=callback,
with_resolution=with_resolution,
with_community=with_community,
)
)
except Exception as e:
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"
db_document.progress_msg = progress_msg
db.commit()
db.refresh(db_document)
result = f"parse document '{db_document.file_name}' processed successfully."
return result
except Exception as e:
if 'db_document' in locals():
db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n"
db_document.run = 0
db.commit()
result = f"parse document '{db_document.file_name}' failed."
return result
finally:
db.close()
@celery_app.task(name="app.core.rag.tasks.build_graphrag_for_kb")
def build_graphrag_for_kb(kb_id: uuid.UUID):
"""
build knowledge graph
"""
db = next(get_db()) # Manually call the generator
db_knowledge = None
try:
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
# 1. Prepare to configure chat_mdl、embedding_model、vision_model information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
vision_model = QWenCV(
key=db_knowledge.image2text.api_keys[0].api_key,
model_name=db_knowledge.image2text.api_keys[0].model_name,
lang="Chinese",
base_url=db_knowledge.image2text.api_keys[0].api_base
)
# 2. get all document_ids from knowledge base
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.search_by_segment(document_id=None, query=None, pagesize=9999, page=1, asc=True)
document_ids = [item.metadata["document_id"] for item in items]
# 2. using graphrag
if db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False):
graphrag_conf = db_knowledge.parser_config.get("graphrag", {})
with_resolution = graphrag_conf.get("resolution", False)
with_community = graphrag_conf.get("community", False)
def callback(msg=None):
print(f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {msg}.\n")
start_time = time.time()
task = {
"id": str(db_knowledge.id),
"workspace_id": str(db_knowledge.workspace_id),
"kb_id": str(db_knowledge.id),
"parser_config": db_knowledge.parser_config,
}
# init_graphrag
vts, _ = embedding_model.encode(["ok"])
vector_size = len(vts[0])
init_graphrag(task, vector_size)
async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service,
chat_model, embedding_model, callback, with_resolution: bool = True,
with_community: bool = True, ) -> dict:
result = await run_graphrag_for_kb(
row=row,
document_ids=document_ids,
language=language,
parser_config=parser_config,
vector_service=vector_service,
chat_model=chat_model,
embedding_model=embedding_model,
callback=callback,
with_resolution=with_resolution,
with_community=with_community,
)
print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n")
return result
try:
trio.run(
lambda: _run(
row=task,
document_ids=document_ids,
language="Chinese",
parser_config=db_knowledge.parser_config,
vector_service=vector_service,
chat_model=chat_model,
embedding_model=embedding_model,
callback=callback,
with_resolution=with_resolution,
with_community=with_community,
)
)
except Exception as e:
print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n")
print(f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)")
result = f"build knowledge graph '{db_knowledge.name}' processed successfully."
return result
except Exception as e:
if 'db_knowledge' in locals():
print(f"Failed to build knowledge grap:{str(e)}\n")
result = f"build knowledge grap '{db_knowledge.name}' failed."
return result
finally:
db.close()
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)

View File

@@ -60,7 +60,9 @@ dependencies = [
"wcwidth==0.2.14",
"websockets==15.0.1",
"requests==2.32.5",
"elastic-transport==8.17.0",
"elasticsearch==8.17.0",
"elasticsearch-dsl==8.17.0",
"xinference-client==1.11.0",
"langchain-ollama",
"chardet==5.2.0",
@@ -128,6 +130,11 @@ dependencies = [
"celery>=5.5.2",
"simpleeval>=1.0.3",
"langchain-aws>=1.0.0a1",
"networkx>=3.4.2",
"editdistance==0.8.1",
"graspologic>=3.4.1,<4.0.0",
"markdown-to-json==2.1.1",
"valkey==6.0.2",
]
[tool.pytest.ini_options]

View File

@@ -53,7 +53,9 @@ watchfiles==1.1.1
wcwidth==0.2.14
websockets==15.0.1
requests==2.32.5
elastic-transport==8.17.0
elasticsearch==8.17.0
elasticsearch-dsl==8.17.0
xinference-client==1.11.0
langchain-ollama
chardet==5.2.0
@@ -122,3 +124,8 @@ pytest-asyncio>=1.3.0
uvicorn>=0.34.0
celery>=5.5.2
simpleeval>=1.0.3
networkx>=3.4.2
editdistance==0.8.1
graspologic>=3.4.1,<4.0.0
markdown-to-json==2.1.1
valkey==6.0.2