[modify] rag qa chunk

This commit is contained in:
Mark
2026-04-28 14:04:36 +08:00
parent 4bef9b578b
commit 140311048a
8 changed files with 279 additions and 110 deletions

View File

@@ -46,7 +46,10 @@ async def run_graphrag(
start = trio.current_time()
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
chunks = []
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id"], sort_by_position=True):
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id", "chunk_type"], sort_by_position=True):
# 跳过 QA chunks只用原文 chunks 构建图谱
if d.get("chunk_type") == "qa":
continue
chunks.append(d["page_content"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@@ -150,6 +153,9 @@ async def run_graphrag_for_kb(
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
for doc in items:
# 跳过 QA chunks只用原文 chunks 构建图谱
if (doc.metadata or {}).get("chunk_type") == "qa":
continue
content = doc.page_content
if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content

View File

@@ -131,18 +131,43 @@ def keyword_extraction(chat_mdl, content, topn=3):
def question_proposal(chat_mdl, content, topn=3):
"""生成问题(向后兼容,返回纯文本问题列表)"""
pairs = qa_proposal(chat_mdl, content, topn)
if not pairs:
return ""
return "\n".join([p["question"] for p in pairs])
def qa_proposal(chat_mdl, content, topn=3):
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]"""
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
raw = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(raw, tuple):
raw = raw[0]
raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL)
if raw.find("**ERROR**") >= 0:
return []
return parse_qa_pairs(raw)
def parse_qa_pairs(text: str) -> list:
"""解析 LLM 返回的 QA 对文本,格式: Q: xxx A: xxx"""
pairs = []
for line in text.strip().split("\n"):
line = line.strip()
if not line:
continue
# 匹配 Q: ... A: ... 格式
match = re.match(r'^Q:\s*(.+?)\s+A:\s*(.+)$', line, re.IGNORECASE)
if match:
q, a = match.group(1).strip(), match.group(2).strip()
if q and a:
pairs.append({"question": q, "answer": a})
return pairs
def graph_entity_types(chat_mdl, scenario):

View File

@@ -1,19 +1,24 @@
## Role
You are a text analyzer.
You are a text analyzer and knowledge extraction expert.
## Task
Propose {{ topn }} questions about a given piece of text content.
Generate {{ topn }} question-answer pairs from the given text content.
## Requirements
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
- Understand and summarize the text content, and generate the top {{ topn }} important question-answer pairs.
- Each question-answer pair MUST be on a single line, formatted as: Q: <question> A: <answer>
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in the same language as the given piece of text content.
- One question per line.
- Output questions ONLY.
- The answers MUST be concise, accurate, and directly derived from the text content.
- The answers SHOULD be self-contained and understandable without additional context.
- Both questions and answers MUST be in the same language as the given text content.
- Output question-answer pairs ONLY, no extra explanation.
## Example Output
Q: What is the capital of France? A: The capital of France is Paris.
Q: When was the Eiffel Tower built? A: The Eiffel Tower was built in 1889.
---
## Text Content
{{ content }}

View File

@@ -53,13 +53,30 @@ class ElasticSearchVector(BaseVector):
return "elasticsearch"
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
# 实现 Elasticsearch 保存向量
texts = [chunk.page_content for chunk in chunks]
# QA chunks: embedding 只对 question 字段做source chunks: 不做 embedding
texts_for_embedding = []
for chunk in chunks:
chunk_type = (chunk.metadata or {}).get("chunk_type", "chunk")
if chunk_type == "source":
# source chunk 不需要向量索引
texts_for_embedding.append("")
elif chunk_type == "qa":
# QA chunk: 用 question 字段做 embedding
texts_for_embedding.append((chunk.metadata or {}).get("question", chunk.page_content))
else:
# 普通 chunk: 用 page_content 做 embedding
texts_for_embedding.append(chunk.page_content)
if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
embeddings = self.embeddings.embed_batch(texts)
embeddings = self.embeddings.embed_batch(texts_for_embedding)
else:
embeddings = self.embeddings.embed_documents(list(texts))
embeddings = self.embeddings.embed_documents(texts_for_embedding)
# source chunk 的向量置空
for i, chunk in enumerate(chunks):
if (chunk.metadata or {}).get("chunk_type") == "source":
embeddings[i] = None
self.create(chunks, embeddings, **kwargs)
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
@@ -72,13 +89,25 @@ class ElasticSearchVector(BaseVector):
uuids = self._get_uuids(chunks)
actions = []
for i, chunk in enumerate(chunks):
source = {
Field.CONTENT_KEY.value: chunk.page_content,
Field.METADATA_KEY.value: chunk.metadata or {},
Field.VECTOR.value: embeddings[i] or None
}
# 写入 QA 相关字段
meta = chunk.metadata or {}
if meta.get("chunk_type"):
source[Field.CHUNK_TYPE.value] = meta["chunk_type"]
if meta.get("question"):
source[Field.QUESTION.value] = meta["question"]
if meta.get("answer"):
source[Field.ANSWER.value] = meta["answer"]
if meta.get("source_chunk_id"):
source[Field.SOURCE_CHUNK_ID.value] = meta["source_chunk_id"]
action = {
"_index": self._collection_name,
"_source": {
Field.CONTENT_KEY.value: chunk.page_content,
Field.METADATA_KEY.value: chunk.metadata or {},
Field.VECTOR.value: embeddings[i] or None
}
"_source": source
}
actions.append(action)
# using bulk mode
@@ -241,10 +270,19 @@ class ElasticSearchVector(BaseVector):
for res in result["hits"]["hits"]:
source = res["_source"]
page_content = source.get(Field.CONTENT_KEY.value)
# vector = source.get(Field.VECTOR.value)
vector = None
metadata = source.get(Field.METADATA_KEY.value, {})
chunk_type = source.get(Field.CHUNK_TYPE.value)
score = res["_score"]
# 将 QA 字段注入 metadata 供前端展示
if chunk_type:
metadata["chunk_type"] = chunk_type
if chunk_type == "qa":
metadata["question"] = source.get(Field.QUESTION.value, "")
metadata["answer"] = source.get(Field.ANSWER.value, "")
page_content = f"Q: {metadata['question']}\nA: {metadata['answer']}"
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
docs = []
@@ -308,27 +346,43 @@ class ElasticSearchVector(BaseVector):
Returns:
updated count.
"""
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"
if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
chunk.vector = self.embeddings.embed_text(chunk.page_content)
indices = kwargs.get("indices", self._collection_name)
chunk_type = (chunk.metadata or {}).get("chunk_type")
# QA chunk: embedding 基于 questionsource chunk: 不更新向量
if chunk_type == "source":
embed_text = ""
elif chunk_type == "qa":
embed_text = (chunk.metadata or {}).get("question", chunk.page_content)
else:
chunk.vector = self.embeddings.embed_query(chunk.page_content)
embed_text = chunk.page_content
if chunk_type != "source":
if self.is_multimodal_embedding:
chunk.vector = self.embeddings.embed_text(embed_text)
else:
chunk.vector = self.embeddings.embed_query(embed_text)
script_source = "ctx._source.page_content = params.new_content; ctx._source.vector = params.new_vector;"
params = {
"new_content": chunk.page_content,
"new_vector": chunk.vector if chunk_type != "source" else None
}
# QA chunk: 同时更新 question/answer 字段
if chunk_type == "qa":
script_source += " ctx._source.question = params.new_question; ctx._source.answer = params.new_answer;"
params["new_question"] = (chunk.metadata or {}).get("question", "")
params["new_answer"] = (chunk.metadata or {}).get("answer", "")
body = {
"script": {
"source": """
ctx._source.page_content = params.new_content;
ctx._source.vector = params.new_vector;
""",
"params": {
"new_content": chunk.page_content,
"new_vector": chunk.vector
}
"source": script_source,
"params": params
},
"query": {
"term": {
Field.DOC_ID.value: chunk.metadata["doc_id"] # exact match doc_id
Field.DOC_ID.value: chunk.metadata["doc_id"]
}
}
}
@@ -336,9 +390,6 @@ class ElasticSearchVector(BaseVector):
index=indices,
body=body,
)
# Remove debug printing and use logging instead
# print(result)
# print(f"Update successful, number of affected documents: {result['updated']}")
return result['updated']
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
@@ -397,11 +448,11 @@ class ElasticSearchVector(BaseVector):
}
}
},
"filter": { # Add the filter condition of status=1
"term": {
"metadata.status": 1
}
}
"filter": [
{"term": {"metadata.status": 1}},
# 排除 source chunk仅供 GraphRAG 使用,不参与检索)
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
]
}
}
# If file_names_filter is passed in, merge the filtering conditions
@@ -415,22 +466,14 @@ class ElasticSearchVector(BaseVector):
},
"script": {
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
# The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1]
"params": {"query_vector": query_vector}
}
}
},
"filter": [
{
"term": {
"metadata.status": 1
}
},
{
"terms": {
"metadata.file_name": file_names_filter # Additional file_name filtering
}
}
{"term": {"metadata.status": 1}},
{"terms": {"metadata.file_name": file_names_filter}},
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
],
}
}
@@ -451,8 +494,19 @@ class ElasticSearchVector(BaseVector):
source = res["_source"]
page_content = source.get(Field.CONTENT_KEY.value)
metadata = source.get(Field.METADATA_KEY.value, {})
chunk_type = source.get(Field.CHUNK_TYPE.value)
score = res["_score"]
score = score / 2 # Normalized [0-1]
# QA chunk: 返回 Q+A 拼接作为上下文
if chunk_type == "qa":
question = source.get(Field.QUESTION.value, "")
answer = source.get(Field.ANSWER.value, "")
page_content = f"Q: {question}\nA: {answer}"
metadata["chunk_type"] = "qa"
metadata["question"] = question
metadata["answer"] = answer
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
docs = []
@@ -491,11 +545,10 @@ class ElasticSearchVector(BaseVector):
}
}
},
"filter": { # Add the filter condition of status=1
"term": {
"metadata.status": 1
}
}
"filter": [
{"term": {"metadata.status": 1}},
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
]
}
}
@@ -512,16 +565,9 @@ class ElasticSearchVector(BaseVector):
}
},
"filter": [
{
"term": {
"metadata.status": 1
}
},
{
"terms": {
"metadata.file_name": file_names_filter # Additional file_name filtering
}
}
{"term": {"metadata.status": 1}},
{"terms": {"metadata.file_name": file_names_filter}},
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
],
}
}
@@ -543,6 +589,17 @@ class ElasticSearchVector(BaseVector):
source = res["_source"]
page_content = source.get(Field.CONTENT_KEY.value)
metadata = source.get(Field.METADATA_KEY.value, {})
chunk_type = source.get(Field.CHUNK_TYPE.value)
# QA chunk: 返回 Q+A 拼接作为上下文
if chunk_type == "qa":
question = source.get(Field.QUESTION.value, "")
answer = source.get(Field.ANSWER.value, "")
page_content = f"Q: {question}\nA: {answer}"
metadata["chunk_type"] = "qa"
metadata["question"] = question
metadata["answer"] = answer
# Normalize the score to the [0,1] interval
normalized_score = res["_score"] / max_score
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))

View File

@@ -14,3 +14,8 @@ class Field(StrEnum):
DOCUMENT_ID = "metadata.document_id"
KNOWLEDGE_ID = "metadata.knowledge_id"
SORT_ID = "metadata.sort_id"
# QA fields
CHUNK_TYPE = "chunk_type" # "chunk" | "source" | "qa"
QUESTION = "question"
ANSWER = "answer"
SOURCE_CHUNK_ID = "source_chunk_id"