Compare commits

..

5 Commits

Author SHA1 Message Date
wxy
cef33fce0d fix(workflow): sanitize condition expression building and cache assigner node inputs
- Sanitize condition expression construction in graph_builder.py using json.dumps to prevent potential injection vulnerabilities.
- Cache input data prior to assigner node execution to ensure variable values are correctly captured before processing.
2026-05-07 16:26:47 +08:00
wxy
d9f08860bc feat(LLM node): integrate exception handling and enable branch routing
- Integrate exception handling configuration into LLM nodes, supporting three strategies: throw exception, return default value, or trigger exception branch.
- Modify execution logic to return a result structure containing a branch signal, enabling routing to designated branches upon failure.
- Update graph_builder to support LLM node branch routing logic using the branch_signal field for conditional judgment.
- Implement backward compatibility to support both legacy and new result formats.
2026-05-07 11:43:24 +08:00
wxy
461674c8d8 feat(workflow): parse and substitute template variables in node configurations
- Implement regex matching for {{xxx}} template variable format.
- Enable recursive parsing of all string template variables within node configurations.
- Resolve and substitute template variables with runtime values during input data extraction.
- Support dynamic parsing and substitution of file selector variables in the document extraction node.
- Make strict template variable mode optional and introduce support for default values.
2026-04-29 14:10:02 +08:00
wxy
c59e179cc2 feat(workflow): incorporate model references and streamline parsing logic
- Incorporate model reference metadata (name, provider, type) into workflow nodes and refactor parsing logic to support the new format.
- Streamline code structure by removing redundant model_id fields to enhance maintainability.
2026-04-28 11:18:06 +08:00
Mark
a5670bfff6 Merge branch 'feature/rag2' into develop 2026-04-27 18:17:49 +08:00
19 changed files with 388 additions and 789 deletions

View File

@@ -1,10 +1,8 @@
import os import os
import csv
import io
from typing import Any, Optional from typing import Any, Optional
import uuid import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -25,7 +23,6 @@ from app.models.user_model import User
from app.schemas import chunk_schema from app.schemas import chunk_schema
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.services.file_storage_service import FileStorageService, get_file_storage_service, generate_kb_file_key
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
# Obtain a dedicated API logger # Obtain a dedicated API logger
@@ -274,9 +271,6 @@ async def create_chunk(
"sort_id": sort_id, "sort_id": sort_id,
"status": 1, "status": 1,
} }
# QA chunk: 注入 chunk_type/question/answer 到 metadata
if create_data.is_qa:
metadata.update(create_data.qa_metadata)
chunk = DocumentChunk(page_content=content, metadata=metadata) chunk = DocumentChunk(page_content=content, metadata=metadata)
# 3. Segmented vector storage # 3. Segmented vector storage
vector_service.add_chunks([chunk]) vector_service.add_chunks([chunk])
@@ -288,187 +282,6 @@ async def create_chunk(
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful") return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
async def create_chunks_batch(
kb_id: uuid.UUID,
document_id: uuid.UUID,
batch_data: chunk_schema.ChunkBatchCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Batch create chunks (max 8)
"""
api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}")
if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}"
)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied")
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Get current max sort_id
sort_id = 0
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
chunks = []
for create_data in batch_data.items:
sort_id += 1
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(document_id),
"knowledge_id": str(kb_id),
"sort_id": sort_id,
"status": 1,
}
if create_data.is_qa:
metadata.update(create_data.qa_metadata)
chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata))
vector_service.add_chunks(chunks)
db_document.chunk_num += len(chunks)
db.commit()
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
@router.post("/{kb_id}/import_qa", response_model=ApiResponse)
async def import_qa_new_doc(
kb_id: uuid.UUID,
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
导入 QA 问答对并新建文档(CSV/Excel),异步处理
"""
from app.schemas import file_schema, document_schema
api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")
# 1. 校验文件格式
filename = file.filename or ""
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
# 2. 校验知识库
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
# 3. 读取文件
contents = await file.read()
file_size = len(contents)
if file_size == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空")
_, file_extension = os.path.splitext(filename)
file_ext = file_extension.lower()
# 4. 创建 File 记录
file_data = file_schema.FileCreate(
kb_id=kb_id, created_by=current_user.id,
parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
file_name=filename, file_ext=file_ext, file_size=file_size,
)
db_file = file_service.create_file(db=db, file=file_data, current_user=current_user)
# 5. 上传文件到存储后端
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
try:
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"文件存储失败: {str(e)}")
db_file.file_key = file_key
db.commit()
db.refresh(db_file)
# 6. 创建 Document 记录(标记为 QA 类型)
doc_data = document_schema.DocumentCreate(
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
file_name=filename, file_ext=file_ext, file_size=file_size,
file_meta={}, parser_id="qa",
parser_config={"doc_type": "qa", "auto_questions": 0}
)
db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user)
api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}, file_key={file_key}")
# 7. 派发异步任务
from app.celery_app import celery_app
task = celery_app.send_task(
"app.core.rag.tasks.import_qa_chunks",
args=[str(kb_id), str(db_document.id), filename, contents],
queue="qa_import"
)
return success(data={
"task_id": task.id,
"document_id": str(db_document.id),
"file_id": str(db_file.id),
}, msg="QA 导入任务已提交,后台处理中")
@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse)
async def import_qa_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
导入 QA 问答对(CSV/Excel),异步处理
"""
api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}")
# 1. 校验文件格式
filename = file.filename or ""
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
# 2. 校验知识库和文档
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问")
# 3. 读取文件内容,派发异步任务
contents = await file.read()
from app.celery_app import celery_app
task = celery_app.send_task(
"app.core.rag.tasks.import_qa_chunks",
args=[str(kb_id), str(document_id), filename, contents],
queue="qa_import"
)
return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中")
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) @router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def get_chunk( async def get_chunk(
kb_id: uuid.UUID, kb_id: uuid.UUID,
@@ -529,9 +342,6 @@ async def update_chunk(
if total: if total:
chunk = items[0] chunk = items[0]
chunk.page_content = content chunk.page_content = content
# QA chunk: 更新 metadata 中的 question/answer
if update_data.is_qa:
chunk.metadata.update(update_data.qa_metadata)
vector_service.update_by_segment(chunk) vector_service.update_by_segment(chunk)
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated") return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
else: else:
@@ -546,7 +356,6 @@ async def delete_chunk(
kb_id: uuid.UUID, kb_id: uuid.UUID,
document_id: uuid.UUID, document_id: uuid.UUID,
doc_id: str, doc_id: str,
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
@@ -564,7 +373,7 @@ async def delete_chunk(
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
if vector_service.text_exists(doc_id): if vector_service.text_exists(doc_id):
vector_service.delete_by_ids([doc_id], refresh=force_refresh) vector_service.delete_by_ids([doc_id])
# 更新 chunk_num # 更新 chunk_num
db_document = db.query(Document).filter(Document.id == document_id).first() db_document = db.query(Document).filter(Document.id == document_id).first()
db_document.chunk_num -= 1 db_document.chunk_num -= 1

View File

@@ -113,33 +113,6 @@ async def create_chunk(
current_user=current_user) current_user=current_user)
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def create_chunks_batch(
kb_id: uuid.UUID,
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
items: list = Body(..., description="chunk items list"),
):
"""
Batch create chunks (max 8)
"""
body = await request.json()
batch_data = chunk_schema.ChunkBatchCreate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.create_chunks_batch(kb_id=kb_id,
document_id=document_id,
batch_data=batch_data,
db=db,
current_user=current_user)
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) @router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"]) @require_api_key(scopes=["rag"])
async def get_chunk( async def get_chunk(
@@ -203,7 +176,6 @@ async def delete_chunk(
request: Request, request: Request,
api_key_auth: ApiKeyAuth = None, api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
): ):
""" """
delete document chunk delete document chunk
@@ -216,7 +188,6 @@ async def delete_chunk(
return await chunk_controller.delete_chunk(kb_id=kb_id, return await chunk_controller.delete_chunk(kb_id=kb_id,
document_id=document_id, document_id=document_id,
doc_id=doc_id, doc_id=doc_id,
force_refresh=force_refresh,
db=db, db=db,
current_user=current_user) current_user=current_user)

View File

@@ -98,7 +98,6 @@ class Settings:
# File Upload # File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800")) MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20")) MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8"))
FILE_PATH: str = os.getenv("FILE_PATH", "/files") FILE_PATH: str = os.getenv("FILE_PATH", "/files")
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600")) FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))

View File

@@ -46,10 +46,7 @@ async def run_graphrag(
start = trio.current_time() start = trio.current_time()
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"] workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
chunks = [] chunks = []
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id", "chunk_type"], sort_by_position=True): for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id"], sort_by_position=True):
# 跳过 QA chunks,只用原文 chunks 构建图谱
if d.get("chunk_type") == "qa":
continue
chunks.append(d["page_content"]) chunks.append(d["page_content"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@@ -153,9 +150,6 @@ async def run_graphrag_for_kb(
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True) total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
for doc in items: for doc in items:
# 跳过 QA chunks,只用原文 chunks 构建图谱
if (doc.metadata or {}).get("chunk_type") == "qa":
continue
content = doc.page_content content = doc.page_content
if num_tokens_from_string(current_chunk + content) < 1024: if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content current_chunk += content

View File

@@ -131,52 +131,18 @@ def keyword_extraction(chat_mdl, content, topn=3):
def question_proposal(chat_mdl, content, topn=3): def question_proposal(chat_mdl, content, topn=3):
"""生成问题(向后兼容,返回纯文本问题列表)""" template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
pairs = qa_proposal(chat_mdl, content, topn) rendered_prompt = template.render(content=content, topn=topn)
if not pairs:
return ""
return "\n".join([p["question"] for p in pairs])
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None):
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]
Args:
chat_mdl: LLM 模型
content: 文本内容
topn: 生成 QA 对数量
custom_prompt: 自定义 prompt 模板(支持 Jinja2,可用变量: content, topn)
"""
if custom_prompt:
template = PROMPT_JINJA_ENV.from_string(custom_prompt)
sys_prompt = template.render(topn=topn)
else:
sys_prompt = QUESTION_PROMPT_TEMPLATE
msg = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": content}]
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096)) _, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
raw = chat_mdl.chat(sys_prompt, msg[1:], {"temperature": 0.2}) kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(raw, tuple): if isinstance(kwd, tuple):
raw = raw[0] kwd = kwd[0]
raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL) kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
if raw.find("**ERROR**") >= 0: if kwd.find("**ERROR**") >= 0:
return [] return ""
return parse_qa_pairs(raw) return kwd
def parse_qa_pairs(text: str) -> list:
"""解析 LLM 返回的 QA 对文本,格式: Q: xxx A: xxx"""
pairs = []
for line in text.strip().split("\n"):
line = line.strip()
if not line:
continue
# 匹配 Q: ... A: ... 格式
match = re.match(r'^Q:\s*(.+?)\s+A:\s*(.+)$', line, re.IGNORECASE)
if match:
q, a = match.group(1).strip(), match.group(2).strip()
if q and a:
pairs.append({"question": q, "answer": a})
return pairs
def graph_entity_types(chat_mdl, scenario): def graph_entity_types(chat_mdl, scenario):

View File

@@ -1,20 +1,19 @@
## Role ## Role
You are a text analyzer and knowledge extraction expert. You are a text analyzer.
## Task ## Task
Generate question-answer pairs from the given text content. Propose {{ topn }} questions about a given piece of text content.
## Requirements ## Requirements
- Understand and summarize the text content, then generate up to {{ topn }} important question-answer pairs. - Understand and summarize the text content, and propose the top {{ topn }} important questions.
- Each question-answer pair MUST be on a single line, formatted as: Q: <question> A: <answer>
- The questions SHOULD NOT have overlapping meanings. - The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible. - The questions SHOULD cover the main content of the text as much as possible.
- The answers MUST be concise, accurate, and directly derived from the text content. - The questions MUST be in the same language as the given piece of text content.
- The answers SHOULD be self-contained and understandable without additional context. - One question per line.
- Both questions and answers MUST be in the same language as the given text content. - Output questions ONLY.
- If the text is too short or lacks substantive content, generate fewer pairs rather than padding.
- Output question-answer pairs ONLY, no extra explanation or commentary. ---
## Text Content
{{ content }}
## Example Output
Q: What is the capital of France? A: The capital of France is Paris.
Q: When was the Eiffel Tower built? A: The Eiffel Tower was built in 1889.

View File

@@ -5,7 +5,7 @@ from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import requests import requests
from elasticsearch import Elasticsearch, helpers, NotFoundError from elasticsearch import Elasticsearch, helpers
from elasticsearch.helpers import BulkIndexError from elasticsearch.helpers import BulkIndexError
from packaging.version import parse as parse_version from packaging.version import parse as parse_version
# langchain-community # langchain-community
@@ -53,30 +53,13 @@ class ElasticSearchVector(BaseVector):
return "elasticsearch" return "elasticsearch"
def add_chunks(self, chunks: list[DocumentChunk], **kwargs): def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
# QA chunks: embedding 只对 question 字段做,source chunks: 不做 embedding # 实现 Elasticsearch 保存向量
texts_for_embedding = [] texts = [chunk.page_content for chunk in chunks]
for chunk in chunks:
chunk_type = (chunk.metadata or {}).get("chunk_type", "chunk")
if chunk_type == "source":
# source chunk 不需要向量索引
texts_for_embedding.append("")
elif chunk_type == "qa":
# QA chunk: 用 question 字段做 embedding
texts_for_embedding.append((chunk.metadata or {}).get("question", chunk.page_content))
else:
# 普通 chunk: 用 page_content 做 embedding
texts_for_embedding.append(chunk.page_content)
if self.is_multimodal_embedding: if self.is_multimodal_embedding:
embeddings = self.embeddings.embed_batch(texts_for_embedding) # 火山引擎多模态 Embedding
embeddings = self.embeddings.embed_batch(texts)
else: else:
embeddings = self.embeddings.embed_documents(texts_for_embedding) embeddings = self.embeddings.embed_documents(list(texts))
# source chunk 的向量置空
for i, chunk in enumerate(chunks):
if (chunk.metadata or {}).get("chunk_type") == "source":
embeddings[i] = None
self.create(chunks, embeddings, **kwargs) self.create(chunks, embeddings, **kwargs)
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
@@ -89,25 +72,13 @@ class ElasticSearchVector(BaseVector):
uuids = self._get_uuids(chunks) uuids = self._get_uuids(chunks)
actions = [] actions = []
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
source = {
Field.CONTENT_KEY.value: chunk.page_content,
Field.METADATA_KEY.value: chunk.metadata or {},
Field.VECTOR.value: embeddings[i] or None
}
# 写入 QA 相关字段
meta = chunk.metadata or {}
if meta.get("chunk_type"):
source[Field.CHUNK_TYPE.value] = meta["chunk_type"]
if meta.get("question"):
source[Field.QUESTION.value] = meta["question"]
if meta.get("answer"):
source[Field.ANSWER.value] = meta["answer"]
if meta.get("source_chunk_id"):
source[Field.SOURCE_CHUNK_ID.value] = meta["source_chunk_id"]
action = { action = {
"_index": self._collection_name, "_index": self._collection_name,
"_source": source "_source": {
Field.CONTENT_KEY.value: chunk.page_content,
Field.METADATA_KEY.value: chunk.metadata or {},
Field.VECTOR.value: embeddings[i] or None
}
} }
actions.append(action) actions.append(action)
# using bulk mode # using bulk mode
@@ -142,7 +113,7 @@ class ElasticSearchVector(BaseVector):
return True return True
def delete_by_ids(self, ids: list[str], *, refresh: bool = False): def delete_by_ids(self, ids: list[str]):
if not ids: if not ids:
return return
if not self._client.indices.exists(index=self._collection_name): if not self._client.indices.exists(index=self._collection_name):
@@ -163,8 +134,6 @@ class ElasticSearchVector(BaseVector):
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids] actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
try: try:
helpers.bulk(self._client, actions) helpers.bulk(self._client, actions)
if refresh:
self._client.indices.refresh(index=self._collection_name)
except BulkIndexError as e: except BulkIndexError as e:
for error in e.errors: for error in e.errors:
delete_error = error.get('delete', {}) delete_error = error.get('delete', {})
@@ -184,7 +153,7 @@ class ElasticSearchVector(BaseVector):
else: else:
return None return None
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False): def delete_by_metadata_field(self, key: str, value: str):
if not self._client.indices.exists(index=self._collection_name): if not self._client.indices.exists(index=self._collection_name):
return False return False
actual_ids = self.get_ids_by_metadata_field(key, value) actual_ids = self.get_ids_by_metadata_field(key, value)
@@ -193,8 +162,6 @@ class ElasticSearchVector(BaseVector):
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids] actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
try: try:
helpers.bulk(self._client, actions) helpers.bulk(self._client, actions)
if refresh:
self._client.indices.refresh(index=self._collection_name)
except BulkIndexError as e: except BulkIndexError as e:
for error in e.errors: for error in e.errors:
delete_error = error.get('delete', {}) delete_error = error.get('delete', {})
@@ -225,8 +192,6 @@ class ElasticSearchVector(BaseVector):
List of DocumentChunk objects that match the query. List of DocumentChunk objects that match the query.
""" """
indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3"
if not self._client.indices.exists(index=indices):
return 0, []
# Calculate the start position for the current page # Calculate the start position for the current page
from_ = pagesize * (page-1) from_ = pagesize * (page-1)
@@ -261,15 +226,12 @@ class ElasticSearchVector(BaseVector):
}) })
# For simplicity, we use from/size here which has a limit (usually up to 10,000). # For simplicity, we use from/size here which has a limit (usually up to 10,000).
try: result = self._client.search(
result = self._client.search( index=indices,
index=indices, from_=from_, # Only use from_ for the first page (simplified)
from_=from_, # Only use from_ for the first page (simplified) size=pagesize,
size=pagesize, body=query_str,
body=query_str, )
)
except NotFoundError:
return 0, []
if "errors" in result: if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}") raise ValueError(f"Error during query: {result['errors']}")
@@ -279,19 +241,10 @@ class ElasticSearchVector(BaseVector):
for res in result["hits"]["hits"]: for res in result["hits"]["hits"]:
source = res["_source"] source = res["_source"]
page_content = source.get(Field.CONTENT_KEY.value) page_content = source.get(Field.CONTENT_KEY.value)
# vector = source.get(Field.VECTOR.value)
vector = None vector = None
metadata = source.get(Field.METADATA_KEY.value, {}) metadata = source.get(Field.METADATA_KEY.value, {})
chunk_type = source.get(Field.CHUNK_TYPE.value)
score = res["_score"] score = res["_score"]
# 将 QA 字段注入 metadata 供前端展示
if chunk_type:
metadata["chunk_type"] = chunk_type
if chunk_type == "qa":
metadata["question"] = source.get(Field.QUESTION.value, "")
metadata["answer"] = source.get(Field.ANSWER.value, "")
page_content = f"Q: {metadata['question']}\nA: {metadata['answer']}"
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score)) docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
docs = [] docs = []
@@ -314,18 +267,13 @@ class ElasticSearchVector(BaseVector):
List of DocumentChunk objects that match the query. List of DocumentChunk objects that match the query.
""" """
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available, etc "index1,index2,index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available, etc "index1,index2,index3"
if not self._client.indices.exists(index=indices):
return 0, []
query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}} query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}}
try: result = self._client.search(
result = self._client.search( index=indices,
index=indices, from_=0, # Only use from_ for the first page (simplified)
from_=0, # Only use from_ for the first page (simplified) size=1,
size=1, body=query_str,
body=query_str, )
)
except NotFoundError:
return 0, []
# print(result) # print(result)
if "errors" in result: if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}") raise ValueError(f"Error during query: {result['errors']}")
@@ -360,43 +308,27 @@ class ElasticSearchVector(BaseVector):
Returns: Returns:
updated count. updated count.
""" """
indices = kwargs.get("indices", self._collection_name) indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available, etc "index1,index2,index3"
chunk_type = (chunk.metadata or {}).get("chunk_type") if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
# QA chunk: embedding 基于 question,source chunk: 不更新向量 chunk.vector = self.embeddings.embed_text(chunk.page_content)
if chunk_type == "source":
embed_text = ""
elif chunk_type == "qa":
embed_text = (chunk.metadata or {}).get("question", chunk.page_content)
else: else:
embed_text = chunk.page_content chunk.vector = self.embeddings.embed_query(chunk.page_content)
if chunk_type != "source":
if self.is_multimodal_embedding:
chunk.vector = self.embeddings.embed_text(embed_text)
else:
chunk.vector = self.embeddings.embed_query(embed_text)
script_source = "ctx._source.page_content = params.new_content; ctx._source.vector = params.new_vector;"
params = {
"new_content": chunk.page_content,
"new_vector": chunk.vector if chunk_type != "source" else None
}
# QA chunk: 同时更新 question/answer 字段
if chunk_type == "qa":
script_source += " ctx._source.question = params.new_question; ctx._source.answer = params.new_answer;"
params["new_question"] = (chunk.metadata or {}).get("question", "")
params["new_answer"] = (chunk.metadata or {}).get("answer", "")
body = { body = {
"script": { "script": {
"source": script_source, "source": """
"params": params ctx._source.page_content = params.new_content;
ctx._source.vector = params.new_vector;
""",
"params": {
"new_content": chunk.page_content,
"new_vector": chunk.vector
}
}, },
"query": { "query": {
"term": { "term": {
Field.DOC_ID.value: chunk.metadata["doc_id"] Field.DOC_ID.value: chunk.metadata["doc_id"] # exact match doc_id
} }
} }
} }
@@ -404,6 +336,9 @@ class ElasticSearchVector(BaseVector):
index=indices, index=indices,
body=body, body=body,
) )
# Remove debug printing and use logging instead
# print(result)
# print(f"Update successful, number of affected documents: {result['updated']}")
return result['updated'] return result['updated']
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str: def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
@@ -462,11 +397,11 @@ class ElasticSearchVector(BaseVector):
} }
} }
}, },
"filter": [ "filter": { # Add the filter condition of status=1
{"term": {"metadata.status": 1}}, "term": {
# 排除 source chunk(仅供 GraphRAG 使用,不参与检索) "metadata.status": 1
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}} }
] }
} }
} }
# If file_names_filter is passed in, merge the filtering conditions # If file_names_filter is passed in, merge the filtering conditions
@@ -480,14 +415,22 @@ class ElasticSearchVector(BaseVector):
}, },
"script": { "script": {
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0", "source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
# The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1]
"params": {"query_vector": query_vector} "params": {"query_vector": query_vector}
} }
} }
}, },
"filter": [ "filter": [
{"term": {"metadata.status": 1}}, {
{"terms": {"metadata.file_name": file_names_filter}}, "term": {
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}} "metadata.status": 1
}
},
{
"terms": {
"metadata.file_name": file_names_filter # Additional file_name filtering
}
}
], ],
} }
} }
@@ -508,19 +451,8 @@ class ElasticSearchVector(BaseVector):
source = res["_source"] source = res["_source"]
page_content = source.get(Field.CONTENT_KEY.value) page_content = source.get(Field.CONTENT_KEY.value)
metadata = source.get(Field.METADATA_KEY.value, {}) metadata = source.get(Field.METADATA_KEY.value, {})
chunk_type = source.get(Field.CHUNK_TYPE.value)
score = res["_score"] score = res["_score"]
score = score / 2 # Normalized [0-1] score = score / 2 # Normalized [0-1]
# QA chunk: 返回 Q+A 拼接作为上下文
if chunk_type == "qa":
question = source.get(Field.QUESTION.value, "")
answer = source.get(Field.ANSWER.value, "")
page_content = f"Q: {question}\nA: {answer}"
metadata["chunk_type"] = "qa"
metadata["question"] = question
metadata["answer"] = answer
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score)) docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
docs = [] docs = []
@@ -559,10 +491,11 @@ class ElasticSearchVector(BaseVector):
} }
} }
}, },
"filter": [ "filter": { # Add the filter condition of status=1
{"term": {"metadata.status": 1}}, "term": {
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}} "metadata.status": 1
] }
}
} }
} }
@@ -579,9 +512,16 @@ class ElasticSearchVector(BaseVector):
} }
}, },
"filter": [ "filter": [
{"term": {"metadata.status": 1}}, {
{"terms": {"metadata.file_name": file_names_filter}}, "term": {
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}} "metadata.status": 1
}
},
{
"terms": {
"metadata.file_name": file_names_filter # Additional file_name filtering
}
}
], ],
} }
} }
@@ -603,17 +543,6 @@ class ElasticSearchVector(BaseVector):
source = res["_source"] source = res["_source"]
page_content = source.get(Field.CONTENT_KEY.value) page_content = source.get(Field.CONTENT_KEY.value)
metadata = source.get(Field.METADATA_KEY.value, {}) metadata = source.get(Field.METADATA_KEY.value, {})
chunk_type = source.get(Field.CHUNK_TYPE.value)
# QA chunk: 返回 Q+A 拼接作为上下文
if chunk_type == "qa":
question = source.get(Field.QUESTION.value, "")
answer = source.get(Field.ANSWER.value, "")
page_content = f"Q: {question}\nA: {answer}"
metadata["chunk_type"] = "qa"
metadata["question"] = question
metadata["answer"] = answer
# Normalize the score to the [0,1] interval # Normalize the score to the [0,1] interval
normalized_score = res["_score"] / max_score normalized_score = res["_score"] / max_score
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score)) docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))
@@ -723,7 +652,7 @@ class ElasticSearchVector(BaseVector):
}, },
Field.VECTOR.value: { Field.VECTOR.value: {
"type": "dense_vector", "type": "dense_vector",
"dims": len(next((e for e in embeddings if e is not None), [0]*768)), # 跳过 None 获取向量维度fallback 768 "dims": len(embeddings[0]), # Make sure the dimension is correct here,The dimension size of the vector. When index is true, it cannot exceed 1024; when index is false or not specified, it cannot exceed 2048, which can improve retrieval efficiency
"index": True, "index": True,
"similarity": "cosine" "similarity": "cosine"
} }

View File

@@ -14,8 +14,3 @@ class Field(StrEnum):
DOCUMENT_ID = "metadata.document_id" DOCUMENT_ID = "metadata.document_id"
KNOWLEDGE_ID = "metadata.knowledge_id" KNOWLEDGE_ID = "metadata.knowledge_id"
SORT_ID = "metadata.sort_id" SORT_ID = "metadata.sort_id"
# QA fields
CHUNK_TYPE = "chunk_type" # "chunk" | "source" | "qa"
QUESTION = "question"
ANSWER = "answer"
SOURCE_CHUNK_ID = "source_chunk_id"

View File

@@ -27,14 +27,14 @@ class BaseVector(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def delete_by_ids(self, ids: list[str], *, refresh: bool = False): def delete_by_ids(self, ids: list[str]):
raise NotImplementedError raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str): def get_ids_by_metadata_field(self, key: str, value: str):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False): def delete_by_metadata_field(self, key: str, value: str):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod

View File

@@ -2,6 +2,7 @@
# Author: Eternity # Author: Eternity
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33 # @Time : 2026/2/10 13:33
import json
import logging import logging
import re import re
import uuid import uuid
@@ -141,9 +142,10 @@ class GraphBuilder:
for node_info in source_nodes: for node_info in source_nodes:
if self.get_node_type(node_info["id"]) in BRANCH_NODES: if self.get_node_type(node_info["id"]) in BRANCH_NODES:
branch_nodes.append( if node_info.get("branch") is not None:
(node_info["id"], node_info["branch"]) branch_nodes.append(
) (node_info["id"], node_info["branch"])
)
else: else:
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT): if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
output_nodes.append(node_info["id"]) output_nodes.append(node_info["id"])
@@ -314,9 +316,12 @@ class GraphBuilder:
for idx in range(len(related_edge)): for idx in range(len(related_edge)):
# Generate a condition expression for each edge # Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output # Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label # For LLM nodes, use branch_signal field for routing (output is dynamic text)
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' # For other branch nodes (e.g. HTTP), use output field
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'" route_field = "branch_signal" if node_type == NodeType.LLM else "output"
related_edge[idx]['condition'] = (
f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
)
if node_instance: if node_instance:
# Wrap node's run method to avoid closure issues # Wrap node's run method to avoid closure issues

View File

@@ -18,10 +18,17 @@ class AssignerNode(BaseNode):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None self.typed_config: AssignerNodeConfig | None = None
self._input_data: dict[str, Any] | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {} return {}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
    """Return the node's input record, preferring a cached pre-execution snapshot.

    Args:
        state: Current workflow state (not read by this method).
        variable_pool: Variable pool used to resolve the config when no
            snapshot is cached.

    Returns:
        ``self._input_data`` when it is not ``None``; otherwise a new dict
        ``{"config": <config resolved against variable_pool>}``.
    """
    snapshot = self._input_data
    if snapshot is None:
        # Nothing cached yet: resolve the config against the pool as it is now.
        # The freshly built dict is returned but not stored on the instance.
        snapshot = {"config": self._resolve_config(self.config, variable_pool)}
    # NOTE(review): the cached dict is returned as-is (no copy) and this method
    # never clears it; if node instances are reused across runs, confirm the
    # snapshot is refreshed before it is read here.
    return snapshot
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
""" """
Execute the assignment operation defined by this node. Execute the assignment operation defined by this node.
@@ -34,6 +41,9 @@ class AssignerNode(BaseNode):
Returns: Returns:
None or the result of the assignment operation. None or the result of the assignment operation.
""" """
# 在执行前提取并缓存输入数据(捕获执行前的变量值)
self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
# Initialize a variable pool for accessing conversation, node, and system variables # Initialize a variable pool for accessing conversation, node, and system variables
self.typed_config = AssignerNodeConfig(**self.config) self.typed_config = AssignerNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} 开始执行") logger.info(f"节点 {self.node_id} 开始执行")

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
import re
import time import time
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -22,6 +23,9 @@ from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Regex that matches template variables of the form {{xxx}}.
# The match is non-greedy, and because no DOTALL flag is set the placeholder
# must open and close on the same line.
_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
class NodeExecutionError(Exception): class NodeExecutionError(Exception):
"""节点执行失败异常。 """节点执行失败异常。
@@ -503,10 +507,29 @@ class BaseNode(ABC):
variable_pool: The variable pool used for reading and writing variables. variable_pool: The variable pool used for reading and writing variables.
Returns: Returns:
A dictionary containing the node's input data. A dictionary containing the node's input data with all template
variables resolved to their actual runtime values.
""" """
# Default implementation returns the node configuration return {"config": self._resolve_config(self.config, variable_pool)}
return {"config": self.config}
@staticmethod
def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
    """Recursively resolve ``{{xxx}}`` template variables inside a config value.

    Args:
        config: The raw node configuration, or any nested piece of it
            (may contain template variables).
        variable_pool: Variable pool handed to the template renderer.

    Returns:
        - dict: a new dict with the same keys and recursively resolved values
          (keys themselves are not resolved);
        - list: a new list with each element recursively resolved;
        - str containing a ``{{...}}`` placeholder: the result of
          ``BaseNode._render_template(config, variable_pool, strict=False)``;
        - anything else (including strings without a placeholder): the value
          itself, unchanged.
    """
    # Containers first: rebuild them so the caller's config is never mutated.
    if isinstance(config, dict):
        return {key: BaseNode._resolve_config(value, variable_pool) for key, value in config.items()}
    if isinstance(config, list):
        return [BaseNode._resolve_config(entry, variable_pool) for entry in config]
    # Only strings that actually contain a placeholder are sent to the renderer.
    # NOTE(review): _render_template is invoked through the class with no
    # instance — confirm it is a staticmethod.
    if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
        return BaseNode._render_template(config, variable_pool, strict=False)
    return config
def _extract_output(self, business_result: Any) -> Any: def _extract_output(self, business_result: Any) -> Any:
"""Extracts the actual output from the business result. """Extracts the actual output from the business result.

View File

@@ -121,7 +121,10 @@ class DocExtractorNode(BaseNode):
return business_result return business_result
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {"file_selector": self.config.get("file_selector")} file_selector = self.config.get("file_selector", "")
# 将变量选择器(如 sys.files解析为实际值
resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
return {"file_selector": resolved}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
config = DocExtractorNodeConfig(**self.config) config = DocExtractorNodeConfig(**self.config)

View File

@@ -31,7 +31,7 @@ class NodeType(StrEnum):
NOTES = "notes" NOTES = "notes"
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER}) BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):

View File

@@ -6,6 +6,7 @@ import uuid
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
@@ -49,6 +50,20 @@ class MemoryWindowSetting(BaseModel):
) )
class LLMErrorHandleConfig(BaseModel):
    """LLM 异常处理配置"""

    # Error-handling configuration for an LLM node.
    # The class docstring above is kept verbatim: pydantic can surface it as the
    # model's schema description, so it is treated as runtime text.

    # Strategy applied when the LLM call fails: 'none' raises the exception,
    # 'default' returns a default value, 'branch' takes the error branch.
    method: HttpErrorHandle = Field(
        description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支",
        default=HttpErrorHandle.NONE,
    )

    # Default output text returned on failure; per the description it only
    # takes effect when method is 'default'.
    output: str = Field(
        description="LLM 异常时返回的默认输出文本method=default 时生效)",
        default="",
    )
class LLMNodeConfig(BaseNodeConfig): class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置 """LLM 节点配置
@@ -152,6 +167,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="输出变量定义(自动生成,通常不需要修改)" description="输出变量定义(自动生成,通常不需要修改)"
) )
error_handle: LLMErrorHandleConfig = Field(
default_factory=LLMErrorHandleConfig,
description="LLM 异常处理配置",
)
@field_validator("messages", "prompt") @field_validator("messages", "prompt")
@classmethod @classmethod
def validate_input_mode(cls, v): def validate_input_mode(cls, v):

View File

@@ -15,6 +15,7 @@ from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.nodes.llm.config import LLMNodeConfig from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context from app.db import get_db_context
@@ -76,7 +77,7 @@ class LLMNode(BaseNode):
self.messages = [] self.messages = []
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
def _render_context(self, message: str, variable_pool: VariablePool): def _render_context(self, message: str, variable_pool: VariablePool):
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>" context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
@@ -239,7 +240,7 @@ class LLMNode(BaseNode):
return llm return llm
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage: async def execute(self, state: WorkflowState, variable_pool: VariablePool):
"""非流式执行 LLM 调用 """非流式执行 LLM 调用
Args: Args:
@@ -247,28 +248,36 @@ class LLMNode(BaseNode):
variable_pool: 变量池 variable_pool: 变量池
Returns: Returns:
LLM 响应消息 dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
{"llm_result": None, "branch_signal": "ERROR"} on branch error
""" """
# self.typed_config = LLMNodeConfig(**self.config) try:
llm = await self._prepare_llm(state, variable_pool, False) # self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, False)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表 # 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(self.messages) response = await llm.ainvoke(self.messages)
# 提取内容 # 提取内容
if hasattr(response, 'content'): if hasattr(response, 'content'):
content = self.process_model_output(response.content) content = self.process_model_output(response.content)
else: else:
content = str(response) content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据 # 返回 AIMessage包含响应元数据
return AIMessage(content=content, response_metadata={ return {
**response.response_metadata, "llm_result": AIMessage(content=content, response_metadata={
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage') **response.response_metadata,
}) "token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
}),
"branch_signal": "SUCCESS",
}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 调用失败: {e}")
return self._handle_llm_error(e)
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)""" """提取输入数据(用于记录)"""
@@ -286,16 +295,36 @@ class LLMNode(BaseNode):
} }
} }
def _extract_output(self, business_result: Any) -> str: def _extract_output(self, business_result: Any) -> dict:
""" AIMessage 中提取文本内容""" """业务结果中提取输出变量
支持新旧两种格式:
- 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"}
- 旧格式AIMessage向后兼容
"""
if isinstance(business_result, dict) and "branch_signal" in business_result:
llm_result = business_result.get("llm_result")
if isinstance(llm_result, AIMessage):
return {
"output": llm_result.content,
"branch_signal": business_result["branch_signal"],
}
return {
"output": str(llm_result) if llm_result else "",
"branch_signal": business_result["branch_signal"],
}
# 旧格式向后兼容
if isinstance(business_result, AIMessage): if isinstance(business_result, AIMessage):
return business_result.content return {"output": business_result.content, "branch_signal": "SUCCESS"}
return str(business_result) return {"output": str(business_result), "branch_signal": "SUCCESS"}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
""" AIMessage 中提取 token 使用情况""" """业务结果中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'): llm_result = business_result
usage = business_result.response_metadata.get('token_usage') if isinstance(business_result, dict):
llm_result = business_result.get("llm_result", business_result)
if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
usage = llm_result.response_metadata.get('token_usage')
if usage: if usage:
return { return {
"prompt_tokens": usage.get('input_tokens', 0), "prompt_tokens": usage.get('input_tokens', 0),
@@ -304,6 +333,44 @@ class LLMNode(BaseNode):
} }
return None return None
def _handle_llm_error(self, error: Exception) -> dict:
    """Turn a failed LLM call into a result dict, or re-raise it.

    The behaviour is selected by ``self.typed_config.error_handle.method``.

    Args:
        error: The exception caught while calling the LLM.

    Returns:
        dict: for the DEFAULT strategy,
        ``{"llm_result": AIMessage(content=<configured default output>), "branch_signal": "SUCCESS"}``;
        for the BRANCH strategy,
        ``{"llm_result": None, "branch_signal": "ERROR"}``.

    Raises:
        Exception: ``error`` itself, when ``typed_config`` is ``None``, when
        the strategy is NONE, or when the strategy matches none of the
        handled values.
    """
    node_config = self.typed_config
    if node_config is None:
        # Without a parsed config there is no strategy to consult.
        raise error

    strategy = node_config.error_handle.method

    if strategy == HttpErrorHandle.NONE:
        raise error

    if strategy == HttpErrorHandle.DEFAULT:
        logger.warning(
            f"节点 {self.node_id}: LLM 调用失败,返回默认输出"
        )
        fallback_text = node_config.error_handle.output or ""
        # The fallback is reported as SUCCESS so routing follows the normal branch.
        return {
            "llm_result": AIMessage(content=fallback_text, response_metadata={}),
            "branch_signal": "SUCCESS",
        }

    if strategy == HttpErrorHandle.BRANCH:
        logger.warning(
            f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支"
        )
        return {
            "llm_result": None,
            "branch_signal": "ERROR",
        }

    # Unrecognised strategy: surface the original failure.
    raise error
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""流式执行 LLM 调用 """流式执行 LLM 调用
@@ -316,54 +383,58 @@ class LLMNode(BaseNode):
""" """
self.typed_config = LLMNodeConfig(**self.config) self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, True) try:
llm = await self._prepare_llm(state, variable_pool, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 累积完整响应 # 累积完整响应
full_response = "" full_response = ""
chunk_count = 0 chunk_count = 0
# 调用 LLM流式支持字符串或消息列表 # 调用 LLM流式支持字符串或消息列表
last_meta_data = {} last_meta_data = {}
last_usage_metadata = {} last_usage_metadata = {}
async for chunk in llm.astream(self.messages): async for chunk in llm.astream(self.messages):
if hasattr(chunk, 'content'): if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content) content = self.process_model_output(chunk.content)
else: else:
content = str(chunk) content = str(chunk)
if hasattr(chunk, 'response_metadata') and chunk.response_metadata: if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
last_meta_data = chunk.response_metadata last_meta_data = chunk.response_metadata
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata: if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
last_usage_metadata = chunk.usage_metadata last_usage_metadata = chunk.usage_metadata
# 只有当内容不为空时才处理 # 只有当内容不为空时才处理
if content: if content:
full_response += content full_response += content
chunk_count += 1 chunk_count += 1
# 流式返回每个文本片段 # 流式返回每个文本片段
yield { yield {
"__final__": False, "__final__": False,
"chunk": content "chunk": content
} }
yield { yield {
"__final__": False, "__final__": False,
"chunk": "", "chunk": "",
"done": True "done": True
}
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
} }
) logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# yield 完成标记 # 构建完整的 AIMessage包含元数据
yield {"__final__": True, "result": final_message} final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
}
)
# yield 完成标记
yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}")
error_result = self._handle_llm_error(e)
yield {"__final__": True, "result": error_result}

View File

@@ -20,26 +20,13 @@ class ChunkCreate(BaseModel):
@property @property
def chunk_content(self) -> str: def chunk_content(self) -> str:
"""Get the actual content string regardless of input type""" """
Get the actual content string regardless of input type
"""
if isinstance(self.content, QAChunk): if isinstance(self.content, QAChunk):
return self.content.question # QA 模式下 page_content 存 question return f"question: {self.content.question} answer: {self.content.answer}"
return self.content return self.content
@property
def is_qa(self) -> bool:
    """Whether ``content`` is a ``QAChunk`` instance."""
    payload = self.content
    return isinstance(payload, QAChunk)
@property
def qa_metadata(self) -> dict:
    """Return the QA-related metadata fields for this chunk.

    Returns:
        ``{"chunk_type": "qa", "question": ..., "answer": ...}`` taken from
        ``content`` when it is a ``QAChunk``; an empty dict otherwise.
    """
    payload = self.content
    if not isinstance(payload, QAChunk):
        return {}
    return {
        "chunk_type": "qa",
        "question": payload.question,
        "answer": payload.answer,
    }
class ChunkUpdate(BaseModel): class ChunkUpdate(BaseModel):
content: Union[str, QAChunk] = Field( content: Union[str, QAChunk] = Field(
@@ -48,26 +35,13 @@ class ChunkUpdate(BaseModel):
@property @property
def chunk_content(self) -> str: def chunk_content(self) -> str:
"""Get the actual content string regardless of input type""" """
Get the actual content string regardless of input type
"""
if isinstance(self.content, QAChunk): if isinstance(self.content, QAChunk):
return self.content.question # QA 模式下 page_content 存 question return f"question: {self.content.question} answer: {self.content.answer}"
return self.content return self.content
@property
def is_qa(self) -> bool:
    """Whether ``content`` is a ``QAChunk`` instance."""
    payload = self.content
    return isinstance(payload, QAChunk)
@property
def qa_metadata(self) -> dict:
    """Return the QA-related metadata fields for this chunk.

    Returns:
        ``{"chunk_type": "qa", "question": ..., "answer": ...}`` taken from
        ``content`` when it is a ``QAChunk``; an empty dict otherwise.
    """
    payload = self.content
    if not isinstance(payload, QAChunk):
        return {}
    return {
        "chunk_type": "qa",
        "question": payload.question,
        "answer": payload.answer,
    }
class ChunkRetrieve(BaseModel): class ChunkRetrieve(BaseModel):
query: str query: str
@@ -77,8 +51,3 @@ class ChunkRetrieve(BaseModel):
vector_similarity_weight: float | None = Field(None) vector_similarity_weight: float | None = Field(None)
top_k: int | None = Field(None) top_k: int | None = Field(None)
retrieve_type: RetrieveType | None = Field(None) retrieve_type: RetrieveType | None = Field(None)
class ChunkBatchCreate(BaseModel):
    """批量创建 chunk"""

    # Request body for creating chunks in batch.
    # The class docstring above is kept verbatim: pydantic can surface it as the
    # model's schema description, so it is treated as runtime text.

    # Required list of chunks to create; min_length=1 rejects an empty list.
    items: list[ChunkCreate] = Field(
        ...,
        description="chunk 列表",
        min_length=1,
    )

View File

@@ -102,6 +102,11 @@ class AppDslService:
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or []) {**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
] ]
return enriched return enriched
if app_type == AppType.WORKFLOW:
enriched = {**cfg}
if "nodes" in cfg:
enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
return enriched
return cfg return cfg
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]: def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
@@ -110,7 +115,7 @@ class AppDslService:
config_data = { config_data = {
"variables": config.variables if config else [], "variables": config.variables if config else [],
"edges": config.edges if config else [], "edges": config.edges if config else [],
"nodes": config.nodes if config else [], "nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
"features": config.features if config else {}, "features": config.features if config else {},
"execution_config": config.execution_config if config else {}, "execution_config": config.execution_config if config else {},
"triggers": config.triggers if config else [], "triggers": config.triggers if config else [],
@@ -190,6 +195,23 @@ class AppDslService:
def _enrich_tools(self, tools: list) -> list: def _enrich_tools(self, tools: list) -> list:
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])] return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
def _enrich_workflow_nodes(self, nodes: list) -> list:
    """Swap ``model_id`` for a ``model_ref`` in model-backed workflow nodes.

    For LLM, question-classifier and parameter-extractor nodes whose config
    has a truthy ``model_id``, the config gains ``model_ref`` (the value
    returned by ``self._model_ref(model_id)`` — presumably the model's
    name/provider/type; see that helper) and ``model_id`` is removed.
    All other nodes only get a shallow copy of their config.

    Args:
        nodes: Workflow node dicts; ``None`` or an empty list yields ``[]``.

    Returns:
        A new list of new node dicts. The input node dicts and their
        ``config`` dicts are not mutated.
    """
    from app.core.workflow.nodes.enums import NodeType

    model_backed_types = (
        NodeType.LLM.value,
        NodeType.QUESTION_CLASSIFIER.value,
        NodeType.PARAMETER_EXTRACTOR.value,
    )

    def _enrich_one(node: dict) -> dict:
        # Work on a shallow copy so the stored workflow config stays untouched.
        node_config = dict(node.get("config") or {})
        if node.get("type") in model_backed_types:
            model_id = node_config.get("model_id")
            if model_id:
                node_config["model_ref"] = self._model_ref(model_id)
                del node_config["model_id"]
        return {**node, "config": node_config}

    return [_enrich_one(node) for node in (nodes or [])]
def _skill_ref(self, skill_id) -> Optional[dict]: def _skill_ref(self, skill_id) -> Optional[dict]:
if not skill_id: if not skill_id:
return None return None
@@ -620,16 +642,16 @@ class AppDslService:
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置") warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
config["knowledge_bases"] = resolved_kbs config["knowledge_bases"] = resolved_kbs
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value): elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
model_ref = config.get("model_id") model_ref = config.get("model_ref") or config.get("model_id")
if model_ref: if model_ref:
ref_dict = None ref_dict = None
if isinstance(model_ref, dict): if isinstance(model_ref, dict):
ref_id = model_ref.get("id") ref_dict = {
ref_name = model_ref.get("name") "id": model_ref.get("id"),
if ref_id: "name": model_ref.get("name"),
ref_dict = {"id": ref_id} "provider": model_ref.get("provider"),
elif ref_name is not None: "type": model_ref.get("type")
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")} }
elif isinstance(model_ref, str): elif isinstance(model_ref, str):
try: try:
uuid.UUID(model_ref) uuid.UUID(model_ref)
@@ -640,12 +662,18 @@ class AppDslService:
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings) resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
if resolved_model_id: if resolved_model_id:
config["model_id"] = resolved_model_id config["model_id"] = resolved_model_id
if "model_ref" in config:
del config["model_ref"]
else: else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置") warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None config["model_id"] = None
if "model_ref" in config:
del config["model_ref"]
else: else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置") warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None config["model_id"] = None
if "model_ref" in config:
del config["model_ref"]
resolved_nodes.append({**node, "config": config}) resolved_nodes.append({**node, "config": config})
return resolved_nodes return resolved_nodes

View File

@@ -30,7 +30,7 @@ from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.embedding_model import OpenAIEmbed from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
from app.core.rag.models.chunk import DocumentChunk from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.prompts.generator import question_proposal, qa_proposal from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory, ElasticSearchVectorFactory,
) )
@@ -311,7 +311,6 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
# 2.2 Vectorize and import batch documents # 2.2 Vectorize and import batch documents
auto_questions_topn = db_document.parser_config.get("auto_questions", 0) auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
qa_prompt = db_document.parser_config.get("qa_prompt", None)
chat_model = None chat_model = None
if auto_questions_topn: if auto_questions_topn:
chat_model = Base( chat_model = Base(
@@ -319,123 +318,62 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
model_name=db_knowledge.llm.api_keys[0].model_name, model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base, base_url=db_knowledge.llm.api_keys[0].api_base,
) )
logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}")
if qa_prompt:
logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)")
# 预先构建所有 batch 的 chunks保证 sort_id 全局有序 # 预先构建所有 batch 的 chunks保证 sort_id 全局有序
all_batch_chunks: list[list[DocumentChunk]] = [] all_batch_chunks: list[list[DocumentChunk]] = []
if auto_questions_topn: if auto_questions_topn:
# QA 模式FastGPT 方案): # auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组
# 1. 原 chunk 标记为 source保留供 GraphRAG 使用,不参与检索) # 构建 (global_idx, item) 列表
# 2. LLM 生成 QA 对,每个 QA 对独立存储为 qa chunk
indexed_items = list(enumerate(res)) indexed_items = list(enumerate(res))
def _generate_qa(idx_item: tuple[int, dict]) -> tuple[int, list]: def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]:
"""为单个 chunk 生成 QA 对(带缓存),返回 (global_idx, qa_pairs)""" """为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)"""
global_idx, item = idx_item global_idx, item = idx_item
content = item["content_with_weight"] content = item["content_with_weight"]
cache_params = {"topn": auto_questions_topn} cached = get_llm_cache(chat_model.model_name, content, "question",
if qa_prompt: {"topn": auto_questions_topn})
import hashlib
cache_params["prompt_hash"] = hashlib.md5(qa_prompt.encode()).hexdigest()[:8]
cached = get_llm_cache(chat_model.model_name, content, "qa", cache_params)
if not cached: if not cached:
logger.info(f"[QA] Cache miss for chunk {global_idx}, calling LLM. cache_params={cache_params}") cached = question_proposal(chat_model, content, auto_questions_topn)
try: set_llm_cache(chat_model.model_name, content, cached, "question",
pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt) {"topn": auto_questions_topn})
except Exception as e: return global_idx, cached
logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}")
return global_idx, []
logger.info(f"[QA] Chunk {global_idx} generated {len(pairs)} QA pairs")
# 缓存存 JSON 字符串
set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa",
cache_params)
return global_idx, pairs
logger.info(f"[QA] Cache hit for chunk {global_idx}, cache_params={cache_params}, cached_type={type(cached).__name__}")
# 从缓存读取:可能是 JSON 字符串或旧格式纯文本
if isinstance(cached, str):
try:
parsed = json.loads(cached)
if isinstance(parsed, list):
logger.info(f"[QA] Chunk {global_idx} loaded {len(parsed)} QA pairs from cache")
return global_idx, parsed
except (json.JSONDecodeError, TypeError):
pass
# 旧缓存格式(纯文本问题),尝试解析
from app.core.rag.prompts.generator import parse_qa_pairs
return global_idx, parse_qa_pairs(cached) if cached else []
return global_idx, cached if isinstance(cached, list) else []
# 并发调用 LLM 生成 QA 对 # 并发调用 LLM 生成问题
qa_map: dict[int, list] = {} question_map: dict[int, str] = {}
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor: with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
futures = {q_executor.submit(_generate_qa, item): item[0] futures = {q_executor.submit(_generate_question, item): item[0]
for item in indexed_items} for item in indexed_items}
for future in futures: for future in futures:
global_idx, pairs = future.result() global_idx, cached = future.result()
qa_map[global_idx] = pairs question_map[global_idx] = cached
progress_lines.append( progress_lines.append(
f"{datetime.now().strftime('%H:%M:%S')} QA pairs generated for {total_chunks} chunks " f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks "
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).") f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
# 组装 chunkssource chunks + qa chunks # 按 batch 分组组装 DocumentChunk
source_chunks = [] for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
qa_chunks = [] batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
qa_sort_id = 0 chunks = []
for global_idx in range(batch_start, batch_end):
for global_idx in range(total_chunks): item = res[global_idx]
item = res[global_idx] metadata = {
source_chunk_id = uuid.uuid4().hex
# source chunk保留原文供 GraphRAG 使用,不参与向量检索
source_meta = {
"doc_id": source_chunk_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(db_document.id),
"knowledge_id": str(db_document.kb_id),
"sort_id": global_idx,
"status": 1,
"chunk_type": "source",
}
source_chunks.append(
DocumentChunk(page_content=item["content_with_weight"], metadata=source_meta))
# qa chunks每个 QA 对独立存储
pairs = qa_map.get(global_idx, [])
for pair in pairs:
qa_meta = {
"doc_id": uuid.uuid4().hex, "doc_id": uuid.uuid4().hex,
"file_id": str(db_document.file_id), "file_id": str(db_document.file_id),
"file_name": db_document.file_name, "file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000), "file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(db_document.id), "document_id": str(db_document.id),
"knowledge_id": str(db_document.kb_id), "knowledge_id": str(db_document.kb_id),
"sort_id": qa_sort_id, "sort_id": global_idx,
"status": 1, "status": 1,
"chunk_type": "qa",
"question": pair["question"],
"answer": pair["answer"],
"source_chunk_id": source_chunk_id,
} }
# page_content 存 question用于向量索引 cached = question_map[global_idx]
qa_chunks.append( chunks.append(
DocumentChunk(page_content=pair["question"], metadata=qa_meta)) DocumentChunk(
qa_sort_id += 1 page_content=f"question: {cached} answer: {item['content_with_weight']}",
metadata=metadata))
# 按 batch 分组source + qa 一起) all_batch_chunks.append(chunks)
all_chunks = source_chunks + qa_chunks
for batch_start in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE):
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, len(all_chunks))
all_batch_chunks.append(all_chunks[batch_start:batch_end])
progress_lines.append(
f"{datetime.now().strftime('%H:%M:%S')} QA mode: {len(source_chunks)} source chunks + "
f"{len(qa_chunks)} QA chunks prepared.")
else: else:
# 无 auto_questions直接构建 chunks # 无 auto_questions直接构建 chunks
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
@@ -697,136 +635,6 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str):
return f"build_graphrag_for_document '{document_id}' failed: {e}" return f"build_graphrag_for_document '{document_id}' failed: {e}"
def _parse_qa_csv_bytes(contents: bytes) -> tuple[list[dict], list[int]]:
    """Parse CSV bytes into QA pairs.

    The first row is treated as a header and skipped. Column 0 is the
    question and column 1 the answer.

    Returns:
        ``(qa_pairs, failed_rows)`` where ``qa_pairs`` is a list of
        ``{"question", "answer"}`` dicts and ``failed_rows`` holds the 1-based
        row numbers that have a question but no usable answer.
    """
    import csv as csv_module
    import io

    # Try UTF-8 (with or without BOM) first, then fall back to GBK and drop
    # any bytes that still cannot be decoded.
    try:
        text = contents.decode("utf-8-sig")
    except UnicodeDecodeError:
        text = contents.decode("gbk", errors="ignore")

    # Let the sniffer detect the delimiter; if it cannot, guess between
    # comma and tab from the start of the text.
    try:
        delimiter = csv_module.Sniffer().sniff(text[:2048]).delimiter
    except csv_module.Error:
        delimiter = "," if "," in text[:500] else "\t"

    qa_pairs: list[dict] = []
    failed_rows: list[int] = []
    reader = csv_module.reader(io.StringIO(text), delimiter=delimiter)
    for row_no, row in enumerate(reader, start=1):
        if row_no == 1:
            continue  # header row
        question = row[0].strip() if row else ""
        answer = row[1].strip() if len(row) >= 2 else ""
        if question and answer:
            qa_pairs.append({"question": question, "answer": answer})
        elif question:
            failed_rows.append(row_no)
    return qa_pairs, failed_rows


def _parse_qa_excel_bytes(contents: bytes) -> tuple[list[dict], list[int]]:
    """Parse Excel workbook bytes into QA pairs, reading every worksheet.

    The first row of each sheet is treated as a header and skipped. Column 0
    is the question and column 1 the answer. Blank handling mirrors the CSV
    parser: a row counts as failed only when its question is non-blank after
    stripping and its answer is blank.

    Returns:
        ``(qa_pairs, failed_rows)``; row numbers are 1-based within their own
        sheet, so they can repeat across sheets.

    Raises:
        Whatever ``openpyxl`` raises for an unreadable workbook; the caller
        converts that into an error result.
    """
    import io

    import openpyxl

    qa_pairs: list[dict] = []
    failed_rows: list[int] = []
    # NOTE(review): openpyxl targets the OOXML formats; a legacy ``.xls`` file
    # presumably fails here and is reported as a parse failure - confirm.
    wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
    try:
        for sheet in wb.worksheets:
            for row_no, row in enumerate(sheet.iter_rows(values_only=True), start=1):
                if row_no == 1:
                    continue  # header row
                # Falsy cells (None, empty string, 0) are treated as empty.
                question = str(row[0]).strip() if len(row) >= 1 and row[0] else ""
                answer = str(row[1]).strip() if len(row) >= 2 and row[1] else ""
                if question and answer:
                    qa_pairs.append({"question": question, "answer": answer})
                elif question:
                    failed_rows.append(row_no)
    finally:
        # Release the workbook handle even when iteration raises.
        wb.close()
    return qa_pairs, failed_rows


@celery_app.task(name="app.core.rag.tasks.import_qa_chunks", queue="qa_import")
def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: bytes):
    """Import question/answer pairs from a CSV or Excel file as QA chunks.

    File layout: the first row is a header (skipped), the first column is the
    question and the second column is the answer.

    Args:
        kb_id: Knowledge base id, as a UUID string.
        document_id: Id of the document that owns the chunks, as a UUID string.
        filename: Original file name; its extension (matched
            case-insensitively) selects the parser.
        contents: Raw bytes of the uploaded file.

    Returns:
        ``{"imported": n, "failed_rows": [...]}`` on success, or
        ``{"error": message, "imported": 0}`` on failure. Exceptions are not
        propagated; on an unexpected exception the document is marked as
        failed on a best-effort basis.
    """
    try:
        from app.db import get_db_context
        with get_db_context() as db:
            db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
            db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first()
            if not db_document or not db_knowledge:
                logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found")
                return {"error": "document or knowledge not found", "imported": 0}

            # 1. Parse the file. The extension is compared in lower case so
            #    that e.g. "DATA.CSV" is accepted.
            lowered_name = filename.lower()
            if lowered_name.endswith(".csv"):
                qa_pairs, failed_rows = _parse_qa_csv_bytes(contents)
            elif lowered_name.endswith((".xlsx", ".xls")):
                try:
                    qa_pairs, failed_rows = _parse_qa_excel_bytes(contents)
                except Exception as e:
                    logger.error(f"[ImportQA] Excel parse failed: {e}")
                    return {"error": f"Excel parse failed: {e}", "imported": 0}
            else:
                # Report the real cause instead of the generic
                # "No valid QA pairs found" message.
                logger.warning(f"[ImportQA] Unsupported file type: {filename}")
                return {"error": f"Unsupported file type: {filename}", "imported": 0}

            if not qa_pairs:
                logger.warning(f"[ImportQA] No valid QA pairs found in {filename}")
                return {"error": "No valid QA pairs found", "imported": 0}
            logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}")

            # 2. Write to Elasticsearch.
            vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
            # Continue numbering after the first segment returned for this
            # document. NOTE(review): presumably asc=False yields the chunk
            # with the highest sort_id - confirm against search_by_segment.
            sort_id = 0
            _, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False)
            if items:
                sort_id = items[0].metadata["sort_id"]

            chunks = []
            for pair in qa_pairs:
                sort_id += 1
                metadata = {
                    "doc_id": uuid.uuid4().hex,
                    "file_id": str(db_document.file_id),
                    "file_name": db_document.file_name,
                    "file_created_at": int(db_document.created_at.timestamp() * 1000),
                    "document_id": document_id,
                    "knowledge_id": kb_id,
                    "sort_id": sort_id,
                    "status": 1,
                    "chunk_type": "qa",
                    "question": pair["question"],
                    "answer": pair["answer"],
                }
                # page_content carries the question text; the answer is kept
                # in the metadata only.
                chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata))

            batch_size = 50
            for start in range(0, len(chunks), batch_size):
                vector_service.add_chunks(chunks[start:start + batch_size])

            # 3. Update chunk_num and progress. Guard against an unset
            #    (None) chunk counter.
            db_document.chunk_num = (db_document.chunk_num or 0) + len(chunks)
            db_document.progress = 1.0
            db_document.progress_msg = f"QA 导入完成: {len(chunks)}"
            db.commit()

            result = {"imported": len(chunks), "failed_rows": failed_rows}
            logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}")
            return result
    except Exception as e:
        logger.error(f"[ImportQA] Failed: {e}", exc_info=True)
        # Best effort: mark the document as failed. A failure here must not
        # mask the original error, but it is logged instead of being dropped.
        try:
            from app.db import get_db_context
            with get_db_context() as err_db:
                doc = err_db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
                if doc:
                    doc.progress = -1.0
                    doc.progress_msg = f"QA 导入失败: {str(e)[:200]}"
                    err_db.commit()
        except Exception as mark_err:
            logger.warning(f"[ImportQA] Could not mark document {document_id} as failed: {mark_err}")
        return {"error": str(e), "imported": 0}
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") @celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
def sync_knowledge_for_kb(kb_id: uuid.UUID): def sync_knowledge_for_kb(kb_id: uuid.UUID):
""" """