Compare commits

...

9 Commits

Author SHA1 Message Date
wxy
cef33fce0d fix(workflow): sanitize condition expression building and cache assigner node inputs
- Sanitize condition expression construction in graph_builder.py using json.dumps to prevent potential injection vulnerabilities.
- Cache input data prior to assigner node execution to ensure variable values are correctly captured before processing.
2026-05-07 16:26:47 +08:00
wxy
d9f08860bc feat(LLM node): integrate exception handling and enable branch routing
- Integrate exception handling configuration into LLM nodes, supporting three strategies: throw exception, return default value, or trigger exception branch.
- Modify execution logic to return a result structure containing a branch signal, enabling routing to designated branches upon failure.
- Update graph_builder to support LLM node branch routing logic using the branch_signal field for conditional judgment.
- Implement backward compatibility to support both legacy and new result formats.
2026-05-07 11:43:24 +08:00
wxy
461674c8d8 feat(workflow): parse and substitute template variables in node configurations
- Implement regex matching for {{xxx}} template variable format.
- Enable recursive parsing of all string template variables within node configurations.
- Resolve and substitute template variables with runtime values during input data extraction.
- Support dynamic parsing and substitution of file selector variables in the document extraction node.
- Make strict template variable mode optional and introduce support for default values.
2026-04-29 14:10:02 +08:00
wxy
c59e179cc2 feat(workflow): incorporate model references and streamline parsing logic
- Incorporate model reference metadata (name, provider, type) into workflow nodes and refactor parsing logic to support the new format.
- Streamline code structure by removing redundant model_id fields to enhance maintainability.
2026-04-28 11:18:06 +08:00
Mark
a5670bfff6 Merge branch 'feature/rag2' into develop 2026-04-27 18:17:49 +08:00
Mark
4bef9b578b [fix] document file delete 2026-04-27 17:35:13 +08:00
Mark
c53fcf3981 [fix] old code file_path 2026-04-27 17:10:00 +08:00
Mark
2997558bc8 Merge branch 'release/v0.3.2' into feature/rag2
* release/v0.3.2: (245 commits)
  fix(conversation_schema): refine citations field type to Dict[str, Any]
  fix(tool_controller): re-raise HTTPException to preserve original status codes
  fix(workflow): add reasoning content, suggested questions, citations and audio status support
  feat(workflow): augment logging queries and ameliorate error handling
  fix(api_key): bypass publication check for SERVICE type API keys
  fix(multimodal_service): add '文档内容:' prefix to document text and simplify image placeholder text
  fix(api): convert config_id to string in write_router
  fix(api): convert end_user_id to string in write_router
  fix(multimodal_service): refactor image processing to use intermediate list before extending result
  fix(web): node status ui
  fix(api): correct import paths in memory_read and celery task command
  fix(api): correct import paths in memory_read and celery task command
  refactor(tool): flatten request body parameters for model exposure
  fix(api): correct import paths in memory_read and celery task command
  refactor(workflow): streamline node execution handling and log service logic
  feat(web): http request add process
  feat(web): workflow app logs
  fix(app_chat_service,draft_run_service): move system_prompt augmentation before LangChainAgent instantiation
  fix(app_chat_service,draft_run_service): move system_prompt augmentation before LangChainAgent instantiation
  refactor(http_request): simplify request handling and remove unused fields
  ...

# Conflicts:
#	api/app/controllers/file_controller.py
#	api/app/tasks.py
2026-04-27 16:13:57 +08:00
Mark
30cdf229de [modify] rag file system 2026-04-27 16:05:27 +08:00
16 changed files with 479 additions and 468 deletions

View File

@@ -82,19 +82,32 @@ async def get_preview_chunks(
detail="The file does not exist or you do not have permission to access it" detail="The file does not exist or you do not have permission to access it"
) )
# 5. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} # 5. Get file content from storage backend
file_path = os.path.join( if not db_file.file_key:
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 6. Check if the file exists
if not os.path.exists(file_path):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)" detail="File has no storage key (legacy data not migrated)"
)
from app.services.file_storage_service import FileStorageService
import asyncio
storage_service = FileStorageService()
async def _download():
return await storage_service.download_file(db_file.file_key)
try:
file_binary = asyncio.run(_download())
except RuntimeError:
loop = asyncio.new_event_loop()
try:
file_binary = loop.run_until_complete(_download())
finally:
loop.close()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"File not found in storage: {e}"
) )
# 7. Document parsing & segmentation # 7. Document parsing & segmentation
@@ -104,11 +117,12 @@ async def get_preview_chunks(
vision_model = QWenCV( vision_model = QWenCV(
key=db_knowledge.image2text.api_keys[0].api_key, key=db_knowledge.image2text.api_keys[0].api_key,
model_name=db_knowledge.image2text.api_keys[0].model_name, model_name=db_knowledge.image2text.api_keys[0].model_name,
lang="Chinese", # Default to Chinese lang="Chinese",
base_url=db_knowledge.image2text.api_keys[0].api_base base_url=db_knowledge.image2text.api_keys[0].api_base
) )
from app.core.rag.app.naive import chunk from app.core.rag.app.naive import chunk
res = chunk(filename=file_path, res = chunk(filename=db_file.file_name,
binary=file_binary,
from_page=0, from_page=0,
to_page=5, to_page=5,
callback=progress_callback, callback=progress_callback,

View File

@@ -20,6 +20,7 @@ from app.models.user_model import User
from app.schemas import document_schema from app.schemas import document_schema
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import document_service, file_service, knowledge_service from app.services import document_service, file_service, knowledge_service
from app.services.file_storage_service import FileStorageService, get_file_storage_service
# Obtain a dedicated API logger # Obtain a dedicated API logger
@@ -231,7 +232,8 @@ async def update_document(
async def delete_document( async def delete_document(
document_id: uuid.UUID, document_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
): ):
""" """
Delete document Delete document
@@ -257,7 +259,7 @@ async def delete_document(
db.commit() db.commit()
# 3. Delete file # 3. Delete file
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user) await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
# 4. Delete vector index # 4. Delete vector index
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user) db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
@@ -305,38 +307,25 @@ async def parse_documents(
detail="The file does not exist or you do not have permission to access it" detail="The file does not exist or you do not have permission to access it"
) )
# 3. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} # 3. Get file_key for storage backend
file_path = os.path.join( if not db_file.file_key:
settings.FILE_PATH, api_logger.error(f"File has no storage key (legacy data not migrated): file_id={db_file.id}")
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 4. Check if the file exists
api_logger.debug(f"Constructed file path: {file_path}")
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
if not os.path.exists(file_path):
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)" detail="File has no storage key (legacy data not migrated)"
) )
# 5. Obtain knowledge base information # 4. Obtain knowledge base information
api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}") api_logger.info(f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user) db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
if not db_knowledge: if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 6. Task: Document parsing, vectorization, and storage # 5. Dispatch parse task with file_key (not file_path)
# from app.tasks import parse_document task = celery_app.send_task(
# parse_document(file_path, document_id) "app.core.rag.tasks.parse_document",
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id]) args=[db_file.file_key, document_id, db_file.file_name]
)
result = { result = {
"task_id": task.id "task_id": task.id
} }

View File

@@ -1,12 +1,10 @@
import os import os
from pathlib import Path
import shutil
from typing import Any, Optional from typing import Any, Optional
import uuid import uuid
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.responses import FileResponse from fastapi.responses import Response
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.config import settings from app.core.config import settings
@@ -19,10 +17,14 @@ from app.models.user_model import User
from app.schemas import file_schema, document_schema from app.schemas import file_schema, document_schema
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import file_service, document_service from app.services import file_service, document_service
from app.services.knowledge_service import get_knowledge_by_id as get_kb_by_id
from app.services.file_storage_service import (
FileStorageService,
generate_kb_file_key,
get_file_storage_service,
)
from app.core.quota_stub import check_knowledge_capacity_quota from app.core.quota_stub import check_knowledge_capacity_quota
# Obtain a dedicated API logger
api_logger = get_api_logger() api_logger = get_api_logger()
router = APIRouter( router = APIRouter(
@@ -35,67 +37,37 @@ router = APIRouter(
async def get_files( async def get_files(
kb_id: uuid.UUID, kb_id: uuid.UUID,
parent_id: uuid.UUID, parent_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0 page: int = Query(1, gt=0),
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items pagesize: int = Query(20, gt=0, le=100),
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"), orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"), desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (file name)"), keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """Paged query file list"""
Paged query file list api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}")
- Support filtering by kb_id and parent_id
- Support keyword search for file names
- Support dynamic sorting
- Return paging metadata + file list
"""
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions if page < 1 or pagesize < 1:
filters = [ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The paging parameter must be greater than 0")
file_model.File.kb_id == kb_id
] filters = [file_model.File.kb_id == kb_id]
if parent_id: if parent_id:
filters.append(file_model.File.parent_id == parent_id) filters.append(file_model.File.parent_id == parent_id)
# Keyword search (fuzzy matching of file name)
if keywords: if keywords:
filters.append(file_model.File.file_name.ilike(f"%{keywords}%")) filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
# 3. Execute paged query
try: try:
api_logger.debug("Start executing file paging query")
total, items = file_service.get_files_paginated( total, items = file_service.get_files_paginated(
db=db, db=db, filters=filters, page=page, pagesize=pagesize,
filters=filters, orderby=orderby, desc=desc, current_user=current_user
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
) )
api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Query failed: {str(e)}")
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = { result = {
"items": items, "items": items,
"page": { "page": {"page": page, "pagesize": pagesize, "total": total, "has_next": page * pagesize < total}
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
} }
return success(data=jsonable_encoder(result), msg="Query of file list succeeded") return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
@@ -108,23 +80,14 @@ async def create_folder(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
""" """Create a new folder"""
Create a new folder api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}")
"""
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
try: try:
api_logger.debug(f"Start creating a folder: {folder_name}") create_folder_data = file_schema.FileCreate(
create_folder = file_schema.FileCreate( kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
kb_id=kb_id, file_name=folder_name, file_ext='folder', file_size=0,
created_by=current_user.id,
parent_id=parent_id,
file_name=folder_name,
file_ext='folder',
file_size=0,
) )
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user) db_file = file_service.create_file(db=db, file=create_folder_data, current_user=current_user)
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful") return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
except Exception as e: except Exception as e:
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}") api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
@@ -138,76 +101,58 @@ async def upload_file(
parent_id: uuid.UUID, parent_id: uuid.UUID,
file: UploadFile = File(...), file: UploadFile = File(...),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
): ):
""" """Upload file to storage backend"""
upload file api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}")
"""
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
# Read the contents of the file
contents = await file.read() contents = await file.read()
# Check file size
file_size = len(contents) file_size = len(contents)
print(f"file size: {file_size} byte")
if file_size == 0: if file_size == 0:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The file is empty.")
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
if file_size > settings.MAX_FILE_SIZE: if file_size > settings.MAX_FILE_SIZE:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"File size exceeds {settings.MAX_FILE_SIZE} byte limit")
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
# Extract the extension using `os.path.splitext`
_, file_extension = os.path.splitext(file.filename) _, file_extension = os.path.splitext(file.filename)
upload_file = file_schema.FileCreate( file_ext = file_extension.lower()
kb_id=kb_id,
created_by=current_user.id, # Create File record
parent_id=parent_id, upload_file_data = file_schema.FileCreate(
file_name=file.filename, kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
file_ext=file_extension.lower(), file_name=file.filename, file_ext=file_ext, file_size=file_size,
file_size=file_size,
) )
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user) db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension} # Upload to storage backend
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id)) file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists try:
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
# Save file # Save file_key
with open(save_path, "wb") as f: db_file.file_key = file_key
f.write(contents) db.commit()
db.refresh(db_file)
# Verify whether the file has been saved successfully # Create document (inherit parser_config from knowledge base)
if not os.path.exists(save_path): default_parser_config = {
raise HTTPException( "layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"
detail="File save failed" }
) try:
db_knowledge = get_kb_by_id(db, knowledge_id=kb_id, current_user=current_user)
if db_knowledge and db_knowledge.parser_config:
default_parser_config.update(dict(db_knowledge.parser_config))
except Exception:
pass
# Create a document
create_data = document_schema.DocumentCreate( create_data = document_schema.DocumentCreate(
kb_id=kb_id, kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
created_by=current_user.id, file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
file_id=db_file.id, file_meta={}, parser_id="naive", parser_config=default_parser_config
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
) )
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user) db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
@@ -221,123 +166,73 @@ async def custom_text(
parent_id: uuid.UUID, parent_id: uuid.UUID,
create_data: file_schema.CustomTextFileCreate, create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
): ):
""" """Custom text upload"""
custom text
"""
api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
# Check file content size
# 将内容编码为字节UTF-8
content_bytes = create_data.content.encode('utf-8') content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes) file_size = len(content_bytes)
print(f"file size: {file_size} byte")
if file_size == 0: if file_size == 0:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The content is empty.")
status_code=status.HTTP_400_BAD_REQUEST,
detail="The content is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
if file_size > settings.MAX_FILE_SIZE: if file_size > settings.MAX_FILE_SIZE:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Content size exceeds {settings.MAX_FILE_SIZE} byte limit")
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
upload_file = file_schema.FileCreate( upload_file_data = file_schema.FileCreate(
kb_id=kb_id, kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
created_by=current_user.id, file_name=f"{create_data.title}.txt", file_ext=".txt", file_size=file_size,
parent_id=parent_id,
file_name=f"{create_data.title}.txt",
file_ext=".txt",
file_size=file_size,
) )
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user) db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension} # Upload to storage backend
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id)) file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=".txt")
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists try:
save_path = os.path.join(save_dir, f"{db_file.id}.txt") await storage_service.storage.upload(file_key=file_key, content=content_bytes, content_type="text/plain")
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
# Save file db_file.file_key = file_key
with open(save_path, "wb") as f: db.commit()
f.write(content_bytes) db.refresh(db_file)
# Verify whether the file has been saved successfully
if not os.path.exists(save_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="File save failed"
)
# Create a document
create_document_data = document_schema.DocumentCreate( create_document_data = document_schema.DocumentCreate(
kb_id=kb_id, kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
created_by=current_user.id, file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
file_id=db_file.id, file_meta={}, parser_id="naive",
file_name=db_file.file_name, parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
file_ext=db_file.file_ext, "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"}
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
) )
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user) db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful") return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
@router.get("/{file_id}", response_model=Any) @router.get("/{file_id}", response_model=Any)
async def get_file( async def get_file(
file_id: uuid.UUID, file_id: uuid.UUID,
db: Session = Depends(get_db) db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any: ) -> Any:
""" """Download file by file_id"""
Download the file based on the file_id
- Query file information from the database
- Construct the file path and check if it exists
- Return a FileResponse to download the file
"""
api_logger.info(f"Download the file based on the file_id: file_id={file_id}")
# 1. Query file information from the database
db_file = file_service.get_file_by_id(db, file_id=file_id) db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file: if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} if not db_file.file_key:
file_path = os.path.join( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File has no storage key (legacy data not migrated)")
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 3. Check if the file exists try:
if not os.path.exists(file_path): content = await storage_service.download_file(db_file.file_key)
raise HTTPException( except Exception as e:
status_code=status.HTTP_404_NOT_FOUND, api_logger.error(f"Storage download failed: {e}")
detail="File not found (possibly deleted)" raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
)
# 4.Return FileResponse (automatically handle download) import mimetypes
return FileResponse( media_type = mimetypes.guess_type(db_file.file_name)[0] or "application/octet-stream"
path=file_path, return Response(
filename=db_file.file_name, # Use original file name content=content,
media_type="application/octet-stream" # Universal binary stream type media_type=media_type,
headers={"Content-Disposition": f'attachment; filename="{db_file.file_name}"'}
) )
@@ -348,50 +243,22 @@ async def update_file(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """Update file information (such as file name)"""
Update file information (such as file name)
- Only specified fields such as file_name are allowed to be modified
"""
api_logger.debug(f"Query the file to be updated: {file_id}")
# 1. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=file_id) db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file: if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the file fields: {file_id}")
updated_fields = []
for field, value in update_data.dict(exclude_unset=True).items(): for field, value in update_data.dict(exclude_unset=True).items():
if hasattr(db_file, field): if hasattr(db_file, field):
old_value = getattr(db_file, field) setattr(db_file, field, value)
if old_value != value:
# update value
setattr(db_file, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 3. Save to database
try: try:
db.commit() db.commit()
db.refresh(db_file) db.refresh(db_file)
api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
except Exception as e: except Exception as e:
db.rollback() db.rollback()
api_logger.error(f"File update failed: file_id={file_id} - {str(e)}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File update failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File update failed: {str(e)}"
)
# 4. Return the updated file
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully") return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
@@ -399,60 +266,43 @@ async def update_file(
async def delete_file( async def delete_file(
file_id: uuid.UUID, file_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
): ):
""" """Delete a file or folder"""
Delete a file or folder api_logger.info(f"Request to delete file: file_id={file_id}")
""" await _delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}")
await _delete_file(db=db, file_id=file_id, current_user=current_user)
return success(msg="File deleted successfully") return success(msg="File deleted successfully")
async def _delete_file( async def _delete_file(
file_id: uuid.UUID, file_id: uuid.UUID,
db: Session = Depends(get_db), db: Session,
current_user: User = Depends(get_current_user) current_user: User,
storage_service: FileStorageService,
) -> None: ) -> None:
""" """Delete a file or folder from storage and database"""
Delete a file or folder
"""
# 1. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=file_id) db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file: if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Construct physical path # Delete from storage backend
file_path = Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.id)
) if db_file.file_ext == 'folder' else Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 3. Delete physical files/folders
try:
if file_path.exists():
if db_file.file_ext == 'folder':
shutil.rmtree(file_path) # Recursively delete folders
else:
file_path.unlink() # Delete a single file
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete physical file/folder: {str(e)}"
)
# 4.Delete db_file
if db_file.file_ext == 'folder': if db_file.file_ext == 'folder':
# For folders, delete all child files from storage first
child_files = db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).all()
for child in child_files:
if child.file_key:
try:
await storage_service.delete_file(child.file_key)
except Exception as e:
api_logger.warning(f"Failed to delete child file from storage: {child.file_key} - {e}")
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete() db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
else:
if db_file.file_key:
try:
await storage_service.delete_file(db_file.file_key)
except Exception as e:
api_logger.warning(f"Failed to delete file from storage: {db_file.file_key} - {e}")
db.delete(db_file) db.delete(db_file)
db.commit() db.commit()

View File

@@ -14,6 +14,7 @@ Transcribe the content from the provided PDF page image into clean Markdown form
6. Do NOT wrap the output in ```markdown or ``` blocks. 6. Do NOT wrap the output in ```markdown or ``` blocks.
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image. 7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
8. Preserve the original language, information, and order exactly as shown in the image. 8. Preserve the original language, information, and order exactly as shown in the image.
9. Your output language MUST match the language of the content in the image. If the image contains Chinese text, output in Chinese. If English, output in English. Never translate.
{% if page %} {% if page %}
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`. At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.

View File

@@ -2,6 +2,7 @@
# Author: Eternity # Author: Eternity
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33 # @Time : 2026/2/10 13:33
import json
import logging import logging
import re import re
import uuid import uuid
@@ -141,9 +142,10 @@ class GraphBuilder:
for node_info in source_nodes: for node_info in source_nodes:
if self.get_node_type(node_info["id"]) in BRANCH_NODES: if self.get_node_type(node_info["id"]) in BRANCH_NODES:
branch_nodes.append( if node_info.get("branch") is not None:
(node_info["id"], node_info["branch"]) branch_nodes.append(
) (node_info["id"], node_info["branch"])
)
else: else:
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT): if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
output_nodes.append(node_info["id"]) output_nodes.append(node_info["id"])
@@ -314,9 +316,12 @@ class GraphBuilder:
for idx in range(len(related_edge)): for idx in range(len(related_edge)):
# Generate a condition expression for each edge # Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output # Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label # For LLM nodes, use branch_signal field for routing (output is dynamic text)
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' # For other branch nodes (e.g. HTTP), use output field
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'" route_field = "branch_signal" if node_type == NodeType.LLM else "output"
related_edge[idx]['condition'] = (
f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
)
if node_instance: if node_instance:
# Wrap node's run method to avoid closure issues # Wrap node's run method to avoid closure issues

View File

@@ -18,10 +18,17 @@ class AssignerNode(BaseNode):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None self.typed_config: AssignerNodeConfig | None = None
self._input_data: dict[str, Any] | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {} return {}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取节点输入,如果有缓存的执行前数据则使用缓存"""
if self._input_data is not None:
return self._input_data
return {"config": self._resolve_config(self.config, variable_pool)}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
""" """
Execute the assignment operation defined by this node. Execute the assignment operation defined by this node.
@@ -34,6 +41,9 @@ class AssignerNode(BaseNode):
Returns: Returns:
None or the result of the assignment operation. None or the result of the assignment operation.
""" """
# 在执行前提取并缓存输入数据(捕获执行前的变量值)
self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
# Initialize a variable pool for accessing conversation, node, and system variables # Initialize a variable pool for accessing conversation, node, and system variables
self.typed_config = AssignerNodeConfig(**self.config) self.typed_config = AssignerNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} 开始执行") logger.info(f"节点 {self.node_id} 开始执行")

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
import re
import time import time
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -22,6 +23,9 @@ from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 匹配模板变量 {{xxx}} 的正则
_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
class NodeExecutionError(Exception): class NodeExecutionError(Exception):
"""节点执行失败异常。 """节点执行失败异常。
@@ -503,10 +507,29 @@ class BaseNode(ABC):
variable_pool: The variable pool used for reading and writing variables. variable_pool: The variable pool used for reading and writing variables.
Returns: Returns:
A dictionary containing the node's input data. A dictionary containing the node's input data with all template
variables resolved to their actual runtime values.
""" """
# Default implementation returns the node configuration return {"config": self._resolve_config(self.config, variable_pool)}
return {"config": self.config}
@staticmethod
def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
"""递归解析 config 中的模板变量,将 {{xxx}} 替换为实际值。
Args:
config: 节点的原始配置(可能包含模板变量)。
variable_pool: 变量池,用于解析模板变量。
Returns:
解析后的配置,所有字符串中的 {{变量}} 已被替换为真实值。
"""
if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
return BaseNode._render_template(config, variable_pool, strict=False)
elif isinstance(config, dict):
return {k: BaseNode._resolve_config(v, variable_pool) for k, v in config.items()}
elif isinstance(config, list):
return [BaseNode._resolve_config(item, variable_pool) for item in config]
return config
def _extract_output(self, business_result: Any) -> Any: def _extract_output(self, business_result: Any) -> Any:
"""Extracts the actual output from the business result. """Extracts the actual output from the business result.

View File

@@ -121,7 +121,10 @@ class DocExtractorNode(BaseNode):
return business_result return business_result
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {"file_selector": self.config.get("file_selector")} file_selector = self.config.get("file_selector", "")
# 将变量选择器(如 sys.files解析为实际值
resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
return {"file_selector": resolved}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
config = DocExtractorNodeConfig(**self.config) config = DocExtractorNodeConfig(**self.config)

View File

@@ -31,7 +31,7 @@ class NodeType(StrEnum):
NOTES = "notes" NOTES = "notes"
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER}) BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):

View File

@@ -6,6 +6,7 @@ import uuid
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
@@ -49,6 +50,20 @@ class MemoryWindowSetting(BaseModel):
) )
class LLMErrorHandleConfig(BaseModel):
"""LLM 异常处理配置"""
method: HttpErrorHandle = Field(
default=HttpErrorHandle.NONE,
description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支",
)
output: str = Field(
default="",
description="LLM 异常时返回的默认输出文本method=default 时生效)",
)
class LLMNodeConfig(BaseNodeConfig): class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置 """LLM 节点配置
@@ -152,6 +167,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="输出变量定义(自动生成,通常不需要修改)" description="输出变量定义(自动生成,通常不需要修改)"
) )
error_handle: LLMErrorHandleConfig = Field(
default_factory=LLMErrorHandleConfig,
description="LLM 异常处理配置",
)
@field_validator("messages", "prompt") @field_validator("messages", "prompt")
@classmethod @classmethod
def validate_input_mode(cls, v): def validate_input_mode(cls, v):

View File

@@ -15,6 +15,7 @@ from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.nodes.llm.config import LLMNodeConfig from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context from app.db import get_db_context
@@ -76,7 +77,7 @@ class LLMNode(BaseNode):
self.messages = [] self.messages = []
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
def _render_context(self, message: str, variable_pool: VariablePool): def _render_context(self, message: str, variable_pool: VariablePool):
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>" context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
@@ -239,7 +240,7 @@ class LLMNode(BaseNode):
return llm return llm
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage: async def execute(self, state: WorkflowState, variable_pool: VariablePool):
"""非流式执行 LLM 调用 """非流式执行 LLM 调用
Args: Args:
@@ -247,28 +248,36 @@ class LLMNode(BaseNode):
variable_pool: 变量池 variable_pool: 变量池
Returns: Returns:
LLM 响应消息 dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
{"llm_result": None, "branch_signal": "ERROR"} on branch error
""" """
# self.typed_config = LLMNodeConfig(**self.config) try:
llm = await self._prepare_llm(state, variable_pool, False) # self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, False)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表 # 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(self.messages) response = await llm.ainvoke(self.messages)
# 提取内容 # 提取内容
if hasattr(response, 'content'): if hasattr(response, 'content'):
content = self.process_model_output(response.content) content = self.process_model_output(response.content)
else: else:
content = str(response) content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据 # 返回 AIMessage包含响应元数据
return AIMessage(content=content, response_metadata={ return {
**response.response_metadata, "llm_result": AIMessage(content=content, response_metadata={
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage') **response.response_metadata,
}) "token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
}),
"branch_signal": "SUCCESS",
}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 调用失败: {e}")
return self._handle_llm_error(e)
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)""" """提取输入数据(用于记录)"""
@@ -286,16 +295,36 @@ class LLMNode(BaseNode):
} }
} }
def _extract_output(self, business_result: Any) -> str: def _extract_output(self, business_result: Any) -> dict:
""" AIMessage 中提取文本内容""" """业务结果中提取输出变量
支持新旧两种格式:
- 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"}
- 旧格式AIMessage向后兼容
"""
if isinstance(business_result, dict) and "branch_signal" in business_result:
llm_result = business_result.get("llm_result")
if isinstance(llm_result, AIMessage):
return {
"output": llm_result.content,
"branch_signal": business_result["branch_signal"],
}
return {
"output": str(llm_result) if llm_result else "",
"branch_signal": business_result["branch_signal"],
}
# 旧格式向后兼容
if isinstance(business_result, AIMessage): if isinstance(business_result, AIMessage):
return business_result.content return {"output": business_result.content, "branch_signal": "SUCCESS"}
return str(business_result) return {"output": str(business_result), "branch_signal": "SUCCESS"}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
""" AIMessage 中提取 token 使用情况""" """业务结果中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'): llm_result = business_result
usage = business_result.response_metadata.get('token_usage') if isinstance(business_result, dict):
llm_result = business_result.get("llm_result", business_result)
if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
usage = llm_result.response_metadata.get('token_usage')
if usage: if usage:
return { return {
"prompt_tokens": usage.get('input_tokens', 0), "prompt_tokens": usage.get('input_tokens', 0),
@@ -304,6 +333,44 @@ class LLMNode(BaseNode):
} }
return None return None
def _handle_llm_error(self, error: Exception) -> dict:
"""处理 LLM 调用异常,根据 error_handle 配置决定行为
Args:
error: LLM 调用中捕获的异常
Returns:
dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode,
or default output for default mode
Raises:
原异常(当 error_handle.method 为 NONE 时)
"""
if self.typed_config is None:
raise error
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
raise error
case HttpErrorHandle.DEFAULT:
logger.warning(
f"节点 {self.node_id}: LLM 调用失败,返回默认输出"
)
default_output = self.typed_config.error_handle.output or ""
return {
"llm_result": AIMessage(content=default_output, response_metadata={}),
"branch_signal": "SUCCESS",
}
case HttpErrorHandle.BRANCH:
logger.warning(
f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支"
)
return {
"llm_result": None,
"branch_signal": "ERROR",
}
raise error
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""流式执行 LLM 调用 """流式执行 LLM 调用
@@ -316,54 +383,58 @@ class LLMNode(BaseNode):
""" """
self.typed_config = LLMNodeConfig(**self.config) self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, True) try:
llm = await self._prepare_llm(state, variable_pool, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 累积完整响应 # 累积完整响应
full_response = "" full_response = ""
chunk_count = 0 chunk_count = 0
# 调用 LLM流式支持字符串或消息列表 # 调用 LLM流式支持字符串或消息列表
last_meta_data = {} last_meta_data = {}
last_usage_metadata = {} last_usage_metadata = {}
async for chunk in llm.astream(self.messages): async for chunk in llm.astream(self.messages):
if hasattr(chunk, 'content'): if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content) content = self.process_model_output(chunk.content)
else: else:
content = str(chunk) content = str(chunk)
if hasattr(chunk, 'response_metadata') and chunk.response_metadata: if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
last_meta_data = chunk.response_metadata last_meta_data = chunk.response_metadata
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata: if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
last_usage_metadata = chunk.usage_metadata last_usage_metadata = chunk.usage_metadata
# 只有当内容不为空时才处理 # 只有当内容不为空时才处理
if content: if content:
full_response += content full_response += content
chunk_count += 1 chunk_count += 1
# 流式返回每个文本片段 # 流式返回每个文本片段
yield { yield {
"__final__": False, "__final__": False,
"chunk": content "chunk": content
} }
yield { yield {
"__final__": False, "__final__": False,
"chunk": "", "chunk": "",
"done": True "done": True
}
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
} }
) logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# yield 完成标记 # 构建完整的 AIMessage包含元数据
yield {"__final__": True, "result": final_message} final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
}
)
# yield 完成标记
yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}")
error_result = self._handle_llm_error(e)
yield {"__final__": True, "result": error_result}

View File

@@ -15,4 +15,5 @@ class File(Base):
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf") file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
file_size = Column(Integer, default=0, comment="file size(byte)") file_size = Column(Integer, default=0, comment="file size(byte)")
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url") file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
file_key = Column(String(512), nullable=True, index=True, comment="storage file key for FileStorageService")
created_at = Column(DateTime, default=datetime.datetime.now) created_at = Column(DateTime, default=datetime.datetime.now)

View File

@@ -11,6 +11,7 @@ class FileBase(BaseModel):
file_ext: str file_ext: str
file_size: int file_size: int
file_url: str | None = None file_url: str | None = None
file_key: str | None = None
created_at: datetime.datetime | None = None created_at: datetime.datetime | None = None

View File

@@ -102,6 +102,11 @@ class AppDslService:
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or []) {**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
] ]
return enriched return enriched
if app_type == AppType.WORKFLOW:
enriched = {**cfg}
if "nodes" in cfg:
enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
return enriched
return cfg return cfg
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]: def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
@@ -110,7 +115,7 @@ class AppDslService:
config_data = { config_data = {
"variables": config.variables if config else [], "variables": config.variables if config else [],
"edges": config.edges if config else [], "edges": config.edges if config else [],
"nodes": config.nodes if config else [], "nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
"features": config.features if config else {}, "features": config.features if config else {},
"execution_config": config.execution_config if config else {}, "execution_config": config.execution_config if config else {},
"triggers": config.triggers if config else [], "triggers": config.triggers if config else [],
@@ -190,6 +195,23 @@ class AppDslService:
def _enrich_tools(self, tools: list) -> list: def _enrich_tools(self, tools: list) -> list:
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])] return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
def _enrich_workflow_nodes(self, nodes: list) -> list:
"""enrich 工作流节点中的模型引用,添加 name、provider、type 信息"""
from app.core.workflow.nodes.enums import NodeType
enriched_nodes = []
for node in (nodes or []):
node_type = node.get("type")
config = dict(node.get("config") or {})
if node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
model_id = config.get("model_id")
if model_id:
config["model_ref"] = self._model_ref(model_id)
del config["model_id"]
enriched_nodes.append({**node, "config": config})
return enriched_nodes
def _skill_ref(self, skill_id) -> Optional[dict]: def _skill_ref(self, skill_id) -> Optional[dict]:
if not skill_id: if not skill_id:
return None return None
@@ -620,16 +642,16 @@ class AppDslService:
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置") warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
config["knowledge_bases"] = resolved_kbs config["knowledge_bases"] = resolved_kbs
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value): elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
model_ref = config.get("model_id") model_ref = config.get("model_ref") or config.get("model_id")
if model_ref: if model_ref:
ref_dict = None ref_dict = None
if isinstance(model_ref, dict): if isinstance(model_ref, dict):
ref_id = model_ref.get("id") ref_dict = {
ref_name = model_ref.get("name") "id": model_ref.get("id"),
if ref_id: "name": model_ref.get("name"),
ref_dict = {"id": ref_id} "provider": model_ref.get("provider"),
elif ref_name is not None: "type": model_ref.get("type")
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")} }
elif isinstance(model_ref, str): elif isinstance(model_ref, str):
try: try:
uuid.UUID(model_ref) uuid.UUID(model_ref)
@@ -640,12 +662,18 @@ class AppDslService:
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings) resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
if resolved_model_id: if resolved_model_id:
config["model_id"] = resolved_model_id config["model_id"] = resolved_model_id
if "model_ref" in config:
del config["model_ref"]
else: else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置") warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None config["model_id"] = None
if "model_ref" in config:
del config["model_ref"]
else: else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置") warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None config["model_id"] = None
if "model_ref" in config:
del config["model_ref"]
resolved_nodes.append({**node, "config": config}) resolved_nodes.append({**node, "config": config})
return resolved_nodes return resolved_nodes

View File

@@ -34,26 +34,7 @@ def generate_file_key(
Generate a unique file key for storage. Generate a unique file key for storage.
The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext} The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
Args:
tenant_id: The tenant UUID.
workspace_id: The workspace UUID.
file_id: The file UUID.
file_ext: The file extension (e.g., '.pdf', '.txt').
Returns:
A unique file key string.
Example:
>>> generate_file_key(
... uuid.UUID('550e8400-e29b-41d4-a716-446655440000'),
... uuid.UUID('660e8400-e29b-41d4-a716-446655440001'),
... uuid.UUID('770e8400-e29b-41d4-a716-446655440002'),
... '.pdf'
... )
'550e8400-e29b-41d4-a716-446655440000/660e8400-e29b-41d4-a716-446655440001/770e8400-e29b-41d4-a716-446655440002.pdf'
""" """
# Ensure file_ext starts with a dot
if file_ext and not file_ext.startswith('.'): if file_ext and not file_ext.startswith('.'):
file_ext = f'.{file_ext}' file_ext = f'.{file_ext}'
if workspace_id: if workspace_id:
@@ -61,6 +42,21 @@ def generate_file_key(
return f"{tenant_id}/{file_id}{file_ext}" return f"{tenant_id}/{file_id}{file_ext}"
def generate_kb_file_key(
kb_id: uuid.UUID,
file_id: uuid.UUID,
file_ext: str,
) -> str:
"""
Generate a file key for knowledge base files.
Format: kb/{kb_id}/{file_id}{file_ext}
"""
if file_ext and not file_ext.startswith('.'):
file_ext = f'.{file_ext}'
return f"kb/{kb_id}/{file_id}{file_ext}"
class FileStorageService: class FileStorageService:
""" """
High-level service for file storage operations. High-level service for file storage operations.

View File

@@ -210,9 +210,14 @@ def _build_vision_model(file_path: str, db_knowledge):
@celery_app.task(name="app.core.rag.tasks.parse_document") @celery_app.task(name="app.core.rag.tasks.parse_document")
def parse_document(file_path: str, document_id: uuid.UUID): def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
""" """
Document parsing, vectorization, and storage Document parsing, vectorization, and storage.
Args:
file_key: Storage key for FileStorageService (e.g. "kb/{kb_id}/{file_id}.docx")
document_id: Document UUID
file_name: Original file name (used for extension detection in chunk())
""" """
db_document = None db_document = None
@@ -223,7 +228,6 @@ def parse_document(file_path: str, document_id: uuid.UUID):
with get_db_context() as db: with get_db_context() as db:
try: try:
# Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确
if not isinstance(document_id, uuid.UUID): if not isinstance(document_id, uuid.UUID):
document_id = uuid.UUID(str(document_id)) document_id = uuid.UUID(str(document_id))
@@ -234,7 +238,11 @@ def parse_document(file_path: str, document_id: uuid.UUID):
if db_knowledge is None: if db_knowledge is None:
raise ValueError(f"Knowledge {db_document.kb_id} not found") raise ValueError(f"Knowledge {db_document.kb_id} not found")
# 1. Document parsing & segmentation # Use file_name from argument or fall back to document record
if not file_name:
file_name = db_document.file_name
# 1. Download file from storage backend
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.") progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
start_time = time.time() start_time = time.time()
db_document.progress = 0.0 db_document.progress = 0.0
@@ -245,45 +253,36 @@ def parse_document(file_path: str, document_id: uuid.UUID):
db.commit() db.commit()
db.refresh(db_document) db.refresh(db_document)
# Read file content from storage backend (no NFS dependency)
from app.services.file_storage_service import FileStorageService
import asyncio
storage_service = FileStorageService()
async def _download():
return await storage_service.download_file(file_key)
try:
file_binary = asyncio.run(_download())
except RuntimeError:
# If there's already a running loop (e.g. in some worker configurations)
loop = asyncio.new_event_loop()
try:
file_binary = loop.run_until_complete(_download())
finally:
loop.close()
if not file_binary:
raise IOError(f"Downloaded empty file from storage: {file_key}")
logger.info(f"[ParseDoc] Downloaded {len(file_binary)} bytes from storage key: {file_key}")
def progress_callback(prog=None, msg=None): def progress_callback(prog=None, msg=None):
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.") progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
# Prepare vision_model for parsing # Prepare vision_model for parsing
vision_model = _build_vision_model(file_path, db_knowledge) vision_model = _build_vision_model(file_name, db_knowledge)
# 先将文件读入内存,避免解析过程中依赖 NFS 文件持续可访问
# python-docx 等库在 binary=None 时会用路径直接打开文件,
# 在 NFS/共享存储上可能因缓存失效导致 "Package not found"
max_wait_seconds = 30
wait_interval = 2
waited = 0
file_binary = None
while waited <= max_wait_seconds:
# os.listdir 强制 NFS 客户端刷新目录缓存
parent_dir = os.path.dirname(file_path)
try:
os.listdir(parent_dir)
except OSError:
pass
try:
with open(file_path, "rb") as f:
file_binary = f.read()
if not file_binary:
# NFS 上文件存在但内容为空(可能还在同步中)
raise IOError(f"File is empty (0 bytes), NFS may still be syncing: {file_path}")
break
except (FileNotFoundError, IOError) as e:
if waited >= max_wait_seconds:
raise type(e)(
f"File not accessible at '{file_path}' after waiting {max_wait_seconds}s: {e}"
)
logger.warning(f"File not ready on this node, retrying in {wait_interval}s: {file_path} ({e})")
time.sleep(wait_interval)
waited += wait_interval
from app.core.rag.app.naive import chunk from app.core.rag.app.naive import chunk
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}") logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
res = chunk(filename=file_path, res = chunk(filename=file_name,
binary=file_binary, binary=file_binary,
from_page=0, from_page=0,
to_page=DEFAULT_PARSE_TO_PAGE, to_page=DEFAULT_PARSE_TO_PAGE,