Compare commits
9 Commits
main...feat/wxy-d

| Author | SHA1 | Date |
|---|---|---|
|  | cef33fce0d |  |
|  | d9f08860bc |  |
|  | 461674c8d8 |  |
|  | c59e179cc2 |  |
|  | a5670bfff6 |  |
|  | 4bef9b578b |  |
|  | c53fcf3981 |  |
|  | 2997558bc8 |  |
|  | 30cdf229de |  |
@@ -82,19 +82,32 @@ async def get_preview_chunks(
             detail="The file does not exist or you do not have permission to access it"
         )
 
-    # 5. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
-    file_path = os.path.join(
-        settings.FILE_PATH,
-        str(db_file.kb_id),
-        str(db_file.parent_id),
-        f"{db_file.id}{db_file.file_ext}"
-    )
-
-    # 6. Check if the file exists
-    if not os.path.exists(file_path):
+    # 5. Get file content from storage backend
+    if not db_file.file_key:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
-            detail="File not found (possibly deleted)"
+            detail="File has no storage key (legacy data not migrated)"
         )
+
+    from app.services.file_storage_service import FileStorageService
+    import asyncio
+    storage_service = FileStorageService()
+
+    async def _download():
+        return await storage_service.download_file(db_file.file_key)
+
+    try:
+        file_binary = asyncio.run(_download())
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        try:
+            file_binary = loop.run_until_complete(_download())
+        finally:
+            loop.close()
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"File not found in storage: {e}"
+        )
 
     # 7. Document parsing & segmentation
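Both this endpoint and the Celery task later in the diff bridge synchronous code to the async storage client with the same asyncio.run-plus-fallback pattern. A minimal standalone sketch of that pattern, using only the standard library (the `_download` stub stands in for `storage_service.download_file`):

```python
import asyncio

async def _download(key: str) -> bytes:
    # Stand-in for storage_service.download_file(key).
    await asyncio.sleep(0)
    return b"bytes-for-" + key.encode()

def download_blocking(key: str) -> bytes:
    """Run an async download from synchronous code.

    asyncio.run() raises RuntimeError when an event loop is already
    associated with the thread, so fall back to a fresh loop, mirroring
    the hunk above. (Neither branch helps inside a *running* loop.)
    """
    try:
        return asyncio.run(_download(key))
    except RuntimeError:
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(_download(key))
        finally:
            loop.close()

print(download_blocking("kb/demo/f1.pdf"))
```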
@@ -104,11 +117,12 @@ async def get_preview_chunks(
         vision_model = QWenCV(
             key=db_knowledge.image2text.api_keys[0].api_key,
             model_name=db_knowledge.image2text.api_keys[0].model_name,
-            lang="Chinese",  # Default to Chinese
+            lang="Chinese",
             base_url=db_knowledge.image2text.api_keys[0].api_base
         )
     from app.core.rag.app.naive import chunk
-    res = chunk(filename=file_path,
+    res = chunk(filename=db_file.file_name,
+                binary=file_binary,
                 from_page=0,
                 to_page=5,
                 callback=progress_callback,
@@ -20,6 +20,7 @@ from app.models.user_model import User
 from app.schemas import document_schema
 from app.schemas.response_schema import ApiResponse
 from app.services import document_service, file_service, knowledge_service
+from app.services.file_storage_service import FileStorageService, get_file_storage_service
 
 
 # Obtain a dedicated API logger
@@ -231,7 +232,8 @@ async def update_document(
 async def delete_document(
     document_id: uuid.UUID,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
+    storage_service: FileStorageService = Depends(get_file_storage_service),
 ):
     """
     Delete document
@@ -257,7 +259,7 @@ async def delete_document(
     db.commit()
 
     # 3. Delete file
-    await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user)
+    await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
 
     # 4. Delete vector index
     db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
@@ -305,38 +307,25 @@ async def parse_documents(
             detail="The file does not exist or you do not have permission to access it"
         )
 
-    # 3. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
-    file_path = os.path.join(
-        settings.FILE_PATH,
-        str(db_file.kb_id),
-        str(db_file.parent_id),
-        f"{db_file.id}{db_file.file_ext}"
-    )
-
-    # 4. Check if the file exists
-    api_logger.debug(f"Constructed file path: {file_path}")
-    api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
-    if not os.path.exists(file_path):
-        api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
+    # 3. Get file_key for storage backend
+    if not db_file.file_key:
+        api_logger.error(f"File has no storage key (legacy data not migrated): file_id={db_file.id}")
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
-            detail="File not found (possibly deleted)"
+            detail="File has no storage key (legacy data not migrated)"
         )
 
-    # 5. Obtain knowledge base information
-    api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
+    # 4. Obtain knowledge base information
+    api_logger.info(f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
     db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
     if not db_knowledge:
         api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="The knowledge base does not exist or access is denied"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
 
-    # 6. Task: Document parsing, vectorization, and storage
-    # from app.tasks import parse_document
-    # parse_document(file_path, document_id)
-    task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
+    # 5. Dispatch parse task with file_key (not file_path)
+    task = celery_app.send_task(
        "app.core.rag.tasks.parse_document",
        args=[db_file.file_key, document_id, db_file.file_name]
    )
     result = {
         "task_id": task.id
     }
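The endpoint now ships a storage key plus the original file name to the worker instead of an NFS path, so any node can fetch the bytes itself. A hedged sketch of the producer side; the broker URL and the argument values are placeholders, and only the task name and argument order come from the diff:

```python
from celery import Celery

# Placeholder broker URL; the real app configures its own celery_app.
celery_app = Celery("app", broker="redis://localhost:6379/0")

task = celery_app.send_task(
    "app.core.rag.tasks.parse_document",
    # [file_key, document_id, file_name] per the new signature
    args=["kb/<kb_id>/<file_id>.docx", "<document-uuid>", "report.docx"],
)
print(task.id)
```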
@@ -1,12 +1,10 @@
 import os
-from pathlib import Path
-import shutil
 from typing import Any, Optional
 import uuid
 
 from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
 from fastapi.encoders import jsonable_encoder
-from fastapi.responses import FileResponse
+from fastapi.responses import Response
 from sqlalchemy.orm import Session
 
 from app.core.config import settings
@@ -19,10 +17,14 @@ from app.models.user_model import User
 from app.schemas import file_schema, document_schema
 from app.schemas.response_schema import ApiResponse
 from app.services import file_service, document_service
+from app.services.knowledge_service import get_knowledge_by_id as get_kb_by_id
+from app.services.file_storage_service import (
+    FileStorageService,
+    generate_kb_file_key,
+    get_file_storage_service,
+)
 from app.core.quota_stub import check_knowledge_capacity_quota
 
 
 # Obtain a dedicated API logger
 api_logger = get_api_logger()
 
 router = APIRouter(
@@ -35,67 +37,37 @@ router = APIRouter(
 async def get_files(
     kb_id: uuid.UUID,
     parent_id: uuid.UUID,
-    page: int = Query(1, gt=0),  # Default: 1, which must be greater than 0
-    pagesize: int = Query(20, gt=0, le=100),  # Default: 20 items per page, maximum: 100 items
+    page: int = Query(1, gt=0),
+    pagesize: int = Query(20, gt=0, le=100),
     orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
     desc: Optional[bool] = Query(False, description="Is it descending order"),
     keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user)
 ):
-    """
-    Paged query file list
-    - Support filtering by kb_id and parent_id
-    - Support keyword search for file names
-    - Support dynamic sorting
-    - Return paging metadata + file list
-    """
-    api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
-    # 1. parameter validation
-    if page < 1 or pagesize < 1:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="The paging parameter must be greater than 0"
-        )
+    """Paged query file list"""
+    api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}")
 
-    # 2. Construct query conditions
-    filters = [
-        file_model.File.kb_id == kb_id
-    ]
+    if page < 1 or pagesize < 1:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The paging parameter must be greater than 0")
+
+    filters = [file_model.File.kb_id == kb_id]
     if parent_id:
         filters.append(file_model.File.parent_id == parent_id)
-    # Keyword search (fuzzy matching of file name)
     if keywords:
         filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
 
-    # 3. Execute paged query
     try:
-        api_logger.debug("Start executing file paging query")
         total, items = file_service.get_files_paginated(
-            db=db,
-            filters=filters,
-            page=page,
-            pagesize=pagesize,
-            orderby=orderby,
-            desc=desc,
-            current_user=current_user
+            db=db, filters=filters, page=page, pagesize=pagesize,
+            orderby=orderby, desc=desc, current_user=current_user
         )
-        api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
     except Exception as e:
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"Query failed: {str(e)}"
-        )
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Query failed: {str(e)}")
 
-    # 4. Return structured response
     result = {
         "items": items,
-        "page": {
-            "page": page,
-            "pagesize": pagesize,
-            "total": total,
-            "has_next": True if page * pagesize < total else False
-        }
+        "page": {"page": page, "pagesize": pagesize, "total": total, "has_next": page * pagesize < total}
     }
     return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
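The `has_next` simplification is pure boolean arithmetic: the pages served so far cover `page * pagesize` items, so more remain exactly when that product is below `total`. A quick check:

```python
def has_next(page: int, pagesize: int, total: int) -> bool:
    # Items covered by pages 1..page are page * pagesize.
    return page * pagesize < total

assert has_next(1, 20, 45)       # items 21-45 still remain
assert has_next(2, 20, 45)       # items 41-45 still remain
assert not has_next(3, 20, 45)   # 60 >= 45: last page reached
```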
@@ -108,23 +80,14 @@ async def create_folder(
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user),
 ):
-    """
-    Create a new folder
-    """
-    api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
-
+    """Create a new folder"""
+    api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}")
     try:
-        api_logger.debug(f"Start creating a folder: {folder_name}")
-        create_folder = file_schema.FileCreate(
-            kb_id=kb_id,
-            created_by=current_user.id,
-            parent_id=parent_id,
-            file_name=folder_name,
-            file_ext='folder',
-            file_size=0,
+        create_folder_data = file_schema.FileCreate(
+            kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
+            file_name=folder_name, file_ext='folder', file_size=0,
         )
-        db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
-        api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
+        db_file = file_service.create_file(db=db, file=create_folder_data, current_user=current_user)
         return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
     except Exception as e:
         api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
@@ -138,76 +101,58 @@ async def upload_file(
     parent_id: uuid.UUID,
     file: UploadFile = File(...),
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
+    storage_service: FileStorageService = Depends(get_file_storage_service),
 ):
-    """
-    upload file
-    """
-    api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
+    """Upload file to storage backend"""
+    api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}")
 
     # Read the contents of the file
     contents = await file.read()
     # Check file size
     file_size = len(contents)
-    print(f"file size: {file_size} byte")
     if file_size == 0:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="The file is empty."
-        )
-    # If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The file is empty.")
     if file_size > settings.MAX_FILE_SIZE:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
-        )
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"File size exceeds {settings.MAX_FILE_SIZE} byte limit")
 
     # Extract the extension using `os.path.splitext`
     _, file_extension = os.path.splitext(file.filename)
-    upload_file = file_schema.FileCreate(
-        kb_id=kb_id,
-        created_by=current_user.id,
-        parent_id=parent_id,
-        file_name=file.filename,
-        file_ext=file_extension.lower(),
-        file_size=file_size,
+    file_ext = file_extension.lower()
+
+    # Create File record
+    upload_file_data = file_schema.FileCreate(
+        kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
+        file_name=file.filename, file_ext=file_ext, file_size=file_size,
     )
-    db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
+    db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
 
-    # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
-    save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
-    Path(save_dir).mkdir(parents=True, exist_ok=True)  # Ensure that the directory exists
-    save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
+    # Upload to storage backend
+    file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
+    try:
+        await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
+    except Exception as e:
+        api_logger.error(f"Storage upload failed: {e}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
 
-    # Save file
-    with open(save_path, "wb") as f:
-        f.write(contents)
+    # Save file_key
+    db_file.file_key = file_key
+    db.commit()
+    db.refresh(db_file)
 
-    # Verify whether the file has been saved successfully
-    if not os.path.exists(save_path):
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="File save failed"
-        )
+    # Create document (inherit parser_config from knowledge base)
+    default_parser_config = {
+        "layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
+        "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"
+    }
+    try:
+        db_knowledge = get_kb_by_id(db, knowledge_id=kb_id, current_user=current_user)
+        if db_knowledge and db_knowledge.parser_config:
+            default_parser_config.update(dict(db_knowledge.parser_config))
+    except Exception:
+        pass
 
     # Create a document
     create_data = document_schema.DocumentCreate(
-        kb_id=kb_id,
-        created_by=current_user.id,
-        file_id=db_file.id,
-        file_name=db_file.file_name,
-        file_ext=db_file.file_ext,
-        file_size=db_file.file_size,
-        file_meta={},
-        parser_id="naive",
-        parser_config={
-            "layout_recognize": "DeepDOC",
-            "chunk_token_num": 128,
-            "delimiter": "\n",
-            "auto_keywords": 0,
-            "auto_questions": 0,
-            "html4excel": "false"
-        }
+        kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
+        file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
+        file_meta={}, parser_id="naive", parser_config=default_parser_config
     )
     db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
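The new upload order matters: the File row is created first so its UUID can become part of the storage key, the bytes are pushed, and only then is `file_key` committed, so a failed upload never leaves a dangling key. A self-contained sketch with an in-memory stand-in for the storage backend (the `upload` signature mirrors the diff; everything else is illustrative):

```python
import asyncio
import uuid

class FakeStorage:
    """Stand-in for storage_service.storage (an S3/MinIO-like backend)."""
    def __init__(self):
        self.blobs: dict[str, bytes] = {}

    async def upload(self, file_key: str, content: bytes, content_type: str | None = None):
        self.blobs[file_key] = content

async def upload_flow(storage: FakeStorage, kb_id: uuid.UUID, contents: bytes) -> str:
    file_id = uuid.uuid4()                  # 1. DB record created -> UUID known
    file_key = f"kb/{kb_id}/{file_id}.pdf"  # 2. key derived from the record
    await storage.upload(file_key=file_key, content=contents,
                         content_type="application/pdf")  # 3. push bytes
    return file_key                         # 4. only now persist file_key

storage = FakeStorage()
key = asyncio.run(upload_flow(storage, uuid.uuid4(), b"%PDF-1.7"))
print(key, len(storage.blobs[key]))
```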
@@ -221,123 +166,73 @@ async def custom_text(
     parent_id: uuid.UUID,
     create_data: file_schema.CustomTextFileCreate,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
+    storage_service: FileStorageService = Depends(get_file_storage_service),
 ):
-    """
-    custom text
-    """
-    api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
-
-    # Check file content size
-    # Encode the content as bytes (UTF-8)
+    """Custom text upload"""
     content_bytes = create_data.content.encode('utf-8')
     file_size = len(content_bytes)
-    print(f"file size: {file_size} byte")
     if file_size == 0:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="The content is empty."
-        )
-    # If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The content is empty.")
     if file_size > settings.MAX_FILE_SIZE:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
-        )
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Content size exceeds {settings.MAX_FILE_SIZE} byte limit")
 
-    upload_file = file_schema.FileCreate(
-        kb_id=kb_id,
-        created_by=current_user.id,
-        parent_id=parent_id,
-        file_name=f"{create_data.title}.txt",
-        file_ext=".txt",
-        file_size=file_size,
+    upload_file_data = file_schema.FileCreate(
+        kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
+        file_name=f"{create_data.title}.txt", file_ext=".txt", file_size=file_size,
     )
-    db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
+    db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
 
-    # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
-    save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
-    Path(save_dir).mkdir(parents=True, exist_ok=True)  # Ensure that the directory exists
-    save_path = os.path.join(save_dir, f"{db_file.id}.txt")
+    # Upload to storage backend
+    file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=".txt")
+    try:
+        await storage_service.storage.upload(file_key=file_key, content=content_bytes, content_type="text/plain")
+    except Exception as e:
+        api_logger.error(f"Storage upload failed: {e}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
 
-    # Save file
-    with open(save_path, "wb") as f:
-        f.write(content_bytes)
+    db_file.file_key = file_key
+    db.commit()
+    db.refresh(db_file)
 
-    # Verify whether the file has been saved successfully
-    if not os.path.exists(save_path):
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="File save failed"
-        )
-
     # Create a document
     create_document_data = document_schema.DocumentCreate(
-        kb_id=kb_id,
-        created_by=current_user.id,
-        file_id=db_file.id,
-        file_name=db_file.file_name,
-        file_ext=db_file.file_ext,
-        file_size=db_file.file_size,
-        file_meta={},
-        parser_id="naive",
-        parser_config={
-            "layout_recognize": "DeepDOC",
-            "chunk_token_num": 128,
-            "delimiter": "\n",
-            "auto_keywords": 0,
-            "auto_questions": 0,
-            "html4excel": "false"
-        }
+        kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
+        file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
+        file_meta={}, parser_id="naive",
+        parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
+                       "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"}
     )
     db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
 
     api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
     return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
 
 
 @router.get("/{file_id}", response_model=Any)
 async def get_file(
     file_id: uuid.UUID,
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
+    storage_service: FileStorageService = Depends(get_file_storage_service),
 ) -> Any:
-    """
-    Download the file based on the file_id
-    - Query file information from the database
-    - Construct the file path and check if it exists
-    - Return a FileResponse to download the file
-    """
-    api_logger.info(f"Download the file based on the file_id: file_id={file_id}")
-
-    # 1. Query file information from the database
+    """Download file by file_id"""
     db_file = file_service.get_file_by_id(db, file_id=file_id)
     if not db_file:
-        api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="The file does not exist or you do not have permission to access it"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
 
-    # 2. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
-    file_path = os.path.join(
-        settings.FILE_PATH,
-        str(db_file.kb_id),
-        str(db_file.parent_id),
-        f"{db_file.id}{db_file.file_ext}"
-    )
+    if not db_file.file_key:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File has no storage key (legacy data not migrated)")
 
-    # 3. Check if the file exists
-    if not os.path.exists(file_path):
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="File not found (possibly deleted)"
-        )
+    try:
+        content = await storage_service.download_file(db_file.file_key)
+    except Exception as e:
+        api_logger.error(f"Storage download failed: {e}")
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
 
-    # 4.Return FileResponse (automatically handle download)
-    return FileResponse(
-        path=file_path,
-        filename=db_file.file_name,  # Use original file name
-        media_type="application/octet-stream"  # Universal binary stream type
+    import mimetypes
+    media_type = mimetypes.guess_type(db_file.file_name)[0] or "application/octet-stream"
+    return Response(
+        content=content,
+        media_type=media_type,
+        headers={"Content-Disposition": f'attachment; filename="{db_file.file_name}"'}
     )
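`FileResponse` streams from a disk path and can infer the type from it; the in-memory `Response` cannot, so the content type is now guessed from the stored file name. The logic in isolation, standard library only:

```python
import mimetypes

def download_headers(file_name: str) -> tuple[str, dict]:
    # guess_type returns (type, encoding); fall back to a generic binary stream
    media_type = mimetypes.guess_type(file_name)[0] or "application/octet-stream"
    headers = {"Content-Disposition": f'attachment; filename="{file_name}"'}
    return media_type, headers

print(download_headers("report.pdf"))    # ('application/pdf', {...})
print(download_headers("data.unknown"))  # ('application/octet-stream', {...})
```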
@@ -348,50 +243,22 @@ async def update_file(
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user)
 ):
-    """
-    Update file information (such as file name)
-    - Only specified fields such as file_name are allowed to be modified
-    """
-    api_logger.debug(f"Query the file to be updated: {file_id}")
-
-    # 1. Check if the file exists
+    """Update file information (such as file name)"""
     db_file = file_service.get_file_by_id(db, file_id=file_id)
 
     if not db_file:
-        api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="The file does not exist or you do not have permission to access it"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
 
-    # 2. Update fields (only update non-null fields)
-    api_logger.debug(f"Start updating the file fields: {file_id}")
-    updated_fields = []
     for field, value in update_data.dict(exclude_unset=True).items():
         if hasattr(db_file, field):
-            old_value = getattr(db_file, field)
-            if old_value != value:
-                # update value
-                setattr(db_file, field, value)
-                updated_fields.append(f"{field}: {old_value} -> {value}")
+            setattr(db_file, field, value)
 
-    if updated_fields:
-        api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
-
-    # 3. Save to database
     try:
         db.commit()
         db.refresh(db_file)
-        api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
     except Exception as e:
         db.rollback()
         api_logger.error(f"File update failed: file_id={file_id} - {str(e)}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"File update failed: {str(e)}"
-        )
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File update failed: {str(e)}")
 
-    # 4. Return the updated file
     return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
@@ -399,60 +266,43 @@ async def update_file(
 async def delete_file(
     file_id: uuid.UUID,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
+    storage_service: FileStorageService = Depends(get_file_storage_service),
 ):
-    """
-    Delete a file or folder
-    """
-    api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}")
-    await _delete_file(db=db, file_id=file_id, current_user=current_user)
+    """Delete a file or folder"""
+    api_logger.info(f"Request to delete file: file_id={file_id}")
+    await _delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
     return success(msg="File deleted successfully")
 
 
 async def _delete_file(
     file_id: uuid.UUID,
-    db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    db: Session,
+    current_user: User,
+    storage_service: FileStorageService,
 ) -> None:
-    """
-    Delete a file or folder
-    """
-    # 1. Check if the file exists
+    """Delete a file or folder from storage and database"""
    db_file = file_service.get_file_by_id(db, file_id=file_id)
 
     if not db_file:
-        api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="The file does not exist or you do not have permission to access it"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
 
-    # 2. Construct physical path
-    file_path = Path(
-        settings.FILE_PATH,
-        str(db_file.kb_id),
-        str(db_file.id)
-    ) if db_file.file_ext == 'folder' else Path(
-        settings.FILE_PATH,
-        str(db_file.kb_id),
-        str(db_file.parent_id),
-        f"{db_file.id}{db_file.file_ext}"
-    )
-
-    # 3. Delete physical files/folders
-    try:
-        if file_path.exists():
-            if db_file.file_ext == 'folder':
-                shutil.rmtree(file_path)  # Recursively delete folders
-            else:
-                file_path.unlink()  # Delete a single file
-    except Exception as e:
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"Failed to delete physical file/folder: {str(e)}"
-        )
-
-    # 4.Delete db_file
+    # Delete from storage backend
+    if db_file.file_ext == 'folder':
+        # For folders, delete all child files from storage first
+        child_files = db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).all()
+        for child in child_files:
+            if child.file_key:
+                try:
+                    await storage_service.delete_file(child.file_key)
+                except Exception as e:
+                    api_logger.warning(f"Failed to delete child file from storage: {child.file_key} - {e}")
+        db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
+    else:
+        if db_file.file_key:
+            try:
+                await storage_service.delete_file(db_file.file_key)
+            except Exception as e:
+                api_logger.warning(f"Failed to delete file from storage: {db_file.file_key} - {e}")
 
     db.delete(db_file)
     db.commit()
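A side note on the `_delete_file` signature change: `Depends(...)` defaults are only resolved when FastAPI routes a request, so on a plain helper called directly (as `delete_document` does) they would arrive as unresolved sentinel objects. The hunk moves dependency resolution to the endpoint and passes plain values down. Illustrative shape only, with string stand-ins for the real objects:

```python
import asyncio

async def _delete_file(db, file_id, current_user, storage_service) -> None:
    # Plain helper: operates on real objects handed in by the caller.
    print(f"delete file_id={file_id} for user={current_user} via {storage_service}")

async def delete_file_endpoint(db, file_id, current_user, storage_service):
    # The endpoint layer owns dependency resolution, then forwards plain values.
    await _delete_file(db=db, file_id=file_id, current_user=current_user,
                       storage_service=storage_service)

asyncio.run(delete_file_endpoint("db-session", 42, "alice", "minio"))
```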
@@ -14,6 +14,7 @@ Transcribe the content from the provided PDF page image into clean Markdown form
 6. Do NOT wrap the output in ```markdown or ``` blocks.
 7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
 8. Preserve the original language, information, and order exactly as shown in the image.
+9. Your output language MUST match the language of the content in the image. If the image contains Chinese text, output in Chinese. If English, output in English. Never translate.
 
 {% if page %}
 At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
@@ -2,6 +2,7 @@
 # Author: Eternity
 # @Email: 1533512157@qq.com
 # @Time : 2026/2/10 13:33
+import json
 import logging
 import re
 import uuid
@@ -141,9 +142,10 @@ class GraphBuilder:
 
         for node_info in source_nodes:
             if self.get_node_type(node_info["id"]) in BRANCH_NODES:
-                branch_nodes.append(
-                    (node_info["id"], node_info["branch"])
-                )
+                if node_info.get("branch") is not None:
+                    branch_nodes.append(
+                        (node_info["id"], node_info["branch"])
+                    )
             else:
                 if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
                     output_nodes.append(node_info["id"])
@@ -314,9 +316,12 @@ class GraphBuilder:
             for idx in range(len(related_edge)):
                 # Generate a condition expression for each edge
                 # Used later to determine which branch to take based on the node's output
-                # Assumes node output `node.<node_id>.output` matches the edge's label
-                # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
-                related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
+                # For LLM nodes, use branch_signal field for routing (output is dynamic text)
+                # For other branch nodes (e.g. HTTP), use output field
+                route_field = "branch_signal" if node_type == NodeType.LLM else "output"
+                related_edge[idx]['condition'] = (
+                    f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
+                )
 
             if node_instance:
                 # Wrap node's run method to avoid closure issues
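The rewritten condition builder quotes every operand with `json.dumps`, so a node id or label containing quotes can no longer break or inject into the generated expression, and LLM nodes are routed by the stable `branch_signal` rather than their free-text `output`. The string construction in isolation:

```python
import json

def make_condition(node_id: str, route_field: str, label: str) -> str:
    # json.dumps quotes and escapes each operand safely.
    return f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(label)}"

print(make_condition("llm-1", "branch_signal", "ERROR"))
# node["llm-1"]["branch_signal"] == "ERROR"
print(make_condition('we"ird', "output", "CASE1"))
# node["we\"ird"]["output"] == "CASE1" -- still a valid expression
```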
@@ -18,10 +18,17 @@ class AssignerNode(BaseNode):
         super().__init__(node_config, workflow_config, down_stream_nodes)
         self.variable_updater = True
         self.typed_config: AssignerNodeConfig | None = None
+        self._input_data: dict[str, Any] | None = None
 
     def _output_types(self) -> dict[str, VariableType]:
         return {}
 
+    def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
+        """Extract the node input; use the cached pre-execution data if available."""
+        if self._input_data is not None:
+            return self._input_data
+        return {"config": self._resolve_config(self.config, variable_pool)}
+
     async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
         """
         Execute the assignment operation defined by this node.
@@ -34,6 +41,9 @@ class AssignerNode(BaseNode):
         Returns:
             None or the result of the assignment operation.
         """
+        # Extract and cache the input data before execution (captures pre-execution variable values)
+        self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
+
         # Initialize a variable pool for accessing conversation, node, and system variables
         self.typed_config = AssignerNodeConfig(**self.config)
         logger.info(f"Node {self.node_id} starting execution")
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import re
 import time
 import uuid
 from abc import ABC, abstractmethod
@@ -22,6 +23,9 @@ from app.services.multimodal_service import MultimodalService
 
 logger = logging.getLogger(__name__)
 
+# Regex matching template variables of the form {{xxx}}
+_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
+
 
 class NodeExecutionError(Exception):
     """Exception raised when node execution fails.
@@ -503,10 +507,29 @@ class BaseNode(ABC):
             variable_pool: The variable pool used for reading and writing variables.
 
         Returns:
-            A dictionary containing the node's input data.
+            A dictionary containing the node's input data with all template
+            variables resolved to their actual runtime values.
         """
-        # Default implementation returns the node configuration
-        return {"config": self.config}
+        return {"config": self._resolve_config(self.config, variable_pool)}
+
+    @staticmethod
+    def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
+        """Recursively resolve template variables in config, replacing {{xxx}} with actual values.
+
+        Args:
+            config: The node's raw configuration (may contain template variables).
+            variable_pool: The variable pool used to resolve template variables.
+
+        Returns:
+            The resolved configuration, with every {{variable}} inside strings replaced by its real value.
+        """
+        if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
+            return BaseNode._render_template(config, variable_pool, strict=False)
+        elif isinstance(config, dict):
+            return {k: BaseNode._resolve_config(v, variable_pool) for k, v in config.items()}
+        elif isinstance(config, list):
+            return [BaseNode._resolve_config(item, variable_pool) for item in config]
+        return config
 
     def _extract_output(self, business_result: Any) -> Any:
         """Extracts the actual output from the business result.
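`_resolve_config` walks strings, dicts, and lists recursively and only rewrites strings that actually contain a `{{...}}` template. A toy version with a plain dict standing in for the `VariablePool` (the real implementation delegates to `_render_template`):

```python
import re
from typing import Any

_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")

def resolve(config: Any, variables: dict[str, str]) -> Any:
    if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
        # Replace each {{ name }} with its value; leave unknown names intact.
        return re.sub(r"\{\{(.*?)\}\}",
                      lambda m: variables.get(m.group(1).strip(), m.group(0)),
                      config)
    if isinstance(config, dict):
        return {k: resolve(v, variables) for k, v in config.items()}
    if isinstance(config, list):
        return [resolve(item, variables) for item in config]
    return config  # numbers, None, etc. pass through untouched

cfg = {"prompt": "Hello {{ user.name }}", "items": ["{{ a }}", 42]}
print(resolve(cfg, {"user.name": "Ada", "a": "x"}))
# {'prompt': 'Hello Ada', 'items': ['x', 42]}
```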
@@ -121,7 +121,10 @@ class DocExtractorNode(BaseNode):
         return business_result
 
     def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
-        return {"file_selector": self.config.get("file_selector")}
+        file_selector = self.config.get("file_selector", "")
+        # Resolve the variable selector (e.g. sys.files) to its actual value
+        resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
+        return {"file_selector": resolved}
 
     async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
         config = DocExtractorNodeConfig(**self.config)
@@ -31,7 +31,7 @@ class NodeType(StrEnum):
     NOTES = "notes"
 
 
-BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
+BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
 
 
 class ComparisonOperator(StrEnum):
@@ -6,6 +6,7 @@ import uuid
 from pydantic import BaseModel, Field, field_validator
 
 from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
+from app.core.workflow.nodes.enums import HttpErrorHandle
 from app.core.workflow.variable.base_variable import VariableType
 
 
@@ -49,6 +50,20 @@ class MemoryWindowSetting(BaseModel):
     )
 
 
+class LLMErrorHandleConfig(BaseModel):
+    """LLM error-handling configuration"""
+
+    method: HttpErrorHandle = Field(
+        default=HttpErrorHandle.NONE,
+        description="Error-handling strategy: 'none' raises the exception, 'default' returns a default value, 'branch' takes the error branch",
+    )
+
+    output: str = Field(
+        default="",
+        description="Default output text returned when the LLM fails (takes effect when method=default)",
+    )
+
+
 class LLMNodeConfig(BaseNodeConfig):
     """LLM node configuration
 
@@ -152,6 +167,11 @@ class LLMNodeConfig(BaseNodeConfig):
         description="Output variable definitions (auto-generated, usually no need to modify)"
     )
 
+    error_handle: LLMErrorHandleConfig = Field(
+        default_factory=LLMErrorHandleConfig,
+        description="LLM error-handling configuration",
+    )
+
     @field_validator("messages", "prompt")
     @classmethod
     def validate_input_mode(cls, v):
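How the new `error_handle` block is meant to be configured, sketched with a stand-in `HttpErrorHandle` enum (the project imports its own) and pydantic v2:

```python
from enum import Enum
from pydantic import BaseModel, Field

class HttpErrorHandle(str, Enum):  # stand-in for app.core.workflow.nodes.enums
    NONE = "none"        # re-raise the exception
    DEFAULT = "default"  # swallow it and emit fixed text
    BRANCH = "branch"    # emit branch_signal="ERROR" and take the error edge

class LLMErrorHandleConfig(BaseModel):
    method: HttpErrorHandle = Field(default=HttpErrorHandle.NONE)
    output: str = Field(default="")

cfg = LLMErrorHandleConfig(method="default", output="Sorry, try again later.")
print(cfg.method, repr(cfg.output))
print(LLMErrorHandleConfig(method="branch").method)
```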
@@ -15,6 +15,7 @@ from app.core.models import RedBearLLM, RedBearModelConfig
 from app.core.workflow.engine.state_manager import WorkflowState
 from app.core.workflow.engine.variable_pool import VariablePool
 from app.core.workflow.nodes.base_node import BaseNode
+from app.core.workflow.nodes.enums import HttpErrorHandle
 from app.core.workflow.nodes.llm.config import LLMNodeConfig
 from app.core.workflow.variable.base_variable import VariableType
 from app.db import get_db_context
@@ -76,7 +77,7 @@ class LLMNode(BaseNode):
         self.messages = []
 
     def _output_types(self) -> dict[str, VariableType]:
-        return {"output": VariableType.STRING}
+        return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
 
     def _render_context(self, message: str, variable_pool: VariablePool):
         context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
@@ -239,7 +240,7 @@ class LLMNode(BaseNode):
 
         return llm
 
-    async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
+    async def execute(self, state: WorkflowState, variable_pool: VariablePool):
         """Non-streaming LLM invocation
 
         Args:
@@ -247,28 +248,36 @@ class LLMNode(BaseNode):
             variable_pool: the variable pool
 
         Returns:
-            The LLM response message
+            dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
+                  {"llm_result": None, "branch_signal": "ERROR"} on branch error
         """
-        # self.typed_config = LLMNodeConfig(**self.config)
-        llm = await self._prepare_llm(state, variable_pool, False)
+        try:
+            # self.typed_config = LLMNodeConfig(**self.config)
+            llm = await self._prepare_llm(state, variable_pool, False)
 
-        logger.info(f"Node {self.node_id} starting LLM call (non-streaming)")
+            logger.info(f"Node {self.node_id} starting LLM call (non-streaming)")
 
-        # Invoke the LLM (accepts a string or a message list)
-        response = await llm.ainvoke(self.messages)
-        # Extract the content
-        if hasattr(response, 'content'):
-            content = self.process_model_output(response.content)
-        else:
-            content = str(response)
+            # Invoke the LLM (accepts a string or a message list)
+            response = await llm.ainvoke(self.messages)
+            # Extract the content
+            if hasattr(response, 'content'):
+                content = self.process_model_output(response.content)
+            else:
+                content = str(response)
 
-        logger.info(f"Node {self.node_id} LLM call finished, output length: {len(content)}")
+            logger.info(f"Node {self.node_id} LLM call finished, output length: {len(content)}")
 
-        # Return an AIMessage (including response metadata)
-        return AIMessage(content=content, response_metadata={
-            **response.response_metadata,
-            "token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
-        })
+            # Return the AIMessage (including response metadata)
+            return {
+                "llm_result": AIMessage(content=content, response_metadata={
+                    **response.response_metadata,
+                    "token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
+                }),
+                "branch_signal": "SUCCESS",
+            }
+        except Exception as e:
+            logger.error(f"Node {self.node_id} LLM call failed: {e}")
+            return self._handle_llm_error(e)
 
     def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
         """Extract input data (for logging)"""
@@ -286,16 +295,36 @@ class LLMNode(BaseNode):
             }
         }
 
-    def _extract_output(self, business_result: Any) -> str:
-        """Extract the text content from an AIMessage"""
+    def _extract_output(self, business_result: Any) -> dict:
+        """Extract the output variables from the business result
+
+        Supports both the new and the old format:
+        - New format: {"llm_result": AIMessage, "branch_signal": "SUCCESS"}
+        - Old format: a bare AIMessage (backward compatible)
+        """
+        if isinstance(business_result, dict) and "branch_signal" in business_result:
+            llm_result = business_result.get("llm_result")
+            if isinstance(llm_result, AIMessage):
+                return {
+                    "output": llm_result.content,
+                    "branch_signal": business_result["branch_signal"],
+                }
+            return {
+                "output": str(llm_result) if llm_result else "",
+                "branch_signal": business_result["branch_signal"],
+            }
+        # Backward compatibility with the old format
         if isinstance(business_result, AIMessage):
-            return business_result.content
-        return str(business_result)
+            return {"output": business_result.content, "branch_signal": "SUCCESS"}
+        return {"output": str(business_result), "branch_signal": "SUCCESS"}
 
     def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
-        """Extract token usage from an AIMessage"""
-        if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
-            usage = business_result.response_metadata.get('token_usage')
+        """Extract token usage from the business result"""
+        llm_result = business_result
+        if isinstance(business_result, dict):
+            llm_result = business_result.get("llm_result", business_result)
+        if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
+            usage = llm_result.response_metadata.get('token_usage')
             if usage:
                 return {
                     "prompt_tokens": usage.get('input_tokens', 0),
@@ -304,6 +333,44 @@ class LLMNode(BaseNode):
                 }
         return None
 
+    def _handle_llm_error(self, error: Exception) -> dict:
+        """Handle an LLM call failure according to the error_handle configuration
+
+        Args:
+            error: the exception caught during the LLM call
+
+        Returns:
+            dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode,
+                  or the default output for default mode
+
+        Raises:
+            The original exception (when error_handle.method is NONE)
+        """
+        if self.typed_config is None:
+            raise error
+
+        match self.typed_config.error_handle.method:
+            case HttpErrorHandle.NONE:
+                raise error
+            case HttpErrorHandle.DEFAULT:
+                logger.warning(
+                    f"Node {self.node_id}: LLM call failed, returning the default output"
+                )
+                default_output = self.typed_config.error_handle.output or ""
+                return {
+                    "llm_result": AIMessage(content=default_output, response_metadata={}),
+                    "branch_signal": "SUCCESS",
+                }
+            case HttpErrorHandle.BRANCH:
+                logger.warning(
+                    f"Node {self.node_id}: LLM call failed, switching to the error branch"
+                )
+                return {
+                    "llm_result": None,
+                    "branch_signal": "ERROR",
+                }
        raise error
 
     async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
         """Streaming LLM invocation
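`_extract_output` now normalizes both result shapes into the same `output`/`branch_signal` pair. A minimal reproduction with a dataclass standing in for LangChain's `AIMessage`:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class AIMessage:  # minimal stand-in for langchain_core.messages.AIMessage
    content: str

def extract_output(business_result: Any) -> dict:
    # New format: {"llm_result": AIMessage | None, "branch_signal": str}
    if isinstance(business_result, dict) and "branch_signal" in business_result:
        llm_result = business_result.get("llm_result")
        if isinstance(llm_result, AIMessage):
            text = llm_result.content
        else:
            text = str(llm_result) if llm_result else ""
        return {"output": text, "branch_signal": business_result["branch_signal"]}
    # Old format: a bare AIMessage maps to an implicit SUCCESS
    if isinstance(business_result, AIMessage):
        return {"output": business_result.content, "branch_signal": "SUCCESS"}
    return {"output": str(business_result), "branch_signal": "SUCCESS"}

print(extract_output({"llm_result": None, "branch_signal": "ERROR"}))
print(extract_output(AIMessage("hi")))
```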
@@ -316,54 +383,58 @@
         """
         self.typed_config = LLMNodeConfig(**self.config)
 
-        llm = await self._prepare_llm(state, variable_pool, True)
+        try:
+            llm = await self._prepare_llm(state, variable_pool, True)
 
-        logger.info(f"Node {self.node_id} starting LLM call (streaming)")
-        # logger.debug(f"LLM config: streaming={getattr(llm._model, 'streaming', 'unknown')}")
+            logger.info(f"Node {self.node_id} starting LLM call (streaming)")
 
-        # Accumulate the full response
-        full_response = ""
-        chunk_count = 0
+            # Accumulate the full response
+            full_response = ""
+            chunk_count = 0
 
-        # Invoke the LLM (streaming, accepts a string or a message list)
-        last_meta_data = {}
-        last_usage_metadata = {}
-        async for chunk in llm.astream(self.messages):
-            if hasattr(chunk, 'content'):
-                content = self.process_model_output(chunk.content)
-            else:
-                content = str(chunk)
-            if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
-                last_meta_data = chunk.response_metadata
-            if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
-                last_usage_metadata = chunk.usage_metadata
+            # Invoke the LLM (streaming, accepts a string or a message list)
+            last_meta_data = {}
+            last_usage_metadata = {}
+            async for chunk in llm.astream(self.messages):
+                if hasattr(chunk, 'content'):
+                    content = self.process_model_output(chunk.content)
+                else:
+                    content = str(chunk)
+                if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
+                    last_meta_data = chunk.response_metadata
+                if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
+                    last_usage_metadata = chunk.usage_metadata
 
-            # Only process non-empty content
-            if content:
-                full_response += content
-                chunk_count += 1
+                # Only process non-empty content
+                if content:
+                    full_response += content
+                    chunk_count += 1
 
-                # Stream each text fragment back
-                yield {
-                    "__final__": False,
-                    "chunk": content
-                }
+                    # Stream each text fragment back
+                    yield {
+                        "__final__": False,
+                        "chunk": content
+                    }
 
-        yield {
-            "__final__": False,
-            "chunk": "",
-            "done": True
-        }
-        logger.info(f"Node {self.node_id} LLM call finished, output length: {len(full_response)}, total chunks: {chunk_count}")
-
-        # Build the complete AIMessage (including metadata)
-        final_message = AIMessage(
-            content=full_response,
-            response_metadata={
-                **last_meta_data,
-                "token_usage": last_usage_metadata or last_meta_data.get('token_usage')
+            yield {
+                "__final__": False,
+                "chunk": "",
+                "done": True
             }
-        )
-
-        # yield the completion marker
-        yield {"__final__": True, "result": final_message}
+            logger.info(f"Node {self.node_id} LLM call finished, output length: {len(full_response)}, total chunks: {chunk_count}")
+
+            # Build the complete AIMessage (including metadata)
+            final_message = AIMessage(
+                content=full_response,
+                response_metadata={
+                    **last_meta_data,
+                    "token_usage": last_usage_metadata or last_meta_data.get('token_usage')
+                }
+            )
+
+            # yield the completion marker
+            yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
+        except Exception as e:
+            logger.error(f"Node {self.node_id} streaming LLM call failed: {e}")
+            error_result = self._handle_llm_error(e)
+            yield {"__final__": True, "result": error_result}
@@ -15,4 +15,5 @@ class File(Base):
     file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
     file_size = Column(Integer, default=0, comment="file size(byte)")
     file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
+    file_key = Column(String(512), nullable=True, index=True, comment="storage file key for FileStorageService")
     created_at = Column(DateTime, default=datetime.datetime.now)
@@ -11,6 +11,7 @@ class FileBase(BaseModel):
     file_ext: str
     file_size: int
     file_url: str | None = None
+    file_key: str | None = None
     created_at: datetime.datetime | None = None
@@ -102,6 +102,11 @@ class AppDslService:
                 {**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
             ]
             return enriched
+        if app_type == AppType.WORKFLOW:
+            enriched = {**cfg}
+            if "nodes" in cfg:
+                enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
+            return enriched
         return cfg
 
     def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
@@ -110,7 +115,7 @@ class AppDslService:
         config_data = {
             "variables": config.variables if config else [],
             "edges": config.edges if config else [],
-            "nodes": config.nodes if config else [],
+            "nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
             "features": config.features if config else {},
             "execution_config": config.execution_config if config else {},
             "triggers": config.triggers if config else [],
@@ -190,6 +195,23 @@ class AppDslService:
     def _enrich_tools(self, tools: list) -> list:
         return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
 
+    def _enrich_workflow_nodes(self, nodes: list) -> list:
+        """Enrich model references in workflow nodes with name, provider, and type information"""
+        from app.core.workflow.nodes.enums import NodeType
+        enriched_nodes = []
+        for node in (nodes or []):
+            node_type = node.get("type")
+            config = dict(node.get("config") or {})
+
+            if node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
+                model_id = config.get("model_id")
+                if model_id:
+                    config["model_ref"] = self._model_ref(model_id)
+                    del config["model_id"]
+
+            enriched_nodes.append({**node, "config": config})
+        return enriched_nodes
+
     def _skill_ref(self, skill_id) -> Optional[dict]:
         if not skill_id:
             return None
@@ -620,16 +642,16 @@ class AppDslService:
                         warnings.append(f"[{node_label}] Knowledge base '{kb_id}' not matched; removed. Configure it manually after import")
                 config["knowledge_bases"] = resolved_kbs
             elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
-                model_ref = config.get("model_id")
+                model_ref = config.get("model_ref") or config.get("model_id")
                 if model_ref:
                     ref_dict = None
                     if isinstance(model_ref, dict):
-                        ref_id = model_ref.get("id")
-                        ref_name = model_ref.get("name")
-                        if ref_id:
-                            ref_dict = {"id": ref_id}
-                        elif ref_name is not None:
-                            ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")}
+                        ref_dict = {
+                            "id": model_ref.get("id"),
+                            "name": model_ref.get("name"),
+                            "provider": model_ref.get("provider"),
+                            "type": model_ref.get("type")
+                        }
                     elif isinstance(model_ref, str):
                         try:
                             uuid.UUID(model_ref)
@@ -640,12 +662,18 @@ class AppDslService:
                     resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
                     if resolved_model_id:
                         config["model_id"] = resolved_model_id
+                        if "model_ref" in config:
+                            del config["model_ref"]
                    else:
                        warnings.append(f"[{node_label}] Model not matched; cleared. Configure it manually after import")
                        config["model_id"] = None
+                        if "model_ref" in config:
+                            del config["model_ref"]
                 else:
                     warnings.append(f"[{node_label}] Model not matched; cleared. Configure it manually after import")
                     config["model_id"] = None
+                    if "model_ref" in config:
+                        del config["model_ref"]
             resolved_nodes.append({**node, "config": config})
         return resolved_nodes
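Export and import are now symmetric around `model_ref`: the exporter swaps the tenant-local `model_id` for a portable reference, and the importer resolves it back (by id first, then by name/provider/type) and strips `model_ref` on every path. A hypothetical round-trip sketch of just the dictionary reshaping:

```python
def enrich(config: dict, model_ref: dict) -> dict:
    # Export side: replace model_id with a portable reference.
    out = dict(config)
    out["model_ref"] = model_ref
    out.pop("model_id", None)
    return out

def resolve(config: dict, resolved_model_id: str | None) -> dict:
    # Import side: restore model_id (or clear it) and drop model_ref.
    out = dict(config)
    out["model_id"] = resolved_model_id
    out.pop("model_ref", None)
    return out

exported = enrich({"model_id": "local-uuid"},
                  {"id": "local-uuid", "name": "gpt-4o",
                   "provider": "openai", "type": "chat"})
print(exported)
print(resolve(exported, "target-tenant-uuid"))  # matched on import
print(resolve(exported, None))                  # unmatched: cleared, warn the user
```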
@@ -34,26 +34,7 @@ def generate_file_key(
     Generate a unique file key for storage.
 
     The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
-
-    Args:
-        tenant_id: The tenant UUID.
-        workspace_id: The workspace UUID.
-        file_id: The file UUID.
-        file_ext: The file extension (e.g., '.pdf', '.txt').
-
-    Returns:
-        A unique file key string.
-
-    Example:
-        >>> generate_file_key(
-        ...     uuid.UUID('550e8400-e29b-41d4-a716-446655440000'),
-        ...     uuid.UUID('660e8400-e29b-41d4-a716-446655440001'),
-        ...     uuid.UUID('770e8400-e29b-41d4-a716-446655440002'),
-        ...     '.pdf'
-        ... )
-        '550e8400-e29b-41d4-a716-446655440000/660e8400-e29b-41d4-a716-446655440001/770e8400-e29b-41d4-a716-446655440002.pdf'
     """
     # Ensure file_ext starts with a dot
     if file_ext and not file_ext.startswith('.'):
         file_ext = f'.{file_ext}'
     if workspace_id:
@@ -61,6 +42,21 @@ def generate_file_key(
     return f"{tenant_id}/{file_id}{file_ext}"
 
 
+def generate_kb_file_key(
+    kb_id: uuid.UUID,
+    file_id: uuid.UUID,
+    file_ext: str,
+) -> str:
+    """
+    Generate a file key for knowledge base files.
+
+    Format: kb/{kb_id}/{file_id}{file_ext}
+    """
+    if file_ext and not file_ext.startswith('.'):
+        file_ext = f'.{file_ext}'
+    return f"kb/{kb_id}/{file_id}{file_ext}"
+
+
 class FileStorageService:
     """
     High-level service for file storage operations.
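Usage of the new helper; the dot normalization means `'pdf'` and `'.pdf'` yield the same key (logic copied from the function above):

```python
import uuid

def generate_kb_file_key(kb_id: uuid.UUID, file_id: uuid.UUID, file_ext: str) -> str:
    # Same logic as the helper added above.
    if file_ext and not file_ext.startswith('.'):
        file_ext = f'.{file_ext}'
    return f"kb/{kb_id}/{file_id}{file_ext}"

kb = uuid.UUID("550e8400-e29b-41d4-a716-446655440000")
f = uuid.UUID("770e8400-e29b-41d4-a716-446655440002")
assert generate_kb_file_key(kb, f, "pdf") == generate_kb_file_key(kb, f, ".pdf")
print(generate_kb_file_key(kb, f, ".pdf"))
# kb/550e8400-e29b-41d4-a716-446655440000/770e8400-e29b-41d4-a716-446655440002.pdf
```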
@@ -210,9 +210,14 @@ def _build_vision_model(file_path: str, db_knowledge):
 
 
 @celery_app.task(name="app.core.rag.tasks.parse_document")
-def parse_document(file_path: str, document_id: uuid.UUID):
+def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
     """
-    Document parsing, vectorization, and storage
+    Document parsing, vectorization, and storage.
+
+    Args:
+        file_key: Storage key for FileStorageService (e.g. "kb/{kb_id}/{file_id}.docx")
+        document_id: Document UUID
+        file_name: Original file name (used for extension detection in chunk())
     """
 
     db_document = None
@@ -223,7 +228,6 @@ def parse_document(file_path: str, document_id: uuid.UUID):
 
     with get_db_context() as db:
         try:
-            # Celery JSON serialization converts UUIDs to strings, so make sure the type is correct
             if not isinstance(document_id, uuid.UUID):
                 document_id = uuid.UUID(str(document_id))
 
@@ -234,7 +238,11 @@ def parse_document(file_path: str, document_id: uuid.UUID):
             if db_knowledge is None:
                 raise ValueError(f"Knowledge {db_document.kb_id} not found")
 
-            # 1. Document parsing & segmentation
+            # Use file_name from argument or fall back to document record
+            if not file_name:
+                file_name = db_document.file_name
+
+            # 1. Download file from storage backend
             progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
             start_time = time.time()
             db_document.progress = 0.0
@@ -245,45 +253,36 @@ def parse_document(file_path: str, document_id: uuid.UUID):
             db.commit()
             db.refresh(db_document)
 
+            # Read file content from storage backend (no NFS dependency)
+            from app.services.file_storage_service import FileStorageService
+            import asyncio
+            storage_service = FileStorageService()
+
+            async def _download():
+                return await storage_service.download_file(file_key)
+
+            try:
+                file_binary = asyncio.run(_download())
+            except RuntimeError:
+                # If there's already a running loop (e.g. in some worker configurations)
+                loop = asyncio.new_event_loop()
+                try:
+                    file_binary = loop.run_until_complete(_download())
+                finally:
+                    loop.close()
+            if not file_binary:
+                raise IOError(f"Downloaded empty file from storage: {file_key}")
+            logger.info(f"[ParseDoc] Downloaded {len(file_binary)} bytes from storage key: {file_key}")
+
             def progress_callback(prog=None, msg=None):
                 progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
 
             # Prepare vision_model for parsing
-            vision_model = _build_vision_model(file_path, db_knowledge)
-
-            # Read the file into memory first, so parsing does not depend on the NFS file staying accessible.
-            # Libraries such as python-docx open the file by path when binary=None,
-            # which can fail with "Package not found" on NFS/shared storage due to stale caches.
-            max_wait_seconds = 30
-            wait_interval = 2
-            waited = 0
-            file_binary = None
-            while waited <= max_wait_seconds:
-                # os.listdir forces the NFS client to refresh its directory cache
-                parent_dir = os.path.dirname(file_path)
-                try:
-                    os.listdir(parent_dir)
-                except OSError:
-                    pass
-                try:
-                    with open(file_path, "rb") as f:
-                        file_binary = f.read()
-                    if not file_binary:
-                        # The file exists on NFS but is empty (it may still be syncing)
-                        raise IOError(f"File is empty (0 bytes), NFS may still be syncing: {file_path}")
-                    break
-                except (FileNotFoundError, IOError) as e:
-                    if waited >= max_wait_seconds:
-                        raise type(e)(
-                            f"File not accessible at '{file_path}' after waiting {max_wait_seconds}s: {e}"
-                        )
-                    logger.warning(f"File not ready on this node, retrying in {wait_interval}s: {file_path} ({e})")
-                    time.sleep(wait_interval)
-                    waited += wait_interval
+            vision_model = _build_vision_model(file_name, db_knowledge)
 
             from app.core.rag.app.naive import chunk
             logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
-            res = chunk(filename=file_path,
+            res = chunk(filename=file_name,
                         binary=file_binary,
                         from_page=0,
                         to_page=DEFAULT_PARSE_TO_PAGE,