Compare commits
5 Commits
feature/ra
...
feat/wxy-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cef33fce0d | ||
|
|
d9f08860bc | ||
|
|
461674c8d8 | ||
|
|
c59e179cc2 | ||
|
|
a5670bfff6 |
@@ -1,10 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import csv
|
|
||||||
import io
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -25,7 +23,6 @@ from app.models.user_model import User
|
|||||||
from app.schemas import chunk_schema
|
from app.schemas import chunk_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||||
from app.services.file_storage_service import FileStorageService, get_file_storage_service, generate_kb_file_key
|
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
@@ -274,9 +271,6 @@ async def create_chunk(
|
|||||||
"sort_id": sort_id,
|
"sort_id": sort_id,
|
||||||
"status": 1,
|
"status": 1,
|
||||||
}
|
}
|
||||||
# QA chunk: 注入 chunk_type/question/answer 到 metadata
|
|
||||||
if create_data.is_qa:
|
|
||||||
metadata.update(create_data.qa_metadata)
|
|
||||||
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
||||||
# 3. Segmented vector storage
|
# 3. Segmented vector storage
|
||||||
vector_service.add_chunks([chunk])
|
vector_service.add_chunks([chunk])
|
||||||
@@ -288,187 +282,6 @@ async def create_chunk(
|
|||||||
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
|
|
||||||
async def create_chunks_batch(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
document_id: uuid.UUID,
|
|
||||||
batch_data: chunk_schema.ChunkBatchCreate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Batch create chunks (max 8)
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}")
|
|
||||||
|
|
||||||
if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
||||||
if not db_knowledge:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied")
|
|
||||||
|
|
||||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
|
||||||
if not db_document:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it")
|
|
||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
|
|
||||||
# Get current max sort_id
|
|
||||||
sort_id = 0
|
|
||||||
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
|
|
||||||
if items:
|
|
||||||
sort_id = items[0].metadata["sort_id"]
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
for create_data in batch_data.items:
|
|
||||||
sort_id += 1
|
|
||||||
doc_id = uuid.uuid4().hex
|
|
||||||
metadata = {
|
|
||||||
"doc_id": doc_id,
|
|
||||||
"file_id": str(db_document.file_id),
|
|
||||||
"file_name": db_document.file_name,
|
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
|
||||||
"document_id": str(document_id),
|
|
||||||
"knowledge_id": str(kb_id),
|
|
||||||
"sort_id": sort_id,
|
|
||||||
"status": 1,
|
|
||||||
}
|
|
||||||
if create_data.is_qa:
|
|
||||||
metadata.update(create_data.qa_metadata)
|
|
||||||
chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata))
|
|
||||||
|
|
||||||
vector_service.add_chunks(chunks)
|
|
||||||
|
|
||||||
db_document.chunk_num += len(chunks)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{kb_id}/import_qa", response_model=ApiResponse)
|
|
||||||
async def import_qa_new_doc(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
导入 QA 问答对并新建文档(CSV/Excel),异步处理
|
|
||||||
"""
|
|
||||||
from app.schemas import file_schema, document_schema
|
|
||||||
|
|
||||||
api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# 1. 校验文件格式
|
|
||||||
filename = file.filename or ""
|
|
||||||
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
|
|
||||||
|
|
||||||
# 2. 校验知识库
|
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
||||||
if not db_knowledge:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
|
|
||||||
|
|
||||||
# 3. 读取文件
|
|
||||||
contents = await file.read()
|
|
||||||
file_size = len(contents)
|
|
||||||
if file_size == 0:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空")
|
|
||||||
|
|
||||||
_, file_extension = os.path.splitext(filename)
|
|
||||||
file_ext = file_extension.lower()
|
|
||||||
|
|
||||||
# 4. 创建 File 记录
|
|
||||||
file_data = file_schema.FileCreate(
|
|
||||||
kb_id=kb_id, created_by=current_user.id,
|
|
||||||
parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
|
|
||||||
file_name=filename, file_ext=file_ext, file_size=file_size,
|
|
||||||
)
|
|
||||||
db_file = file_service.create_file(db=db, file=file_data, current_user=current_user)
|
|
||||||
|
|
||||||
# 5. 上传文件到存储后端
|
|
||||||
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
|
|
||||||
try:
|
|
||||||
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Storage upload failed: {e}")
|
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"文件存储失败: {str(e)}")
|
|
||||||
|
|
||||||
db_file.file_key = file_key
|
|
||||||
db.commit()
|
|
||||||
db.refresh(db_file)
|
|
||||||
|
|
||||||
# 6. 创建 Document 记录(标记为 QA 类型)
|
|
||||||
doc_data = document_schema.DocumentCreate(
|
|
||||||
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
|
|
||||||
file_name=filename, file_ext=file_ext, file_size=file_size,
|
|
||||||
file_meta={}, parser_id="qa",
|
|
||||||
parser_config={"doc_type": "qa", "auto_questions": 0}
|
|
||||||
)
|
|
||||||
db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user)
|
|
||||||
|
|
||||||
api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}, file_key={file_key}")
|
|
||||||
|
|
||||||
# 7. 派发异步任务
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
task = celery_app.send_task(
|
|
||||||
"app.core.rag.tasks.import_qa_chunks",
|
|
||||||
args=[str(kb_id), str(db_document.id), filename, contents],
|
|
||||||
queue="qa_import"
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data={
|
|
||||||
"task_id": task.id,
|
|
||||||
"document_id": str(db_document.id),
|
|
||||||
"file_id": str(db_file.id),
|
|
||||||
}, msg="QA 导入任务已提交,后台处理中")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse)
|
|
||||||
async def import_qa_chunks(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
document_id: uuid.UUID,
|
|
||||||
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
导入 QA 问答对(CSV/Excel),异步处理
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# 1. 校验文件格式
|
|
||||||
filename = file.filename or ""
|
|
||||||
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
|
|
||||||
|
|
||||||
# 2. 校验知识库和文档
|
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
||||||
if not db_knowledge:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
|
|
||||||
|
|
||||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
|
||||||
if not db_document:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问")
|
|
||||||
|
|
||||||
# 3. 读取文件内容,派发异步任务
|
|
||||||
contents = await file.read()
|
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
task = celery_app.send_task(
|
|
||||||
"app.core.rag.tasks.import_qa_chunks",
|
|
||||||
args=[str(kb_id), str(document_id), filename, contents],
|
|
||||||
queue="qa_import"
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||||
async def get_chunk(
|
async def get_chunk(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
@@ -529,9 +342,6 @@ async def update_chunk(
|
|||||||
if total:
|
if total:
|
||||||
chunk = items[0]
|
chunk = items[0]
|
||||||
chunk.page_content = content
|
chunk.page_content = content
|
||||||
# QA chunk: 更新 metadata 中的 question/answer
|
|
||||||
if update_data.is_qa:
|
|
||||||
chunk.metadata.update(update_data.qa_metadata)
|
|
||||||
vector_service.update_by_segment(chunk)
|
vector_service.update_by_segment(chunk)
|
||||||
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
||||||
else:
|
else:
|
||||||
@@ -546,7 +356,6 @@ async def delete_chunk(
|
|||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
document_id: uuid.UUID,
|
document_id: uuid.UUID,
|
||||||
doc_id: str,
|
doc_id: str,
|
||||||
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
@@ -564,7 +373,7 @@ async def delete_chunk(
|
|||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
if vector_service.text_exists(doc_id):
|
if vector_service.text_exists(doc_id):
|
||||||
vector_service.delete_by_ids([doc_id], refresh=force_refresh)
|
vector_service.delete_by_ids([doc_id])
|
||||||
# 更新 chunk_num
|
# 更新 chunk_num
|
||||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||||
db_document.chunk_num -= 1
|
db_document.chunk_num -= 1
|
||||||
|
|||||||
@@ -113,33 +113,6 @@ async def create_chunk(
|
|||||||
current_user=current_user)
|
current_user=current_user)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
|
|
||||||
@require_api_key(scopes=["rag"])
|
|
||||||
async def create_chunks_batch(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
document_id: uuid.UUID,
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
items: list = Body(..., description="chunk items list"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Batch create chunks (max 8)
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
batch_data = chunk_schema.ChunkBatchCreate(**body)
|
|
||||||
# 0. Obtain the creator of the api key
|
|
||||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
|
||||||
current_user = api_key.creator
|
|
||||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
|
||||||
|
|
||||||
return await chunk_controller.create_chunks_batch(kb_id=kb_id,
|
|
||||||
document_id=document_id,
|
|
||||||
batch_data=batch_data,
|
|
||||||
db=db,
|
|
||||||
current_user=current_user)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||||
@require_api_key(scopes=["rag"])
|
@require_api_key(scopes=["rag"])
|
||||||
async def get_chunk(
|
async def get_chunk(
|
||||||
@@ -203,7 +176,6 @@ async def delete_chunk(
|
|||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
delete document chunk
|
delete document chunk
|
||||||
@@ -216,7 +188,6 @@ async def delete_chunk(
|
|||||||
return await chunk_controller.delete_chunk(kb_id=kb_id,
|
return await chunk_controller.delete_chunk(kb_id=kb_id,
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
doc_id=doc_id,
|
doc_id=doc_id,
|
||||||
force_refresh=force_refresh,
|
|
||||||
db=db,
|
db=db,
|
||||||
current_user=current_user)
|
current_user=current_user)
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ class Settings:
|
|||||||
# File Upload
|
# File Upload
|
||||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||||
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||||
MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8"))
|
|
||||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||||
|
|
||||||
|
|||||||
@@ -46,10 +46,7 @@ async def run_graphrag(
|
|||||||
start = trio.current_time()
|
start = trio.current_time()
|
||||||
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
|
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
|
||||||
chunks = []
|
chunks = []
|
||||||
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id", "chunk_type"], sort_by_position=True):
|
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id"], sort_by_position=True):
|
||||||
# 跳过 QA chunks,只用原文 chunks 构建图谱
|
|
||||||
if d.get("chunk_type") == "qa":
|
|
||||||
continue
|
|
||||||
chunks.append(d["page_content"])
|
chunks.append(d["page_content"])
|
||||||
|
|
||||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||||
@@ -153,9 +150,6 @@ async def run_graphrag_for_kb(
|
|||||||
|
|
||||||
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
|
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
|
||||||
for doc in items:
|
for doc in items:
|
||||||
# 跳过 QA chunks,只用原文 chunks 构建图谱
|
|
||||||
if (doc.metadata or {}).get("chunk_type") == "qa":
|
|
||||||
continue
|
|
||||||
content = doc.page_content
|
content = doc.page_content
|
||||||
if num_tokens_from_string(current_chunk + content) < 1024:
|
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||||
current_chunk += content
|
current_chunk += content
|
||||||
|
|||||||
@@ -131,52 +131,18 @@ def keyword_extraction(chat_mdl, content, topn=3):
|
|||||||
|
|
||||||
|
|
||||||
def question_proposal(chat_mdl, content, topn=3):
|
def question_proposal(chat_mdl, content, topn=3):
|
||||||
"""生成问题(向后兼容,返回纯文本问题列表)"""
|
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
|
||||||
pairs = qa_proposal(chat_mdl, content, topn)
|
rendered_prompt = template.render(content=content, topn=topn)
|
||||||
if not pairs:
|
|
||||||
return ""
|
|
||||||
return "\n".join([p["question"] for p in pairs])
|
|
||||||
|
|
||||||
|
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||||
def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None):
|
|
||||||
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_mdl: LLM 模型
|
|
||||||
content: 文本内容
|
|
||||||
topn: 生成 QA 对数量
|
|
||||||
custom_prompt: 自定义 prompt 模板(支持 Jinja2,可用变量: content, topn)
|
|
||||||
"""
|
|
||||||
if custom_prompt:
|
|
||||||
template = PROMPT_JINJA_ENV.from_string(custom_prompt)
|
|
||||||
sys_prompt = template.render(topn=topn)
|
|
||||||
else:
|
|
||||||
sys_prompt = QUESTION_PROMPT_TEMPLATE
|
|
||||||
msg = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": content}]
|
|
||||||
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
|
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
|
||||||
raw = chat_mdl.chat(sys_prompt, msg[1:], {"temperature": 0.2})
|
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
||||||
if isinstance(raw, tuple):
|
if isinstance(kwd, tuple):
|
||||||
raw = raw[0]
|
kwd = kwd[0]
|
||||||
raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL)
|
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||||
if raw.find("**ERROR**") >= 0:
|
if kwd.find("**ERROR**") >= 0:
|
||||||
return []
|
return ""
|
||||||
return parse_qa_pairs(raw)
|
return kwd
|
||||||
|
|
||||||
|
|
||||||
def parse_qa_pairs(text: str) -> list:
|
|
||||||
"""解析 LLM 返回的 QA 对文本,格式: Q: xxx A: xxx"""
|
|
||||||
pairs = []
|
|
||||||
for line in text.strip().split("\n"):
|
|
||||||
line = line.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
# 匹配 Q: ... A: ... 格式
|
|
||||||
match = re.match(r'^Q:\s*(.+?)\s+A:\s*(.+)$', line, re.IGNORECASE)
|
|
||||||
if match:
|
|
||||||
q, a = match.group(1).strip(), match.group(2).strip()
|
|
||||||
if q and a:
|
|
||||||
pairs.append({"question": q, "answer": a})
|
|
||||||
return pairs
|
|
||||||
|
|
||||||
|
|
||||||
def graph_entity_types(chat_mdl, scenario):
|
def graph_entity_types(chat_mdl, scenario):
|
||||||
|
|||||||
@@ -1,20 +1,19 @@
|
|||||||
## Role
|
## Role
|
||||||
You are a text analyzer and knowledge extraction expert.
|
You are a text analyzer.
|
||||||
|
|
||||||
## Task
|
## Task
|
||||||
Generate question-answer pairs from the given text content.
|
Propose {{ topn }} questions about a given piece of text content.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
- Understand and summarize the text content, then generate up to {{ topn }} important question-answer pairs.
|
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
|
||||||
- Each question-answer pair MUST be on a single line, formatted as: Q: <question> A: <answer>
|
|
||||||
- The questions SHOULD NOT have overlapping meanings.
|
- The questions SHOULD NOT have overlapping meanings.
|
||||||
- The questions SHOULD cover the main content of the text as much as possible.
|
- The questions SHOULD cover the main content of the text as much as possible.
|
||||||
- The answers MUST be concise, accurate, and directly derived from the text content.
|
- The questions MUST be in the same language as the given piece of text content.
|
||||||
- The answers SHOULD be self-contained and understandable without additional context.
|
- One question per line.
|
||||||
- Both questions and answers MUST be in the same language as the given text content.
|
- Output questions ONLY.
|
||||||
- If the text is too short or lacks substantive content, generate fewer pairs rather than padding.
|
|
||||||
- Output question-answer pairs ONLY, no extra explanation or commentary.
|
---
|
||||||
|
|
||||||
|
## Text Content
|
||||||
|
{{ content }}
|
||||||
|
|
||||||
## Example Output
|
|
||||||
Q: What is the capital of France? A: The capital of France is Paris.
|
|
||||||
Q: When was the Eiffel Tower built? A: The Eiffel Tower was built in 1889.
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from elasticsearch import Elasticsearch, helpers, NotFoundError
|
from elasticsearch import Elasticsearch, helpers
|
||||||
from elasticsearch.helpers import BulkIndexError
|
from elasticsearch.helpers import BulkIndexError
|
||||||
from packaging.version import parse as parse_version
|
from packaging.version import parse as parse_version
|
||||||
# langchain-community
|
# langchain-community
|
||||||
@@ -53,30 +53,13 @@ class ElasticSearchVector(BaseVector):
|
|||||||
return "elasticsearch"
|
return "elasticsearch"
|
||||||
|
|
||||||
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
||||||
# QA chunks: embedding 只对 question 字段做;source chunks: 不做 embedding
|
# 实现 Elasticsearch 保存向量
|
||||||
texts_for_embedding = []
|
texts = [chunk.page_content for chunk in chunks]
|
||||||
for chunk in chunks:
|
|
||||||
chunk_type = (chunk.metadata or {}).get("chunk_type", "chunk")
|
|
||||||
if chunk_type == "source":
|
|
||||||
# source chunk 不需要向量索引
|
|
||||||
texts_for_embedding.append("")
|
|
||||||
elif chunk_type == "qa":
|
|
||||||
# QA chunk: 用 question 字段做 embedding
|
|
||||||
texts_for_embedding.append((chunk.metadata or {}).get("question", chunk.page_content))
|
|
||||||
else:
|
|
||||||
# 普通 chunk: 用 page_content 做 embedding
|
|
||||||
texts_for_embedding.append(chunk.page_content)
|
|
||||||
|
|
||||||
if self.is_multimodal_embedding:
|
if self.is_multimodal_embedding:
|
||||||
embeddings = self.embeddings.embed_batch(texts_for_embedding)
|
# 火山引擎多模态 Embedding
|
||||||
|
embeddings = self.embeddings.embed_batch(texts)
|
||||||
else:
|
else:
|
||||||
embeddings = self.embeddings.embed_documents(texts_for_embedding)
|
embeddings = self.embeddings.embed_documents(list(texts))
|
||||||
|
|
||||||
# source chunk 的向量置空
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
if (chunk.metadata or {}).get("chunk_type") == "source":
|
|
||||||
embeddings[i] = None
|
|
||||||
|
|
||||||
self.create(chunks, embeddings, **kwargs)
|
self.create(chunks, embeddings, **kwargs)
|
||||||
|
|
||||||
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
||||||
@@ -89,25 +72,13 @@ class ElasticSearchVector(BaseVector):
|
|||||||
uuids = self._get_uuids(chunks)
|
uuids = self._get_uuids(chunks)
|
||||||
actions = []
|
actions = []
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
source = {
|
|
||||||
Field.CONTENT_KEY.value: chunk.page_content,
|
|
||||||
Field.METADATA_KEY.value: chunk.metadata or {},
|
|
||||||
Field.VECTOR.value: embeddings[i] or None
|
|
||||||
}
|
|
||||||
# 写入 QA 相关字段
|
|
||||||
meta = chunk.metadata or {}
|
|
||||||
if meta.get("chunk_type"):
|
|
||||||
source[Field.CHUNK_TYPE.value] = meta["chunk_type"]
|
|
||||||
if meta.get("question"):
|
|
||||||
source[Field.QUESTION.value] = meta["question"]
|
|
||||||
if meta.get("answer"):
|
|
||||||
source[Field.ANSWER.value] = meta["answer"]
|
|
||||||
if meta.get("source_chunk_id"):
|
|
||||||
source[Field.SOURCE_CHUNK_ID.value] = meta["source_chunk_id"]
|
|
||||||
|
|
||||||
action = {
|
action = {
|
||||||
"_index": self._collection_name,
|
"_index": self._collection_name,
|
||||||
"_source": source
|
"_source": {
|
||||||
|
Field.CONTENT_KEY.value: chunk.page_content,
|
||||||
|
Field.METADATA_KEY.value: chunk.metadata or {},
|
||||||
|
Field.VECTOR.value: embeddings[i] or None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
actions.append(action)
|
actions.append(action)
|
||||||
# using bulk mode
|
# using bulk mode
|
||||||
@@ -142,7 +113,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
|
def delete_by_ids(self, ids: list[str]):
|
||||||
if not ids:
|
if not ids:
|
||||||
return
|
return
|
||||||
if not self._client.indices.exists(index=self._collection_name):
|
if not self._client.indices.exists(index=self._collection_name):
|
||||||
@@ -163,8 +134,6 @@ class ElasticSearchVector(BaseVector):
|
|||||||
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
||||||
try:
|
try:
|
||||||
helpers.bulk(self._client, actions)
|
helpers.bulk(self._client, actions)
|
||||||
if refresh:
|
|
||||||
self._client.indices.refresh(index=self._collection_name)
|
|
||||||
except BulkIndexError as e:
|
except BulkIndexError as e:
|
||||||
for error in e.errors:
|
for error in e.errors:
|
||||||
delete_error = error.get('delete', {})
|
delete_error = error.get('delete', {})
|
||||||
@@ -184,7 +153,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
|
def delete_by_metadata_field(self, key: str, value: str):
|
||||||
if not self._client.indices.exists(index=self._collection_name):
|
if not self._client.indices.exists(index=self._collection_name):
|
||||||
return False
|
return False
|
||||||
actual_ids = self.get_ids_by_metadata_field(key, value)
|
actual_ids = self.get_ids_by_metadata_field(key, value)
|
||||||
@@ -193,8 +162,6 @@ class ElasticSearchVector(BaseVector):
|
|||||||
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
||||||
try:
|
try:
|
||||||
helpers.bulk(self._client, actions)
|
helpers.bulk(self._client, actions)
|
||||||
if refresh:
|
|
||||||
self._client.indices.refresh(index=self._collection_name)
|
|
||||||
except BulkIndexError as e:
|
except BulkIndexError as e:
|
||||||
for error in e.errors:
|
for error in e.errors:
|
||||||
delete_error = error.get('delete', {})
|
delete_error = error.get('delete', {})
|
||||||
@@ -225,8 +192,6 @@ class ElasticSearchVector(BaseVector):
|
|||||||
List of DocumentChunk objects that match the query.
|
List of DocumentChunk objects that match the query.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3"
|
||||||
if not self._client.indices.exists(index=indices):
|
|
||||||
return 0, []
|
|
||||||
|
|
||||||
# Calculate the start position for the current page
|
# Calculate the start position for the current page
|
||||||
from_ = pagesize * (page-1)
|
from_ = pagesize * (page-1)
|
||||||
@@ -261,15 +226,12 @@ class ElasticSearchVector(BaseVector):
|
|||||||
})
|
})
|
||||||
|
|
||||||
# For simplicity, we use from/size here which has a limit (usually up to 10,000).
|
# For simplicity, we use from/size here which has a limit (usually up to 10,000).
|
||||||
try:
|
result = self._client.search(
|
||||||
result = self._client.search(
|
index=indices,
|
||||||
index=indices,
|
from_=from_, # Only use from_ for the first page (simplified)
|
||||||
from_=from_, # Only use from_ for the first page (simplified)
|
size=pagesize,
|
||||||
size=pagesize,
|
body=query_str,
|
||||||
body=query_str,
|
)
|
||||||
)
|
|
||||||
except NotFoundError:
|
|
||||||
return 0, []
|
|
||||||
|
|
||||||
if "errors" in result:
|
if "errors" in result:
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
@@ -279,19 +241,10 @@ class ElasticSearchVector(BaseVector):
|
|||||||
for res in result["hits"]["hits"]:
|
for res in result["hits"]["hits"]:
|
||||||
source = res["_source"]
|
source = res["_source"]
|
||||||
page_content = source.get(Field.CONTENT_KEY.value)
|
page_content = source.get(Field.CONTENT_KEY.value)
|
||||||
|
# vector = source.get(Field.VECTOR.value)
|
||||||
vector = None
|
vector = None
|
||||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||||
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
|
||||||
score = res["_score"]
|
score = res["_score"]
|
||||||
|
|
||||||
# 将 QA 字段注入 metadata 供前端展示
|
|
||||||
if chunk_type:
|
|
||||||
metadata["chunk_type"] = chunk_type
|
|
||||||
if chunk_type == "qa":
|
|
||||||
metadata["question"] = source.get(Field.QUESTION.value, "")
|
|
||||||
metadata["answer"] = source.get(Field.ANSWER.value, "")
|
|
||||||
page_content = f"Q: {metadata['question']}\nA: {metadata['answer']}"
|
|
||||||
|
|
||||||
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
|
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
@@ -314,18 +267,13 @@ class ElasticSearchVector(BaseVector):
|
|||||||
List of DocumentChunk objects that match the query.
|
List of DocumentChunk objects that match the query.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||||
if not self._client.indices.exists(index=indices):
|
|
||||||
return 0, []
|
|
||||||
query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}}
|
query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}}
|
||||||
try:
|
result = self._client.search(
|
||||||
result = self._client.search(
|
index=indices,
|
||||||
index=indices,
|
from_=0, # Only use from_ for the first page (simplified)
|
||||||
from_=0, # Only use from_ for the first page (simplified)
|
size=1,
|
||||||
size=1,
|
body=query_str,
|
||||||
body=query_str,
|
)
|
||||||
)
|
|
||||||
except NotFoundError:
|
|
||||||
return 0, []
|
|
||||||
# print(result)
|
# print(result)
|
||||||
if "errors" in result:
|
if "errors" in result:
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
@@ -360,43 +308,27 @@ class ElasticSearchVector(BaseVector):
|
|||||||
Returns:
|
Returns:
|
||||||
updated count.
|
updated count.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name)
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||||
chunk_type = (chunk.metadata or {}).get("chunk_type")
|
if self.is_multimodal_embedding:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
# QA chunk: embedding 基于 question;source chunk: 不更新向量
|
chunk.vector = self.embeddings.embed_text(chunk.page_content)
|
||||||
if chunk_type == "source":
|
|
||||||
embed_text = ""
|
|
||||||
elif chunk_type == "qa":
|
|
||||||
embed_text = (chunk.metadata or {}).get("question", chunk.page_content)
|
|
||||||
else:
|
else:
|
||||||
embed_text = chunk.page_content
|
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||||
|
|
||||||
if chunk_type != "source":
|
|
||||||
if self.is_multimodal_embedding:
|
|
||||||
chunk.vector = self.embeddings.embed_text(embed_text)
|
|
||||||
else:
|
|
||||||
chunk.vector = self.embeddings.embed_query(embed_text)
|
|
||||||
|
|
||||||
script_source = "ctx._source.page_content = params.new_content; ctx._source.vector = params.new_vector;"
|
|
||||||
params = {
|
|
||||||
"new_content": chunk.page_content,
|
|
||||||
"new_vector": chunk.vector if chunk_type != "source" else None
|
|
||||||
}
|
|
||||||
|
|
||||||
# QA chunk: 同时更新 question/answer 字段
|
|
||||||
if chunk_type == "qa":
|
|
||||||
script_source += " ctx._source.question = params.new_question; ctx._source.answer = params.new_answer;"
|
|
||||||
params["new_question"] = (chunk.metadata or {}).get("question", "")
|
|
||||||
params["new_answer"] = (chunk.metadata or {}).get("answer", "")
|
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
"script": {
|
"script": {
|
||||||
"source": script_source,
|
"source": """
|
||||||
"params": params
|
ctx._source.page_content = params.new_content;
|
||||||
|
ctx._source.vector = params.new_vector;
|
||||||
|
""",
|
||||||
|
"params": {
|
||||||
|
"new_content": chunk.page_content,
|
||||||
|
"new_vector": chunk.vector
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"query": {
|
"query": {
|
||||||
"term": {
|
"term": {
|
||||||
Field.DOC_ID.value: chunk.metadata["doc_id"]
|
Field.DOC_ID.value: chunk.metadata["doc_id"] # exact match doc_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -404,6 +336,9 @@ class ElasticSearchVector(BaseVector):
|
|||||||
index=indices,
|
index=indices,
|
||||||
body=body,
|
body=body,
|
||||||
)
|
)
|
||||||
|
# Remove debug printing and use logging instead
|
||||||
|
# print(result)
|
||||||
|
# print(f"Update successful, number of affected documents: {result['updated']}")
|
||||||
return result['updated']
|
return result['updated']
|
||||||
|
|
||||||
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
|
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
|
||||||
@@ -462,11 +397,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": { # Add the filter condition of status=1
|
||||||
{"term": {"metadata.status": 1}},
|
"term": {
|
||||||
# 排除 source chunk(仅供 GraphRAG 使用,不参与检索)
|
"metadata.status": 1
|
||||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
}
|
||||||
]
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# If file_names_filter is passed in, merge the filtering conditions
|
# If file_names_filter is passed in, merge the filtering conditions
|
||||||
@@ -480,14 +415,22 @@ class ElasticSearchVector(BaseVector):
|
|||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
|
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
|
||||||
|
# The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1]
|
||||||
"params": {"query_vector": query_vector}
|
"params": {"query_vector": query_vector}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": [
|
||||||
{"term": {"metadata.status": 1}},
|
{
|
||||||
{"terms": {"metadata.file_name": file_names_filter}},
|
"term": {
|
||||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
"metadata.status": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"terms": {
|
||||||
|
"metadata.file_name": file_names_filter # Additional file_name filtering
|
||||||
|
}
|
||||||
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -508,19 +451,8 @@ class ElasticSearchVector(BaseVector):
|
|||||||
source = res["_source"]
|
source = res["_source"]
|
||||||
page_content = source.get(Field.CONTENT_KEY.value)
|
page_content = source.get(Field.CONTENT_KEY.value)
|
||||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||||
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
|
||||||
score = res["_score"]
|
score = res["_score"]
|
||||||
score = score / 2 # Normalized [0-1]
|
score = score / 2 # Normalized [0-1]
|
||||||
|
|
||||||
# QA chunk: 返回 Q+A 拼接作为上下文
|
|
||||||
if chunk_type == "qa":
|
|
||||||
question = source.get(Field.QUESTION.value, "")
|
|
||||||
answer = source.get(Field.ANSWER.value, "")
|
|
||||||
page_content = f"Q: {question}\nA: {answer}"
|
|
||||||
metadata["chunk_type"] = "qa"
|
|
||||||
metadata["question"] = question
|
|
||||||
metadata["answer"] = answer
|
|
||||||
|
|
||||||
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
|
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
@@ -559,10 +491,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": { # Add the filter condition of status=1
|
||||||
{"term": {"metadata.status": 1}},
|
"term": {
|
||||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
"metadata.status": 1
|
||||||
]
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -579,9 +512,16 @@ class ElasticSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": [
|
||||||
{"term": {"metadata.status": 1}},
|
{
|
||||||
{"terms": {"metadata.file_name": file_names_filter}},
|
"term": {
|
||||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
"metadata.status": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"terms": {
|
||||||
|
"metadata.file_name": file_names_filter # Additional file_name filtering
|
||||||
|
}
|
||||||
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -603,17 +543,6 @@ class ElasticSearchVector(BaseVector):
|
|||||||
source = res["_source"]
|
source = res["_source"]
|
||||||
page_content = source.get(Field.CONTENT_KEY.value)
|
page_content = source.get(Field.CONTENT_KEY.value)
|
||||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||||
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
|
||||||
|
|
||||||
# QA chunk: 返回 Q+A 拼接作为上下文
|
|
||||||
if chunk_type == "qa":
|
|
||||||
question = source.get(Field.QUESTION.value, "")
|
|
||||||
answer = source.get(Field.ANSWER.value, "")
|
|
||||||
page_content = f"Q: {question}\nA: {answer}"
|
|
||||||
metadata["chunk_type"] = "qa"
|
|
||||||
metadata["question"] = question
|
|
||||||
metadata["answer"] = answer
|
|
||||||
|
|
||||||
# Normalize the score to the [0,1] interval
|
# Normalize the score to the [0,1] interval
|
||||||
normalized_score = res["_score"] / max_score
|
normalized_score = res["_score"] / max_score
|
||||||
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))
|
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))
|
||||||
@@ -723,7 +652,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
},
|
},
|
||||||
Field.VECTOR.value: {
|
Field.VECTOR.value: {
|
||||||
"type": "dense_vector",
|
"type": "dense_vector",
|
||||||
"dims": len(next((e for e in embeddings if e is not None), [0]*768)), # 跳过 None 获取向量维度,fallback 768
|
"dims": len(embeddings[0]), # Make sure the dimension is correct here,The dimension size of the vector. When index is true, it cannot exceed 1024; when index is false or not specified, it cannot exceed 2048, which can improve retrieval efficiency
|
||||||
"index": True,
|
"index": True,
|
||||||
"similarity": "cosine"
|
"similarity": "cosine"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,3 @@ class Field(StrEnum):
|
|||||||
DOCUMENT_ID = "metadata.document_id"
|
DOCUMENT_ID = "metadata.document_id"
|
||||||
KNOWLEDGE_ID = "metadata.knowledge_id"
|
KNOWLEDGE_ID = "metadata.knowledge_id"
|
||||||
SORT_ID = "metadata.sort_id"
|
SORT_ID = "metadata.sort_id"
|
||||||
# QA fields
|
|
||||||
CHUNK_TYPE = "chunk_type" # "chunk" | "source" | "qa"
|
|
||||||
QUESTION = "question"
|
|
||||||
ANSWER = "answer"
|
|
||||||
SOURCE_CHUNK_ID = "source_chunk_id"
|
|
||||||
|
|||||||
@@ -27,14 +27,14 @@ class BaseVector(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
|
def delete_by_ids(self, ids: list[str]):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
|
def delete_by_metadata_field(self, key: str, value: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# Author: Eternity
|
# Author: Eternity
|
||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/10 13:33
|
# @Time : 2026/2/10 13:33
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
@@ -141,9 +142,10 @@ class GraphBuilder:
|
|||||||
|
|
||||||
for node_info in source_nodes:
|
for node_info in source_nodes:
|
||||||
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
||||||
branch_nodes.append(
|
if node_info.get("branch") is not None:
|
||||||
(node_info["id"], node_info["branch"])
|
branch_nodes.append(
|
||||||
)
|
(node_info["id"], node_info["branch"])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
|
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
|
||||||
output_nodes.append(node_info["id"])
|
output_nodes.append(node_info["id"])
|
||||||
@@ -314,9 +316,12 @@ class GraphBuilder:
|
|||||||
for idx in range(len(related_edge)):
|
for idx in range(len(related_edge)):
|
||||||
# Generate a condition expression for each edge
|
# Generate a condition expression for each edge
|
||||||
# Used later to determine which branch to take based on the node's output
|
# Used later to determine which branch to take based on the node's output
|
||||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
# For LLM nodes, use branch_signal field for routing (output is dynamic text)
|
||||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
# For other branch nodes (e.g. HTTP), use output field
|
||||||
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
|
route_field = "branch_signal" if node_type == NodeType.LLM else "output"
|
||||||
|
related_edge[idx]['condition'] = (
|
||||||
|
f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
|
||||||
|
)
|
||||||
|
|
||||||
if node_instance:
|
if node_instance:
|
||||||
# Wrap node's run method to avoid closure issues
|
# Wrap node's run method to avoid closure issues
|
||||||
|
|||||||
@@ -18,10 +18,17 @@ class AssignerNode(BaseNode):
|
|||||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.variable_updater = True
|
self.variable_updater = True
|
||||||
self.typed_config: AssignerNodeConfig | None = None
|
self.typed_config: AssignerNodeConfig | None = None
|
||||||
|
self._input_data: dict[str, Any] | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
|
"""提取节点输入,如果有缓存的执行前数据则使用缓存"""
|
||||||
|
if self._input_data is not None:
|
||||||
|
return self._input_data
|
||||||
|
return {"config": self._resolve_config(self.config, variable_pool)}
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute the assignment operation defined by this node.
|
Execute the assignment operation defined by this node.
|
||||||
@@ -34,6 +41,9 @@ class AssignerNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
None or the result of the assignment operation.
|
None or the result of the assignment operation.
|
||||||
"""
|
"""
|
||||||
|
# 在执行前提取并缓存输入数据(捕获执行前的变量值)
|
||||||
|
self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
|
||||||
|
|
||||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||||
self.typed_config = AssignerNodeConfig(**self.config)
|
self.typed_config = AssignerNodeConfig(**self.config)
|
||||||
logger.info(f"节点 {self.node_id} 开始执行")
|
logger.info(f"节点 {self.node_id} 开始执行")
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -22,6 +23,9 @@ from app.services.multimodal_service import MultimodalService
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 匹配模板变量 {{xxx}} 的正则
|
||||||
|
_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||||
|
|
||||||
|
|
||||||
class NodeExecutionError(Exception):
|
class NodeExecutionError(Exception):
|
||||||
"""节点执行失败异常。
|
"""节点执行失败异常。
|
||||||
@@ -503,10 +507,29 @@ class BaseNode(ABC):
|
|||||||
variable_pool: The variable pool used for reading and writing variables.
|
variable_pool: The variable pool used for reading and writing variables.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the node's input data.
|
A dictionary containing the node's input data with all template
|
||||||
|
variables resolved to their actual runtime values.
|
||||||
"""
|
"""
|
||||||
# Default implementation returns the node configuration
|
return {"config": self._resolve_config(self.config, variable_pool)}
|
||||||
return {"config": self.config}
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
|
||||||
|
"""递归解析 config 中的模板变量,将 {{xxx}} 替换为实际值。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 节点的原始配置(可能包含模板变量)。
|
||||||
|
variable_pool: 变量池,用于解析模板变量。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的配置,所有字符串中的 {{变量}} 已被替换为真实值。
|
||||||
|
"""
|
||||||
|
if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
|
||||||
|
return BaseNode._render_template(config, variable_pool, strict=False)
|
||||||
|
elif isinstance(config, dict):
|
||||||
|
return {k: BaseNode._resolve_config(v, variable_pool) for k, v in config.items()}
|
||||||
|
elif isinstance(config, list):
|
||||||
|
return [BaseNode._resolve_config(item, variable_pool) for item in config]
|
||||||
|
return config
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> Any:
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
"""Extracts the actual output from the business result.
|
"""Extracts the actual output from the business result.
|
||||||
|
|||||||
@@ -121,7 +121,10 @@ class DocExtractorNode(BaseNode):
|
|||||||
return business_result
|
return business_result
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
return {"file_selector": self.config.get("file_selector")}
|
file_selector = self.config.get("file_selector", "")
|
||||||
|
# 将变量选择器(如 sys.files)解析为实际值
|
||||||
|
resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
|
||||||
|
return {"file_selector": resolved}
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
config = DocExtractorNodeConfig(**self.config)
|
config = DocExtractorNodeConfig(**self.config)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class NodeType(StrEnum):
|
|||||||
NOTES = "notes"
|
NOTES = "notes"
|
||||||
|
|
||||||
|
|
||||||
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
|
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
|
||||||
|
|
||||||
|
|
||||||
class ComparisonOperator(StrEnum):
|
class ComparisonOperator(StrEnum):
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import uuid
|
|||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||||
|
from app.core.workflow.nodes.enums import HttpErrorHandle
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
|
||||||
|
|
||||||
@@ -49,6 +50,20 @@ class MemoryWindowSetting(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMErrorHandleConfig(BaseModel):
|
||||||
|
"""LLM 异常处理配置"""
|
||||||
|
|
||||||
|
method: HttpErrorHandle = Field(
|
||||||
|
default=HttpErrorHandle.NONE,
|
||||||
|
description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支",
|
||||||
|
)
|
||||||
|
|
||||||
|
output: str = Field(
|
||||||
|
default="",
|
||||||
|
description="LLM 异常时返回的默认输出文本(method=default 时生效)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMNodeConfig(BaseNodeConfig):
|
class LLMNodeConfig(BaseNodeConfig):
|
||||||
"""LLM 节点配置
|
"""LLM 节点配置
|
||||||
|
|
||||||
@@ -152,6 +167,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
description="输出变量定义(自动生成,通常不需要修改)"
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
error_handle: LLMErrorHandleConfig = Field(
|
||||||
|
default_factory=LLMErrorHandleConfig,
|
||||||
|
description="LLM 异常处理配置",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("messages", "prompt")
|
@field_validator("messages", "prompt")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_input_mode(cls, v):
|
def validate_input_mode(cls, v):
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from app.core.models import RedBearLLM, RedBearModelConfig
|
|||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
|
from app.core.workflow.nodes.enums import HttpErrorHandle
|
||||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
@@ -76,7 +77,7 @@ class LLMNode(BaseNode):
|
|||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING}
|
return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
|
||||||
|
|
||||||
def _render_context(self, message: str, variable_pool: VariablePool):
|
def _render_context(self, message: str, variable_pool: VariablePool):
|
||||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||||
@@ -239,7 +240,7 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
"""非流式执行 LLM 调用
|
"""非流式执行 LLM 调用
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -247,28 +248,36 @@ class LLMNode(BaseNode):
|
|||||||
variable_pool: 变量池
|
variable_pool: 变量池
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLM 响应消息
|
dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
|
||||||
|
{"llm_result": None, "branch_signal": "ERROR"} on branch error
|
||||||
"""
|
"""
|
||||||
# self.typed_config = LLMNodeConfig(**self.config)
|
try:
|
||||||
llm = await self._prepare_llm(state, variable_pool, False)
|
# self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
llm = await self._prepare_llm(state, variable_pool, False)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|
||||||
# 调用 LLM(支持字符串或消息列表)
|
# 调用 LLM(支持字符串或消息列表)
|
||||||
response = await llm.ainvoke(self.messages)
|
response = await llm.ainvoke(self.messages)
|
||||||
# 提取内容
|
# 提取内容
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, 'content'):
|
||||||
content = self.process_model_output(response.content)
|
content = self.process_model_output(response.content)
|
||||||
else:
|
else:
|
||||||
content = str(response)
|
content = str(response)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||||
|
|
||||||
# 返回 AIMessage(包含响应元数据)
|
# 返回 AIMessage(包含响应元数据)
|
||||||
return AIMessage(content=content, response_metadata={
|
return {
|
||||||
**response.response_metadata,
|
"llm_result": AIMessage(content=content, response_metadata={
|
||||||
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
**response.response_metadata,
|
||||||
})
|
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
||||||
|
}),
|
||||||
|
"branch_signal": "SUCCESS",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"节点 {self.node_id} LLM 调用失败: {e}")
|
||||||
|
return self._handle_llm_error(e)
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""提取输入数据(用于记录)"""
|
"""提取输入数据(用于记录)"""
|
||||||
@@ -286,16 +295,36 @@ class LLMNode(BaseNode):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> str:
|
def _extract_output(self, business_result: Any) -> dict:
|
||||||
"""从 AIMessage 中提取文本内容"""
|
"""从业务结果中提取输出变量
|
||||||
|
|
||||||
|
支持新旧两种格式:
|
||||||
|
- 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"}
|
||||||
|
- 旧格式:AIMessage(向后兼容)
|
||||||
|
"""
|
||||||
|
if isinstance(business_result, dict) and "branch_signal" in business_result:
|
||||||
|
llm_result = business_result.get("llm_result")
|
||||||
|
if isinstance(llm_result, AIMessage):
|
||||||
|
return {
|
||||||
|
"output": llm_result.content,
|
||||||
|
"branch_signal": business_result["branch_signal"],
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"output": str(llm_result) if llm_result else "",
|
||||||
|
"branch_signal": business_result["branch_signal"],
|
||||||
|
}
|
||||||
|
# 旧格式向后兼容
|
||||||
if isinstance(business_result, AIMessage):
|
if isinstance(business_result, AIMessage):
|
||||||
return business_result.content
|
return {"output": business_result.content, "branch_signal": "SUCCESS"}
|
||||||
return str(business_result)
|
return {"output": str(business_result), "branch_signal": "SUCCESS"}
|
||||||
|
|
||||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
"""从 AIMessage 中提取 token 使用情况"""
|
"""从业务结果中提取 token 使用情况"""
|
||||||
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
llm_result = business_result
|
||||||
usage = business_result.response_metadata.get('token_usage')
|
if isinstance(business_result, dict):
|
||||||
|
llm_result = business_result.get("llm_result", business_result)
|
||||||
|
if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
|
||||||
|
usage = llm_result.response_metadata.get('token_usage')
|
||||||
if usage:
|
if usage:
|
||||||
return {
|
return {
|
||||||
"prompt_tokens": usage.get('input_tokens', 0),
|
"prompt_tokens": usage.get('input_tokens', 0),
|
||||||
@@ -304,6 +333,44 @@ class LLMNode(BaseNode):
|
|||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _handle_llm_error(self, error: Exception) -> dict:
|
||||||
|
"""处理 LLM 调用异常,根据 error_handle 配置决定行为
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: LLM 调用中捕获的异常
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode,
|
||||||
|
or default output for default mode
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
原异常(当 error_handle.method 为 NONE 时)
|
||||||
|
"""
|
||||||
|
if self.typed_config is None:
|
||||||
|
raise error
|
||||||
|
|
||||||
|
match self.typed_config.error_handle.method:
|
||||||
|
case HttpErrorHandle.NONE:
|
||||||
|
raise error
|
||||||
|
case HttpErrorHandle.DEFAULT:
|
||||||
|
logger.warning(
|
||||||
|
f"节点 {self.node_id}: LLM 调用失败,返回默认输出"
|
||||||
|
)
|
||||||
|
default_output = self.typed_config.error_handle.output or ""
|
||||||
|
return {
|
||||||
|
"llm_result": AIMessage(content=default_output, response_metadata={}),
|
||||||
|
"branch_signal": "SUCCESS",
|
||||||
|
}
|
||||||
|
case HttpErrorHandle.BRANCH:
|
||||||
|
logger.warning(
|
||||||
|
f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"llm_result": None,
|
||||||
|
"branch_signal": "ERROR",
|
||||||
|
}
|
||||||
|
raise error
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
"""流式执行 LLM 调用
|
"""流式执行 LLM 调用
|
||||||
|
|
||||||
@@ -316,54 +383,58 @@ class LLMNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
self.typed_config = LLMNodeConfig(**self.config)
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
|
||||||
llm = await self._prepare_llm(state, variable_pool, True)
|
try:
|
||||||
|
llm = await self._prepare_llm(state, variable_pool, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
# 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
last_meta_data = {}
|
last_meta_data = {}
|
||||||
last_usage_metadata = {}
|
last_usage_metadata = {}
|
||||||
async for chunk in llm.astream(self.messages):
|
async for chunk in llm.astream(self.messages):
|
||||||
if hasattr(chunk, 'content'):
|
if hasattr(chunk, 'content'):
|
||||||
content = self.process_model_output(chunk.content)
|
content = self.process_model_output(chunk.content)
|
||||||
else:
|
else:
|
||||||
content = str(chunk)
|
content = str(chunk)
|
||||||
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
|
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
|
||||||
last_meta_data = chunk.response_metadata
|
last_meta_data = chunk.response_metadata
|
||||||
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
||||||
last_usage_metadata = chunk.usage_metadata
|
last_usage_metadata = chunk.usage_metadata
|
||||||
|
|
||||||
# 只有当内容不为空时才处理
|
# 只有当内容不为空时才处理
|
||||||
if content:
|
if content:
|
||||||
full_response += content
|
full_response += content
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
|
|
||||||
# 流式返回每个文本片段
|
# 流式返回每个文本片段
|
||||||
yield {
|
yield {
|
||||||
"__final__": False,
|
"__final__": False,
|
||||||
"chunk": content
|
"chunk": content
|
||||||
}
|
}
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"__final__": False,
|
"__final__": False,
|
||||||
"chunk": "",
|
"chunk": "",
|
||||||
"done": True
|
"done": True
|
||||||
}
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
|
||||||
|
|
||||||
# 构建完整的 AIMessage(包含元数据)
|
|
||||||
final_message = AIMessage(
|
|
||||||
content=full_response,
|
|
||||||
response_metadata={
|
|
||||||
**last_meta_data,
|
|
||||||
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
|
|
||||||
}
|
}
|
||||||
)
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||||
|
|
||||||
# yield 完成标记
|
# 构建完整的 AIMessage(包含元数据)
|
||||||
yield {"__final__": True, "result": final_message}
|
final_message = AIMessage(
|
||||||
|
content=full_response,
|
||||||
|
response_metadata={
|
||||||
|
**last_meta_data,
|
||||||
|
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# yield 完成标记
|
||||||
|
yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}")
|
||||||
|
error_result = self._handle_llm_error(e)
|
||||||
|
yield {"__final__": True, "result": error_result}
|
||||||
|
|||||||
@@ -20,26 +20,13 @@ class ChunkCreate(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def chunk_content(self) -> str:
|
def chunk_content(self) -> str:
|
||||||
"""Get the actual content string regardless of input type"""
|
"""
|
||||||
|
Get the actual content string regardless of input type
|
||||||
|
"""
|
||||||
if isinstance(self.content, QAChunk):
|
if isinstance(self.content, QAChunk):
|
||||||
return self.content.question # QA 模式下 page_content 存 question
|
return f"question: {self.content.question} answer: {self.content.answer}"
|
||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
@property
|
|
||||||
def is_qa(self) -> bool:
|
|
||||||
return isinstance(self.content, QAChunk)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def qa_metadata(self) -> dict:
|
|
||||||
"""返回 QA 相关的 metadata 字段"""
|
|
||||||
if isinstance(self.content, QAChunk):
|
|
||||||
return {
|
|
||||||
"chunk_type": "qa",
|
|
||||||
"question": self.content.question,
|
|
||||||
"answer": self.content.answer,
|
|
||||||
}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkUpdate(BaseModel):
|
class ChunkUpdate(BaseModel):
|
||||||
content: Union[str, QAChunk] = Field(
|
content: Union[str, QAChunk] = Field(
|
||||||
@@ -48,26 +35,13 @@ class ChunkUpdate(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def chunk_content(self) -> str:
|
def chunk_content(self) -> str:
|
||||||
"""Get the actual content string regardless of input type"""
|
"""
|
||||||
|
Get the actual content string regardless of input type
|
||||||
|
"""
|
||||||
if isinstance(self.content, QAChunk):
|
if isinstance(self.content, QAChunk):
|
||||||
return self.content.question # QA 模式下 page_content 存 question
|
return f"question: {self.content.question} answer: {self.content.answer}"
|
||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
@property
|
|
||||||
def is_qa(self) -> bool:
|
|
||||||
return isinstance(self.content, QAChunk)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def qa_metadata(self) -> dict:
|
|
||||||
"""返回 QA 相关的 metadata 字段"""
|
|
||||||
if isinstance(self.content, QAChunk):
|
|
||||||
return {
|
|
||||||
"chunk_type": "qa",
|
|
||||||
"question": self.content.question,
|
|
||||||
"answer": self.content.answer,
|
|
||||||
}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkRetrieve(BaseModel):
|
class ChunkRetrieve(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
@@ -77,8 +51,3 @@ class ChunkRetrieve(BaseModel):
|
|||||||
vector_similarity_weight: float | None = Field(None)
|
vector_similarity_weight: float | None = Field(None)
|
||||||
top_k: int | None = Field(None)
|
top_k: int | None = Field(None)
|
||||||
retrieve_type: RetrieveType | None = Field(None)
|
retrieve_type: RetrieveType | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class ChunkBatchCreate(BaseModel):
|
|
||||||
"""批量创建 chunk"""
|
|
||||||
items: list[ChunkCreate] = Field(..., min_length=1, description="chunk 列表")
|
|
||||||
|
|||||||
@@ -102,6 +102,11 @@ class AppDslService:
|
|||||||
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
||||||
]
|
]
|
||||||
return enriched
|
return enriched
|
||||||
|
if app_type == AppType.WORKFLOW:
|
||||||
|
enriched = {**cfg}
|
||||||
|
if "nodes" in cfg:
|
||||||
|
enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
|
||||||
|
return enriched
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||||
@@ -110,7 +115,7 @@ class AppDslService:
|
|||||||
config_data = {
|
config_data = {
|
||||||
"variables": config.variables if config else [],
|
"variables": config.variables if config else [],
|
||||||
"edges": config.edges if config else [],
|
"edges": config.edges if config else [],
|
||||||
"nodes": config.nodes if config else [],
|
"nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
|
||||||
"features": config.features if config else {},
|
"features": config.features if config else {},
|
||||||
"execution_config": config.execution_config if config else {},
|
"execution_config": config.execution_config if config else {},
|
||||||
"triggers": config.triggers if config else [],
|
"triggers": config.triggers if config else [],
|
||||||
@@ -190,6 +195,23 @@ class AppDslService:
|
|||||||
def _enrich_tools(self, tools: list) -> list:
|
def _enrich_tools(self, tools: list) -> list:
|
||||||
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||||
|
|
||||||
|
def _enrich_workflow_nodes(self, nodes: list) -> list:
|
||||||
|
"""enrich 工作流节点中的模型引用,添加 name、provider、type 信息"""
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
enriched_nodes = []
|
||||||
|
for node in (nodes or []):
|
||||||
|
node_type = node.get("type")
|
||||||
|
config = dict(node.get("config") or {})
|
||||||
|
|
||||||
|
if node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
||||||
|
model_id = config.get("model_id")
|
||||||
|
if model_id:
|
||||||
|
config["model_ref"] = self._model_ref(model_id)
|
||||||
|
del config["model_id"]
|
||||||
|
|
||||||
|
enriched_nodes.append({**node, "config": config})
|
||||||
|
return enriched_nodes
|
||||||
|
|
||||||
def _skill_ref(self, skill_id) -> Optional[dict]:
|
def _skill_ref(self, skill_id) -> Optional[dict]:
|
||||||
if not skill_id:
|
if not skill_id:
|
||||||
return None
|
return None
|
||||||
@@ -620,16 +642,16 @@ class AppDslService:
|
|||||||
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
||||||
config["knowledge_bases"] = resolved_kbs
|
config["knowledge_bases"] = resolved_kbs
|
||||||
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
||||||
model_ref = config.get("model_id")
|
model_ref = config.get("model_ref") or config.get("model_id")
|
||||||
if model_ref:
|
if model_ref:
|
||||||
ref_dict = None
|
ref_dict = None
|
||||||
if isinstance(model_ref, dict):
|
if isinstance(model_ref, dict):
|
||||||
ref_id = model_ref.get("id")
|
ref_dict = {
|
||||||
ref_name = model_ref.get("name")
|
"id": model_ref.get("id"),
|
||||||
if ref_id:
|
"name": model_ref.get("name"),
|
||||||
ref_dict = {"id": ref_id}
|
"provider": model_ref.get("provider"),
|
||||||
elif ref_name is not None:
|
"type": model_ref.get("type")
|
||||||
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")}
|
}
|
||||||
elif isinstance(model_ref, str):
|
elif isinstance(model_ref, str):
|
||||||
try:
|
try:
|
||||||
uuid.UUID(model_ref)
|
uuid.UUID(model_ref)
|
||||||
@@ -640,12 +662,18 @@ class AppDslService:
|
|||||||
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
|
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
|
||||||
if resolved_model_id:
|
if resolved_model_id:
|
||||||
config["model_id"] = resolved_model_id
|
config["model_id"] = resolved_model_id
|
||||||
|
if "model_ref" in config:
|
||||||
|
del config["model_ref"]
|
||||||
else:
|
else:
|
||||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||||
config["model_id"] = None
|
config["model_id"] = None
|
||||||
|
if "model_ref" in config:
|
||||||
|
del config["model_ref"]
|
||||||
else:
|
else:
|
||||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||||
config["model_id"] = None
|
config["model_id"] = None
|
||||||
|
if "model_ref" in config:
|
||||||
|
del config["model_ref"]
|
||||||
resolved_nodes.append({**node, "config": config})
|
resolved_nodes.append({**node, "config": config})
|
||||||
return resolved_nodes
|
return resolved_nodes
|
||||||
|
|
||||||
|
|||||||
254
api/app/tasks.py
254
api/app/tasks.py
@@ -30,7 +30,7 @@ from app.core.rag.llm.cv_model import QWenCV
|
|||||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||||
from app.core.rag.models.chunk import DocumentChunk
|
from app.core.rag.models.chunk import DocumentChunk
|
||||||
from app.core.rag.prompts.generator import question_proposal, qa_proposal
|
from app.core.rag.prompts.generator import question_proposal
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||||
ElasticSearchVectorFactory,
|
ElasticSearchVectorFactory,
|
||||||
)
|
)
|
||||||
@@ -311,7 +311,6 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||||
# 2.2 Vectorize and import batch documents
|
# 2.2 Vectorize and import batch documents
|
||||||
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
|
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
|
||||||
qa_prompt = db_document.parser_config.get("qa_prompt", None)
|
|
||||||
chat_model = None
|
chat_model = None
|
||||||
if auto_questions_topn:
|
if auto_questions_topn:
|
||||||
chat_model = Base(
|
chat_model = Base(
|
||||||
@@ -319,123 +318,62 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||||
)
|
)
|
||||||
logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}")
|
|
||||||
if qa_prompt:
|
|
||||||
logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)")
|
|
||||||
|
|
||||||
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
||||||
all_batch_chunks: list[list[DocumentChunk]] = []
|
all_batch_chunks: list[list[DocumentChunk]] = []
|
||||||
|
|
||||||
if auto_questions_topn:
|
if auto_questions_topn:
|
||||||
# QA 模式(FastGPT 方案):
|
# auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组
|
||||||
# 1. 原 chunk 标记为 source(保留供 GraphRAG 使用,不参与检索)
|
# 构建 (global_idx, item) 列表
|
||||||
# 2. LLM 生成 QA 对,每个 QA 对独立存储为 qa chunk
|
|
||||||
indexed_items = list(enumerate(res))
|
indexed_items = list(enumerate(res))
|
||||||
|
|
||||||
def _generate_qa(idx_item: tuple[int, dict]) -> tuple[int, list]:
|
def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]:
|
||||||
"""为单个 chunk 生成 QA 对(带缓存),返回 (global_idx, qa_pairs)"""
|
"""为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)"""
|
||||||
global_idx, item = idx_item
|
global_idx, item = idx_item
|
||||||
content = item["content_with_weight"]
|
content = item["content_with_weight"]
|
||||||
cache_params = {"topn": auto_questions_topn}
|
cached = get_llm_cache(chat_model.model_name, content, "question",
|
||||||
if qa_prompt:
|
{"topn": auto_questions_topn})
|
||||||
import hashlib
|
|
||||||
cache_params["prompt_hash"] = hashlib.md5(qa_prompt.encode()).hexdigest()[:8]
|
|
||||||
cached = get_llm_cache(chat_model.model_name, content, "qa", cache_params)
|
|
||||||
if not cached:
|
if not cached:
|
||||||
logger.info(f"[QA] Cache miss for chunk {global_idx}, calling LLM. cache_params={cache_params}")
|
cached = question_proposal(chat_model, content, auto_questions_topn)
|
||||||
try:
|
set_llm_cache(chat_model.model_name, content, cached, "question",
|
||||||
pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt)
|
{"topn": auto_questions_topn})
|
||||||
except Exception as e:
|
return global_idx, cached
|
||||||
logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}")
|
|
||||||
return global_idx, []
|
|
||||||
logger.info(f"[QA] Chunk {global_idx} generated {len(pairs)} QA pairs")
|
|
||||||
# 缓存存 JSON 字符串
|
|
||||||
set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa",
|
|
||||||
cache_params)
|
|
||||||
return global_idx, pairs
|
|
||||||
logger.info(f"[QA] Cache hit for chunk {global_idx}, cache_params={cache_params}, cached_type={type(cached).__name__}")
|
|
||||||
# 从缓存读取:可能是 JSON 字符串或旧格式纯文本
|
|
||||||
if isinstance(cached, str):
|
|
||||||
try:
|
|
||||||
parsed = json.loads(cached)
|
|
||||||
if isinstance(parsed, list):
|
|
||||||
logger.info(f"[QA] Chunk {global_idx} loaded {len(parsed)} QA pairs from cache")
|
|
||||||
return global_idx, parsed
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
pass
|
|
||||||
# 旧缓存格式(纯文本问题),尝试解析
|
|
||||||
from app.core.rag.prompts.generator import parse_qa_pairs
|
|
||||||
return global_idx, parse_qa_pairs(cached) if cached else []
|
|
||||||
return global_idx, cached if isinstance(cached, list) else []
|
|
||||||
|
|
||||||
# 并发调用 LLM 生成 QA 对
|
# 并发调用 LLM 生成问题
|
||||||
qa_map: dict[int, list] = {}
|
question_map: dict[int, str] = {}
|
||||||
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
|
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
|
||||||
futures = {q_executor.submit(_generate_qa, item): item[0]
|
futures = {q_executor.submit(_generate_question, item): item[0]
|
||||||
for item in indexed_items}
|
for item in indexed_items}
|
||||||
for future in futures:
|
for future in futures:
|
||||||
global_idx, pairs = future.result()
|
global_idx, cached = future.result()
|
||||||
qa_map[global_idx] = pairs
|
question_map[global_idx] = cached
|
||||||
|
|
||||||
progress_lines.append(
|
progress_lines.append(
|
||||||
f"{datetime.now().strftime('%H:%M:%S')} QA pairs generated for {total_chunks} chunks "
|
f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks "
|
||||||
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
|
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
|
||||||
|
|
||||||
# 组装 chunks:source chunks + qa chunks
|
# 按 batch 分组组装 DocumentChunk
|
||||||
source_chunks = []
|
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||||
qa_chunks = []
|
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
|
||||||
qa_sort_id = 0
|
chunks = []
|
||||||
|
for global_idx in range(batch_start, batch_end):
|
||||||
for global_idx in range(total_chunks):
|
item = res[global_idx]
|
||||||
item = res[global_idx]
|
metadata = {
|
||||||
source_chunk_id = uuid.uuid4().hex
|
|
||||||
|
|
||||||
# source chunk:保留原文,供 GraphRAG 使用,不参与向量检索
|
|
||||||
source_meta = {
|
|
||||||
"doc_id": source_chunk_id,
|
|
||||||
"file_id": str(db_document.file_id),
|
|
||||||
"file_name": db_document.file_name,
|
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
|
||||||
"document_id": str(db_document.id),
|
|
||||||
"knowledge_id": str(db_document.kb_id),
|
|
||||||
"sort_id": global_idx,
|
|
||||||
"status": 1,
|
|
||||||
"chunk_type": "source",
|
|
||||||
}
|
|
||||||
source_chunks.append(
|
|
||||||
DocumentChunk(page_content=item["content_with_weight"], metadata=source_meta))
|
|
||||||
|
|
||||||
# qa chunks:每个 QA 对独立存储
|
|
||||||
pairs = qa_map.get(global_idx, [])
|
|
||||||
for pair in pairs:
|
|
||||||
qa_meta = {
|
|
||||||
"doc_id": uuid.uuid4().hex,
|
"doc_id": uuid.uuid4().hex,
|
||||||
"file_id": str(db_document.file_id),
|
"file_id": str(db_document.file_id),
|
||||||
"file_name": db_document.file_name,
|
"file_name": db_document.file_name,
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||||
"document_id": str(db_document.id),
|
"document_id": str(db_document.id),
|
||||||
"knowledge_id": str(db_document.kb_id),
|
"knowledge_id": str(db_document.kb_id),
|
||||||
"sort_id": qa_sort_id,
|
"sort_id": global_idx,
|
||||||
"status": 1,
|
"status": 1,
|
||||||
"chunk_type": "qa",
|
|
||||||
"question": pair["question"],
|
|
||||||
"answer": pair["answer"],
|
|
||||||
"source_chunk_id": source_chunk_id,
|
|
||||||
}
|
}
|
||||||
# page_content 存 question,用于向量索引
|
cached = question_map[global_idx]
|
||||||
qa_chunks.append(
|
chunks.append(
|
||||||
DocumentChunk(page_content=pair["question"], metadata=qa_meta))
|
DocumentChunk(
|
||||||
qa_sort_id += 1
|
page_content=f"question: {cached} answer: {item['content_with_weight']}",
|
||||||
|
metadata=metadata))
|
||||||
# 按 batch 分组(source + qa 一起)
|
all_batch_chunks.append(chunks)
|
||||||
all_chunks = source_chunks + qa_chunks
|
|
||||||
for batch_start in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE):
|
|
||||||
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, len(all_chunks))
|
|
||||||
all_batch_chunks.append(all_chunks[batch_start:batch_end])
|
|
||||||
|
|
||||||
progress_lines.append(
|
|
||||||
f"{datetime.now().strftime('%H:%M:%S')} QA mode: {len(source_chunks)} source chunks + "
|
|
||||||
f"{len(qa_chunks)} QA chunks prepared.")
|
|
||||||
else:
|
else:
|
||||||
# 无 auto_questions:直接构建 chunks
|
# 无 auto_questions:直接构建 chunks
|
||||||
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||||
@@ -697,136 +635,6 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str):
|
|||||||
return f"build_graphrag_for_document '{document_id}' failed: {e}"
|
return f"build_graphrag_for_document '{document_id}' failed: {e}"
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.rag.tasks.import_qa_chunks", queue="qa_import")
|
|
||||||
def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: bytes):
|
|
||||||
"""
|
|
||||||
异步导入 QA 问答对(CSV/Excel)
|
|
||||||
|
|
||||||
文件格式:第一行标题(跳过),第一列问题,第二列答案
|
|
||||||
"""
|
|
||||||
import csv as csv_module
|
|
||||||
import io
|
|
||||||
|
|
||||||
db = None
|
|
||||||
try:
|
|
||||||
from app.db import get_db_context
|
|
||||||
with get_db_context() as db:
|
|
||||||
db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
|
|
||||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first()
|
|
||||||
if not db_document or not db_knowledge:
|
|
||||||
logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found")
|
|
||||||
return {"error": "document or knowledge not found", "imported": 0}
|
|
||||||
|
|
||||||
# 1. 解析文件
|
|
||||||
qa_pairs = []
|
|
||||||
failed_rows = []
|
|
||||||
|
|
||||||
if filename.endswith(".csv"):
|
|
||||||
try:
|
|
||||||
text = contents.decode("utf-8-sig")
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
text = contents.decode("gbk", errors="ignore")
|
|
||||||
|
|
||||||
sniffer = csv_module.Sniffer()
|
|
||||||
try:
|
|
||||||
dialect = sniffer.sniff(text[:2048])
|
|
||||||
delimiter = dialect.delimiter
|
|
||||||
except csv_module.Error:
|
|
||||||
delimiter = "," if "," in text[:500] else "\t"
|
|
||||||
|
|
||||||
reader = csv_module.reader(io.StringIO(text), delimiter=delimiter)
|
|
||||||
for i, row in enumerate(reader):
|
|
||||||
if i == 0:
|
|
||||||
continue
|
|
||||||
if len(row) >= 2 and row[0].strip() and row[1].strip():
|
|
||||||
qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()})
|
|
||||||
elif len(row) >= 1 and row[0].strip():
|
|
||||||
failed_rows.append(i + 1)
|
|
||||||
|
|
||||||
elif filename.endswith(".xlsx") or filename.endswith(".xls"):
|
|
||||||
try:
|
|
||||||
import openpyxl
|
|
||||||
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
|
|
||||||
for sheet in wb.worksheets:
|
|
||||||
for i, row in enumerate(sheet.iter_rows(values_only=True)):
|
|
||||||
if i == 0:
|
|
||||||
continue
|
|
||||||
if len(row) >= 2 and row[0] and row[1]:
|
|
||||||
q = str(row[0]).strip()
|
|
||||||
a = str(row[1]).strip()
|
|
||||||
if q and a:
|
|
||||||
qa_pairs.append({"question": q, "answer": a})
|
|
||||||
elif len(row) >= 1 and row[0]:
|
|
||||||
failed_rows.append(i + 1)
|
|
||||||
wb.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ImportQA] Excel parse failed: {e}")
|
|
||||||
return {"error": f"Excel parse failed: {e}", "imported": 0}
|
|
||||||
|
|
||||||
if not qa_pairs:
|
|
||||||
logger.warning(f"[ImportQA] No valid QA pairs found in {filename}")
|
|
||||||
return {"error": "No valid QA pairs found", "imported": 0}
|
|
||||||
|
|
||||||
logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}")
|
|
||||||
|
|
||||||
# 2. 写入 ES
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
|
|
||||||
sort_id = 0
|
|
||||||
total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False)
|
|
||||||
if items:
|
|
||||||
sort_id = items[0].metadata["sort_id"]
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
for pair in qa_pairs:
|
|
||||||
sort_id += 1
|
|
||||||
doc_id = uuid.uuid4().hex
|
|
||||||
metadata = {
|
|
||||||
"doc_id": doc_id,
|
|
||||||
"file_id": str(db_document.file_id),
|
|
||||||
"file_name": db_document.file_name,
|
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
|
||||||
"document_id": document_id,
|
|
||||||
"knowledge_id": kb_id,
|
|
||||||
"sort_id": sort_id,
|
|
||||||
"status": 1,
|
|
||||||
"chunk_type": "qa",
|
|
||||||
"question": pair["question"],
|
|
||||||
"answer": pair["answer"],
|
|
||||||
}
|
|
||||||
chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata))
|
|
||||||
|
|
||||||
batch_size = 50
|
|
||||||
for i in range(0, len(chunks), batch_size):
|
|
||||||
batch = chunks[i:i + batch_size]
|
|
||||||
vector_service.add_chunks(batch)
|
|
||||||
|
|
||||||
# 3. 更新 chunk_num 和 progress
|
|
||||||
db_document.chunk_num += len(chunks)
|
|
||||||
db_document.progress = 1.0
|
|
||||||
db_document.progress_msg = f"QA 导入完成: {len(chunks)} 条"
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
result = {"imported": len(chunks), "failed_rows": failed_rows}
|
|
||||||
logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ImportQA] Failed: {e}", exc_info=True)
|
|
||||||
# 尝试更新文档状态为失败
|
|
||||||
try:
|
|
||||||
from app.db import get_db_context
|
|
||||||
with get_db_context() as err_db:
|
|
||||||
doc = err_db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
|
|
||||||
if doc:
|
|
||||||
doc.progress = -1.0
|
|
||||||
doc.progress_msg = f"QA 导入失败: {str(e)[:200]}"
|
|
||||||
err_db.commit()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {"error": str(e), "imported": 0}
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
||||||
def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user