From 140311048a1fc13acda98ff0ff1d8c2f0351cc2d Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Tue, 28 Apr 2026 14:04:36 +0800 Subject: [PATCH 01/11] [modify] rag qa chunk --- api/app/controllers/chunk_controller.py | 6 + api/app/core/rag/graphrag/general/index.py | 8 +- api/app/core/rag/prompts/generator.py | 39 +++- api/app/core/rag/prompts/question_prompt.md | 19 +- .../vdb/elasticsearch/elasticsearch_vector.py | 175 ++++++++++++------ api/app/core/rag/vdb/field.py | 5 + api/app/schemas/chunk_schema.py | 42 ++++- api/app/tasks.py | 95 +++++++--- 8 files changed, 279 insertions(+), 110 deletions(-) diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index e1fdaa89..fee26669 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -271,6 +271,9 @@ async def create_chunk( "sort_id": sort_id, "status": 1, } + # QA chunk: 注入 chunk_type/question/answer 到 metadata + if create_data.is_qa: + metadata.update(create_data.qa_metadata) chunk = DocumentChunk(page_content=content, metadata=metadata) # 3. Segmented vector storage vector_service.add_chunks([chunk]) @@ -342,6 +345,9 @@ async def update_chunk( if total: chunk = items[0] chunk.page_content = content + # QA chunk: 更新 metadata 中的 question/answer + if update_data.is_qa: + chunk.metadata.update(update_data.qa_metadata) vector_service.update_by_segment(chunk) return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated") else: diff --git a/api/app/core/rag/graphrag/general/index.py b/api/app/core/rag/graphrag/general/index.py index 1bd826ca..1f1ee756 100644 --- a/api/app/core/rag/graphrag/general/index.py +++ b/api/app/core/rag/graphrag/general/index.py @@ -46,7 +46,10 @@ async def run_graphrag( start = trio.current_time() workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"] chunks = [] - for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id"], sort_by_position=True): + for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id", "chunk_type"], sort_by_position=True): + # 跳过 QA chunks,只用原文 chunks 构建图谱 + if d.get("chunk_type") == "qa": + continue chunks.append(d["page_content"]) with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): @@ -150,6 +153,9 @@ async def run_graphrag_for_kb( total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True) for doc in items: + # 跳过 QA chunks,只用原文 chunks 构建图谱 + if (doc.metadata or {}).get("chunk_type") == "qa": + continue content = doc.page_content if num_tokens_from_string(current_chunk + content) < 1024: current_chunk += content diff --git a/api/app/core/rag/prompts/generator.py b/api/app/core/rag/prompts/generator.py index 642d0849..88e9171f 100644 --- a/api/app/core/rag/prompts/generator.py +++ b/api/app/core/rag/prompts/generator.py @@ -131,18 +131,43 @@ def keyword_extraction(chat_mdl, content, topn=3): def question_proposal(chat_mdl, content, topn=3): + """生成问题(向后兼容,返回纯文本问题列表)""" + pairs = qa_proposal(chat_mdl, content, topn) + if not pairs: + return "" + return "\n".join([p["question"] for p in pairs]) + + +def qa_proposal(chat_mdl, content, topn=3): + """生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]""" template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096)) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) - if isinstance(kwd, tuple): - kwd = kwd[0] - kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) - if kwd.find("**ERROR**") >= 0: - return "" - return kwd + raw = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) + if isinstance(raw, tuple): + raw = raw[0] + raw = re.sub(r"^.*", "", raw, flags=re.DOTALL) + if raw.find("**ERROR**") >= 0: + return [] + return parse_qa_pairs(raw) + + +def parse_qa_pairs(text: str) -> list: + """解析 LLM 返回的 QA 对文本,格式: Q: xxx A: xxx""" + pairs = [] + for line in text.strip().split("\n"): + line = line.strip() + if not line: + continue + # 匹配 Q: ... A: ... 格式 + match = re.match(r'^Q:\s*(.+?)\s+A:\s*(.+)$', line, re.IGNORECASE) + if match: + q, a = match.group(1).strip(), match.group(2).strip() + if q and a: + pairs.append({"question": q, "answer": a}) + return pairs def graph_entity_types(chat_mdl, scenario): diff --git a/api/app/core/rag/prompts/question_prompt.md b/api/app/core/rag/prompts/question_prompt.md index ec9889fb..91e43d05 100644 --- a/api/app/core/rag/prompts/question_prompt.md +++ b/api/app/core/rag/prompts/question_prompt.md @@ -1,19 +1,24 @@ ## Role -You are a text analyzer. +You are a text analyzer and knowledge extraction expert. ## Task -Propose {{ topn }} questions about a given piece of text content. +Generate {{ topn }} question-answer pairs from the given text content. ## Requirements -- Understand and summarize the text content, and propose the top {{ topn }} important questions. +- Understand and summarize the text content, and generate the top {{ topn }} important question-answer pairs. +- Each question-answer pair MUST be on a single line, formatted as: Q: A: - The questions SHOULD NOT have overlapping meanings. - The questions SHOULD cover the main content of the text as much as possible. -- The questions MUST be in the same language as the given piece of text content. -- One question per line. -- Output questions ONLY. +- The answers MUST be concise, accurate, and directly derived from the text content. +- The answers SHOULD be self-contained and understandable without additional context. +- Both questions and answers MUST be in the same language as the given text content. +- Output question-answer pairs ONLY, no extra explanation. + +## Example Output +Q: What is the capital of France? A: The capital of France is Paris. +Q: When was the Eiffel Tower built? A: The Eiffel Tower was built in 1889. --- ## Text Content {{ content }} - diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index cc9ec120..3f64ad85 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -53,13 +53,30 @@ class ElasticSearchVector(BaseVector): return "elasticsearch" def add_chunks(self, chunks: list[DocumentChunk], **kwargs): - # 实现 Elasticsearch 保存向量 - texts = [chunk.page_content for chunk in chunks] + # QA chunks: embedding 只对 question 字段做;source chunks: 不做 embedding + texts_for_embedding = [] + for chunk in chunks: + chunk_type = (chunk.metadata or {}).get("chunk_type", "chunk") + if chunk_type == "source": + # source chunk 不需要向量索引 + texts_for_embedding.append("") + elif chunk_type == "qa": + # QA chunk: 用 question 字段做 embedding + texts_for_embedding.append((chunk.metadata or {}).get("question", chunk.page_content)) + else: + # 普通 chunk: 用 page_content 做 embedding + texts_for_embedding.append(chunk.page_content) + if self.is_multimodal_embedding: - # 火山引擎多模态 Embedding - embeddings = self.embeddings.embed_batch(texts) + embeddings = self.embeddings.embed_batch(texts_for_embedding) else: - embeddings = self.embeddings.embed_documents(list(texts)) + embeddings = self.embeddings.embed_documents(texts_for_embedding) + + # source chunk 的向量置空 + for i, chunk in enumerate(chunks): + if (chunk.metadata or {}).get("chunk_type") == "source": + embeddings[i] = None + self.create(chunks, embeddings, **kwargs) def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): @@ -72,13 +89,25 @@ class ElasticSearchVector(BaseVector): uuids = self._get_uuids(chunks) actions = [] for i, chunk in enumerate(chunks): + source = { + Field.CONTENT_KEY.value: chunk.page_content, + Field.METADATA_KEY.value: chunk.metadata or {}, + Field.VECTOR.value: embeddings[i] or None + } + # 写入 QA 相关字段 + meta = chunk.metadata or {} + if meta.get("chunk_type"): + source[Field.CHUNK_TYPE.value] = meta["chunk_type"] + if meta.get("question"): + source[Field.QUESTION.value] = meta["question"] + if meta.get("answer"): + source[Field.ANSWER.value] = meta["answer"] + if meta.get("source_chunk_id"): + source[Field.SOURCE_CHUNK_ID.value] = meta["source_chunk_id"] + action = { "_index": self._collection_name, - "_source": { - Field.CONTENT_KEY.value: chunk.page_content, - Field.METADATA_KEY.value: chunk.metadata or {}, - Field.VECTOR.value: embeddings[i] or None - } + "_source": source } actions.append(action) # using bulk mode @@ -241,10 +270,19 @@ class ElasticSearchVector(BaseVector): for res in result["hits"]["hits"]: source = res["_source"] page_content = source.get(Field.CONTENT_KEY.value) - # vector = source.get(Field.VECTOR.value) vector = None metadata = source.get(Field.METADATA_KEY.value, {}) + chunk_type = source.get(Field.CHUNK_TYPE.value) score = res["_score"] + + # 将 QA 字段注入 metadata 供前端展示 + if chunk_type: + metadata["chunk_type"] = chunk_type + if chunk_type == "qa": + metadata["question"] = source.get(Field.QUESTION.value, "") + metadata["answer"] = source.get(Field.ANSWER.value, "") + page_content = f"Q: {metadata['question']}\nA: {metadata['answer']}" + docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score)) docs = [] @@ -308,27 +346,43 @@ class ElasticSearchVector(BaseVector): Returns: updated count. """ - indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" - if self.is_multimodal_embedding: - # 火山引擎多模态 Embedding - chunk.vector = self.embeddings.embed_text(chunk.page_content) + indices = kwargs.get("indices", self._collection_name) + chunk_type = (chunk.metadata or {}).get("chunk_type") + + # QA chunk: embedding 基于 question;source chunk: 不更新向量 + if chunk_type == "source": + embed_text = "" + elif chunk_type == "qa": + embed_text = (chunk.metadata or {}).get("question", chunk.page_content) else: - chunk.vector = self.embeddings.embed_query(chunk.page_content) + embed_text = chunk.page_content + + if chunk_type != "source": + if self.is_multimodal_embedding: + chunk.vector = self.embeddings.embed_text(embed_text) + else: + chunk.vector = self.embeddings.embed_query(embed_text) + + script_source = "ctx._source.page_content = params.new_content; ctx._source.vector = params.new_vector;" + params = { + "new_content": chunk.page_content, + "new_vector": chunk.vector if chunk_type != "source" else None + } + + # QA chunk: 同时更新 question/answer 字段 + if chunk_type == "qa": + script_source += " ctx._source.question = params.new_question; ctx._source.answer = params.new_answer;" + params["new_question"] = (chunk.metadata or {}).get("question", "") + params["new_answer"] = (chunk.metadata or {}).get("answer", "") body = { "script": { - "source": """ - ctx._source.page_content = params.new_content; - ctx._source.vector = params.new_vector; - """, - "params": { - "new_content": chunk.page_content, - "new_vector": chunk.vector - } + "source": script_source, + "params": params }, "query": { "term": { - Field.DOC_ID.value: chunk.metadata["doc_id"] # exact match doc_id + Field.DOC_ID.value: chunk.metadata["doc_id"] } } } @@ -336,9 +390,6 @@ class ElasticSearchVector(BaseVector): index=indices, body=body, ) - # Remove debug printing and use logging instead - # print(result) - # print(f"Update successful, number of affected documents: {result['updated']}") return result['updated'] def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str: @@ -397,11 +448,11 @@ class ElasticSearchVector(BaseVector): } } }, - "filter": { # Add the filter condition of status=1 - "term": { - "metadata.status": 1 - } - } + "filter": [ + {"term": {"metadata.status": 1}}, + # 排除 source chunk(仅供 GraphRAG 使用,不参与检索) + {"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}} + ] } } # If file_names_filter is passed in, merge the filtering conditions @@ -415,22 +466,14 @@ class ElasticSearchVector(BaseVector): }, "script": { "source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0", - # The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1] "params": {"query_vector": query_vector} } } }, "filter": [ - { - "term": { - "metadata.status": 1 - } - }, - { - "terms": { - "metadata.file_name": file_names_filter # Additional file_name filtering - } - } + {"term": {"metadata.status": 1}}, + {"terms": {"metadata.file_name": file_names_filter}}, + {"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}} ], } } @@ -451,8 +494,19 @@ class ElasticSearchVector(BaseVector): source = res["_source"] page_content = source.get(Field.CONTENT_KEY.value) metadata = source.get(Field.METADATA_KEY.value, {}) + chunk_type = source.get(Field.CHUNK_TYPE.value) score = res["_score"] score = score / 2 # Normalized [0-1] + + # QA chunk: 返回 Q+A 拼接作为上下文 + if chunk_type == "qa": + question = source.get(Field.QUESTION.value, "") + answer = source.get(Field.ANSWER.value, "") + page_content = f"Q: {question}\nA: {answer}" + metadata["chunk_type"] = "qa" + metadata["question"] = question + metadata["answer"] = answer + docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score)) docs = [] @@ -491,11 +545,10 @@ class ElasticSearchVector(BaseVector): } } }, - "filter": { # Add the filter condition of status=1 - "term": { - "metadata.status": 1 - } - } + "filter": [ + {"term": {"metadata.status": 1}}, + {"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}} + ] } } @@ -512,16 +565,9 @@ class ElasticSearchVector(BaseVector): } }, "filter": [ - { - "term": { - "metadata.status": 1 - } - }, - { - "terms": { - "metadata.file_name": file_names_filter # Additional file_name filtering - } - } + {"term": {"metadata.status": 1}}, + {"terms": {"metadata.file_name": file_names_filter}}, + {"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}} ], } } @@ -543,6 +589,17 @@ class ElasticSearchVector(BaseVector): source = res["_source"] page_content = source.get(Field.CONTENT_KEY.value) metadata = source.get(Field.METADATA_KEY.value, {}) + chunk_type = source.get(Field.CHUNK_TYPE.value) + + # QA chunk: 返回 Q+A 拼接作为上下文 + if chunk_type == "qa": + question = source.get(Field.QUESTION.value, "") + answer = source.get(Field.ANSWER.value, "") + page_content = f"Q: {question}\nA: {answer}" + metadata["chunk_type"] = "qa" + metadata["question"] = question + metadata["answer"] = answer + # Normalize the score to the [0,1] interval normalized_score = res["_score"] / max_score docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score)) diff --git a/api/app/core/rag/vdb/field.py b/api/app/core/rag/vdb/field.py index 99d872c2..5f20a21a 100644 --- a/api/app/core/rag/vdb/field.py +++ b/api/app/core/rag/vdb/field.py @@ -14,3 +14,8 @@ class Field(StrEnum): DOCUMENT_ID = "metadata.document_id" KNOWLEDGE_ID = "metadata.knowledge_id" SORT_ID = "metadata.sort_id" + # QA fields + CHUNK_TYPE = "chunk_type" # "chunk" | "source" | "qa" + QUESTION = "question" + ANSWER = "answer" + SOURCE_CHUNK_ID = "source_chunk_id" diff --git a/api/app/schemas/chunk_schema.py b/api/app/schemas/chunk_schema.py index ce8f70f2..b888c361 100644 --- a/api/app/schemas/chunk_schema.py +++ b/api/app/schemas/chunk_schema.py @@ -20,13 +20,26 @@ class ChunkCreate(BaseModel): @property def chunk_content(self) -> str: - """ - Get the actual content string regardless of input type - """ + """Get the actual content string regardless of input type""" if isinstance(self.content, QAChunk): - return f"question: {self.content.question} answer: {self.content.answer}" + return self.content.question # QA 模式下 page_content 存 question return self.content + @property + def is_qa(self) -> bool: + return isinstance(self.content, QAChunk) + + @property + def qa_metadata(self) -> dict: + """返回 QA 相关的 metadata 字段""" + if isinstance(self.content, QAChunk): + return { + "chunk_type": "qa", + "question": self.content.question, + "answer": self.content.answer, + } + return {} + class ChunkUpdate(BaseModel): content: Union[str, QAChunk] = Field( @@ -35,13 +48,26 @@ class ChunkUpdate(BaseModel): @property def chunk_content(self) -> str: - """ - Get the actual content string regardless of input type - """ + """Get the actual content string regardless of input type""" if isinstance(self.content, QAChunk): - return f"question: {self.content.question} answer: {self.content.answer}" + return self.content.question # QA 模式下 page_content 存 question return self.content + @property + def is_qa(self) -> bool: + return isinstance(self.content, QAChunk) + + @property + def qa_metadata(self) -> dict: + """返回 QA 相关的 metadata 字段""" + if isinstance(self.content, QAChunk): + return { + "chunk_type": "qa", + "question": self.content.question, + "answer": self.content.answer, + } + return {} + class ChunkRetrieve(BaseModel): query: str diff --git a/api/app/tasks.py b/api/app/tasks.py index 3ad1a0dd..ed961115 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -30,7 +30,7 @@ from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.embedding_model import OpenAIEmbed from app.core.rag.llm.sequence2txt_model import QWenSeq2txt from app.core.rag.models.chunk import DocumentChunk -from app.core.rag.prompts.generator import question_proposal +from app.core.rag.prompts.generator import question_proposal, qa_proposal from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ElasticSearchVectorFactory, ) @@ -323,57 +323,96 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): all_batch_chunks: list[list[DocumentChunk]] = [] if auto_questions_topn: - # auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组 - # 构建 (global_idx, item) 列表 + # QA 模式(FastGPT 方案): + # 1. 原 chunk 标记为 source(保留供 GraphRAG 使用,不参与检索) + # 2. LLM 生成 QA 对,每个 QA 对独立存储为 qa chunk indexed_items = list(enumerate(res)) - def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]: - """为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)""" + def _generate_qa(idx_item: tuple[int, dict]) -> tuple[int, list]: + """为单个 chunk 生成 QA 对(带缓存),返回 (global_idx, qa_pairs)""" global_idx, item = idx_item content = item["content_with_weight"] - cached = get_llm_cache(chat_model.model_name, content, "question", + cached = get_llm_cache(chat_model.model_name, content, "qa", {"topn": auto_questions_topn}) if not cached: - cached = question_proposal(chat_model, content, auto_questions_topn) - set_llm_cache(chat_model.model_name, content, cached, "question", + pairs = qa_proposal(chat_model, content, auto_questions_topn) + cached = pairs + set_llm_cache(chat_model.model_name, content, cached, "qa", {"topn": auto_questions_topn}) + elif isinstance(cached, str): + # 兼容旧缓存格式(纯文本问题) + from app.core.rag.prompts.generator import parse_qa_pairs + cached = parse_qa_pairs(cached) if cached else [] return global_idx, cached - # 并发调用 LLM 生成问题 - question_map: dict[int, str] = {} + # 并发调用 LLM 生成 QA 对 + qa_map: dict[int, list] = {} with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor: - futures = {q_executor.submit(_generate_question, item): item[0] + futures = {q_executor.submit(_generate_qa, item): item[0] for item in indexed_items} for future in futures: - global_idx, cached = future.result() - question_map[global_idx] = cached + global_idx, pairs = future.result() + qa_map[global_idx] = pairs progress_lines.append( - f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks " + f"{datetime.now().strftime('%H:%M:%S')} QA pairs generated for {total_chunks} chunks " f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).") - # 按 batch 分组组装 DocumentChunk - for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): - batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) - chunks = [] - for global_idx in range(batch_start, batch_end): - item = res[global_idx] - metadata = { + # 组装 chunks:source chunks + qa chunks + source_chunks = [] + qa_chunks = [] + qa_sort_id = 0 + + for global_idx in range(total_chunks): + item = res[global_idx] + source_chunk_id = uuid.uuid4().hex + + # source chunk:保留原文,供 GraphRAG 使用,不参与向量检索 + source_meta = { + "doc_id": source_chunk_id, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + "chunk_type": "source", + } + source_chunks.append( + DocumentChunk(page_content=item["content_with_weight"], metadata=source_meta)) + + # qa chunks:每个 QA 对独立存储 + pairs = qa_map.get(global_idx, []) + for pair in pairs: + qa_meta = { "doc_id": uuid.uuid4().hex, "file_id": str(db_document.file_id), "file_name": db_document.file_name, "file_created_at": int(db_document.created_at.timestamp() * 1000), "document_id": str(db_document.id), "knowledge_id": str(db_document.kb_id), - "sort_id": global_idx, + "sort_id": qa_sort_id, "status": 1, + "chunk_type": "qa", + "question": pair["question"], + "answer": pair["answer"], + "source_chunk_id": source_chunk_id, } - cached = question_map[global_idx] - chunks.append( - DocumentChunk( - page_content=f"question: {cached} answer: {item['content_with_weight']}", - metadata=metadata)) - all_batch_chunks.append(chunks) + # page_content 存 question,用于向量索引 + qa_chunks.append( + DocumentChunk(page_content=pair["question"], metadata=qa_meta)) + qa_sort_id += 1 + + # 按 batch 分组(source + qa 一起) + all_chunks = source_chunks + qa_chunks + for batch_start in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE): + batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, len(all_chunks)) + all_batch_chunks.append(all_chunks[batch_start:batch_end]) + + progress_lines.append( + f"{datetime.now().strftime('%H:%M:%S')} QA mode: {len(source_chunks)} source chunks + " + f"{len(qa_chunks)} QA chunks prepared.") else: # 无 auto_questions:直接构建 chunks for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): From 64e640d882a82e83a90568f8cfc7ce31dfbae1c5 Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Tue, 28 Apr 2026 15:33:44 +0800 Subject: [PATCH 02/11] [add] batch chunk. qa_prompt set --- api/app/controllers/chunk_controller.py | 61 +++++++++++++++++++++++++ api/app/core/config.py | 1 + api/app/core/rag/prompts/generator.py | 16 +++++-- api/app/schemas/chunk_schema.py | 5 ++ api/app/tasks.py | 30 +++++++++--- 5 files changed, 103 insertions(+), 10 deletions(-) diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index fee26669..07379ee4 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -285,6 +285,67 @@ async def create_chunk( return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful") +@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse) +async def create_chunks_batch( + kb_id: uuid.UUID, + document_id: uuid.UUID, + batch_data: chunk_schema.ChunkBatchCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + Batch create chunks (max 8) + """ + api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}") + + if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}" + ) + + db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) + if not db_knowledge: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied") + + db_document = db.query(Document).filter(Document.id == document_id).first() + if not db_document: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it") + + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + # Get current max sort_id + sort_id = 0 + total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False) + if items: + sort_id = items[0].metadata["sort_id"] + + chunks = [] + for create_data in batch_data.items: + sort_id += 1 + doc_id = uuid.uuid4().hex + metadata = { + "doc_id": doc_id, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(document_id), + "knowledge_id": str(kb_id), + "sort_id": sort_id, + "status": 1, + } + if create_data.is_qa: + metadata.update(create_data.qa_metadata) + chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata)) + + vector_service.add_chunks(chunks) + + db_document.chunk_num += len(chunks) + db.commit() + + return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully") + + @router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) async def get_chunk( kb_id: uuid.UUID, diff --git a/api/app/core/config.py b/api/app/core/config.py index 64c5520e..725c35ce 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -98,6 +98,7 @@ class Settings: # File Upload MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800")) MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20")) + MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8")) FILE_PATH: str = os.getenv("FILE_PATH", "/files") FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600")) diff --git a/api/app/core/rag/prompts/generator.py b/api/app/core/rag/prompts/generator.py index 88e9171f..4af8fa6f 100644 --- a/api/app/core/rag/prompts/generator.py +++ b/api/app/core/rag/prompts/generator.py @@ -138,9 +138,19 @@ def question_proposal(chat_mdl, content, topn=3): return "\n".join([p["question"] for p in pairs]) -def qa_proposal(chat_mdl, content, topn=3): - """生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]""" - template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) +def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None): + """生成 QA 对,返回 [{"question": ..., "answer": ...}, ...] + + Args: + chat_mdl: LLM 模型 + content: 文本内容 + topn: 生成 QA 对数量 + custom_prompt: 自定义 prompt 模板(支持 Jinja2,可用变量: content, topn) + """ + if custom_prompt: + template = PROMPT_JINJA_ENV.from_string(custom_prompt) + else: + template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] diff --git a/api/app/schemas/chunk_schema.py b/api/app/schemas/chunk_schema.py index b888c361..9c785c5e 100644 --- a/api/app/schemas/chunk_schema.py +++ b/api/app/schemas/chunk_schema.py @@ -77,3 +77,8 @@ class ChunkRetrieve(BaseModel): vector_similarity_weight: float | None = Field(None) top_k: int | None = Field(None) retrieve_type: RetrieveType | None = Field(None) + + +class ChunkBatchCreate(BaseModel): + """批量创建 chunk""" + items: list[ChunkCreate] = Field(..., min_length=1, description="chunk 列表") diff --git a/api/app/tasks.py b/api/app/tasks.py index ed961115..114959b6 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -311,6 +311,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) # 2.2 Vectorize and import batch documents auto_questions_topn = db_document.parser_config.get("auto_questions", 0) + qa_prompt = db_document.parser_config.get("qa_prompt", None) chat_model = None if auto_questions_topn: chat_model = Base( @@ -318,6 +319,9 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): model_name=db_knowledge.llm.api_keys[0].model_name, base_url=db_knowledge.llm.api_keys[0].api_base, ) + logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}") + if qa_prompt: + logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)") # 预先构建所有 batch 的 chunks,保证 sort_id 全局有序 all_batch_chunks: list[list[DocumentChunk]] = [] @@ -335,15 +339,27 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): cached = get_llm_cache(chat_model.model_name, content, "qa", {"topn": auto_questions_topn}) if not cached: - pairs = qa_proposal(chat_model, content, auto_questions_topn) - cached = pairs - set_llm_cache(chat_model.model_name, content, cached, "qa", + try: + pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt) + except Exception as e: + logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}") + return global_idx, [] + # 缓存存 JSON 字符串 + set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa", {"topn": auto_questions_topn}) - elif isinstance(cached, str): - # 兼容旧缓存格式(纯文本问题) + return global_idx, pairs + # 从缓存读取:可能是 JSON 字符串或旧格式纯文本 + if isinstance(cached, str): + try: + parsed = json.loads(cached) + if isinstance(parsed, list): + return global_idx, parsed + except (json.JSONDecodeError, TypeError): + pass + # 旧缓存格式(纯文本问题),尝试解析 from app.core.rag.prompts.generator import parse_qa_pairs - cached = parse_qa_pairs(cached) if cached else [] - return global_idx, cached + return global_idx, parse_qa_pairs(cached) if cached else [] + return global_idx, cached if isinstance(cached, list) else [] # 并发调用 LLM 生成 QA 对 qa_map: dict[int, list] = {} From f667936664b0323b0b14b37c090c1332ec607597 Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Tue, 28 Apr 2026 15:53:07 +0800 Subject: [PATCH 03/11] [fix] qa cache --- api/app/tasks.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/api/app/tasks.py b/api/app/tasks.py index 114959b6..811677f7 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -336,8 +336,11 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): """为单个 chunk 生成 QA 对(带缓存),返回 (global_idx, qa_pairs)""" global_idx, item = idx_item content = item["content_with_weight"] - cached = get_llm_cache(chat_model.model_name, content, "qa", - {"topn": auto_questions_topn}) + cache_params = {"topn": auto_questions_topn} + if qa_prompt: + import hashlib + cache_params["prompt_hash"] = hashlib.md5(qa_prompt.encode()).hexdigest()[:8] + cached = get_llm_cache(chat_model.model_name, content, "qa", cache_params) if not cached: try: pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt) @@ -346,7 +349,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): return global_idx, [] # 缓存存 JSON 字符串 set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa", - {"topn": auto_questions_topn}) + cache_params) return global_idx, pairs # 从缓存读取:可能是 JSON 字符串或旧格式纯文本 if isinstance(cached, str): From 6c47bb77ab5353fa083611ac77a5b47ef5834ffe Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Tue, 28 Apr 2026 16:13:26 +0800 Subject: [PATCH 04/11] [add] task log --- api/app/tasks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/app/tasks.py b/api/app/tasks.py index 811677f7..4d39cf7a 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -342,20 +342,24 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): cache_params["prompt_hash"] = hashlib.md5(qa_prompt.encode()).hexdigest()[:8] cached = get_llm_cache(chat_model.model_name, content, "qa", cache_params) if not cached: + logger.info(f"[QA] Cache miss for chunk {global_idx}, calling LLM. cache_params={cache_params}") try: pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt) except Exception as e: logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}") return global_idx, [] + logger.info(f"[QA] Chunk {global_idx} generated {len(pairs)} QA pairs") # 缓存存 JSON 字符串 set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa", cache_params) return global_idx, pairs + logger.info(f"[QA] Cache hit for chunk {global_idx}, cache_params={cache_params}, cached_type={type(cached).__name__}") # 从缓存读取:可能是 JSON 字符串或旧格式纯文本 if isinstance(cached, str): try: parsed = json.loads(cached) if isinstance(parsed, list): + logger.info(f"[QA] Chunk {global_idx} loaded {len(parsed)} QA pairs from cache") return global_idx, parsed except (json.JSONDecodeError, TypeError): pass From 90aa4cef21c74e353afb001a413bf3e87e68d5e8 Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Tue, 28 Apr 2026 16:38:14 +0800 Subject: [PATCH 05/11] [add] import qa chunks --- api/app/controllers/chunk_controller.py | 44 ++++++++- api/app/tasks.py | 117 ++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 1 deletion(-) diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index 07379ee4..cb927504 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -1,8 +1,10 @@ import os +import csv +import io from typing import Any, Optional import uuid -from fastapi import APIRouter, Depends, HTTPException, status, Query +from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File from fastapi.encoders import jsonable_encoder from sqlalchemy.orm import Session @@ -346,6 +348,46 @@ async def create_chunks_batch( return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully") +@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse) +async def import_qa_chunks( + kb_id: uuid.UUID, + document_id: uuid.UUID, + file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + 导入 QA 问答对(CSV/Excel),异步处理 + """ + api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}") + + # 1. 校验文件格式 + filename = file.filename or "" + if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式") + + # 2. 校验知识库和文档 + db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) + if not db_knowledge: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问") + + db_document = db.query(Document).filter(Document.id == document_id).first() + if not db_document: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问") + + # 3. 读取文件内容,派发异步任务 + contents = await file.read() + + from app.celery_app import celery_app + task = celery_app.send_task( + "app.core.rag.tasks.import_qa_chunks", + args=[str(kb_id), str(document_id), filename, contents], + queue="qa_import" + ) + + return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中") + + @router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) async def get_chunk( kb_id: uuid.UUID, diff --git a/api/app/tasks.py b/api/app/tasks.py index 4d39cf7a..2fcab818 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -697,6 +697,123 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str): return f"build_graphrag_for_document '{document_id}' failed: {e}" +@celery_app.task(name="app.core.rag.tasks.import_qa_chunks", queue="qa_import") +def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: bytes): + """ + 异步导入 QA 问答对(CSV/Excel) + + 文件格式:第一行标题(跳过),第一列问题,第二列答案 + """ + import csv as csv_module + import io + + db = SessionLocal() + try: + db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first() + db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first() + if not db_document or not db_knowledge: + logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found") + return {"error": "document or knowledge not found", "imported": 0} + + # 1. 解析文件 + qa_pairs = [] + failed_rows = [] + + if filename.endswith(".csv"): + try: + text = contents.decode("utf-8-sig") + except UnicodeDecodeError: + text = contents.decode("gbk", errors="ignore") + + sniffer = csv_module.Sniffer() + try: + dialect = sniffer.sniff(text[:2048]) + delimiter = dialect.delimiter + except csv_module.Error: + delimiter = "," if "," in text[:500] else "\t" + + reader = csv_module.reader(io.StringIO(text), delimiter=delimiter) + for i, row in enumerate(reader): + if i == 0: + continue + if len(row) >= 2 and row[0].strip() and row[1].strip(): + qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()}) + elif len(row) >= 1 and row[0].strip(): + failed_rows.append(i + 1) + + elif filename.endswith(".xlsx") or filename.endswith(".xls"): + try: + import openpyxl + wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True) + for sheet in wb.worksheets: + for i, row in enumerate(sheet.iter_rows(values_only=True)): + if i == 0: + continue + if len(row) >= 2 and row[0] and row[1]: + q = str(row[0]).strip() + a = str(row[1]).strip() + if q and a: + qa_pairs.append({"question": q, "answer": a}) + elif len(row) >= 1 and row[0]: + failed_rows.append(i + 1) + wb.close() + except Exception as e: + logger.error(f"[ImportQA] Excel parse failed: {e}") + return {"error": f"Excel parse failed: {e}", "imported": 0} + + if not qa_pairs: + logger.warning(f"[ImportQA] No valid QA pairs found in {filename}") + return {"error": "No valid QA pairs found", "imported": 0} + + logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}") + + # 2. 写入 ES + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + sort_id = 0 + total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False) + if items: + sort_id = items[0].metadata["sort_id"] + + chunks = [] + for pair in qa_pairs: + sort_id += 1 + doc_id = uuid.uuid4().hex + metadata = { + "doc_id": doc_id, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": document_id, + "knowledge_id": kb_id, + "sort_id": sort_id, + "status": 1, + "chunk_type": "qa", + "question": pair["question"], + "answer": pair["answer"], + } + chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata)) + + batch_size = 50 + for i in range(0, len(chunks), batch_size): + batch = chunks[i:i + batch_size] + vector_service.add_chunks(batch) + + # 3. 更新 chunk_num + db_document.chunk_num += len(chunks) + db.commit() + + result = {"imported": len(chunks), "failed_rows": failed_rows} + logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}") + return result + + except Exception as e: + logger.error(f"[ImportQA] Failed: {e}", exc_info=True) + return {"error": str(e), "imported": 0} + finally: + db.close() + + @celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") def sync_knowledge_for_kb(kb_id: uuid.UUID): """ From 6e89302cb29432b78bf2fb721a468d91a850276b Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Wed, 29 Apr 2026 11:44:03 +0800 Subject: [PATCH 06/11] no message --- api/app/controllers/chunk_controller.py | 63 +++++++++ api/app/tasks.py | 174 ++++++++++++------------ 2 files changed, 150 insertions(+), 87 deletions(-) diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index cb927504..fe383cb1 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -348,6 +348,69 @@ async def create_chunks_batch( return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully") +@router.post("/{kb_id}/import_qa", response_model=ApiResponse) +async def import_qa_new_doc( + kb_id: uuid.UUID, + file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + 导入 QA 问答对并新建文档(CSV/Excel),异步处理 + """ + api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}") + + # 1. 校验文件格式 + filename = file.filename or "" + if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式") + + # 2. 校验知识库 + db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) + if not db_knowledge: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问") + + # 3. 创建 File 记录 + from app.schemas import file_schema, document_schema + _, file_extension = os.path.splitext(filename) + file_ext = file_extension.lower() + contents = await file.read() + file_size = len(contents) + + file_data = file_schema.FileCreate( + kb_id=kb_id, created_by=current_user.id, + parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), + file_name=filename, file_ext=file_ext, file_size=file_size, + ) + db_file = file_service.create_file(db=db, file=file_data, current_user=current_user) + + # 4. 创建 Document 记录 + doc_data = document_schema.DocumentCreate( + kb_id=kb_id, created_by=current_user.id, file_id=db_file.id, + file_name=filename, file_ext=file_ext, file_size=file_size, + file_meta={}, parser_id="naive", + parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128, + "delimiter": "\n", "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"} + ) + db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user) + + api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}") + + # 5. 派发异步任务 + from app.celery_app import celery_app + task = celery_app.send_task( + "app.core.rag.tasks.import_qa_chunks", + args=[str(kb_id), str(db_document.id), filename, contents], + queue="qa_import" + ) + + return success(data={ + "task_id": task.id, + "document_id": str(db_document.id), + "file_id": str(db_file.id), + }, msg="QA 导入任务已提交,后台处理中") + + @router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse) async def import_qa_chunks( kb_id: uuid.UUID, diff --git a/api/app/tasks.py b/api/app/tasks.py index 2fcab818..48368b76 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -707,111 +707,111 @@ def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: byte import csv as csv_module import io - db = SessionLocal() + db = None try: - db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first() - db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first() - if not db_document or not db_knowledge: - logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found") - return {"error": "document or knowledge not found", "imported": 0} + from app.db import get_db_context + with get_db_context() as db: + db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first() + db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first() + if not db_document or not db_knowledge: + logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found") + return {"error": "document or knowledge not found", "imported": 0} - # 1. 解析文件 - qa_pairs = [] - failed_rows = [] + # 1. 解析文件 + qa_pairs = [] + failed_rows = [] - if filename.endswith(".csv"): - try: - text = contents.decode("utf-8-sig") - except UnicodeDecodeError: - text = contents.decode("gbk", errors="ignore") + if filename.endswith(".csv"): + try: + text = contents.decode("utf-8-sig") + except UnicodeDecodeError: + text = contents.decode("gbk", errors="ignore") - sniffer = csv_module.Sniffer() - try: - dialect = sniffer.sniff(text[:2048]) - delimiter = dialect.delimiter - except csv_module.Error: - delimiter = "," if "," in text[:500] else "\t" + sniffer = csv_module.Sniffer() + try: + dialect = sniffer.sniff(text[:2048]) + delimiter = dialect.delimiter + except csv_module.Error: + delimiter = "," if "," in text[:500] else "\t" - reader = csv_module.reader(io.StringIO(text), delimiter=delimiter) - for i, row in enumerate(reader): - if i == 0: - continue - if len(row) >= 2 and row[0].strip() and row[1].strip(): - qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()}) - elif len(row) >= 1 and row[0].strip(): - failed_rows.append(i + 1) + reader = csv_module.reader(io.StringIO(text), delimiter=delimiter) + for i, row in enumerate(reader): + if i == 0: + continue + if len(row) >= 2 and row[0].strip() and row[1].strip(): + qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()}) + elif len(row) >= 1 and row[0].strip(): + failed_rows.append(i + 1) - elif filename.endswith(".xlsx") or filename.endswith(".xls"): - try: - import openpyxl - wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True) - for sheet in wb.worksheets: - for i, row in enumerate(sheet.iter_rows(values_only=True)): - if i == 0: - continue - if len(row) >= 2 and row[0] and row[1]: - q = str(row[0]).strip() - a = str(row[1]).strip() - if q and a: - qa_pairs.append({"question": q, "answer": a}) - elif len(row) >= 1 and row[0]: - failed_rows.append(i + 1) - wb.close() - except Exception as e: - logger.error(f"[ImportQA] Excel parse failed: {e}") - return {"error": f"Excel parse failed: {e}", "imported": 0} + elif filename.endswith(".xlsx") or filename.endswith(".xls"): + try: + import openpyxl + wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True) + for sheet in wb.worksheets: + for i, row in enumerate(sheet.iter_rows(values_only=True)): + if i == 0: + continue + if len(row) >= 2 and row[0] and row[1]: + q = str(row[0]).strip() + a = str(row[1]).strip() + if q and a: + qa_pairs.append({"question": q, "answer": a}) + elif len(row) >= 1 and row[0]: + failed_rows.append(i + 1) + wb.close() + except Exception as e: + logger.error(f"[ImportQA] Excel parse failed: {e}") + return {"error": f"Excel parse failed: {e}", "imported": 0} - if not qa_pairs: - logger.warning(f"[ImportQA] No valid QA pairs found in {filename}") - return {"error": "No valid QA pairs found", "imported": 0} + if not qa_pairs: + logger.warning(f"[ImportQA] No valid QA pairs found in {filename}") + return {"error": "No valid QA pairs found", "imported": 0} - logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}") + logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}") - # 2. 写入 ES - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + # 2. 写入 ES + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - sort_id = 0 - total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False) - if items: - sort_id = items[0].metadata["sort_id"] + sort_id = 0 + total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False) + if items: + sort_id = items[0].metadata["sort_id"] - chunks = [] - for pair in qa_pairs: - sort_id += 1 - doc_id = uuid.uuid4().hex - metadata = { - "doc_id": doc_id, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": document_id, - "knowledge_id": kb_id, - "sort_id": sort_id, - "status": 1, - "chunk_type": "qa", - "question": pair["question"], - "answer": pair["answer"], - } - chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata)) + chunks = [] + for pair in qa_pairs: + sort_id += 1 + doc_id = uuid.uuid4().hex + metadata = { + "doc_id": doc_id, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": document_id, + "knowledge_id": kb_id, + "sort_id": sort_id, + "status": 1, + "chunk_type": "qa", + "question": pair["question"], + "answer": pair["answer"], + } + chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata)) - batch_size = 50 - for i in range(0, len(chunks), batch_size): - batch = chunks[i:i + batch_size] - vector_service.add_chunks(batch) + batch_size = 50 + for i in range(0, len(chunks), batch_size): + batch = chunks[i:i + batch_size] + vector_service.add_chunks(batch) - # 3. 更新 chunk_num - db_document.chunk_num += len(chunks) - db.commit() + # 3. 更新 chunk_num + db_document.chunk_num += len(chunks) + db.commit() - result = {"imported": len(chunks), "failed_rows": failed_rows} - logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}") - return result + result = {"imported": len(chunks), "failed_rows": failed_rows} + logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}") + return result except Exception as e: logger.error(f"[ImportQA] Failed: {e}", exc_info=True) return {"error": str(e), "imported": 0} - finally: - db.close() @celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") From 5fceba54b4d15ff6176b67ee98fc046d3040c597 Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Wed, 29 Apr 2026 13:41:14 +0800 Subject: [PATCH 07/11] [fix] file upload --- api/app/controllers/chunk_controller.py | 40 ++++++++++++++++++------- api/app/tasks.py | 15 +++++++++- 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index fe383cb1..a0b985bf 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -25,6 +25,7 @@ from app.models.user_model import User from app.schemas import chunk_schema from app.schemas.response_schema import ApiResponse from app.services import knowledge_service, document_service, file_service, knowledgeshare_service +from app.services.file_storage_service import FileStorageService, get_file_storage_service, generate_kb_file_key from app.services.model_service import ModelApiKeyService # Obtain a dedicated API logger @@ -353,11 +354,14 @@ async def import_qa_new_doc( kb_id: uuid.UUID, file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"), db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + storage_service: FileStorageService = Depends(get_file_storage_service), ): """ 导入 QA 问答对并新建文档(CSV/Excel),异步处理 """ + from app.schemas import file_schema, document_schema + api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}") # 1. 校验文件格式 @@ -370,13 +374,16 @@ async def import_qa_new_doc( if not db_knowledge: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问") - # 3. 创建 File 记录 - from app.schemas import file_schema, document_schema - _, file_extension = os.path.splitext(filename) - file_ext = file_extension.lower() + # 3. 读取文件 contents = await file.read() file_size = len(contents) + if file_size == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空") + _, file_extension = os.path.splitext(filename) + file_ext = file_extension.lower() + + # 4. 创建 File 记录 file_data = file_schema.FileCreate( kb_id=kb_id, created_by=current_user.id, parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), @@ -384,19 +391,30 @@ async def import_qa_new_doc( ) db_file = file_service.create_file(db=db, file=file_data, current_user=current_user) - # 4. 创建 Document 记录 + # 5. 上传文件到存储后端 + file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext) + try: + await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type) + except Exception as e: + api_logger.error(f"Storage upload failed: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"文件存储失败: {str(e)}") + + db_file.file_key = file_key + db.commit() + db.refresh(db_file) + + # 6. 创建 Document 记录(标记为 QA 类型) doc_data = document_schema.DocumentCreate( kb_id=kb_id, created_by=current_user.id, file_id=db_file.id, file_name=filename, file_ext=file_ext, file_size=file_size, - file_meta={}, parser_id="naive", - parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128, - "delimiter": "\n", "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"} + file_meta={}, parser_id="qa", + parser_config={"doc_type": "qa", "auto_questions": 0} ) db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user) - api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}") + api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}, file_key={file_key}") - # 5. 派发异步任务 + # 7. 派发异步任务 from app.celery_app import celery_app task = celery_app.send_task( "app.core.rag.tasks.import_qa_chunks", diff --git a/api/app/tasks.py b/api/app/tasks.py index 48368b76..77e20e2c 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -801,8 +801,10 @@ def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: byte batch = chunks[i:i + batch_size] vector_service.add_chunks(batch) - # 3. 更新 chunk_num + # 3. 更新 chunk_num 和 progress db_document.chunk_num += len(chunks) + db_document.progress = 1.0 + db_document.progress_msg = f"QA 导入完成: {len(chunks)} 条" db.commit() result = {"imported": len(chunks), "failed_rows": failed_rows} @@ -811,6 +813,17 @@ def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: byte except Exception as e: logger.error(f"[ImportQA] Failed: {e}", exc_info=True) + # 尝试更新文档状态为失败 + try: + from app.db import get_db_context + with get_db_context() as err_db: + doc = err_db.query(Document).filter(Document.id == uuid.UUID(document_id)).first() + if doc: + doc.progress = -1.0 + doc.progress_msg = f"QA 导入失败: {str(e)[:200]}" + err_db.commit() + except Exception: + pass return {"error": str(e), "imported": 0} From f85c0594c9c0364ac0697d110898d9d3ffe9984b Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Wed, 29 Apr 2026 15:24:25 +0800 Subject: [PATCH 08/11] [fix] es vector --- api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 3f64ad85..cd52550b 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -709,7 +709,7 @@ class ElasticSearchVector(BaseVector): }, Field.VECTOR.value: { "type": "dense_vector", - "dims": len(embeddings[0]), # Make sure the dimension is correct here,The dimension size of the vector. When index is true, it cannot exceed 1024; when index is false or not specified, it cannot exceed 2048, which can improve retrieval efficiency + "dims": len(next((e for e in embeddings if e is not None), [0]*768)), # 跳过 None 获取向量维度,fallback 768 "index": True, "similarity": "cosine" } From 70c6d161c83778446455bf401e3fcac97438e19d Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Wed, 6 May 2026 15:19:46 +0800 Subject: [PATCH 09/11] [fix] delete chunk refresh index --- api/app/controllers/chunk_controller.py | 3 ++- api/app/controllers/service/rag_api_chunk_controller.py | 2 ++ .../core/rag/vdb/elasticsearch/elasticsearch_vector.py | 8 ++++++-- api/app/core/rag/vdb/vector_base.py | 4 ++-- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index a0b985bf..cfe36a3a 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -546,6 +546,7 @@ async def delete_chunk( kb_id: uuid.UUID, document_id: uuid.UUID, doc_id: str, + force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): @@ -563,7 +564,7 @@ async def delete_chunk( vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) if vector_service.text_exists(doc_id): - vector_service.delete_by_ids([doc_id]) + vector_service.delete_by_ids([doc_id], refresh=force_refresh) # 更新 chunk_num db_document = db.query(Document).filter(Document.id == document_id).first() db_document.chunk_num -= 1 diff --git a/api/app/controllers/service/rag_api_chunk_controller.py b/api/app/controllers/service/rag_api_chunk_controller.py index a4d9a20c..c9f5e7de 100644 --- a/api/app/controllers/service/rag_api_chunk_controller.py +++ b/api/app/controllers/service/rag_api_chunk_controller.py @@ -176,6 +176,7 @@ async def delete_chunk( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), + force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"), ): """ delete document chunk @@ -188,6 +189,7 @@ async def delete_chunk( return await chunk_controller.delete_chunk(kb_id=kb_id, document_id=document_id, doc_id=doc_id, + force_refresh=force_refresh, db=db, current_user=current_user) diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index cd52550b..5f9e86c5 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector): return True - def delete_by_ids(self, ids: list[str]): + def delete_by_ids(self, ids: list[str], *, refresh: bool = False): if not ids: return if not self._client.indices.exists(index=self._collection_name): @@ -163,6 +163,8 @@ class ElasticSearchVector(BaseVector): actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids] try: helpers.bulk(self._client, actions) + if refresh: + self._client.indices.refresh(index=self._collection_name) except BulkIndexError as e: for error in e.errors: delete_error = error.get('delete', {}) @@ -182,7 +184,7 @@ class ElasticSearchVector(BaseVector): else: return None - def delete_by_metadata_field(self, key: str, value: str): + def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False): if not self._client.indices.exists(index=self._collection_name): return False actual_ids = self.get_ids_by_metadata_field(key, value) @@ -191,6 +193,8 @@ class ElasticSearchVector(BaseVector): actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids] try: helpers.bulk(self._client, actions) + if refresh: + self._client.indices.refresh(index=self._collection_name) except BulkIndexError as e: for error in e.errors: delete_error = error.get('delete', {}) diff --git a/api/app/core/rag/vdb/vector_base.py b/api/app/core/rag/vdb/vector_base.py index df3ac7d8..266a3f40 100644 --- a/api/app/core/rag/vdb/vector_base.py +++ b/api/app/core/rag/vdb/vector_base.py @@ -27,14 +27,14 @@ class BaseVector(ABC): raise NotImplementedError @abstractmethod - def delete_by_ids(self, ids: list[str]): + def delete_by_ids(self, ids: list[str], *, refresh: bool = False): raise NotImplementedError def get_ids_by_metadata_field(self, key: str, value: str): raise NotImplementedError @abstractmethod - def delete_by_metadata_field(self, key: str, value: str): + def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False): raise NotImplementedError @abstractmethod From ad2e885f721fc54a27ffaea7a6c9f0a8e6eb766e Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Wed, 6 May 2026 18:34:07 +0800 Subject: [PATCH 10/11] [fix] index_not_found_exception --- .../vdb/elasticsearch/elasticsearch_vector.py | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 5f9e86c5..03e0ed46 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -5,7 +5,7 @@ from typing import Any from urllib.parse import urlparse import requests -from elasticsearch import Elasticsearch, helpers +from elasticsearch import Elasticsearch, helpers, NotFoundError from elasticsearch.helpers import BulkIndexError from packaging.version import parse as parse_version # langchain-community @@ -225,6 +225,8 @@ class ElasticSearchVector(BaseVector): List of DocumentChunk objects that match the query. """ indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3" + if not self._client.indices.exists(index=indices): + return 0, [] # Calculate the start position for the current page from_ = pagesize * (page-1) @@ -259,12 +261,15 @@ class ElasticSearchVector(BaseVector): }) # For simplicity, we use from/size here which has a limit (usually up to 10,000). - result = self._client.search( - index=indices, - from_=from_, # Only use from_ for the first page (simplified) - size=pagesize, - body=query_str, - ) + try: + result = self._client.search( + index=indices, + from_=from_, # Only use from_ for the first page (simplified) + size=pagesize, + body=query_str, + ) + except NotFoundError: + return 0, [] if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") @@ -309,13 +314,18 @@ class ElasticSearchVector(BaseVector): List of DocumentChunk objects that match the query. """ indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" + if not self._client.indices.exists(index=indices): + return 0, [] query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}} - result = self._client.search( - index=indices, - from_=0, # Only use from_ for the first page (simplified) - size=1, - body=query_str, - ) + try: + result = self._client.search( + index=indices, + from_=0, # Only use from_ for the first page (simplified) + size=1, + body=query_str, + ) + except NotFoundError: + return 0, [] # print(result) if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") From e222490bce4d314cb01d82776b45fb99680a0093 Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Thu, 7 May 2026 18:45:36 +0800 Subject: [PATCH 11/11] [add] batch add chunk for v1 --- .../service/rag_api_chunk_controller.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/api/app/controllers/service/rag_api_chunk_controller.py b/api/app/controllers/service/rag_api_chunk_controller.py index c9f5e7de..689128be 100644 --- a/api/app/controllers/service/rag_api_chunk_controller.py +++ b/api/app/controllers/service/rag_api_chunk_controller.py @@ -113,6 +113,33 @@ async def create_chunk( current_user=current_user) +@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse) +@require_api_key(scopes=["rag"]) +async def create_chunks_batch( + kb_id: uuid.UUID, + document_id: uuid.UUID, + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + items: list = Body(..., description="chunk items list"), +): + """ + Batch create chunks (max 8) + """ + body = await request.json() + batch_data = chunk_schema.ChunkBatchCreate(**body) + # 0. Obtain the creator of the api key + api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id) + current_user = api_key.creator + current_user.current_workspace_id = api_key_auth.workspace_id + + return await chunk_controller.create_chunks_batch(kb_id=kb_id, + document_id=document_id, + batch_data=batch_data, + db=db, + current_user=current_user) + + @router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) @require_api_key(scopes=["rag"]) async def get_chunk(